import torch
from torch import nn, optim
from torchvision.models import resnet34, ResNet34_Weights
from src.processing import generate_test_images
from src.config import IDX2CODE


class BirdNet(nn.Module):
    """ResNet34 backbone adapted to single-channel mel-spectrogram input."""

    def __init__(self, n_out=len(IDX2CODE), pretrained=True, freeze_backbone=True, dropout=.25):
        super().__init__()
        self.model = resnet34(weights=ResNet34_Weights.DEFAULT if pretrained else None)

        # Replace the first convolution so the network accepts 1-channel
        # grayscale mel-spectrograms instead of the 3-channel RGB input
        # the stock ResNet34 expects.
        original_conv1 = self.model.conv1
        self.model.conv1 = nn.Conv2d(
            in_channels=1,  # grayscale input
            out_channels=original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            bias=original_conv1.bias is not None,  # nn.Conv2d expects a bool here, not the bias tensor itself
        )

        if pretrained:
            # Initialise the new conv from the pretrained RGB filters by
            # averaging them across the channel dimension, preserving the
            # learned low-level features.
            with torch.no_grad():
                self.model.conv1.weight.copy_(original_conv1.weight.mean(dim=1, keepdim=True))

        # Alternative encoder-only variant, kept for reference:
        # in_features = self.model.fc.in_features
        # layers = list(self.model.children())[:-2]
        # layers.append(nn.AdaptiveMaxPool2d(1))
        # self.encoder = nn.Sequential(*layers)

        self.model.fc = nn.Linear(self.model.fc.in_features, n_out)
        # Alternative two-layer head, kept for reference; note this is the
        # only place the `dropout` argument would be used:
        # self.model.fc = nn.Sequential(
        #     nn.Linear(self.model.fc.in_features, 256),
        #     nn.ReLU(),
        #     nn.Dropout(dropout),
        #     nn.Linear(256, n_out)
        # )
        # Optional: Freeze backbone for fine-tuning (train only the final layer)
        if freeze_backbone:
            for param in self.model.parameters():
                param.requires_grad = False
            # Unfreeze the final layer
            for param in self.model.fc.parameters():
                param.requires_grad = True
    
    def forward(self, x):
        return self.model(x)
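
# A quick shape check for the grayscale adaptation above (illustrative, not
# part of the original module): assuming 224x224 mel-spectrogram inputs, a
# (2, 1, 224, 224) batch maps to logits of shape (2, len(IDX2CODE)).
#
#     net = BirdNet(pretrained=False)
#     assert net(torch.randn(2, 1, 224, 224)).shape == (2, len(IDX2CODE))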

class Model:
    """Training/inference wrapper: network, optimiser, loss, and metric history."""

    def __init__(self, device, n_out=len(IDX2CODE), loss_fn=nn.CrossEntropyLoss(),
                 pretrained=True, freeze_backbone=True, dropout=.1):
        self.n_out = n_out
        self.device = device
        self.model = BirdNet(self.n_out, pretrained=pretrained, 
                             freeze_backbone=freeze_backbone, dropout=dropout).to(self.device)
        self.lr = 5e-3
        self.loss_fn = loss_fn
        self.opt = optim.Adam(self.model.parameters(), lr=self.lr)
        # self.lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.opt, mode='min', factor=.5, patience=3, min_lr=1e-5)
        self.epoch_train_losses = []
        self.epoch_val_losses = []
        self.epoch_train_accs = []
        self.epoch_val_accs = []
        self.epoch = 0

    def load_from_chkpt(self, chkpt_path):
        # weights_only=False: the checkpoint stores Python objects (the metric
        # lists below), not just tensors.
        chkpt = torch.load(chkpt_path, weights_only=False, map_location=torch.device(self.device))
        self.epoch = chkpt['epoch']
        self.model.load_state_dict(chkpt['model'])
        self.opt.load_state_dict(chkpt['optim'])
        self.epoch_train_losses = chkpt['train_losses']
        self.epoch_val_losses = chkpt['valid_losses']
        self.epoch_train_accs = chkpt['train_accs']
        self.epoch_val_accs = chkpt['valid_accs']
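
    def save_to_chkpt(self, chkpt_path):
        # Not in the original module: a save helper sketched to mirror the
        # keys that load_from_chkpt reads above. Treat it as an assumption
        # about how the checkpoints were written.
        torch.save({
            'epoch': self.epoch,
            'model': self.model.state_dict(),
            'optim': self.opt.state_dict(),
            'train_losses': self.epoch_train_losses,
            'valid_losses': self.epoch_val_losses,
            'train_accs': self.epoch_train_accs,
            'valid_accs': self.epoch_val_accs,
        }, chkpt_path)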

    def make_preds(self, fp):
        # Spectrogram images are generated per clip; the file-level prediction
        # is the species code that wins the majority vote across clips.
        arrs = generate_test_images(fp)
        self.model.eval()
        with torch.no_grad():
            out = self.model(arrs.to(self.device).float())
        labels = out.argmax(dim=1)
        values, counts = labels.unique(return_counts=True)
        return IDX2CODE[values[counts.argmax()].item()]
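

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The checkpoint and
# audio paths below are hypothetical placeholders; generate_test_images and
# IDX2CODE come from the project's own src package.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Model(device)

    # Resume from a saved checkpoint (hypothetical path):
    # model.load_from_chkpt("checkpoints/birdnet_epoch_10.pt")

    # Predict the majority species code for one audio file (hypothetical path):
    # print(model.make_preds("data/audio/example.ogg"))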