Spaces:
Sleeping
Sleeping
"""
YOLOv8 Fracture Detection — Fine-Tuning Script
==============================================
Dataset: GRAZPEDWRI-DX (recommended, already in YOLO format)
  https://www.kaggle.com/datasets/kmader/grazpedwri-dx
Alternative datasets:
  FracAtlas: https://huggingface.co/datasets/FracAtlas/FracAtlas
  MURA (Stanford): https://stanfordmlgroup.github.io/competitions/mura/
Setup:
  pip install ultralytics kaggle
Download GRAZPEDWRI-DX from Kaggle:
  1. Get your API token from https://www.kaggle.com/settings → API → Create Token
  2. Place kaggle.json in ~/.kaggle/
  3. Run: kaggle datasets download -d kmader/grazpedwri-dx -p data/grazpedwri-dx --unzip
Usage:
  python train_fracture_model.py --dataset data/grazpedwri-dx --epochs 100
  python train_fracture_model.py --dataset data/grazpedwri-dx --epochs 50 --model yolov8s.pt
After training, copy the best weights (or pass --deploy to do this automatically):
  cp runs/fracture/train/weights/best.pt models/fracture_yolov8.pt
Then update backend/.env:
  YOLO_WEIGHTS_PATH=models/fracture_yolov8.pt
  ALLOW_GENERIC_YOLO_WEIGHTS=false
"""
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import shutil | |
| import sys | |
| from pathlib import Path | |
# ── Dataset YAML generation ──────────────────────────────────────────
def build_dataset_yaml(dataset_path: str, output_path: str = "fracture_dataset.yaml") -> str:
    """Write an Ultralytics dataset YAML pointing at a local GRAZPEDWRI-DX copy.

    Two directory layouts are recognised (the first one wins if its train
    directory exists):

    * ``<root>/images/train`` and ``<root>/images/valid``
    * ``<root>/train/images`` and ``<root>/valid/images``

    The ``train``/``val`` entries are written relative to the dataset root for
    whichever layout was selected, so both layouts yield a usable config.

    Args:
        dataset_path: Dataset root directory; it is resolved to an absolute path.
        output_path: Where to write the YAML file (overwritten if present).

    Returns:
        ``output_path``, unchanged, so it can be passed straight to Ultralytics.
    """
    dataset = Path(dataset_path).resolve()

    # GRAZPEDWRI-DX ships with train/ and valid/ splits.
    train_path = dataset / "images" / "train"
    val_path = dataset / "images" / "valid"
    if not train_path.exists():
        # Try the "<split>/images" layout instead.
        train_path = dataset / "train" / "images"
        val_path = dataset / "valid" / "images"
    if not train_path.exists():
        print(f"[WARN] Expected train images at {train_path} — verify dataset structure.")
    if not val_path.exists():
        print(f"[WARN] Expected validation images at {val_path} — verify dataset structure.")

    # Split paths relative to `path:`; POSIX separators are accepted on every OS.
    train_rel = train_path.relative_to(dataset).as_posix()
    val_rel = val_path.relative_to(dataset).as_posix()
    # Single-quoted YAML scalar: backslashes, ':' and '#' stay literal; only a
    # single quote needs escaping (by doubling it).
    root = str(dataset).replace("'", "''")

    # NOTE(review): the public GRAZPEDWRI-DX release appears to define 9 classes
    # (0-8, no 'hardware'); confirm nc/names against the dataset's own meta file.
    yaml_content = f"""# GRAZPEDWRI-DX Fracture Detection Dataset
# https://www.kaggle.com/datasets/kmader/grazpedwri-dx
path: '{root}'
train: {train_rel}
val: {val_rel}
# Classes — GRAZPEDWRI-DX classes
nc: 10
names:
  0: boneanomaly
  1: bonelesion
  2: foreignbody
  3: fracture
  4: metal
  5: periostealreaction
  6: pronationsign
  7: softtissue
  8: text
  9: hardware
# Notes:
# The primary class of interest is 'fracture' (index 3).
# During inference, XRayVision AI reports any detected class
# that relates to bone damage.
"""
    # Explicit encoding: the content contains non-ASCII punctuation.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(yaml_content)
    print(f"[OK] Dataset YAML written to: {output_path}")
    return output_path
