Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn, optim | |
| from torchvision.models import resnet34, ResNet34_Weights | |
| from src.processing import generate_test_images | |
| from src.config import IDX2CODE | |
class BirdNet(nn.Module):
    """ResNet-34 classifier adapted to single-channel (grayscale) input.

    The stock torchvision ResNet-34 stem convolution expects 3 input channels.
    This class replaces ``conv1`` with a 1-input-channel convolution (all other
    hyper-parameters copied) and replaces the final ``fc`` layer with a linear
    head producing ``n_out`` logits.

    Args:
        n_out: Number of output classes. The default, ``len(IDX2CODE.keys())``,
            is evaluated once when the class is defined.
        pretrained: If True, start from ``ResNet34_Weights.DEFAULT`` and
            initialise the new ``conv1`` from the mean of the pretrained stem
            filters over the input-channel axis.
        freeze_backbone: If True, set ``requires_grad=False`` on every
            parameter except those of the final ``fc`` layer.
        dropout: Currently unused; kept only for interface compatibility.
    """

    def __init__(self, n_out=len(IDX2CODE.keys()), pretrained=True, freeze_backbone=True, dropout=.25):
        super().__init__()
        self.model = resnet34(weights=ResNet34_Weights.DEFAULT if pretrained else None)

        # Replace the stem so it accepts a single input channel instead of RGB.
        original_conv1 = self.model.conv1
        self.model.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            dilation=original_conv1.dilation,
            # Conv2d's ``bias`` argument is a bool. Passing the original bias
            # Parameter (or None) only worked because ResNet's stem has no bias.
            bias=original_conv1.bias is not None,
            padding_mode=original_conv1.padding_mode,
        )

        if pretrained:
            with torch.no_grad():
                # Collapse the pretrained RGB filters to one channel by
                # averaging over the input-channel axis (dim=1 of the
                # (out, in, kH, kW) Conv2d weight).
                self.model.conv1.weight.copy_(original_conv1.weight.mean(dim=1, keepdim=True))
                if original_conv1.bias is not None:
                    self.model.conv1.bias.copy_(original_conv1.bias)

        # New classification head sized for this task.
        self.model.fc = nn.Linear(self.model.fc.in_features, n_out)

        if freeze_backbone:
            # Train only the classification head. This also freezes the
            # replaced conv1 (randomly initialised when pretrained=False).
            # requires_grad does not stop BatchNorm layers from updating their
            # running statistics while the module is in train() mode.
            for param in self.model.parameters():
                param.requires_grad = False
            for param in self.model.fc.parameters():
                param.requires_grad = True

    def forward(self, x):
        """Return class logits of shape (batch, n_out) for ``x``.

        The input must have exactly one channel (fixed by ``conv1``).
        NOTE(review): assumes x is (batch, 1, H, W) float tensor of
        spectrogram images; confirm H/W against the preprocessing pipeline.
        """
        return self.model(x)
class Model:
    """Training/inference wrapper around ``BirdNet``.

    Bundles the network (moved to ``device``), an Adam optimiser, the loss
    function, an epoch counter and per-epoch loss/accuracy histories. It also
    provides checkpoint restoration and clip-level prediction.
    """

    def __init__(self, device, n_out=len(IDX2CODE.keys()), loss_fn=None,
                 pretrained=True, freeze_backbone=True, dropout=.1):
        """Build the network, optimiser and empty metric histories.

        Args:
            device: Target device for the network (anything accepted by
                ``nn.Module.to`` and ``torch.device``).
            n_out: Number of output classes.
            loss_fn: Loss callable. If None, a fresh ``nn.CrossEntropyLoss()``
                is created for this instance instead of sharing one object
                built at import time across every ``Model``.
            pretrained: Forwarded to ``BirdNet``.
            freeze_backbone: Forwarded to ``BirdNet``.
            dropout: Forwarded to ``BirdNet``.
        """
        self.n_out = n_out
        self.device = device
        self.model = BirdNet(self.n_out, pretrained=pretrained,
                             freeze_backbone=freeze_backbone, dropout=dropout).to(self.device)
        self.lr = 5e-3
        self.loss_fn = nn.CrossEntropyLoss() if loss_fn is None else loss_fn
        # Frozen parameters are registered too. Adam skips any parameter whose
        # .grad is None, and the param-group layout stays independent of
        # freeze_backbone, which matters when restoring optimiser state.
        self.opt = optim.Adam(self.model.parameters(), lr=self.lr)
        self.epoch_train_losses = []
        self.epoch_val_losses = []
        self.epoch_train_accs = []
        self.epoch_val_accs = []
        self.epoch = 0

    def load_from_chkpt(self, chkpt_path):
        """Restore network, optimiser, epoch counter and metric histories.

        The checkpoint must be a dict with keys 'epoch', 'model', 'optim',
        'train_losses', 'valid_losses', 'train_accs' and 'valid_accs'.
        Tensors are mapped onto ``self.device``.

        SECURITY: ``weights_only=False`` lets ``torch.load`` unpickle arbitrary
        objects, which can execute code. Only load checkpoints from trusted
        sources.
        """
        chkpt = torch.load(chkpt_path, weights_only=False, map_location=torch.device(self.device))
        self.epoch = chkpt['epoch']
        self.model.load_state_dict(chkpt['model'])
        self.opt.load_state_dict(chkpt['optim'])
        self.epoch_train_losses = chkpt['train_losses']
        self.epoch_val_losses = chkpt['valid_losses']
        self.epoch_train_accs = chkpt['train_accs']
        self.epoch_val_accs = chkpt['valid_accs']

    def make_preds(self, fp):
        """Predict one class code for ``fp`` by majority vote over a batch.

        ``generate_test_images(fp)`` yields a batch of inputs (presumably one
        per audio segment; TODO confirm). Each row is classified, the most
        frequent predicted index wins, and it is mapped through ``IDX2CODE``.
        Ties go to the smallest class index, because ``unique`` returns
        sorted values and ``argmax`` picks the first maximum.

        The network's previous train/eval mode is restored afterwards, even
        if inference raises.
        """
        arrs = generate_test_images(fp)
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                out = self.model(arrs.to(self.device).float())
        finally:
            self.model.train(was_training)
        labels = out.argmax(dim=1)
        values, counts = labels.unique(return_counts=True)
        return IDX2CODE[values[counts.argmax()].item()]