# Author: Vivek Vaddina
# initial working commit
# commit 254b144 (unverified)
import torch
from torch import nn, optim
from torchvision.models import resnet34, ResNet34_Weights
from src.processing import generate_test_images
from src.config import IDX2CODE
class BirdNet(nn.Module):
    """ResNet34-based classifier adapted for 1-channel (grayscale) melspectrogram input.

    Args:
        n_out: Number of output classes (defaults to the size of IDX2CODE).
        pretrained: When True, load ImageNet weights for the backbone.
        freeze_backbone: When True, freeze all backbone parameters and train
            only the final classification layer.
        dropout: Currently unused — the head is a plain Linear layer. Kept for
            backward compatibility with existing callers.
    """

    def __init__(self, n_out=len(IDX2CODE), pretrained=True, freeze_backbone=True, dropout=.25):
        super().__init__()
        self.model = resnet34(weights=ResNet34_Weights.DEFAULT if pretrained else None)

        # Replace the first convolution so the network accepts 1-channel
        # grayscale melspectrograms instead of the 3-channel RGB input
        # ResNet34 was designed for.
        original_conv1 = self.model.conv1
        self.model.conv1 = nn.Conv2d(
            in_channels=1,  # grayscale input
            out_channels=original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            # BUG FIX: nn.Conv2d's `bias` argument must be a bool. The
            # original passed `original_conv1.bias` (a Tensor or None) and
            # only worked by accident because ResNet's conv1 has no bias
            # (None is falsy).
            bias=original_conv1.bias is not None,
        )
        if pretrained:
            # Reuse the pretrained RGB filters by averaging across the input
            # channel dimension into a single grayscale filter.
            with torch.no_grad():
                self.model.conv1.weight.copy_(
                    original_conv1.weight.mean(dim=1, keepdim=True)
                )

        # Plain linear classification head sized for our label set.
        self.model.fc = nn.Linear(self.model.fc.in_features, n_out)

        # Optional fine-tuning mode: freeze everything, then re-enable
        # gradients on just the final layer.
        if freeze_backbone:
            for param in self.model.parameters():
                param.requires_grad = False
            for param in self.model.fc.parameters():
                param.requires_grad = True

    def forward(self, x):
        """Forward pass; assumes x is (batch, 1, H, W) — TODO confirm with caller."""
        return self.model(x)
class Model:
    """Training/inference wrapper around BirdNet.

    Bundles the network, optimizer, loss function, and per-epoch metric
    bookkeeping, plus checkpoint restoration and file-level prediction.
    """

    def __init__(self, device, n_out=len(IDX2CODE), loss_fn=None,
                 pretrained=True, freeze_backbone=True, dropout=.1):
        """
        Args:
            device: torch device (string or torch.device) the model runs on.
            n_out: Number of output classes.
            loss_fn: Loss module; defaults to a fresh nn.CrossEntropyLoss().
                (Default is constructed lazily here rather than in the
                signature, which would build one shared instance at import time.)
            pretrained, freeze_backbone, dropout: Forwarded to BirdNet.
        """
        self.n_out = n_out
        self.device = device
        self.model = BirdNet(self.n_out, pretrained=pretrained,
                             freeze_backbone=freeze_backbone, dropout=dropout).to(self.device)
        self.lr = 5e-3
        self.loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss()
        self.opt = optim.Adam(self.model.parameters(), lr=self.lr)
        # Per-epoch metric history (populated by the training loop elsewhere).
        self.epoch_train_losses = []
        self.epoch_val_losses = []
        self.epoch_train_accs = []
        self.epoch_val_accs = []
        self.epoch = 0

    def load_from_chkpt(self, chkpt_path):
        """Restore model/optimizer state and metric history from a checkpoint.

        NOTE(security): weights_only=False unpickles arbitrary objects —
        only load checkpoints from trusted sources.
        """
        chkpt = torch.load(chkpt_path, weights_only=False,
                           map_location=torch.device(self.device))
        self.epoch = chkpt['epoch']
        self.model.load_state_dict(chkpt['model'])
        self.opt.load_state_dict(chkpt['optim'])
        self.epoch_train_losses = chkpt['train_losses']
        self.epoch_val_losses = chkpt['valid_losses']
        self.epoch_train_accs = chkpt['train_accs']
        self.epoch_val_accs = chkpt['valid_accs']

    def make_preds(self, fp):
        """Predict the bird code for audio file *fp* by majority vote.

        Runs the model over every clip image produced from the file, takes
        the per-clip argmax class, and returns the IDX2CODE entry for the
        most frequent class.
        """
        arrs = generate_test_images(fp)
        self.model.eval()
        with torch.no_grad():
            out = self.model(arrs.to(self.device).float())
        labels = out.argmax(dim=1)
        # Majority vote across clips.
        values, counts = labels.unique(return_counts=True)
        return IDX2CODE[values[counts.argmax()].item()]