# ── Training ─────────────────────────────────────────────────────────
def train(
    dataset_yaml: str,
    base_model: str = "yolov8n.pt",
    epochs: int = 100,
    batch: int = 16,
    imgsz: int = 640,
    project: str = "runs/fracture",
    name: str = "train",
    device: str = "0",  # GPU 0; use "cpu" for CPU-only
    optimizer: str = "SGD",
) -> Path:
    """Fine-tune YOLOv8 on the fracture dataset.

    Args:
        dataset_yaml: Ultralytics dataset config (e.g. from ``build_dataset_yaml``).
        base_model: Starting checkpoint, e.g. ``yolov8n.pt``.
        epochs: Maximum number of epochs (early stopping may end the run sooner).
        batch: Batch size.
        imgsz: Training image size in pixels.
        project: Parent directory for run outputs.
        name: Run sub-directory under ``project`` (reused, since ``exist_ok=True``).
        device: Ultralytics device string, e.g. ``"0"``, ``"0,1"`` or ``"cpu"``.
        optimizer: Optimizer name forwarded to Ultralytics. With ``"auto"``,
            Ultralytics chooses its own optimizer, ``lr0`` and ``momentum`` and
            ignores the values configured below; any explicit optimizer makes
            them take effect.

    Returns:
        Path to the run's ``best.pt`` checkpoint.

    Raises:
        SystemExit: if ``ultralytics`` is not installed.
    """
    try:
        from ultralytics import YOLO
    except ImportError:
        sys.exit("ultralytics not installed. Run: pip install ultralytics")

    rule = "=" * 60
    print(f"\n{rule}")
    print(" XRayVision AI — YOLOv8 Fracture Fine-Tuning")
    print(rule)
    print(f" Base model : {base_model}")
    print(f" Dataset : {dataset_yaml}")
    print(f" Epochs : {epochs}")
    print(f" Batch size : {batch}")
    print(f" Image size : {imgsz}px")
    print(f" Device : {device}")
    print(f" Optimizer : {optimizer}")
    print(f"{rule}\n")

    model = YOLO(base_model)
    results = model.train(
        data=dataset_yaml,
        epochs=epochs,
        batch=batch,
        imgsz=imgsz,
        project=project,
        name=name,
        device=device,
        # NOTE(review): in Ultralytics `augment` requests test-time augmentation
        # for prediction/validation; training-time augmentation is governed by
        # the hsv_*/degrees/translate/scale/flip/mosaic settings below.
        # Confirm this flag is actually wanted.
        augment=True,
        hsv_h=0.0,  # X-rays are grayscale, skip hue shift
        hsv_s=0.0,  # Skip saturation shift
        hsv_v=0.4,  # Vary brightness (simulates exposure differences)
        degrees=10.0,  # Mild rotation (patient positioning variation)
        translate=0.1,
        scale=0.5,
        shear=0.0,
        flipud=0.0,
        fliplr=0.5,  # Horizontal flip is anatomically valid
        mosaic=0.5,
        # Hyperparameters chosen for medical images. An explicit optimizer is
        # required for lr0/momentum to be honoured (see docstring).
        optimizer=optimizer,
        lr0=0.001,
        lrf=0.01,
        momentum=0.937,
        weight_decay=0.0005,
        warmup_epochs=3.0,
        patience=20,  # Early stopping
        save_period=10,
        val=True,
        plots=True,
        # Loss gains for the box / classification / DFL terms (Ultralytics
        # defaults). They scale whole loss components and do NOT weight
        # individual classes such as 'fracture'.
        cls=0.5,
        box=7.5,
        dfl=1.5,
        verbose=True,
        exist_ok=True,
    )

    # Prefer the path the trainer actually wrote; fall back to the usual layout.
    trainer_best = getattr(getattr(model, "trainer", None), "best", None)
    if trainer_best:
        best_weights = Path(trainer_best)
    else:
        best_weights = Path(project) / name / "weights" / "best.pt"

    # model.train() may return None (e.g. multi-GPU DDP runs), so don't assume metrics.
    metrics = getattr(results, "results_dict", None) or {}

    print("\n[DONE] Training complete!")
    print(f" Best weights: {best_weights.resolve()}")
    if not best_weights.exists():
        print(f"[WARN] {best_weights} not found — training may not have produced a checkpoint.")
    print(f" mAP50: {metrics.get('metrics/mAP50(B)', 'N/A')}")
    print(f" mAP50-95: {metrics.get('metrics/mAP50-95(B)', 'N/A')}")
    return best_weights
