| """Обучение модели детекции дефектов окраски кузова. |
| |
| Запуск: python -m src.train |
| """ |
| from __future__ import annotations |
| import time |
| import json |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import CosineAnnealingLR |
| from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix |
| from tqdm import tqdm |
|
|
| from . import config as C |
| from .dataset import make_loaders |
| from .model import build_model |
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed every RNG in play (python, numpy, torch CPU and CUDA).

    Makes training runs reproducible up to CUDA nondeterminism.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # No-op when CUDA is unavailable, so safe on CPU-only machines.
    torch.cuda.manual_seed_all(seed)
|
|
|
|
def evaluate(model: nn.Module, loader, device) -> dict:
    """Score *model* on *loader* and return validation metrics.

    The positive-class probability (softmax column 1) is thresholded at
    ``C.DEFECT_THRESHOLD`` to get hard predictions. Returns a dict with
    "auc" (NaN when the labels contain a single class), "f1", "acc" and
    the 2x2 confusion matrix "cm" as nested lists.
    """
    model.eval()
    prob_chunks = []
    label_chunks = []
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device, non_blocking=True)
            scores = torch.softmax(model(batch_x), dim=1)
            prob_chunks.append(scores[:, 1].cpu().numpy())
            label_chunks.append(batch_y.numpy())
    p = np.concatenate(prob_chunks)
    y = np.concatenate(label_chunks)
    pred = np.where(p >= C.DEFECT_THRESHOLD, 1, 0)
    # ROC AUC is undefined for a single-class label vector.
    if len(np.unique(y)) > 1:
        auc = float(roc_auc_score(y, p))
    else:
        auc = float("nan")
    return {
        "auc": auc,
        "f1": float(f1_score(y, pred, zero_division=0)),
        "acc": float((pred == y).mean()),
        "cm": confusion_matrix(y, pred, labels=[0, 1]).tolist(),
    }
|
|
|
|
def main() -> None:
    """Train the paint-defect classifier and keep the best checkpoint.

    Side effects: creates the checkpoint/run directories, writes
    ``best.pt`` and ``history.json``, prints per-epoch progress.
    """
    set_seed(C.SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Устройство: {device}")


    train_loader, val_loader = make_loaders()
    model = build_model(pretrained=True).to(device)


    # Cosine LR decay spans the whole run; one scheduler step per epoch
    # (sched.step() is called after the inner batch loop below).
    optim = AdamW(model.parameters(), lr=C.LR, weight_decay=C.WEIGHT_DECAY)
    sched = CosineAnnealingLR(optim, T_max=C.EPOCHS)
    criterion = nn.CrossEntropyLoss(label_smoothing=C.LABEL_SMOOTH)


    C.CHECKPOINTS.mkdir(parents=True, exist_ok=True)
    C.RUNS.mkdir(parents=True, exist_ok=True)
    history = []       # per-epoch records, dumped to history.json at the end
    best_score = -1.0  # sentinel: any real metric beats it on epoch 1
    best_path = C.CHECKPOINTS / "best.pt"


    for epoch in range(1, C.EPOCHS + 1):
        model.train()
        running = 0.0  # sample-weighted running sum of batch losses
        n = 0          # number of samples seen this epoch
        t0 = time.time()
        pbar = tqdm(train_loader, desc=f"Эпоха {epoch}/{C.EPOCHS}")
        for x, y in pbar:
            x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
            optim.zero_grad(set_to_none=True)
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            # Fixed max-norm clip to guard against exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optim.step()
            running += float(loss.item()) * x.size(0); n += x.size(0)
            pbar.set_postfix(loss=f"{running / n:.4f}")
        sched.step()


        metrics = evaluate(model, val_loader, device)
        # AUC is NaN when the val set has one class; fall back to F1 then.
        score = metrics["auc"] if not np.isnan(metrics["auc"]) else metrics["f1"]
        elapsed = time.time() - t0
        print(f" val: AUC={metrics['auc']:.3f} F1={metrics['f1']:.3f} "
              f"acc={metrics['acc']:.3f} cm={metrics['cm']} ({elapsed:.1f}s)")
        history.append({"epoch": epoch, "train_loss": running / max(n, 1), **metrics})


        # Only the single best checkpoint (by score) is kept on disk.
        if score > best_score:
            best_score = score
            torch.save({"model": model.state_dict(),
                        "backbone": C.BACKBONE,
                        "img_size": C.IMG_SIZE,
                        "metrics": metrics}, best_path)
            print(f" ✓ сохранён лучший чекпоинт {best_path.name} (score={best_score:.3f})")


    (C.RUNS / "history.json").write_text(json.dumps(history, indent=2, ensure_ascii=False))
    print(f"\nГотово. Лучший score: {best_score:.3f}\nЧекпоинт: {best_path}")
|
|
|
|
# Entry point: run as ``python -m src.train``.
if __name__ == "__main__":
    main()
|
|