| """Модель детекции дефектов на основе EfficientNetV2-S из timm (transfer learning).""" |
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
| import timm |
|
|
| from . import config as C |
|
|
|
|
| class DefectClassifier(nn.Module): |
| """Бинарный классификатор патч/деталь: defect vs clean. |
| |
| Используем предобученный backbone из timm и свою классификационную голову |
| с дропаутом — это устойчиво на малых датасетах. |
| """ |
|
|
| def __init__(self, backbone: str = C.BACKBONE, num_classes: int = C.NUM_CLASSES, |
| pretrained: bool = True, drop_rate: float = 0.3): |
| super().__init__() |
| self.backbone = timm.create_model( |
| backbone, |
| pretrained=pretrained, |
| num_classes=0, |
| global_pool="avg", |
| ) |
| feat_dim = self.backbone.num_features |
| self.head = nn.Sequential( |
| nn.Dropout(drop_rate), |
| nn.Linear(feat_dim, 256), |
| nn.GELU(), |
| nn.Dropout(drop_rate), |
| nn.Linear(256, num_classes), |
| ) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| feats = self.backbone(x) |
| return self.head(feats) |
|
|
| @torch.no_grad() |
| def predict_proba(self, x: torch.Tensor) -> torch.Tensor: |
| return torch.softmax(self.forward(x), dim=1) |
|
|
| def gradcam_target_layer(self) -> nn.Module: |
| """Слой для построения Grad-CAM (последний conv-блок backbone'а).""" |
| |
| if hasattr(self.backbone, "conv_head"): |
| return self.backbone.conv_head |
| |
| if hasattr(self.backbone, "blocks"): |
| return self.backbone.blocks[-1] |
| raise RuntimeError("Не нашёл подходящий слой для Grad-CAM") |
|
|
|
|
| def build_model(pretrained: bool = True) -> DefectClassifier: |
| return DefectClassifier(pretrained=pretrained) |
|
|