arm-model / model /train.py
pragadeeshv23's picture
Upload folder using huggingface_hub
5b86813 verified
#!/usr/bin/env python3
"""
Unified Training Script – YOLOv11 + CNN-BiGRU
Based on: Nature Scientific Reports (Nov 2025)
Usage:
# Train YOLOv11 detector only
python train.py yolo --data dataset/data.yaml --epochs 100
# Train CNN-BiGRU severity model (requires sequence data)
python train.py bigru --bigru-data severity_sequences/ --bigru-epochs 50
# Train both sequentially
python train.py all --data dataset/data.yaml --bigru-data severity_sequences/
"""
import os
import sys
import shutil
import logging
import argparse
from pathlib import Path
from datetime import datetime
import torch
import yaml
# Root logging setup: records at INFO and above go to the console first and
# then to a ``training.log`` file opened in the current working directory.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("training.log"),
    ],
    level=logging.INFO,
)

# Logger used by the rest of this script.
logger = logging.getLogger("train")
# ═══════════════════════════════════════════════════════════════════════════
# YOLOv11 Training
# ═══════════════════════════════════════════════════════════════════════════
def train_yolo(args):
    """Train the YOLOv11 detector, publish the best weights and run exports.

    Args:
        args: Parsed CLI namespace. Reads ``data``, ``model``, ``imgsz``,
            ``epochs``, ``batch``, ``optimizer``, ``lr``, ``weight_decay``,
            ``warmup``, ``cache``, ``no_amp``, ``workers``, ``project``,
            ``name``, ``resume`` and ``export``.

    Returns:
        Whatever ``YOLOv11Detector.train`` returns.
    """
    # Function-scope import: only this training mode needs the detector module.
    from yolo_detection import YOLOv11Detector

    logger.info("=" * 60)
    logger.info("  YOLOv11 Road Anomaly Detection – Training")
    logger.info("=" * 60)

    # Query device 0 once; the VRAM figure is reused for auto batch sizing.
    has_cuda = torch.cuda.is_available()
    vram_gb = 0.0
    if has_cuda:
        props = torch.cuda.get_device_properties(0)
        vram_gb = props.total_memory / (1024 ** 3)
        logger.info("GPU: %s (%.1f GB)", props.name, vram_gb)
    else:
        logger.info("Training on CPU (this will be slow)")

    # Resolve data.yaml to an absolute path before handing it to the detector.
    data_yaml = str(Path(args.data).resolve())
    logger.info("Dataset config: %s", data_yaml)

    # Determine batch size; 0 on the command line means "pick from VRAM".
    batch = args.batch
    if batch == 0:
        # RTX 2050  (4 GB)   → batch 4 @ 416px
        # RTX 3060  (8 GB)   → batch 8
        # RTX 3090+ (20+ GB) → batch 16
        if not has_cuda:
            batch = 2
        elif vram_gb >= 20:
            batch = 16
        elif vram_gb >= 8:
            batch = 8
        else:
            batch = 4
        logger.info("Auto batch size: %d (VRAM=%.1f GB)", batch, vram_gb)

    detector = YOLOv11Detector(
        model_path=args.model,
        img_size=args.imgsz,
    )
    results = detector.train(
        data_yaml=data_yaml,
        epochs=args.epochs,
        batch=batch,
        optimizer=args.optimizer,
        lr0=args.lr,
        weight_decay=args.weight_decay,
        warmup_epochs=args.warmup,
        mosaic=0.5,
        cache=args.cache,
        amp=not args.no_amp,
        workers=args.workers,
        project=args.project,
        name=args.name,
        resume=args.resume,
    )

    # Copy best.pt to runs/ for easy access (only if training produced one).
    best_src = Path(args.project) / args.name / "weights" / "best.pt"
    if best_src.exists():
        best_dst = Path("runs/best.pt")
        best_dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(best_src, best_dst)
        logger.info("✅ Best model → %s", best_dst)

    # Export is best-effort: one failing format is logged and must not abort
    # the remaining formats or discard the training results.
    for fmt in args.export or ():
        try:
            detector.export(format=fmt, half=(fmt == "engine"))
            logger.info("✅ Exported → %s", fmt)
        except Exception as e:
            logger.warning("Export %s failed: %s", fmt, e)

    return results
