paint_defect_detector / src\train.py
therealestcoder's picture
Upload src\train.py with huggingface_hub
53ec820 verified
"""Обучение модели детекции дефектов окраски кузова.
Запуск: python -m src.train
"""
from __future__ import annotations
import time
import json
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix
from tqdm import tqdm
from . import config as C
from .dataset import make_loaders
from .model import build_model
def set_seed(seed: int) -> None:
    """Seed every RNG the training run touches.

    Seeds, in this order: Python's ``random`` module, NumPy's global RNG,
    PyTorch's CPU generator and the generators of all CUDA devices.

    Args:
        seed: the integer seed applied to each generator.
    """
    import random as _random

    seeders = (
        _random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def evaluate(model: nn.Module, loader, device, *, threshold: float | None = None) -> dict:
    """Score *model* on *loader* and return binary-classification metrics.

    Column 1 of the softmaxed logits is treated as the probability of the
    positive class (label 1). A sample is predicted positive when that
    probability is ``>= threshold``. The model is switched to eval mode and
    stays in eval mode on return.

    Args:
        model: classifier whose output has at least two columns along dim 1.
        loader: iterable of ``(inputs, labels)`` batches. Labels may be CPU
            tensors, device tensors or array-likes; they are expected to be
            0/1 class ids.
        device: device the inputs are moved to before the forward pass.
        threshold: decision cut-off on the positive-class probability;
            ``None`` (default) uses ``C.DEFECT_THRESHOLD``.

    Returns:
        dict with keys ``auc`` (ROC AUC, NaN when the labels contain a single
        class), ``f1``, ``acc`` and ``cm`` (2x2 confusion matrix as nested
        lists: rows = true label, columns = predicted label, both ordered
        [0, 1]).

    Raises:
        ValueError: if *loader* yields no batches.
    """
    if threshold is None:
        threshold = C.DEFECT_THRESHOLD
    model.eval()
    prob_chunks, label_chunks = [], []
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device, non_blocking=True)
            prob = torch.softmax(model(x), dim=1)[:, 1]
            prob_chunks.append(prob.cpu().numpy())
            # as_tensor(...).cpu() accepts CPU tensors, device tensors and
            # array-likes alike; a bare .numpy() fails on CUDA tensors.
            label_chunks.append(torch.as_tensor(y).cpu().numpy())
    if not prob_chunks:
        # np.concatenate([]) would otherwise raise an opaque error.
        raise ValueError("evaluate(): loader produced no batches")
    p = np.concatenate(prob_chunks)
    y_true = np.concatenate(label_chunks)
    pred = (p >= threshold).astype(int)
    # ROC AUC is undefined when only one class is present in the labels.
    if len(np.unique(y_true)) > 1:
        auc = float(roc_auc_score(y_true, p))
    else:
        auc = float("nan")
    return {
        "auc": auc,
        "f1": float(f1_score(y_true, pred, zero_division=0)),
        "acc": float((pred == y_true).mean()),
        "cm": confusion_matrix(y_true, pred, labels=[0, 1]).tolist(),
    }
def _train_one_epoch(model: nn.Module, loader, optim, criterion, device, desc: str) -> float:
    """Run one optimisation pass over *loader*; return the per-sample mean loss.

    The loss is weighted by batch size, so the result is averaged over
    samples rather than batches. Returns 0.0 when *loader* yields no batches.
    """
    model.train()
    running = 0.0
    n = 0
    pbar = tqdm(loader, desc=desc)
    for x, y in pbar:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        optim.zero_grad(set_to_none=True)
        loss = criterion(model(x), y)
        loss.backward()
        # Cap the total gradient norm at 5.0 (guards against exploding gradients).
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optim.step()
        running += float(loss.item()) * x.size(0)
        n += x.size(0)
        pbar.set_postfix(loss=f"{running / n:.4f}")
    return running / max(n, 1)


def _write_history(path: Path, history: list) -> None:
    """Serialise the per-epoch history list to *path* as indented JSON."""
    path.write_text(json.dumps(history, indent=2, ensure_ascii=False))


def main() -> None:
    """Train the defect classifier and keep the best checkpoint.

    All hyper-parameters and output paths come from the project ``config``
    module. After every epoch the model is evaluated on the validation
    loader. The selection score is validation AUC, or F1 when AUC is NaN
    (the validation labels contain a single class). Whenever that score
    strictly improves, the state dict is saved to ``C.CHECKPOINTS / "best.pt"``
    together with the backbone name, image size, epoch and metrics.
    Per-epoch metrics are written to ``C.RUNS / "history.json"`` after every
    epoch, so an interrupted run keeps what it has computed so far.
    """
    set_seed(C.SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Устройство: {device}")
    train_loader, val_loader = make_loaders()
    model = build_model(pretrained=True).to(device)
    optim = AdamW(model.parameters(), lr=C.LR, weight_decay=C.WEIGHT_DECAY)
    # One cosine cycle over the whole run; stepped once per epoch below.
    sched = CosineAnnealingLR(optim, T_max=C.EPOCHS)
    criterion = nn.CrossEntropyLoss(label_smoothing=C.LABEL_SMOOTH)

    C.CHECKPOINTS.mkdir(parents=True, exist_ok=True)
    C.RUNS.mkdir(parents=True, exist_ok=True)
    history_path = C.RUNS / "history.json"
    history = []
    best_score = -1.0  # below any AUC/F1 value, so epoch 1 is always saved
    best_path = C.CHECKPOINTS / "best.pt"

    for epoch in range(1, C.EPOCHS + 1):
        t0 = time.time()
        train_loss = _train_one_epoch(model, train_loader, optim, criterion, device,
                                      desc=f"Эпоха {epoch}/{C.EPOCHS}")
        sched.step()
        metrics = evaluate(model, val_loader, device)
        # AUC is NaN when the validation labels contain one class only; fall back to F1.
        score = metrics["auc"] if not np.isnan(metrics["auc"]) else metrics["f1"]
        elapsed = time.time() - t0
        print(f" val: AUC={metrics['auc']:.3f} F1={metrics['f1']:.3f} "
              f"acc={metrics['acc']:.3f} cm={metrics['cm']} ({elapsed:.1f}s)")
        history.append({"epoch": epoch, "train_loss": train_loss, **metrics})
        # Persist after every epoch so a crash or interrupt does not lose the log.
        _write_history(history_path, history)
        if score > best_score:
            best_score = score
            torch.save({"model": model.state_dict(),
                        "backbone": C.BACKBONE,
                        "img_size": C.IMG_SIZE,
                        "epoch": epoch,
                        "metrics": metrics}, best_path)
            print(f" ✓ сохранён лучший чекпоинт {best_path.name} (score={best_score:.3f})")

    # Final write also produces an empty list when C.EPOCHS == 0.
    _write_history(history_path, history)
    print(f"\nГотово. Лучший score: {best_score:.3f}\nЧекпоинт: {best_path}")
if __name__ == "__main__":
    # Script entry point, e.g. ``python -m src.train``.
    main()