xrayvision-backend / train_fracture_model.py
zohaibcodez's picture
initial deploy
ce4fddb
Raw
History Blame Contribute Delete
8.59 kB
"""
YOLOv8 Fracture Detection — Fine-Tuning Script
================================================
Dataset: GRAZPEDWRI-DX (recommended, already in YOLO format)
https://www.kaggle.com/datasets/kmader/grazpedwri-dx
Alternative datasets:
FracAtlas: https://huggingface.co/datasets/FracAtlas/FracAtlas
MURA (Stanford): https://stanfordmlgroup.github.io/competitions/mura/
Setup:
pip install ultralytics kaggle
Download GRAZPEDWRI-DX from Kaggle:
1. Get your API token from https://www.kaggle.com/settings → API → Create Token
2. Place kaggle.json in ~/.kaggle/
3. Run: kaggle datasets download -d kmader/grazpedwri-dx -p data/grazpedwri-dx --unzip
Usage:
python train_fracture_model.py --dataset data/grazpedwri-dx --epochs 100
python train_fracture_model.py --dataset data/grazpedwri-dx --epochs 50 --model yolov8s.pt
After training, copy the best weights:
cp runs/fracture/train/weights/best.pt models/fracture_yolov8.pt
Then update backend/.env:
YOLO_WEIGHTS_PATH=models/fracture_yolov8.pt
ALLOW_GENERIC_YOLO_WEIGHTS=false
"""
from __future__ import annotations
import argparse
import os
import shutil
import sys
from pathlib import Path
# ── Dataset YAML generation ──────────────────────────────────────────
def build_dataset_yaml(dataset_path: str, output_path: str = "fracture_dataset.yaml") -> str:
    """Write a YOLO dataset config that points at a local dataset copy.

    Two directory layouts are probed under ``dataset_path``:

    * ``images/train`` + ``images/valid``
    * ``train/images`` + ``valid/images``

    The ``train:``/``val:`` entries of the generated YAML name whichever
    layout was found. When neither exists, a warning is printed and the
    ``images/...`` layout is written anyway so that a file is still produced.

    Args:
        dataset_path: Dataset root directory; resolved to an absolute path.
        output_path: Destination of the generated YAML file.

    Returns:
        ``output_path``, unchanged.
    """
    dataset = Path(dataset_path).resolve()

    # Candidate (train, val) sub-directories, relative to the dataset root,
    # in order of preference.
    layouts = (
        ("images/train", "images/valid"),
        ("train/images", "valid/images"),
    )
    train_rel, val_rel = layouts[0]
    for candidate_train, candidate_val in layouts:
        if (dataset / candidate_train).exists():
            train_rel, val_rel = candidate_train, candidate_val
            break
    else:
        missing = dataset / "train" / "images"
        print(f"[WARN] Expected train images at {missing} — verify dataset structure.")

    # NOTE(review): nc/names below are hard-coded; confirm they match the class
    # list shipped with the dataset's own config before training.
    yaml_content = f"""# GRAZPEDWRI-DX Fracture Detection Dataset
# https://www.kaggle.com/datasets/kmader/grazpedwri-dx
path: {dataset}
train: {train_rel}
val: {val_rel}
# Classes — GRAZPEDWRI-DX classes
nc: 10
names:
  0: boneanomaly
  1: bonelesion
  2: foreignbody
  3: fracture
  4: metal
  5: periostealreaction
  6: pronationsign
  7: softtissue
  8: text
  9: hardware
# Notes:
# The primary class of interest is 'fracture' (index 3).
# During inference, XRayVision AI reports any detected class
# that relates to bone damage.
"""
    # Explicit UTF-8: the content contains non-ASCII characters, and the
    # platform default encoding is not guaranteed to be able to encode them.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(yaml_content)
    print(f"[OK] Dataset YAML written to: {output_path}")
    return output_path
