"""Обучение модели детекции дефектов окраски кузова. Запуск: python -m src.train """ from __future__ import annotations import time import json from pathlib import Path import numpy as np import torch import torch.nn as nn from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix from tqdm import tqdm from . import config as C from .dataset import make_loaders from .model import build_model def set_seed(seed: int) -> None: import random random.seed(seed); np.random.seed(seed); torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) def evaluate(model: nn.Module, loader, device) -> dict: model.eval() all_p, all_y = [], [] with torch.no_grad(): for x, y in loader: x = x.to(device, non_blocking=True) logits = model(x) prob = torch.softmax(logits, dim=1)[:, 1] all_p.append(prob.cpu().numpy()) all_y.append(y.numpy()) p = np.concatenate(all_p); y = np.concatenate(all_y) pred = (p >= C.DEFECT_THRESHOLD).astype(int) metrics = { "auc": float(roc_auc_score(y, p)) if len(np.unique(y)) > 1 else float("nan"), "f1": float(f1_score(y, pred, zero_division=0)), "acc": float((pred == y).mean()), "cm": confusion_matrix(y, pred, labels=[0, 1]).tolist(), } return metrics def main() -> None: set_seed(C.SEED) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Устройство: {device}") train_loader, val_loader = make_loaders() model = build_model(pretrained=True).to(device) optim = AdamW(model.parameters(), lr=C.LR, weight_decay=C.WEIGHT_DECAY) sched = CosineAnnealingLR(optim, T_max=C.EPOCHS) criterion = nn.CrossEntropyLoss(label_smoothing=C.LABEL_SMOOTH) C.CHECKPOINTS.mkdir(parents=True, exist_ok=True) C.RUNS.mkdir(parents=True, exist_ok=True) history = [] best_score = -1.0 best_path = C.CHECKPOINTS / "best.pt" for epoch in range(1, C.EPOCHS + 1): model.train() running = 0.0 n = 0 t0 = time.time() pbar = tqdm(train_loader, desc=f"Эпоха {epoch}/{C.EPOCHS}") for x, y in pbar: x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True) optim.zero_grad(set_to_none=True) logits = model(x) loss = criterion(logits, y) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0) optim.step() running += float(loss.item()) * x.size(0); n += x.size(0) pbar.set_postfix(loss=f"{running / n:.4f}") sched.step() metrics = evaluate(model, val_loader, device) score = metrics["auc"] if not np.isnan(metrics["auc"]) else metrics["f1"] elapsed = time.time() - t0 print(f" val: AUC={metrics['auc']:.3f} F1={metrics['f1']:.3f} " f"acc={metrics['acc']:.3f} cm={metrics['cm']} ({elapsed:.1f}s)") history.append({"epoch": epoch, "train_loss": running / max(n, 1), **metrics}) if score > best_score: best_score = score torch.save({"model": model.state_dict(), "backbone": C.BACKBONE, "img_size": C.IMG_SIZE, "metrics": metrics}, best_path) print(f" ✓ сохранён лучший чекпоинт {best_path.name} (score={best_score:.3f})") (C.RUNS / "history.json").write_text(json.dumps(history, indent=2, ensure_ascii=False)) print(f"\nГотово. Лучший score: {best_score:.3f}\nЧекпоинт: {best_path}") if __name__ == "__main__": main()