#!/usr/bin/env python3
"""
Unified Training Script – YOLOv11 + CNN-BiGRU
Based on: Nature Scientific Reports (Nov 2025)

Usage:
    # Train YOLOv11 detector only
    python train.py yolo --data dataset/data.yaml --epochs 100

    # Train CNN-BiGRU severity model (requires sequence data)
    python train.py bigru --bigru-data severity_sequences/ --bigru-epochs 50

    # Train both sequentially
    python train.py all --data dataset/data.yaml --bigru-data severity_sequences/
"""

import os
import sys
import shutil
import logging
import argparse
from pathlib import Path
from datetime import datetime

import torch
import yaml

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler(), logging.FileHandler("training.log")],
)
logger = logging.getLogger("train")

_BYTES_PER_GIB = 1024 ** 3

# Slack applied to the VRAM tiers below.  The figure derived from
# ``total_memory`` is commonly a little under a card's nominal size (an
# "8 GB" card tends to show up as roughly 7.8 GiB), so comparing against the
# nominal number directly would push such a card into the next tier down.
_VRAM_TOLERANCE_GB = 0.5

# Fraction of the sequence dataset held out for validation.
_VAL_FRACTION = 0.2


# ═══════════════════════════════════════════════════════════════════════════
#  Shared helpers
# ═══════════════════════════════════════════════════════════════════════════

def _auto_batch_size(vram_gb):
    """Return the YOLO batch size for a GPU with *vram_gb* GiB of memory.

    Tiers (nominal card size): 20 GB and up → 16, 8 GB and up → 8,
    anything smaller → 4.
    """
    effective_gb = vram_gb + _VRAM_TOLERANCE_GB
    if effective_gb >= 20:
        return 16
    if effective_gb >= 8:
        return 8
    return 4


def _publish_checkpoint(src, dst, label):
    """Copy checkpoint *src* to *dst* if it exists, creating dst's directory.

    Returns True when a copy was made, False when *src* is missing.
    """
    if not src.exists():
        return False
    # Create the target directory first: shutil.copy2 does not, and would
    # otherwise fail with FileNotFoundError after training already finished.
    dst.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src, dst)
    logger.info("✅ %s → %s", label, dst)
    return True


# ═══════════════════════════════════════════════════════════════════════════
#  YOLOv11 Training
# ═══════════════════════════════════════════════════════════════════════════

def train_yolo(args):
    """Train the YOLOv11 detector described by the parsed CLI *args*.

    Reads the YOLO options registered by ``build_parser`` (``data``,
    ``model``, ``epochs``, ``batch``, ...).  After training, copies
    ``<project>/<name>/weights/best.pt`` to ``runs/best.pt`` when that file
    exists, then attempts every requested export format.

    Returns:
        Whatever ``YOLOv11Detector.train`` returns.
    """
    from yolo_detection import YOLOv11Detector

    logger.info("=" * 60)
    logger.info("  YOLOv11 Road Anomaly Detection – Training")
    logger.info("=" * 60)

    # GPU info (queried once and reused for the auto batch size below)
    has_gpu = torch.cuda.is_available()
    vram_gb = 0.0
    if has_gpu:
        props = torch.cuda.get_device_properties(0)
        vram_gb = props.total_memory / _BYTES_PER_GIB
        logger.info("GPU: %s (%.1f GB)", props.name, vram_gb)
    else:
        logger.info("Training on CPU (this will be slow)")

    # Resolve data.yaml
    data_yaml = str(Path(args.data).resolve())
    logger.info("Dataset config: %s", data_yaml)

    # Determine batch size from VRAM
    batch = args.batch
    if batch == 0:
        # Auto-select based on GPU VRAM
        #   RTX 2050 (4 GB)    → batch 4 @ 416px
        #   RTX 3060 (8 GB)    → batch 8
        #   RTX 3090+ (20+ GB) → batch 16
        batch = _auto_batch_size(vram_gb) if has_gpu else 2
        logger.info("Auto batch size: %d (VRAM=%.1f GB)", batch, vram_gb)

    detector = YOLOv11Detector(
        model_path=args.model,
        img_size=args.imgsz,
    )

    results = detector.train(
        data_yaml=data_yaml,
        epochs=args.epochs,
        batch=batch,
        optimizer=args.optimizer,
        lr0=args.lr,
        weight_decay=args.weight_decay,
        warmup_epochs=args.warmup,
        mosaic=0.5,
        cache=args.cache,
        amp=not args.no_amp,
        workers=args.workers,
        project=args.project,
        name=args.name,
        resume=args.resume,
    )

    # Copy best.pt to project root for easy access
    _publish_checkpoint(
        Path(args.project) / args.name / "weights" / "best.pt",
        Path("runs/best.pt"),
        "Best model",
    )

    # Export – best effort: one failing format must not abort the others.
    for fmt in args.export or []:
        try:
            detector.export(format=fmt, half=(fmt == "engine"))
            logger.info("✅ Exported → %s", fmt)
        except Exception as e:
            logger.warning("Export %s failed: %s", fmt, e)

    return results


