import torch
from torch import nn, optim
from torchvision.models import resnet34, ResNet34_Weights

from src.processing import generate_test_images
from src.config import IDX2CODE


class BirdNet(nn.Module):
    """ResNet34 backbone adapted to single-channel spectrogram input.

    The stock ResNet34 expects 3-channel RGB images; here the first
    convolution is rebuilt for 1 input channel so the network accepts
    grayscale melspectrograms, and the ImageNet classifier head is
    replaced with a linear layer producing ``n_out`` class logits.
    """

    def __init__(self, n_out=len(IDX2CODE), pretrained=True,
                 freeze_backbone=True, dropout=.25):
        """
        Args:
            n_out: Number of output classes (defaults to the label-map size).
            pretrained: Load ImageNet weights and adapt conv1 from them.
            freeze_backbone: Freeze everything except the new classifier head.
            dropout: Kept for interface compatibility; currently unused —
                the single-linear head below has no dropout layer.
        """
        super().__init__()
        self.model = resnet34(weights=ResNet34_Weights.DEFAULT if pretrained else None)

        # Rebuild the first conv for 1-channel input, preserving every other
        # hyperparameter of the original layer.
        original_conv1 = self.model.conv1
        self.model.conv1 = nn.Conv2d(
            in_channels=1,  # grayscale melspectrogram
            out_channels=original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            # FIX: Conv2d's `bias` argument is a bool; the original code passed
            # the bias *tensor* (None for ResNet's conv1), which only behaved
            # correctly by accident because None is falsy.
            bias=original_conv1.bias is not None,
        )

        if pretrained:
            # Initialize the 1-channel kernel as the mean over the RGB
            # channels of the pretrained 3-channel kernel.
            with torch.no_grad():
                self.model.conv1.weight.copy_(
                    original_conv1.weight.mean(dim=1, keepdim=True)
                )

        # Replace the 1000-way ImageNet head with an n_out-way linear layer.
        self.model.fc = nn.Linear(self.model.fc.in_features, n_out)

        # Optionally freeze the backbone for fine-tuning: only the freshly
        # initialized head keeps requires_grad=True.
        if freeze_backbone:
            for param in self.model.parameters():
                param.requires_grad = False
            for param in self.model.fc.parameters():
                param.requires_grad = True

    def forward(self, x):
        """Run a batch of 1-channel images through the adapted ResNet34."""
        return self.model(x)


class Model:
    """Training/inference wrapper around :class:`BirdNet`.

    Owns the network, optimizer, loss function, and per-epoch metric
    history, and can restore all of them from a saved checkpoint.
    """

    def __init__(self, device, n_out=len(IDX2CODE), loss_fn=None,
                 pretrained=True, freeze_backbone=True, dropout=.1):
        """
        Args:
            device: Torch device the model and inputs are moved to.
            n_out: Number of output classes.
            loss_fn: Loss module; defaults to a fresh ``nn.CrossEntropyLoss()``.
                (FIX: the original signature used ``loss_fn=nn.CrossEntropyLoss()``,
                a shared default instance created once at class-definition time.)
            pretrained: Forwarded to BirdNet.
            freeze_backbone: Forwarded to BirdNet.
            dropout: Forwarded to BirdNet.
        """
        self.n_out = n_out
        self.device = device
        self.model = BirdNet(self.n_out, pretrained=pretrained,
                             freeze_backbone=freeze_backbone,
                             dropout=dropout).to(self.device)
        self.lr = 5e-3
        self.loss_fn = nn.CrossEntropyLoss() if loss_fn is None else loss_fn
        self.opt = optim.Adam(self.model.parameters(), lr=self.lr)
        # Per-epoch metric history; repopulated by load_from_chkpt.
        self.epoch_train_losses = []
        self.epoch_val_losses = []
        self.epoch_train_accs = []
        self.epoch_val_accs = []
        self.epoch = 0

    def load_from_chkpt(self, chkpt_path):
        """Restore model/optimizer state and metric history from a checkpoint.

        NOTE(review): ``weights_only=False`` unpickles arbitrary objects —
        only load checkpoints from trusted sources.
        """
        chkpt = torch.load(chkpt_path, weights_only=False,
                           map_location=torch.device(self.device))
        self.epoch = chkpt['epoch']
        self.model.load_state_dict(chkpt['model'])
        self.opt.load_state_dict(chkpt['optim'])
        self.epoch_train_losses = chkpt['train_losses']
        self.epoch_val_losses = chkpt['valid_losses']
        self.epoch_train_accs = chkpt['train_accs']
        self.epoch_val_accs = chkpt['valid_accs']

    def make_preds(self, fp):
        """Predict the species code for an audio file by majority vote.

        ``generate_test_images(fp)`` produces a batch of spectrogram windows
        for the file; each window is classified independently and the most
        frequent predicted class index is mapped back through ``IDX2CODE``.
        """
        arrs = generate_test_images(fp)
        self.model.eval()
        with torch.no_grad():
            out = self.model(arrs.to(self.device).float())
        labels = out.argmax(dim=1)
        # Majority vote across windows.
        values, counts = labels.unique(return_counts=True)
        return IDX2CODE[values[counts.argmax()].item()]