# ── Training ─────────────────────────────────────────────────────────
def train(
    dataset_yaml: str,
    base_model: str = "yolov8n.pt",
    epochs: int = 100,
    batch: int = 16,
    imgsz: int = 640,
    project: str = "runs/fracture",
    name: str = "train",
    device: str = "0",  # GPU 0; use "cpu" for CPU-only
) -> Path:
    """Fine-tune a YOLOv8 checkpoint on the fracture dataset.

    Args:
        dataset_yaml: Path to the YOLO dataset config file.
        base_model: Checkpoint handed to ``YOLO(...)`` as the starting point.
        epochs: Number of training epochs.
        batch: Batch size.
        imgsz: Training image size in pixels.
        project: Output root directory for the run.
        name: Run directory name under ``project``.
        device: Device selector forwarded to ``model.train``.

    Returns:
        Path to ``best.pt``. The path reported by the trainer is used when it
        exists on disk; otherwise ``<project>/<name>/weights/best.pt``.

    Raises:
        SystemExit: If the ``ultralytics`` package is not installed.
    """
    try:
        from ultralytics import YOLO
    except ImportError:
        sys.exit("ultralytics not installed. Run: pip install ultralytics")

    rule = "=" * 60
    print(f"\n{rule}")
    print(" XRayVision AI — YOLOv8 Fracture Fine-Tuning")
    print(rule)
    print(f" Base model : {base_model}")
    print(f" Dataset : {dataset_yaml}")
    print(f" Epochs : {epochs}")
    print(f" Batch size : {batch}")
    print(f" Image size : {imgsz}px")
    print(f" Device : {device}")
    print(f"{rule}\n")

    model = YOLO(base_model)
    results = model.train(
        data=dataset_yaml,
        epochs=epochs,
        batch=batch,
        imgsz=imgsz,
        project=project,
        name=name,
        device=device,
        # Medical imaging — use slower, more accurate augmentations
        augment=True,
        hsv_h=0.0,  # X-rays are grayscale, skip hue shift
        hsv_s=0.0,  # Skip saturation shift
        hsv_v=0.4,  # Vary brightness (simulates exposure differences)
        degrees=10.0,  # Mild rotation (patient positioning variation)
        translate=0.1,
        scale=0.5,
        shear=0.0,
        flipud=0.0,
        fliplr=0.5,  # Horizontal flip is anatomically valid
        mosaic=0.5,
        # Optimizer hyperparameters
        lr0=0.001,
        lrf=0.01,
        momentum=0.937,
        weight_decay=0.0005,
        warmup_epochs=3.0,
        patience=20,  # Early stopping
        save_period=10,
        val=True,
        plots=True,
        # Loss-component gains (classification, box regression, distribution
        # focal loss). These weight loss terms; they are not per-class weights.
        cls=0.5,
        box=7.5,
        dfl=1.5,
        verbose=True,
        exist_ok=True,
    )

    # Prefer the checkpoint path the trainer itself reports over a path
    # reconstructed from project/name; fall back to the conventional location
    # when the trainer does not expose one or the file is not there.
    best_weights = Path(project) / name / "weights" / "best.pt"
    reported = getattr(getattr(model, "trainer", None), "best", None)
    if reported is not None and Path(reported).exists():
        best_weights = Path(reported)

    # Guard against a missing results object so the summary never crashes
    # after a completed run.
    summary = getattr(results, "results_dict", None) or {}
    print("\n[DONE] Training complete!")
    print(f" Best weights: {best_weights.resolve()}")
    print(f" mAP50: {summary.get('metrics/mAP50(B)', 'N/A')}")
    print(f" mAP50-95: {summary.get('metrics/mAP50-95(B)', 'N/A')}")
    return best_weights
