| """ResNet-50 backbone variants for multi-task classification.""" |
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision.models import ResNet50_Weights, resnet50 |
|
|
|
|
| def _topk_from_logits( |
| logits: torch.Tensor, id2label: dict[int, str], k: int |
| ) -> list[tuple[str, float]]: |
| """Map top-k softmax probabilities to label names (batch size 1).""" |
| n = logits.size(-1) |
| k = min(k, n) |
| probs = F.softmax(logits, dim=-1)[0] |
| top = probs.topk(k) |
| return [(id2label[int(i)], float(p)) for i, p in zip(top.indices.tolist(), top.values.tolist())] |
|
|
|
|
class _ThreeHeadPredictMixin:
    """Top-k decoding for genre / style / artist heads (shared by CNN and CNN–RNN).

    Meant to be mixed into an ``nn.Module`` whose ``forward`` returns a
    ``(genre_logits, style_logits, artist_logits)`` tuple. The methods below
    rely on ``self.training``, ``self.eval()``, ``self.train()`` and ``self(x)``.
    """

    def predict_topk(
        self,
        x: torch.Tensor,
        *,
        genre_id2label: dict[int, str],
        style_id2label: dict[int, str],
        artist_id2label: dict[int, str],
        k: int = 3,
    ) -> tuple[list[tuple[str, float]], list[tuple[str, float]], list[tuple[str, float]]]:
        """Run one gradient-free forward pass and decode the top-k labels per head.

        The forward pass runs in eval mode. The module's previous train/eval
        mode is restored afterwards, even if the forward pass raises.

        Args:
            x: Model input; only the first sample of the batch is decoded.
            genre_id2label: Class index -> label name for the genre head.
            style_id2label: Class index -> label name for the style head.
            artist_id2label: Class index -> label name for the artist head.
            k: Number of ``(label, probability)`` pairs to return per head.

        Returns:
            ``(genre_topk, style_topk, artist_topk)``. Each is a list of
            ``(label, probability)`` pairs in descending probability order.
        """
        # Remember the caller's mode. The forward pass must run in eval mode
        # (dropout off), but leaving the module in eval mode afterwards would
        # silently disable dropout in a training loop that calls this.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                lg, ls, la = self(x)
        finally:
            self.train(was_training)
        return (
            _topk_from_logits(lg, genre_id2label, k),
            _topk_from_logits(ls, style_id2label, k),
            _topk_from_logits(la, artist_id2label, k),
        )

    def predict_topk_from_path(
        self,
        path: Path | str,
        transform: torch.nn.Module,
        device: torch.device,
        *,
        genre_id2label: dict[int, str],
        style_id2label: dict[int, str],
        artist_id2label: dict[int, str],
        k: int = 3,
    ) -> tuple[list[tuple[str, float]], list[tuple[str, float]], list[tuple[str, float]]]:
        """Load an image file, preprocess it and return the per-head top-k labels.

        Processing steps:

        - Convert the image to RGB.
        - Pass it through ``transform``.
        - Add a leading batch dimension and move it to ``device``.
        - Decode it with ``predict_topk``.

        NOTE(review): ``transform`` is only ever called on the PIL image, so
        presumably any callable returning a tensor works (e.g. a torchvision
        ``Compose``, which is not an ``nn.Module``). TODO: confirm and widen
        the annotation.
        """
        # Local import, presumably to keep Pillow optional for callers that
        # only use predict_topk.
        from PIL import Image

        # Open via the context manager so the file handle is closed promptly.
        # convert() loads the pixel data and returns an independent image,
        # so it stays usable after the file is closed.
        with Image.open(Path(path)) as img:
            rgb = img.convert("RGB")
        batch = transform(rgb).unsqueeze(0).to(device)
        return self.predict_topk(
            batch,
            genre_id2label=genre_id2label,
            style_id2label=style_id2label,
            artist_id2label=artist_id2label,
            k=k,
        )
