Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision import models | |
| def find_last_conv2d(module: nn.Module) -> nn.Conv2d | None: | |
| """ | |
| Returns the last nn.Conv2d found in a module traversal. | |
| Important: we do NOT attach this as a child module on the model instance, | |
| otherwise it becomes part of state_dict and breaks checkpoint loading. | |
| """ | |
| last = None | |
| for m in module.modules(): | |
| if isinstance(m, nn.Conv2d): | |
| last = m | |
| return last | |
class MultiTaskResNet50(nn.Module):
    """
    ResNet-50 backbone feeding two linear heads from one shared feature vector.

    - ``class_head``: ``num_classes``-way logits.
    - ``bio_head``: 2-way logits (presumably biological vs non-biological,
      as described on ``MultiTaskConvNeXt`` -- confirm against training code).

    Submodule names (``backbone``, ``class_head``, ``bio_head``) determine the
    ``state_dict`` keys, so they must stay as-is for checkpoints to load.
    """

    def __init__(self, num_classes: int = 9):
        super().__init__()

        # No pretrained weights are requested here (weights=None).
        resnet = models.resnet50(weights=None)

        # Read the width of the stock classifier's input before removing it.
        embed_dim = resnet.fc.in_features
        # Replace the stock classifier with a pass-through so the backbone
        # outputs the features that would have fed ``fc``.
        resnet.fc = nn.Identity()
        self.backbone = resnet

        # Heads are created in this order on purpose: changing it would
        # change parameter ordering and seeded initialisation.
        self.class_head = nn.Linear(embed_dim, num_classes)
        self.bio_head = nn.Linear(embed_dim, 2)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Run the backbone once and apply both heads to its output.

        Returns:
            A dict with keys ``"class"`` and ``"bio"`` holding the outputs
            of ``class_head`` and ``bio_head`` respectively.
        """
        shared = self.backbone(x)
        class_logits = self.class_head(shared)
        bio_logits = self.bio_head(shared)
        return {"class": class_logits, "bio": bio_logits}
class MultiTaskConvNeXt(nn.Module):
    """
    ConvNeXt-Base feature extractor shared by two linear heads:

    - ``class_head``: N-class structural/mold classifier
    - ``bio_head``: 2-class biological vs non-biological classifier

    According to the original author's note, this mirrors the training setup
    from the ConvNeXt Kaggle notebook.

    Submodule names (``backbone``, ``pool``, ``class_head``, ``bio_head``,
    ``dropout``) determine the ``state_dict`` keys, so they must stay as-is
    for checkpoints to load.
    """

    def __init__(self, num_classes: int):
        super().__init__()

        # Task-specific weights are loaded afterwards, so no ImageNet
        # weights are requested here.
        convnext = models.convnext_base(weights=None)

        # The stock ConvNeXt classifier is [LayerNorm2d, Flatten, Linear];
        # index 2 is the Linear layer whose input width we reuse.
        embed_dim = convnext.classifier[2].in_features
        # forward() only calls ``backbone.features``; the stock classifier
        # is swapped for a pass-through.
        convnext.classifier = nn.Identity()
        self.backbone = convnext

        # Collapses the spatial dimensions of the feature map to 1x1.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.class_head = nn.Linear(embed_dim, num_classes)
        self.bio_head = nn.Linear(embed_dim, 2)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Extract features, pool, flatten, apply dropout, then run both heads.

        Returns:
            A dict with keys ``"class"`` and ``"bio"`` holding the outputs
            of ``class_head`` and ``bio_head`` respectively.
        """
        feature_map = self.backbone.features(x)
        pooled = self.pool(feature_map)
        # Flatten everything after the batch dimension into one vector.
        embedding = torch.flatten(pooled, start_dim=1)
        embedding = self.dropout(embedding)

        class_logits = self.class_head(embedding)
        bio_logits = self.bio_head(embedding)
        return {"class": class_logits, "bio": bio_logits}