# ── Validation ───────────────────────────────────────────────────────
def validate(weights_path: str, dataset_yaml: str, imgsz: int = 640) -> None:
    """Run validation on the val split and print aggregate and per-class metrics.

    Args:
        weights_path: Checkpoint to load with ``YOLO(...)``.
        dataset_yaml: Path to the YOLO dataset config file.
        imgsz: Validation image size in pixels.

    Raises:
        SystemExit: If the ``ultralytics`` package is not installed.
    """
    try:
        from ultralytics import YOLO
    except ImportError:
        sys.exit("ultralytics not installed. Run: pip install ultralytics")

    print(f"\n[INFO] Validating {weights_path} ...")
    model = YOLO(weights_path)
    metrics = model.val(data=dataset_yaml, imgsz=imgsz)
    box = metrics.box

    print("\n[RESULTS]")
    print(f" mAP50 : {box.map50:.4f}")
    print(f" mAP50-95 : {box.map:.4f}")
    print(f" Precision : {box.mp:.4f}")
    print(f" Recall : {box.mr:.4f}")

    # Per-class arrays only cover classes present in the validation data and
    # are ordered by ap_class_index, so map them back to class ids explicitly.
    # ap50 (not ap, which averages over IoU 0.5:0.95) matches the printed label.
    ap50_by_class = {
        int(class_id): float(ap50)
        for class_id, ap50 in zip(box.ap_class_index, box.ap50)
    }
    print("\n[Per-class]")
    for class_id, class_name in metrics.names.items():
        # Classes with no entry are reported as 0.
        ap50 = ap50_by_class.get(class_id, 0.0)
        print(f" {class_name:<30} AP50={ap50:.4f}")
# ── Deploy helpers ────────────────────────────────────────────────────
def deploy(best_weights: str, target: str = "models/fracture_yolov8.pt") -> None:
    """Copy trained weights to ``target`` and print follow-up instructions.

    The parent directory of ``target`` is created if needed.

    Args:
        best_weights: Path of the weights file to copy.
        target: Destination file path.

    Raises:
        FileNotFoundError: Propagated from ``shutil.copy`` when
            ``best_weights`` does not exist.
    """
    target_path = Path(target)
    target_path.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy(best_weights, target_path)

    print(f"\n[DEPLOYED] Weights copied to: {target_path.resolve()}")
    print("\nNext steps:")
    print(" 1. Update backend/.env:")
    # Echo the destination actually used so the hint stays correct when a
    # non-default target is passed.
    print(f" YOLO_WEIGHTS_PATH={target_path.as_posix()}")
    print(" ALLOW_GENERIC_YOLO_WEIGHTS=false")
    print(" 2. Restart the FastAPI server.")
# ── CLI ───────────────────────────────────────────────────────────────
def main() -> None:
    """Command-line entry point: build the dataset YAML, then validate or train.

    With ``--validate-only WEIGHTS`` the given weights are validated and the
    function returns. Otherwise the model is trained, the resulting weights
    are validated, and, if ``--deploy`` was given, copied via ``deploy``.
    """
    cli = argparse.ArgumentParser(
        description="Fine-tune YOLOv8 for fracture detection (XRayVision AI FYP)"
    )
    cli.add_argument("--dataset", required=True, help="Path to GRAZPEDWRI-DX dataset root")
    cli.add_argument("--model", default="yolov8n.pt", help="Base YOLO model (yolov8n/s/m/l/x.pt)")
    cli.add_argument("--epochs", type=int, default=100)
    cli.add_argument("--batch", type=int, default=16)
    cli.add_argument("--imgsz", type=int, default=640)
    cli.add_argument("--device", default="0", help="GPU id (0,1,...) or 'cpu'")
    cli.add_argument(
        "--validate-only",
        metavar="WEIGHTS",
        default=None,
        help="Skip training and only validate existing weights",
    )
    cli.add_argument(
        "--deploy",
        action="store_true",
        help="Copy best.pt to models/fracture_yolov8.pt after training",
    )
    opts = cli.parse_args()

    # The dataset config is needed by both the validate-only and training paths.
    dataset_cfg = build_dataset_yaml(opts.dataset)

    if opts.validate_only:
        validate(opts.validate_only, dataset_cfg, opts.imgsz)
        return

    trained_weights = train(
        dataset_yaml=dataset_cfg,
        base_model=opts.model,
        epochs=opts.epochs,
        batch=opts.batch,
        imgsz=opts.imgsz,
        device=opts.device,
    )
    weights_str = str(trained_weights)
    validate(weights_str, dataset_cfg, opts.imgsz)
    if opts.deploy:
        deploy(weights_str)


if __name__ == "__main__":
    main()