File size: 3,833 Bytes
18e5c60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Обучение модели детекции дефектов окраски кузова.

Запуск:  python -m src.train
"""
from __future__ import annotations
import time
import json
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix
from tqdm import tqdm

from . import config as C
from .dataset import make_loaders
from .model import build_model


def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed unchanged to every seeding function.
    """
    import random  # kept local: the module does not import it at top level

    # Order matches the original: stdlib, NumPy, torch, then all CUDA devices.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def evaluate(model: nn.Module, loader, device) -> dict:
    """Score ``model`` on every batch of ``loader`` and summarise the results.

    The model is switched to eval mode and run without gradient tracking.
    For each batch the softmax probability of class index 1 is collected
    (presumably the "defect" class -- confirm against the dataset labels),
    then thresholded at ``C.DEFECT_THRESHOLD`` to obtain hard predictions.

    Args:
        model: Network whose output is indexed as ``[batch, class]`` logits.
        loader: Iterable yielding ``(inputs, labels)`` pairs; labels are
            converted with ``.numpy()``, so they are expected on the CPU.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        A dict with keys ``"auc"`` (NaN when only one label value is
        present), ``"f1"``, ``"acc"`` and ``"cm"`` (2x2 confusion matrix as
        nested lists, label order ``[0, 1]``).
    """
    model.eval()
    prob_chunks = []
    label_chunks = []
    with torch.no_grad():
        for inputs, targets in loader:
            scores = model(inputs.to(device, non_blocking=True))
            positive = torch.softmax(scores, dim=1)[:, 1]
            prob_chunks.append(positive.cpu().numpy())
            label_chunks.append(targets.numpy())

    probs = np.concatenate(prob_chunks)
    labels = np.concatenate(label_chunks)
    decisions = (probs >= C.DEFECT_THRESHOLD).astype(int)

    # ROC AUC is undefined when the labels contain a single class.
    if len(np.unique(labels)) > 1:
        auc = float(roc_auc_score(labels, probs))
    else:
        auc = float("nan")

    return {
        "auc": auc,
        "f1": float(f1_score(labels, decisions, zero_division=0)),
        "acc": float((decisions == labels).mean()),
        "cm": confusion_matrix(labels, decisions, labels=[0, 1]).tolist(),
    }


def _train_one_epoch(model, loader, optim, criterion, device, desc):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network being trained; put into train mode here.
        loader: Iterable of ``(inputs, labels)`` batches.
        optim: Optimizer stepped once per batch.
        criterion: Loss callable applied to ``(logits, labels)``.
        device: Device both tensors are moved to.
        desc: Progress-bar caption.

    Returns:
        ``(loss_sum, sample_count)`` where ``loss_sum`` is the per-batch loss
        weighted by batch size, so ``loss_sum / sample_count`` is the mean
        per-sample loss.
    """
    model.train()
    running = 0.0
    n = 0
    pbar = tqdm(loader, desc=desc)
    for x, y in pbar:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        optim.zero_grad(set_to_none=True)
        loss = criterion(model(x), y)
        loss.backward()
        # Cap the global gradient norm at 5.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optim.step()
        batch_size = x.size(0)
        running += float(loss.item()) * batch_size
        n += batch_size
        pbar.set_postfix(loss=f"{running / n:.4f}")
    return running, n


def main() -> None:
    """Train the model, keep the best checkpoint and record per-epoch metrics.

    Steps, all driven by values from the ``config`` module:
      * seed the RNGs and pick CUDA when available, otherwise CPU;
      * build loaders and the model (``pretrained=True``);
      * train for ``C.EPOCHS`` epochs with AdamW, cosine LR annealing and
        label-smoothed cross-entropy;
      * after each epoch evaluate on the validation loader; the selection
        score is AUC, falling back to F1 when AUC is NaN;
      * save ``best.pt`` under ``C.CHECKPOINTS`` whenever the score improves;
      * write ``history.json`` under ``C.RUNS``.

    The history file is rewritten after every epoch so that an interrupted
    or crashed run still leaves the metrics gathered so far on disk.
    """
    set_seed(C.SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Устройство: {device}")

    train_loader, val_loader = make_loaders()
    model = build_model(pretrained=True).to(device)

    optim = AdamW(model.parameters(), lr=C.LR, weight_decay=C.WEIGHT_DECAY)
    sched = CosineAnnealingLR(optim, T_max=C.EPOCHS)
    criterion = nn.CrossEntropyLoss(label_smoothing=C.LABEL_SMOOTH)

    C.CHECKPOINTS.mkdir(parents=True, exist_ok=True)
    C.RUNS.mkdir(parents=True, exist_ok=True)
    history = []
    best_score = -1.0
    best_path = C.CHECKPOINTS / "best.pt"
    history_path = C.RUNS / "history.json"

    def dump_history() -> None:
        # ensure_ascii=False may emit non-ASCII text, so pin the encoding
        # instead of relying on the platform locale.
        history_path.write_text(
            json.dumps(history, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )

    for epoch in range(1, C.EPOCHS + 1):
        t0 = time.time()
        running, n = _train_one_epoch(
            model, train_loader, optim, criterion, device,
            desc=f"Эпоха {epoch}/{C.EPOCHS}",
        )
        sched.step()

        metrics = evaluate(model, val_loader, device)
        # AUC is NaN when validation labels hold a single class; use F1 then.
        score = metrics["auc"] if not np.isnan(metrics["auc"]) else metrics["f1"]
        elapsed = time.time() - t0
        print(f"  val: AUC={metrics['auc']:.3f}  F1={metrics['f1']:.3f}  "
              f"acc={metrics['acc']:.3f}  cm={metrics['cm']}  ({elapsed:.1f}s)")
        # max(n, 1) guards the mean against an empty training loader.
        history.append({"epoch": epoch, "train_loss": running / max(n, 1), **metrics})
        dump_history()

        if score > best_score:
            best_score = score
            torch.save({"model": model.state_dict(),
                        "backbone": C.BACKBONE,
                        "img_size": C.IMG_SIZE,
                        "metrics": metrics}, best_path)
            print(f"  ✓ сохранён лучший чекпоинт {best_path.name}  (score={best_score:.3f})")

    # Final write keeps the original end state, including when EPOCHS == 0.
    dump_history()
    print(f"\nГотово. Лучший score: {best_score:.3f}\nЧекпоинт: {best_path}")


# Script entry point: start training only when the module is executed
# directly (e.g. ``python -m src.train``), not when it is imported.
if __name__ == "__main__":
    main()