| |
| """ |
| Unified Training Script — YOLOv11 + CNN-BiGRU |
| Based on: Nature Scientific Reports (Nov 2025) |
| |
| Usage: |
| # Train YOLOv11 detector only |
| python train.py yolo --data dataset/data.yaml --epochs 100 |
| |
| # Train CNN-BiGRU severity model (requires sequence data) |
| python train.py bigru --bigru-data severity_sequences/ --bigru-epochs 50 |
| |
| # Train both sequentially |
| python train.py all --data dataset/data.yaml --bigru-data severity_sequences/ |
| """ |
|
|
| import os |
| import sys |
| import shutil |
| import logging |
| import argparse |
| from pathlib import Path |
| from datetime import datetime |
|
|
| import torch |
| import yaml |
|
|
# Log to the console and to ``training.log`` in the current working directory.
# The log file is opened as UTF-8 explicitly: FileHandler otherwise uses the
# locale's default encoding, which on some platforms (e.g. Windows cp1252)
# cannot encode the non-ASCII characters used in this script's log messages.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("training.log", encoding="utf-8"),
    ],
)
logger = logging.getLogger("train")
|
|
|
|
| |
| |
| |
def _auto_batch_size():
    """Choose a YOLO batch size from the total memory of CUDA device 0.

    Returns:
        tuple[int, float]: ``(batch, vram_gb)``. ``batch`` is 16 for
        >= 20 GB, 8 for >= 8 GB, 4 for smaller GPUs, and 2 when CUDA is not
        available (``vram_gb`` is then reported as 0.0).
    """
    if not torch.cuda.is_available():
        return 2, 0.0
    vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
    if vram_gb >= 20:
        return 16, vram_gb
    if vram_gb >= 8:
        return 8, vram_gb
    return 4, vram_gb


def train_yolo(args):
    """Train the YOLOv11 road-anomaly detector from parsed CLI arguments.

    Uses the options registered on the ``yolo``/``all`` subcommands
    (``data``, ``model``, ``epochs``, ``batch``, ``imgsz``, ``optimizer``,
    ``lr``, ``weight_decay``, ``warmup``, ``cache``, ``no_amp``, ``workers``,
    ``project``, ``name``, ``resume``, ``export``).

    After training, ``<project>/<name>/weights/best.pt`` is copied to
    ``runs/best.pt`` when it exists. The model is then exported to each
    format in ``args.export``. Export failures are logged as warnings and do
    not abort the run.

    Args:
        args: ``argparse.Namespace`` produced by ``build_parser``.

    Returns:
        Whatever ``YOLOv11Detector.train`` returns.
    """
    from yolo_detection import YOLOv11Detector

    logger.info("=" * 60)
    logger.info(" YOLOv11 Road Anomaly Detection — Training")
    logger.info("=" * 60)

    # Report which device training will run on.
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(0)
        logger.info("GPU: %s (%.1f GB)", props.name,
                    props.total_memory / (1024 ** 3))
    else:
        logger.info("Training on CPU (this will be slow)")

    # Hand the trainer an absolute path to the dataset config.
    data_yaml = str(Path(args.data).resolve())
    logger.info("Dataset config: %s", data_yaml)

    # --batch 0 means "choose automatically from available VRAM".
    batch = args.batch
    if batch == 0:
        batch, vram_gb = _auto_batch_size()
        logger.info("Auto batch size: %d (VRAM=%.1f GB)", batch, vram_gb)

    detector = YOLOv11Detector(
        model_path=args.model,
        img_size=args.imgsz,
    )

    results = detector.train(
        data_yaml=data_yaml,
        epochs=args.epochs,
        batch=batch,
        optimizer=args.optimizer,
        lr0=args.lr,
        weight_decay=args.weight_decay,
        warmup_epochs=args.warmup,
        mosaic=0.5,
        cache=args.cache,
        amp=not args.no_amp,
        workers=args.workers,
        project=args.project,
        name=args.name,
        resume=args.resume,
    )

    # NOTE(review): assumes the trainer writes to exactly <project>/<name>;
    # if the run directory gets auto-suffixed, this copy is skipped (a warning
    # is logged) — confirm against YOLOv11Detector.train.
    best_src = Path(args.project) / args.name / "weights" / "best.pt"
    if best_src.exists():
        best_dst = Path("runs/best.pt")
        best_dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(best_src, best_dst)
        logger.info("✅ Best model → %s", best_dst)
    else:
        logger.warning("Best weights not found at %s; skipping copy", best_src)

    # Exports are best-effort: one failing format must not lose the others.
    for fmt in args.export or []:
        try:
            detector.export(format=fmt, half=(fmt == "engine"))
        except Exception as e:
            logger.warning("Export %s failed: %s", fmt, e)
        else:
            logger.info("✅ Exported → %s", fmt)

    return results