# ═══════════════════════════════════════════════════════════════════════════
# CNN-BiGRU Training
# ═══════════════════════════════════════════════════════════════════════════
def train_bigru(args):
    """Train the CNN-BiGRU severity model and publish the best checkpoint.

    Args:
        args: Parsed CLI namespace. Reads ``bigru_data``, ``seq_len``,
            ``bigru_batch``, ``workers``, ``bigru_lr``, ``bigru_epochs``,
            ``bigru_save_dir`` and ``patience``.

    Returns:
        Whatever ``BiGRUTrainer.fit`` returns.

    Raises:
        ValueError: If the dataset holds fewer than two sequences, so it
            cannot be split into non-empty train and validation subsets.
    """
    # Function-scope imports: only this training mode needs them.
    from cnn_bigru import CNNBiGRU, AnomalySequenceDataset, BiGRUTrainer
    from torch.utils.data import DataLoader, random_split

    logger.info("=" * 60)
    logger.info("  CNN-BiGRU Severity Prediction – Training")
    logger.info("=" * 60)

    # Load dataset
    dataset = AnomalySequenceDataset(
        root=args.bigru_data,
        seq_len=args.seq_len,
        patch_size=64,
    )

    # Fail early with a clear message: with 0 or 1 sequences the 80/20 split
    # below would yield a negative or empty training subset.
    n_total = len(dataset)
    if n_total < 2:
        raise ValueError(
            f"Need at least 2 sequences to build a train/val split, "
            f"found {n_total} in {args.bigru_data!r}"
        )

    # Split 80/20, always keeping at least one validation sequence.
    n_val = max(1, int(n_total * 0.2))
    n_train = n_total - n_val
    train_ds, val_ds = random_split(dataset, [n_train, n_val])

    train_loader = DataLoader(
        train_ds, batch_size=args.bigru_batch, shuffle=True,
        num_workers=args.workers, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds, batch_size=args.bigru_batch, shuffle=False,
        num_workers=args.workers, pin_memory=True,
    )
    logger.info("Train sequences: %d | Val sequences: %d", n_train, n_val)

    # Create model
    model = CNNBiGRU(
        in_channels=3,
        hidden_size=128,
        num_gru_layers=2,
        num_severity_classes=4,
    )
    trainer = BiGRUTrainer(
        model=model,
        lr=args.bigru_lr,
        weight_decay=1e-4,
    )
    history = trainer.fit(
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=args.bigru_epochs,
        save_dir=args.bigru_save_dir,
        patience=args.patience,
    )

    # Copy best checkpoint to runs/ for easy access.
    best_src = Path(args.bigru_save_dir) / "best_bigru.pth"
    if best_src.exists():
        best_dst = Path("runs/best_bigru.pth")
        # runs/ may not exist when this mode runs on its own; without this
        # the copy fails with FileNotFoundError after training has finished.
        best_dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(best_src, best_dst)
        logger.info("✅ Best BiGRU → %s", best_dst)

    return history
# ═══════════════════════════════════════════════════════════════════════════
# CLI
# ═══════════════════════════════════════════════════════════════════════════
def _add_yolo_args(parser):
    """Register the YOLOv11 training options on *parser*."""
    parser.add_argument("--data", required=True, help="data.yaml path")
    parser.add_argument("--model", default="yolo11n.pt",
                        help="Base model (yolo11n/s/m/l/x.pt)")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch", type=int, default=0,
                        help="Batch size (0 = auto from VRAM)")
    parser.add_argument("--imgsz", type=int, default=416)
    parser.add_argument("--optimizer", default="AdamW")
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--weight-decay", type=float, default=0.0005)
    parser.add_argument("--warmup", type=float, default=3.0)
    parser.add_argument("--cache", default="disk",
                        help="'ram', 'disk', or '' for none")
    parser.add_argument("--no-amp", action="store_true")
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--project", default="road_anomaly")
    parser.add_argument("--name", default="yolov11_road_detection")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--export", nargs="*", default=[],
                        help="Export formats after training (onnx, engine, tflite)")


