File size: 2,150 Bytes
8cc2137
 
 
 
7d6580c
7a5f7fb
 
 
 
 
 
 
 
 
 
 
 
 
8cc2137
 
 
 
 
 
 
 
 
7d6580c
8cc2137
 
 
7d6580c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cc2137
 
7d6580c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torch.nn as nn
from torchvision import models


def find_last_conv2d(module: nn.Module) -> nn.Conv2d | None:
    """
    Returns the last nn.Conv2d found in a module traversal.
    Important: we do NOT attach this as a child module on the model instance,
    otherwise it becomes part of state_dict and breaks checkpoint loading.
    """
    last = None
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            last = m
    return last


class MultiTaskResNet50(nn.Module):
    """
    ResNet-50 backbone with two linear heads sharing the same features:
    - ``num_classes``-way classifier, returned under the "class" key
    - 2-way head, returned under the "bio" key

    The backbone is created without pretrained weights, and its stock ``fc``
    layer is replaced by nn.Identity so the backbone output feeds both heads.
    """

    def __init__(self, num_classes=9):
        super().__init__()
        resnet = models.resnet50(weights=None)
        # Read the stock classifier's input width before discarding it;
        # both heads take that many input features.
        head_in = resnet.fc.in_features
        resnet.fc = nn.Identity()

        # Registration order (backbone, class_head, bio_head) is kept as-is
        # so state_dict key order stays stable.
        self.backbone = resnet
        self.class_head = nn.Linear(head_in, num_classes)
        self.bio_head = nn.Linear(head_in, 2)

    def forward(self, x: torch.Tensor):
        """Run the backbone once and apply both heads to the shared features."""
        shared = self.backbone(x)
        class_out = self.class_head(shared)
        bio_out = self.bio_head(shared)
        return {"class": class_out, "bio": bio_out}


class MultiTaskConvNeXt(nn.Module):
    """
    ConvNeXt-Base backbone feeding two linear heads:
    - N-class structural/mold classifier (output key "class")
    - 2-class biological vs non-biological head (output key "bio")

    Mirrors the training setup from the ConvNeXt Kaggle notebook.
    """

    def __init__(self, num_classes: int):
        super().__init__()

        # Task-specific weights are loaded later, so skip ImageNet weights.
        convnext = models.convnext_base(weights=None)

        # The stock ConvNeXt classifier is [LayerNorm2d, Flatten, Linear];
        # index 2 is the Linear layer whose input width sizes our heads.
        head_in = convnext.classifier[2].in_features
        convnext.classifier = nn.Identity()
        self.backbone = convnext

        # Attribute order below is kept unchanged from the original layout.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.class_head = nn.Linear(head_in, num_classes)
        self.bio_head = nn.Linear(head_in, 2)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x: torch.Tensor):
        """
        Compute both head outputs from one pass through the backbone features.

        Only ``self.backbone.features`` is invoked; pooling, flattening and
        dropout are applied here before the two heads.
        """
        feature_map = self.backbone.features(x)
        pooled = self.pool(feature_map)
        # Collapse everything after the batch dimension into one axis.
        flat = pooled.flatten(start_dim=1)
        shared = self.dropout(flat)

        class_out = self.class_head(shared)
        bio_out = self.bio_head(shared)
        return {"class": class_out, "bio": bio_out}