|
|
|
|
| |
| |
| |
def train_bigru(args):
    """Train the CNN-BiGRU severity model from parsed CLI arguments.

    Uses the options registered on the ``bigru``/``all`` subcommands
    (``bigru_data``, ``seq_len``, ``bigru_batch``, ``bigru_epochs``,
    ``bigru_lr``, ``bigru_save_dir``, ``patience``, ``workers``).
    ``args.bigru_data`` is the dataset root (per the CLI help: a directory
    with ``sequences/`` + ``labels.csv``).

    The sequences are randomly split 80/20 into train/validation, with at
    least one validation sequence. After training,
    ``<bigru_save_dir>/best_bigru.pth`` is copied to ``runs/best_bigru.pth``
    when it exists.

    Args:
        args: ``argparse.Namespace`` produced by ``build_parser``.

    Returns:
        Whatever ``BiGRUTrainer.fit`` returns.

    Raises:
        ValueError: if the dataset holds fewer than 2 sequences, so no
            non-empty train/validation split is possible.
    """
    from cnn_bigru import CNNBiGRU, AnomalySequenceDataset, BiGRUTrainer
    from torch.utils.data import DataLoader, random_split

    logger.info("=" * 60)
    logger.info(" CNN-BiGRU Severity Prediction — Training")
    logger.info("=" * 60)

    dataset = AnomalySequenceDataset(
        root=args.bigru_data,
        seq_len=args.seq_len,
        patch_size=64,
    )

    # Hold out 20% (at least one sequence) for validation. Fewer than two
    # sequences would leave the training split empty (or negative-sized).
    n_total = len(dataset)
    if n_total < 2:
        raise ValueError(
            f"Need at least 2 sequences for a train/val split, "
            f"found {n_total} in {args.bigru_data!r}"
        )
    n_val = max(1, int(n_total * 0.2))
    n_train = n_total - n_val
    train_ds, val_ds = random_split(dataset, [n_train, n_val])

    # Pinned host memory only helps host->GPU copies; on CPU-only machines
    # PyTorch just warns about it.
    loader_kwargs = dict(
        batch_size=args.bigru_batch,
        num_workers=args.workers,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    logger.info("Train sequences: %d | Val sequences: %d", n_train, n_val)

    model = CNNBiGRU(
        in_channels=3,
        hidden_size=128,
        num_gru_layers=2,
        num_severity_classes=4,
    )

    trainer = BiGRUTrainer(
        model=model,
        lr=args.bigru_lr,
        weight_decay=1e-4,
    )

    history = trainer.fit(
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=args.bigru_epochs,
        save_dir=args.bigru_save_dir,
        patience=args.patience,
    )

    best_src = Path(args.bigru_save_dir) / "best_bigru.pth"
    if best_src.exists():
        best_dst = Path("runs/best_bigru.pth")
        # runs/ may not exist yet when only the BiGRU stage is run.
        best_dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(best_src, best_dst)
        logger.info("✅ Best BiGRU → %s", best_dst)
    else:
        logger.warning("Best BiGRU checkpoint not found at %s; skipping copy",
                       best_src)

    return history
|
|
|
|
| |
| |
| |
def _add_yolo_args(parser):
    """Register the YOLOv11 training options shared by ``yolo`` and ``all``."""
    parser.add_argument("--data", required=True, help="data.yaml path")
    parser.add_argument("--model", default="yolo11n.pt",
                        help="Base model (yolo11n/s/m/l/x.pt)")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch", type=int, default=0,
                        help="Batch size (0 = auto from VRAM)")
    parser.add_argument("--imgsz", type=int, default=416)
    parser.add_argument("--optimizer", default="AdamW")
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--weight-decay", type=float, default=0.0005)
    parser.add_argument("--warmup", type=float, default=3.0)
    parser.add_argument("--cache", default="disk",
                        help="'ram', 'disk', or '' for none")
    parser.add_argument("--no-amp", action="store_true")
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--project", default="road_anomaly")
    parser.add_argument("--name", default="yolov11_road_detection")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--export", nargs="*", default=[],
                        help="Export formats after training (onnx, engine, tflite)")