|
|
|
|
class ResNet50ThreeHeads(_ThreeHeadPredictMixin, nn.Module):
    """Multi-task classifier: ResNet-50 trunk -> global average pool -> three heads.

    The trunk is torchvision's ResNet-50 without its ``avgpool``/``fc`` layers.
    By default it is initialised from ``ResNet50_Weights.IMAGENET1K_V2``; pass
    ``weights=None`` for random initialisation. Each head is a dropout layer
    followed by a linear layer. The heads produce genre, style and artist
    logits, in that order.
    """

    def __init__(
        self,
        n_genre: int,
        n_style: int,
        n_artist: int,
        weights: ResNet50_Weights | None = ResNet50_Weights.IMAGENET1K_V2,
        dropout: float = 0.4,
    ) -> None:
        super().__init__()
        self.n_genre, self.n_style, self.n_artist = n_genre, n_style, n_artist

        # Build the full torchvision model, then keep only its convolutional
        # trunk (everything up to and including layer4) in forward order.
        resnet = resnet50(weights=weights)
        trunk_names = (
            "conv1", "bn1", "relu", "maxpool",
            "layer1", "layer2", "layer3", "layer4",
        )
        self.backbone = nn.Sequential(*(getattr(resnet, name) for name in trunk_names))
        self.pool = nn.AdaptiveAvgPool2d(1)

        # layer4's channel count equals the input width of the discarded fc
        # layer (2048 for ResNet-50).
        width = resnet.fc.in_features
        self.genre_head = self._make_head(width, n_genre, dropout)
        self.style_head = self._make_head(width, n_style, dropout)
        self.artist_head = self._make_head(width, n_artist, dropout)

    @staticmethod
    def _make_head(in_features: int, out_features: int, p: float) -> nn.Sequential:
        """Build one classification head: dropout, then a linear layer."""
        return nn.Sequential(nn.Dropout(p=p), nn.Linear(in_features, out_features))

    def forward(
        self,
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(genre_logits, style_logits, artist_logits)`` for a batch.

        Assumes ``x`` is an NCHW image batch (TODO: confirm against the data
        pipeline). Pooling plus flattening turns the trunk output into one
        feature vector per sample, which every head shares.
        """
        pooled = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
        heads = (self.genre_head, self.style_head, self.artist_head)
        return tuple(head(pooled) for head in heads)
|
|
|
|
class ResNet50BiLSTMThreeHeads(_ThreeHeadPredictMixin, nn.Module):
    """
    ResNet-50 (ImageNet pretrained) feature map -> column pooling -> BiLSTM -> mean pool -> three heads.

    A minimal-change variant of ``ResNet50ThreeHeads`` that keeps the training
    loop comparable. It uses the same trunk and the same dropout + linear heads.
    Only the aggregation differs: the feature map is averaged over its height,
    each remaining column becomes one time step of a bidirectional LSTM, and
    the LSTM outputs are averaged over time.
    """

    def __init__(
        self,
        n_genre: int,
        n_style: int,
        n_artist: int,
        weights: ResNet50_Weights | None = ResNet50_Weights.IMAGENET1K_V2,
        lstm_hidden: int = 256,
        lstm_layers: int = 1,
        dropout: float = 0.4,
    ) -> None:
        super().__init__()
        self.n_genre, self.n_style, self.n_artist = n_genre, n_style, n_artist

        # Keep only the convolutional trunk of torchvision's ResNet-50
        # (conv1 ... layer4); avgpool and fc are dropped.
        resnet = resnet50(weights=weights)
        trunk_names = (
            "conv1", "bn1", "relu", "maxpool",
            "layer1", "layer2", "layer3", "layer4",
        )
        self.backbone = nn.Sequential(*(getattr(resnet, name) for name in trunk_names))

        # Each time step is one column descriptor with layer4's channel count
        # (the fc layer's input width, 2048 for ResNet-50).
        channels = resnet.fc.in_features
        self.lstm = nn.LSTM(
            input_size=channels,
            hidden_size=lstm_hidden,
            num_layers=lstm_layers,
            bidirectional=True,
            batch_first=True,
        )

        # Bidirectional: forward and backward hidden states are concatenated.
        width = lstm_hidden * 2
        self.genre_head = self._make_head(width, n_genre, dropout)
        self.style_head = self._make_head(width, n_style, dropout)
        self.artist_head = self._make_head(width, n_artist, dropout)

    @staticmethod
    def _make_head(in_features: int, out_features: int, p: float) -> nn.Sequential:
        """Build one classification head: dropout, then a linear layer."""
        return nn.Sequential(nn.Dropout(p=p), nn.Linear(in_features, out_features))

    def forward(
        self,
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(genre_logits, style_logits, artist_logits)`` for a batch.

        Assumes the trunk output is laid out as (batch, channels, height,
        width). TODO: confirm for the expected input shapes.
        """
        fmap = self.backbone(x)
        # Average over the height axis, then move the width (column) axis
        # into the time position the batch_first LSTM expects.
        columns = fmap.mean(dim=2).transpose(1, 2).contiguous()
        # Element 0 of the LSTM result is the per-time-step output sequence.
        hidden_seq = self.lstm(columns)[0]
        # Average the per-column LSTM outputs into one vector per sample.
        pooled = hidden_seq.mean(dim=1)
        heads = (self.genre_head, self.style_head, self.artist_head)
        return tuple(head(pooled) for head in heads)
|
|