import torch
import torch.nn as nn
from torchvision import models


def find_last_conv2d(module: nn.Module) -> nn.Conv2d | None:
    """Return the final ``nn.Conv2d`` yielded by ``module.modules()``.

    Every submodule is visited in the order ``module.modules()`` produces
    them; the last one that is an ``nn.Conv2d`` instance is returned, or
    ``None`` when the traversal contains no such layer.

    Important: the result is deliberately NOT attached as a child module on
    the model instance, because it would then become part of ``state_dict``
    and break checkpoint loading.
    """
    conv_layers = [m for m in module.modules() if isinstance(m, nn.Conv2d)]
    return conv_layers[-1] if conv_layers else None


class MultiTaskResNet50(nn.Module):
    """ResNet-50 backbone feeding two linear heads.

    The backbone's own ``fc`` layer is swapped for ``nn.Identity`` so the
    backbone output is passed directly to:
    - ``class_head``: linear layer with ``num_classes`` outputs
    - ``bio_head``: linear layer with 2 outputs
    """

    def __init__(self, num_classes=9):
        super().__init__()
        # No pretrained weights are requested for the backbone.
        resnet = models.resnet50(weights=None)
        # Width of the features that used to enter the original fc layer.
        hidden = resnet.fc.in_features
        resnet.fc = nn.Identity()
        self.backbone = resnet

        self.class_head = nn.Linear(hidden, num_classes)
        self.bio_head = nn.Linear(hidden, 2)

    def forward(self, x: torch.Tensor):
        """Run the backbone once and return both head outputs.

        Returns a dict with keys ``"class"`` and ``"bio"`` holding the
        outputs of ``class_head`` and ``bio_head`` respectively.
        """
        embedding = self.backbone(x)
        class_out = self.class_head(embedding)
        bio_out = self.bio_head(embedding)
        return {"class": class_out, "bio": bio_out}


class MultiTaskConvNeXt(nn.Module):
    """
    ConvNeXt-Base backbone with two heads:
    - N-class structural/mold classifier
    - 2-class biological vs non-biological head

    Mirrors the training setup from the ConvNeXt Kaggle notebook.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        # Task-specific weights are loaded instead, so ImageNet weights are
        # not requested here.
        convnext = models.convnext_base(weights=None)
        # ConvNeXt classifier is [LayerNorm2d, Flatten, Linear]; index 2 is
        # the Linear whose input width the heads must match.
        hidden = convnext.classifier[2].in_features
        convnext.classifier = nn.Identity()
        self.backbone = convnext

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.class_head = nn.Linear(hidden, num_classes)
        self.bio_head = nn.Linear(hidden, 2)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x: torch.Tensor):
        """Compute both head outputs from the backbone's feature extractor.

        Only ``backbone.features`` is invoked; its output is pooled to 1x1,
        flattened from dim 1, passed through dropout and fed to both heads.

        Returns a dict with keys ``"class"`` and ``"bio"``.
        """
        feature_map = self.backbone.features(x)
        pooled = self.pool(feature_map).flatten(1)
        embedding = self.dropout(pooled)
        class_out = self.class_head(embedding)
        bio_out = self.bio_head(embedding)
        return {"class": class_out, "bio": bio_out}