"""Модель детекции дефектов на основе EfficientNetV2-S из timm (transfer learning).""" from __future__ import annotations import torch import torch.nn as nn import timm from . import config as C class DefectClassifier(nn.Module): """Бинарный классификатор патч/деталь: defect vs clean. Используем предобученный backbone из timm и свою классификационную голову с дропаутом — это устойчиво на малых датасетах. """ def __init__(self, backbone: str = C.BACKBONE, num_classes: int = C.NUM_CLASSES, pretrained: bool = True, drop_rate: float = 0.3): super().__init__() self.backbone = timm.create_model( backbone, pretrained=pretrained, num_classes=0, # без головы — берём фичи global_pool="avg", ) feat_dim = self.backbone.num_features self.head = nn.Sequential( nn.Dropout(drop_rate), nn.Linear(feat_dim, 256), nn.GELU(), nn.Dropout(drop_rate), nn.Linear(256, num_classes), ) def forward(self, x: torch.Tensor) -> torch.Tensor: feats = self.backbone(x) return self.head(feats) @torch.no_grad() def predict_proba(self, x: torch.Tensor) -> torch.Tensor: return torch.softmax(self.forward(x), dim=1) def gradcam_target_layer(self) -> nn.Module: """Слой для построения Grad-CAM (последний conv-блок backbone'а).""" # У EfficientNet-семейства это последний блок перед глобальным пулингом if hasattr(self.backbone, "conv_head"): return self.backbone.conv_head # Fallback: последний блок features if hasattr(self.backbone, "blocks"): return self.backbone.blocks[-1] raise RuntimeError("Не нашёл подходящий слой для Grad-CAM") def build_model(pretrained: bool = True) -> DefectClassifier: return DefectClassifier(pretrained=pretrained)