def _add_bigru_args(parser, *, standalone):
    """Register the CNN-BiGRU training options on *parser*.

    With ``standalone=True`` (the ``bigru`` sub-command) the dataset root is
    required, ``--workers`` is registered here, and the short spellings
    ``--data`` / ``--epochs`` are accepted as aliases. With
    ``standalone=False`` (the ``all`` sub-command) those three names already
    belong to the YOLO options, so only the ``--bigru-*`` spellings are added
    and the dataset root is optional.
    """
    if standalone:
        # The canonical flag comes first so argparse derives dest=bigru_data.
        data_flags, data_kwargs = ("--bigru-data", "--data"), {"required": True}
        epoch_flags = ("--bigru-epochs", "--epochs")
    else:
        data_flags, data_kwargs = ("--bigru-data",), {"default": None}
        epoch_flags = ("--bigru-epochs",)

    parser.add_argument(*data_flags,
                        help="Root dir with sequences/ + labels.csv",
                        **data_kwargs)
    parser.add_argument("--seq-len", type=int, default=8)
    parser.add_argument("--bigru-batch", type=int, default=8)
    parser.add_argument(*epoch_flags, type=int, default=50)
    parser.add_argument("--bigru-lr", type=float, default=1e-3)
    parser.add_argument("--bigru-save-dir", default="bigru_checkpoints")
    parser.add_argument("--patience", type=int, default=10)
    if standalone:
        parser.add_argument("--workers", type=int, default=4)


def build_parser():
    """Build the CLI parser with the ``yolo``, ``bigru`` and ``all`` modes.

    The chosen sub-command is stored in ``args.mode`` and is mandatory.

    Returns:
        argparse.ArgumentParser: The configured top-level parser.
    """
    parser = argparse.ArgumentParser(
        description="Train YOLOv11 + CNN-BiGRU Road Anomaly Detection System",
    )
    sub = parser.add_subparsers(dest="mode", required=True)

    # ---- yolo ----
    p_yolo = sub.add_parser("yolo", help="Train YOLOv11 detector")
    _add_yolo_args(p_yolo)

    # ---- bigru ----
    p_bigru = sub.add_parser("bigru", help="Train CNN-BiGRU severity model")
    _add_bigru_args(p_bigru, standalone=True)

    # ---- all ----
    # Union of both option sets; BiGRU data stays optional so the second
    # phase can be skipped.
    p_all = sub.add_parser("all", help="Train YOLO then BiGRU")
    _add_yolo_args(p_all)
    _add_bigru_args(p_all, standalone=False)

    return parser
def main(argv=None):
    """Parse the command line and run the selected training mode.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.
    """
    args = build_parser().parse_args(argv)

    print()
    print("🚗 ROAD ANOMALY DETECTION – YOLOv11 + CNN-BiGRU")
    print("   Based on Nature Scientific Reports (Nov 2025)")
    print(f"   Started: {datetime.now():%Y-%m-%d %H:%M:%S}")
    print()

    if args.mode == "yolo":
        train_yolo(args)
    elif args.mode == "bigru":
        train_bigru(args)
    elif args.mode == "all":
        # Phase 1: YOLO
        logger.info("═══ Phase 1/2: YOLOv11 Training ═══")
        train_yolo(args)
        # Phase 2: BiGRU, only when sequence data was supplied
        if args.bigru_data:
            logger.info("═══ Phase 2/2: CNN-BiGRU Training ═══")
            train_bigru(args)
        else:
            logger.info(
                "Skipping BiGRU training – provide --bigru-data to enable."
            )

    print()
    print("🎯 Training pipeline complete!")
    print(f"   Finished: {datetime.now():%Y-%m-%d %H:%M:%S}")
    print()


if __name__ == "__main__":
    main()