# ── Validation ───────────────────────────────────────────────────────
def validate(weights_path: str, dataset_yaml: str, imgsz: int = 640) -> None:
    """Run validation on the val split and print overall and per-class metrics.

    Args:
        weights_path: Checkpoint to evaluate.
        dataset_yaml: Ultralytics dataset config whose ``val`` split is used.
        imgsz: Validation image size in pixels.

    Raises:
        SystemExit: if ``ultralytics`` is not installed.
    """
    try:
        from ultralytics import YOLO
    except ImportError:
        sys.exit("ultralytics not installed. Run: pip install ultralytics")

    print(f"\n[INFO] Validating {weights_path} ...")
    model = YOLO(weights_path)
    metrics = model.val(data=dataset_yaml, imgsz=imgsz)
    box = metrics.box

    print("\n[RESULTS]")
    print(f" mAP50 : {box.map50:.4f}")
    print(f" mAP50-95 : {box.map:.4f}")
    print(f" Precision : {box.mp:.4f}")
    print(f" Recall : {box.mr:.4f}")

    print("\n[Per-class]")
    # box.ap50 is ordered like box.ap_class_index (only classes present in the
    # validation labels), not by class id, so map positions back to class ids.
    # (box.ap would be per-class AP50-95, not AP50.)
    ap50_by_class = {
        int(cls_id): float(box.ap50[pos])
        for pos, cls_id in enumerate(box.ap_class_index)
    }
    for cls_id, cls_name in metrics.names.items():
        # Classes absent from the validation labels are reported as 0.
        ap50 = ap50_by_class.get(int(cls_id), 0.0)
        print(f" {cls_name:<30} AP50={ap50:.4f}")
# ── Deploy helpers ───────────────────────────────────────────────────
def deploy(best_weights: str, target: str = "models/fracture_yolov8.pt") -> None:
    """Copy trained weights into the backend models/ directory and print .env hints.

    Args:
        best_weights: Path to the checkpoint produced by training.
        target: Destination path; missing parent directories are created.

    Raises:
        FileNotFoundError: if ``best_weights`` does not exist (from ``shutil.copy``).
    """
    target_path = Path(target)
    target_path.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy(best_weights, target_path)
    print(f"\n[DEPLOYED] Weights copied to: {target_path.resolve()}")
    print("\nNext steps:")
    print(" 1. Update backend/.env:")
    # Echo the real destination so a custom `target` yields a correct .env line.
    print(f" YOLO_WEIGHTS_PATH={target_path.as_posix()}")
    print(" ALLOW_GENERIC_YOLO_WEIGHTS=false")
    print(" 2. Restart the FastAPI server.")
# ── CLI ──────────────────────────────────────────────────────────────
def main() -> None:
    """Command-line entry point.

    Parses the options and regenerates the dataset YAML. With
    ``--validate-only WEIGHTS`` it just evaluates those weights. Otherwise it
    trains a model, validates the resulting checkpoint and, if ``--deploy`` is
    given, copies it into the backend models/ directory.
    """
    cli = argparse.ArgumentParser(
        description="Fine-tune YOLOv8 for fracture detection (XRayVision AI FYP)"
    )
    cli.add_argument("--dataset", required=True, help="Path to GRAZPEDWRI-DX dataset root")
    cli.add_argument("--model", default="yolov8n.pt", help="Base YOLO model (yolov8n/s/m/l/x.pt)")
    # Plain integer knobs share the same shape; registered in the original order.
    for flag, fallback in (("--epochs", 100), ("--batch", 16), ("--imgsz", 640)):
        cli.add_argument(flag, type=int, default=fallback)
    cli.add_argument("--device", default="0", help="GPU id (0,1,...) or 'cpu'")
    cli.add_argument(
        "--validate-only",
        metavar="WEIGHTS",
        default=None,
        help="Skip training and only validate existing weights",
    )
    cli.add_argument(
        "--deploy",
        action="store_true",
        help="Copy best.pt to models/fracture_yolov8.pt after training",
    )
    opts = cli.parse_args()

    config_yaml = build_dataset_yaml(opts.dataset)

    if opts.validate_only:
        validate(opts.validate_only, config_yaml, opts.imgsz)
    else:
        trained = train(
            dataset_yaml=config_yaml,
            base_model=opts.model,
            epochs=opts.epochs,
            batch=opts.batch,
            imgsz=opts.imgsz,
            device=opts.device,
        )
        validate(str(trained), config_yaml, opts.imgsz)
        if opts.deploy:
            deploy(str(trained))


if __name__ == "__main__":
    main()