def _add_bigru_args(parser, *, data_required):
    """Register the CNN-BiGRU options shared by ``bigru`` and ``all``.

    ``data_required`` selects whether ``--bigru-data`` is mandatory (``bigru``
    mode) or optional with a ``None`` default (``all`` mode, where omitting it
    skips the BiGRU phase).
    """
    if data_required:
        parser.add_argument("--bigru-data", required=True,
                            help="Root dir with sequences/ + labels.csv")
    else:
        parser.add_argument("--bigru-data", default=None,
                            help="Root dir with sequences/ + labels.csv "
                                 "(omit to skip BiGRU training)")
    parser.add_argument("--seq-len", type=int, default=8)
    parser.add_argument("--bigru-batch", type=int, default=8)
    parser.add_argument("--bigru-epochs", type=int, default=50)
    parser.add_argument("--bigru-lr", type=float, default=1e-3)
    parser.add_argument("--bigru-save-dir", default="bigru_checkpoints")
    parser.add_argument("--patience", type=int, default=10)


def build_parser():
    """Build the CLI parser with ``yolo``, ``bigru`` and ``all`` subcommands.

    ``all`` accepts the union of the ``yolo`` and ``bigru`` options (with
    ``--bigru-data`` optional). Both option sets are registered through shared
    helpers so their defaults and help text cannot drift apart.

    Returns:
        argparse.ArgumentParser: parser whose ``mode`` attribute holds the
        selected subcommand.
    """
    parser = argparse.ArgumentParser(
        description="Train YOLOv11 + CNN-BiGRU Road Anomaly Detection System",
    )
    sub = parser.add_subparsers(dest="mode", required=True)

    p_yolo = sub.add_parser("yolo", help="Train YOLOv11 detector")
    _add_yolo_args(p_yolo)

    p_bigru = sub.add_parser("bigru", help="Train CNN-BiGRU severity model")
    _add_bigru_args(p_bigru, data_required=True)
    p_bigru.add_argument("--workers", type=int, default=4)

    p_all = sub.add_parser("all", help="Train YOLO then BiGRU")
    # The YOLO option set already defines --workers, shared by both phases.
    _add_yolo_args(p_all)
    _add_bigru_args(p_all, data_required=False)

    return parser
|
|
|
|
def main():
    """Parse the command line and run the requested training mode(s).

    Modes:
        ``yolo``  — train the YOLOv11 detector only.
        ``bigru`` — train the CNN-BiGRU severity model only.
        ``all``   — train YOLOv11, then the BiGRU if ``--bigru-data`` was
                    given (otherwise the BiGRU phase is skipped with a log
                    message).
    """
    parser = build_parser()
    args = parser.parse_args()

    print()
    print("🚀 ROAD ANOMALY DETECTION — YOLOv11 + CNN-BiGRU")
    print(" Based on Nature Scientific Reports (Nov 2025)")
    print(f" Started: {datetime.now():%Y-%m-%d %H:%M:%S}")
    print()

    if args.mode == "yolo":
        train_yolo(args)

    elif args.mode == "bigru":
        train_bigru(args)

    elif args.mode == "all":
        logger.info("═══ Phase 1/2: YOLOv11 Training ═══")
        train_yolo(args)

        # The BiGRU phase is optional in "all" mode.
        if args.bigru_data:
            logger.info("═══ Phase 2/2: CNN-BiGRU Training ═══")
            train_bigru(args)
        else:
            logger.info(
                "Skipping BiGRU training — provide --bigru-data to enable."
            )

    print()
    print("🎯 Training pipeline complete!")
    print(f" Finished: {datetime.now():%Y-%m-%d %H:%M:%S}")
    print()
|
|
|
|
# Entry point: run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|