| |
| """ |
| High-Accuracy Training Script for Road Anomaly Detection |
| ========================================================= |
| Optimised for: RTX 2050 (4 GB), i5-12450H, 15 GB RAM |
| |
Model: YOLO11s - 9.4M params, 21.5 GFLOPs (~3.6x the params and ~3.3x the FLOPs of YOLO11n)
| |
| Usage: |
| python train_high_accuracy.py # Full training (300 epochs) |
| python train_high_accuracy.py --dry-run # Quick 2-epoch test run |
| """ |
|
|
| import os |
| import sys |
| import shutil |
| import logging |
| import argparse |
| from pathlib import Path |
| from datetime import datetime |
|
|
# Send INFO-and-above records both to the console and to a log file
# ("training_optimised.log", relative to the current working directory).
logging.basicConfig(
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("training_optimised.log"),
    ],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)

# Named logger used throughout this script.
logger = logging.getLogger("train_optimised")
|
|
|
|
def _select_batch_size(vram_gb: float) -> int:
    """Pick a training batch size for a GPU with *vram_gb* GiB of memory.

    The cut-offs sit a little below the nominal 6 GB / 4 GB card sizes.
    ``total_memory`` as reported by PyTorch can come in slightly under the
    marketed capacity, and an exact ``>= 4`` test would then drop a "4 GB"
    card such as the RTX 2050 to the smallest batch. ``vram_gb == 0`` (no
    GPU) yields the smallest batch.
    """
    if vram_gb >= 5.5:
        return 8
    if vram_gb >= 3.5:
        return 4
    return 2


def _count_images(folder: Path) -> int:
    """Count image files directly inside *folder*; 0 if it does not exist.

    Matching is case-insensitive on the file extension so ``.png``/``.JPG``
    datasets are reported correctly (the count is informational only).
    """
    suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
    if not folder.is_dir():
        return 0
    return sum(
        1 for entry in folder.iterdir()
        if entry.is_file() and entry.suffix.lower() in suffixes
    )


def _find_training_weights(search_roots) -> "Path | None":
    """Return the newest ``weights/best.pt`` under the first root that has one.

    Roots are tried in order and missing roots are skipped. Only files whose
    parent directory is named ``weights`` are accepted. Those are the files the
    trainer writes, whereas this script's own published copies
    (``runs/best.pt``, ``<script_dir>/best.pt``) are not. Accepting those copies
    could return a stale model. Copying one onto itself would also raise
    ``shutil.SameFileError``. Returns ``None`` if nothing matches.
    """
    for root in search_roots:
        if not root.exists():
            continue
        candidates = [p for p in root.rglob("best.pt") if p.parent.name == "weights"]
        if candidates:
            return max(candidates, key=lambda p: p.stat().st_mtime)
    return None


