File size: 2,246 Bytes
545e859
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""Модель детекции дефектов на основе EfficientNetV2-S из timm (transfer learning)."""
from __future__ import annotations

import torch
import torch.nn as nn
import timm

from . import config as C


class DefectClassifier(nn.Module):
    """Бинарный классификатор патч/деталь: defect vs clean.

    Используем предобученный backbone из timm и свою классификационную голову
    с дропаутом — это устойчиво на малых датасетах.
    """

    def __init__(self, backbone: str = C.BACKBONE, num_classes: int = C.NUM_CLASSES,
                 pretrained: bool = True, drop_rate: float = 0.3):
        super().__init__()
        self.backbone = timm.create_model(
            backbone,
            pretrained=pretrained,
            num_classes=0,           # без головы — берём фичи
            global_pool="avg",
        )
        feat_dim = self.backbone.num_features
        self.head = nn.Sequential(
            nn.Dropout(drop_rate),
            nn.Linear(feat_dim, 256),
            nn.GELU(),
            nn.Dropout(drop_rate),
            nn.Linear(256, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.backbone(x)
        return self.head(feats)

    @torch.no_grad()
    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
        return torch.softmax(self.forward(x), dim=1)

    def gradcam_target_layer(self) -> nn.Module:
        """Слой для построения Grad-CAM (последний conv-блок backbone'а)."""
        # У EfficientNet-семейства это последний блок перед глобальным пулингом
        if hasattr(self.backbone, "conv_head"):
            return self.backbone.conv_head
        # Fallback: последний блок features
        if hasattr(self.backbone, "blocks"):
            return self.backbone.blocks[-1]
        raise RuntimeError("Не нашёл подходящий слой для Grad-CAM")


def build_model(pretrained: bool = True) -> DefectClassifier:
    return DefectClassifier(pretrained=pretrained)