# ═══════════════════════════════════════════════════════════════════════════
#  CNN-BiGRU Training
# ═══════════════════════════════════════════════════════════════════════════

def train_bigru(args):
    """Train the CNN-BiGRU severity model described by the parsed CLI *args*.

    Loads the sequence dataset from ``args.bigru_data``, splits it 80/20
    into train/validation subsets, fits the model, and copies
    ``<bigru_save_dir>/best_bigru.pth`` to ``runs/best_bigru.pth`` when that
    file exists.

    Returns:
        Whatever ``BiGRUTrainer.fit`` returns.

    Raises:
        ValueError: if the dataset holds fewer than two sequences, since
            both the training and the validation split need at least one.
    """
    from cnn_bigru import CNNBiGRU, AnomalySequenceDataset, BiGRUTrainer
    from torch.utils.data import DataLoader, random_split

    logger.info("=" * 60)
    logger.info("  CNN-BiGRU Severity Prediction – Training")
    logger.info("=" * 60)

    # Load dataset
    dataset = AnomalySequenceDataset(
        root=args.bigru_data,
        seq_len=args.seq_len,
        patch_size=64,
    )

    total = len(dataset)
    if total < 2:
        # Fail early with a readable message instead of an obscure error
        # from random_split / the shuffling sampler on an empty split.
        raise ValueError(
            f"Need at least 2 sequences in {args.bigru_data!r} for a "
            f"train/val split, found {total}."
        )

    # Split 80/20
    n_val = max(1, int(total * _VAL_FRACTION))
    n_train = total - n_val
    train_ds, val_ds = random_split(dataset, [n_train, n_val])

    # Pinned host memory is only requested when a CUDA device is present.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(
        train_ds,
        batch_size=args.bigru_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.bigru_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
    )

    logger.info("Train sequences: %d | Val sequences: %d", n_train, n_val)

    # Create model
    model = CNNBiGRU(
        in_channels=3,
        hidden_size=128,
        num_gru_layers=2,
        num_severity_classes=4,
    )

    trainer = BiGRUTrainer(
        model=model,
        lr=args.bigru_lr,
        weight_decay=1e-4,
    )

    history = trainer.fit(
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=args.bigru_epochs,
        save_dir=args.bigru_save_dir,
        patience=args.patience,
    )

    # Copy best to project root
    _publish_checkpoint(
        Path(args.bigru_save_dir) / "best_bigru.pth",
        Path("runs/best_bigru.pth"),
        "Best BiGRU",
    )

    return history


# ═══════════════════════════════════════════════════════════════════════════
#  CLI
# ═══════════════════════════════════════════════════════════════════════════