def _f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of *precision* and *recall*; 0.0 when both are zero."""
    total = precision + recall
    return 2 * precision * recall / total if total > 0 else 0.0


def _print_metrics(metrics, names) -> None:
    """Print overall and per-class detection metrics from a validation run.

    *metrics* is the object returned by ``YOLO.val()``. *names* maps class id to
    class name (``YOLO.names``).

    ``metrics.box.ap50`` has one entry per class present in the validation
    labels. Those entries are ordered by ``metrics.box.ap_class_index``, not one
    per model class. The two are therefore zipped to recover each class id,
    instead of treating the enumeration position as the id.
    """
    box = metrics.box
    precision = box.mp
    recall = box.mr
    # F1 of the mean precision / mean recall (not the mean of per-class F1s).
    f1 = _f1_score(precision, recall)

    print(f" mAP@0.5: {box.map50*100:.1f}%")
    print(f" mAP@0.5:0.95: {box.map*100:.1f}%")
    print(f" Precision: {precision*100:.1f}%")
    print(f" Recall: {recall*100:.1f}%")
    print(f" F1-score: {f1*100:.1f}%")
    print(f" Inference: {metrics.speed['inference']:.1f} ms/image")
    print()

    print(" Per-class mAP@0.5:")
    for cls_idx, ap in zip(box.ap_class_index, box.ap50):
        print(f" {names[int(cls_idx)]:>20s}: {ap*100:.1f}%")
    print("=" * 60)


def main():
    """Train YOLO11s on ``dataset/data.yaml`` and publish the best weights.

    Steps:
    1. Parse the CLI flags.
    2. Choose the device and batch size from the detected GPU (falling back to
       CPU when CUDA is unavailable).
    3. Run Ultralytics training.
    4. Locate the ``best.pt`` / ``last.pt`` the run produced and copy them next
       to the script.
    5. Re-validate the best weights and print summary metrics.

    With ``--dry-run`` only 2 epochs are trained. The resulting weights are
    copied to ``runs/dry_run/`` so they do not overwrite the real deployment
    copies (``runs/best.pt`` and ``<script_dir>/best.pt``).

    Exits with status 1 in any of these cases:
    - missing dependencies;
    - missing dataset YAML;
    - training failure;
    - no ``best.pt`` found after training.
    """
    parser = argparse.ArgumentParser(description="Train YOLO11s for road anomaly detection")
    parser.add_argument("--dry-run", action="store_true",
                        help="Quick 2-epoch test to verify everything works")
    args = parser.parse_args()

    # Heavy third-party imports are deferred so argument errors / --help work without them.
    try:
        import torch
        from ultralytics import YOLO
    except ImportError as e:
        print(f"Missing dependency: {e}")
        print("Run: pip install ultralytics torch")
        sys.exit(1)

    is_dry_run = args.dry_run
    epochs = 2 if is_dry_run else 300
    run_name = "dry_run" if is_dry_run else "high_accuracy_s"

    print()
    print("=" * 60)
    if is_dry_run:
        print(" DRY RUN - 2 epochs to verify setup")
    else:
        print(" HIGH-ACCURACY ROAD ANOMALY DETECTION TRAINING")
    print(" YOLO11s - RTX 2050 (4 GB) optimised")
    print("=" * 60)
    print(f" Started: {datetime.now():%Y-%m-%d %H:%M:%S}")
    print()

    # Ultralytics rejects device=0 when CUDA is unavailable, so fall back to CPU.
    if torch.cuda.is_available():
        gpu = torch.cuda.get_device_properties(0)
        vram_gb = gpu.total_memory / (1024 ** 3)
        device = 0
        print(f" GPU: {gpu.name} ({vram_gb:.1f} GB)")
    else:
        vram_gb = 0.0
        device = "cpu"
        print(" WARNING: No GPU - training on CPU will be very slow")

    batch = _select_batch_size(vram_gb)

    script_dir = Path(__file__).resolve().parent
    dataset_dir = script_dir / "dataset"
    data_yaml = dataset_dir / "data.yaml"
    if not data_yaml.exists():
        print(f" ERROR: Dataset not found: {data_yaml}")
        sys.exit(1)

    n_train = _count_images(dataset_dir / "train" / "images")
    n_valid = _count_images(dataset_dir / "valid" / "images")
    print(f" Dataset: {n_train} train / {n_valid} val images")
    print(f" Batch size: {batch}")
    print(" Image size: 640 (native 600x600 - no downscaling)")
    print(f" Epochs: {epochs}")
    print()

    model_name = "yolo11s.pt"
    print(f" Base model: {model_name}")
    model = YOLO(model_name)

    try:
        model.train(
            # Data
            data=str(data_yaml),
            imgsz=640,
            # Schedule
            epochs=epochs,
            patience=0 if is_dry_run else 50,  # 0 disables early stopping
            batch=batch,
            # Optimiser
            optimizer="AdamW",
            lr0=0.002,
            lrf=0.01,
            momentum=0.937,
            weight_decay=0.0005,
            warmup_epochs=10,
            warmup_momentum=0.5,
            warmup_bias_lr=0.01,
            # Augmentation
            hsv_h=0.02,
            hsv_s=0.75,
            hsv_v=0.5,
            degrees=15.0,
            translate=0.2,
            scale=0.5,
            shear=5.0,
            perspective=0.0001,
            flipud=0.1,
            fliplr=0.5,
            mosaic=1.0,
            mixup=0.15,
            copy_paste=0.1,
            erasing=0.2,
            close_mosaic=20,
            # Hardware
            device=device,
            workers=4,
            cache="disk",
            amp=True,
            # Output
            project="road_anomaly",
            name=run_name,
            exist_ok=True,
            save=True,
            save_period=25,
            val=True,
            plots=True,
            # Misc
            cos_lr=True,
            nbs=64,  # nominal batch size used for gradient accumulation
        )
    except Exception as e:
        # logger.exception records the traceback in the log file, not just stderr.
        logger.exception("Training failed: %s", e)
        sys.exit(1)

    save_dir = Path(model.trainer.save_dir)
    logger.info("Training save dir: %s", save_dir)

    best_src = save_dir / "weights" / "best.pt"
    if not best_src.exists():
        logger.warning("best.pt not at expected path: %s", best_src)
        logger.info("Searching for best.pt...")
        found = _find_training_weights([Path("road_anomaly"), Path("runs"), script_dir])
        if found is None:
            logger.error("FATAL: best.pt not found anywhere after training!")
            logger.error("Check these directories manually:")
            logger.error(" %s", save_dir)
            for p in Path(".").rglob("best.pt"):
                logger.error(" Found: %s", p)
            sys.exit(1)
        best_src = found
        logger.info("Found best.pt at: %s", best_src)
    last_src = best_src.parent / "last.pt"

    # A dry run must not replace the real deployment weights.
    dest_dir = script_dir / "runs" / "dry_run" if is_dry_run else script_dir / "runs"
    dest_dir.mkdir(parents=True, exist_ok=True)

    dest_best = dest_dir / "best.pt"
    shutil.copy2(best_src, dest_best)
    logger.info("Best model copied to: %s", dest_best)

    if not is_dry_run:
        root_best = script_dir / "best.pt"
        shutil.copy2(best_src, root_best)
        logger.info("Best model copied to: %s", root_best)

    if last_src.exists():
        dest_last = dest_dir / "last.pt"
        shutil.copy2(last_src, dest_last)
        logger.info("Last model copied to: %s", dest_last)

    print()
    print("=" * 60)
    print(" FINAL VALIDATION")
    print("=" * 60)

    try:
        best_model = YOLO(str(dest_best))
        metrics = best_model.val(data=str(data_yaml), imgsz=640, device=device)
        _print_metrics(metrics, best_model.names)
    except Exception as e:
        # Validation is best-effort: the weights are already saved.
        logger.exception("Validation failed: %s", e)
        print(" Validation failed but model was saved successfully.")
        print(f" Model at: {dest_best}")

    print()
    print(f" Finished: {datetime.now():%Y-%m-%d %H:%M:%S}")
    print(f" Model saved to: {dest_best}")
    print(" Run 'python evaluate.py' to re-check anytime.")
    print()
|
|
|
|
# Run the training pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|