def _add_yolo_args(parser):
    """Register the YOLOv11 training options (including ``--workers``)."""
    parser.add_argument("--data", required=True, help="data.yaml path")
    parser.add_argument("--model", default="yolo11n.pt",
                        help="Base model (yolo11n/s/m/l/x.pt)")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch", type=int, default=0,
                        help="Batch size (0 = auto from VRAM)")
    parser.add_argument("--imgsz", type=int, default=416)
    parser.add_argument("--optimizer", default="AdamW")
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--weight-decay", type=float, default=0.0005)
    parser.add_argument("--warmup", type=float, default=3.0)
    parser.add_argument("--cache", default="disk",
                        help="'ram', 'disk', or '' for none")
    parser.add_argument("--no-amp", action="store_true")
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--project", default="road_anomaly")
    parser.add_argument("--name", default="yolov11_road_detection")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--export", nargs="*", default=[],
                        help="Export formats after training (onnx, engine, tflite)")


def _add_bigru_args(parser, data_required):
    """Register the CNN-BiGRU options; ``--workers`` is left to the caller.

    *data_required* controls whether ``--bigru-data`` is mandatory; when it
    is not, the option defaults to ``None``.
    """
    if data_required:
        parser.add_argument("--bigru-data", required=True,
                            help="Root dir with sequences/ + labels.csv")
    else:
        parser.add_argument("--bigru-data", default=None,
                            help="Root dir with sequences/ + labels.csv")
    parser.add_argument("--seq-len", type=int, default=8)
    parser.add_argument("--bigru-batch", type=int, default=8)
    parser.add_argument("--bigru-epochs", type=int, default=50)
    parser.add_argument("--bigru-lr", type=float, default=1e-3)
    parser.add_argument("--bigru-save-dir", default="bigru_checkpoints")
    parser.add_argument("--patience", type=int, default=10)


def build_parser():
    """Build the CLI parser with the ``yolo``, ``bigru`` and ``all`` modes.

    The selected sub-command is stored in ``args.mode``.  ``all`` accepts
    the union of the YOLO and BiGRU options, with ``--bigru-data`` optional.
    """
    parser = argparse.ArgumentParser(
        description="Train YOLOv11 + CNN-BiGRU Road Anomaly Detection System",
    )
    sub = parser.add_subparsers(dest="mode", required=True)

    # ---- yolo ----
    p_yolo = sub.add_parser("yolo", help="Train YOLOv11 detector")
    _add_yolo_args(p_yolo)

    # ---- bigru ----
    p_bigru = sub.add_parser("bigru", help="Train CNN-BiGRU severity model")
    _add_bigru_args(p_bigru, data_required=True)
    p_bigru.add_argument("--workers", type=int, default=4)

    # ---- all ----
    # Built from the same two helpers so the option sets cannot drift apart.
    p_all = sub.add_parser("all", help="Train YOLO then BiGRU")
    _add_yolo_args(p_all)
    _add_bigru_args(p_all, data_required=False)

    return parser


def main():
    """Parse the command line and run the selected training phase(s)."""
    args = build_parser().parse_args()

    print()
    print("🚗 ROAD ANOMALY DETECTION – YOLOv11 + CNN-BiGRU")
    print("   Based on Nature Scientific Reports (Nov 2025)")
    print(f"   Started: {datetime.now():%Y-%m-%d %H:%M:%S}")
    print()

    if args.mode == "yolo":
        train_yolo(args)

    elif args.mode == "bigru":
        train_bigru(args)

    elif args.mode == "all":
        # Phase 1: YOLO
        logger.info("═══ Phase 1/2: YOLOv11 Training ═══")
        train_yolo(args)

        # Phase 2: BiGRU (if data provided)
        if args.bigru_data:
            logger.info("═══ Phase 2/2: CNN-BiGRU Training ═══")
            train_bigru(args)
        else:
            logger.info(
                "Skipping BiGRU training – provide --bigru-data to enable."
            )

    print()
    print("🎯 Training pipeline complete!")
    print(f"   Finished: {datetime.now():%Y-%m-%d %H:%M:%S}")
    print()


if __name__ == "__main__":
    main()