File size: 5,420 Bytes
import asyncio
import io
import logging
import time
from datetime import datetime

import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Setup logging: send INFO-and-above records from the root logger to the
# console, then create this module's own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
app = FastAPI(title="Bone Fracture Detection API")

# CORS is enforced only by browsers; native mobile clients are not subject to
# it, so this middleware matters only for browser-based callers. Any origin may
# call the API, but without credentials: the endpoints here read no cookies or
# auth headers, and FastAPI/Starlette document that wildcard origins, methods
# and headers cannot be combined with allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the Hugging Face model and its image processor once, at import time.
# If anything fails, both are set to None so the app still starts; /predict
# checks for this and answers 503 "Model not loaded".
MODEL_ID = "prithivMLmods/Bone-Fracture-Detection"
try:
    logger.info(f"Loading model: {MODEL_ID}")
    processor = AutoImageProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
    model.eval()  # switch to evaluation (inference) mode
    logger.info("✅ Model loaded successfully")
except Exception as e:
    # logger.exception logs at ERROR level *and* keeps the traceback, which is
    # essential for diagnosing download/config failures.
    logger.exception(f"❌ Error loading model: {e}")
    model = None
    processor = None
# Device setup: prefer CUDA when it is available, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Only move the model if it actually loaded. Calling .to() on None would raise
# at import time on a GPU host and take the whole app down, instead of letting
# /predict report its 503 "Model not loaded" error.
if model is not None:
    model = model.to(device)

logger.info("✅ Using GPU" if device.type == "cuda" else "✅ Using CPU")
@app.get("/health")
async def health():
    """Health check endpoint.

    "status" is "ok" whenever the process is serving requests. "model_loaded"
    tells callers whether /predict can actually run inference; when it is
    False, /predict answers 503.
    """
    return {
        "status": "ok",
        "message": "Bone Fracture Detection API is running",
        "model": "prithivMLmods/Bone-Fracture-Detection",
        "device": str(device),
        "model_loaded": model is not None and processor is not None,
    }
# Whole-word tokens that negate "fracture" in a class label
# (e.g. "Not Fractured", "no_fracture", "non-fracture").
_NEGATION_TOKENS = frozenset({"no", "not", "non", "without"})
# Glued negated forms such as "nonfractured" / "nofracture" / "unfractured".
_NEGATED_PREFIXES = ("nonfract", "nofract", "unfract")


def _is_fracture_label(label: str) -> bool:
    """Return True if *label* names the fracture-positive class.

    A plain substring test for "fracture" is wrong because negative class
    names ("Not Fractured", "no_fracture", ...) contain it too.
    NOTE(review): labels come from model.config.id2label and are assumed to
    be phrased like "Fractured" / "Not Fractured" — confirm against the
    model's config.
    """
    normalized = label.lower().replace("_", " ").replace("-", " ")
    if "fracture" not in normalized:
        return False
    tokens = normalized.split()
    if any(token in _NEGATION_TOKENS for token in tokens):
        return False
    return not any(token.startswith(_NEGATED_PREFIXES) for token in tokens)


def _run_inference(image: Image.Image) -> tuple[str, float]:
    """Classify *image*; return (predicted label, confidence in percent).

    Synchronous and compute-bound: callers run it via asyncio.to_thread so
    the forward pass does not block the event loop.
    """
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Grad mode is thread-local, so no_grad must be entered in this thread.
    with torch.no_grad():
        logits = model(**inputs).logits
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    confidence, predicted_class = torch.max(probabilities, 1)
    predicted_label = model.config.id2label[predicted_class.item()]
    return predicted_label, float(confidence[0]) * 100


def _grade_fracture(confidence_score: float) -> tuple[str, list[str]]:
    """Map a fracture-positive confidence (0-100) to (severity, affected_areas).

    NOTE(review): these values are heuristics keyed only on the classifier's
    confidence. The model yields a single class label; it neither localises
    the fracture nor grades its severity. Kept for API compatibility — do not
    present them as clinical findings without review.
    """
    if confidence_score > 85:
        return "high", ["Radius", "Ulna", "Carpals", "Metacarpals"]
    if confidence_score > 70:
        return "medium", ["Radius", "Ulna"]
    return "low", ["Minor fracture detected"]


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Predict bone fracture from an X-ray image.

    Returns:
        {
            "fracture_detected": bool,
            "confidence": float (0-100),
            "affected_areas": list,
            "severity": str ("none" / "low" / "medium" / "high"),
            "timestamp": str,
            "predicted_class": str,
            "additional_info": dict
        }

    If processing fails unexpectedly (e.g. the upload is not a decodable
    image), a response with an "error" key, severity "error" and
    predicted_class "error" is returned instead of raising.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 if the file is empty.
    """
    start = time.perf_counter()
    try:
        if model is None or processor is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="Empty file")

        # convert() returns a new image, so the decoder handle can be closed.
        with Image.open(io.BytesIO(contents)) as raw_image:
            image = raw_image.convert('RGB')
        logger.info(f"Processing image: {file.filename}, size: {image.size}")

        predicted_label, confidence_score = await asyncio.to_thread(
            _run_inference, image
        )
        logger.info(f"Prediction: {predicted_label}, Confidence: {confidence_score:.2f}%")

        fracture_detected = _is_fracture_label(predicted_label)
        if fracture_detected:
            severity, affected_areas = _grade_fracture(confidence_score)
        else:
            severity, affected_areas = "none", []

        # Measured wall-clock time (previously a hard-coded 250).
        processing_time_ms = round((time.perf_counter() - start) * 1000)
        return {
            "fracture_detected": fracture_detected,
            "confidence": round(confidence_score, 2),
            "affected_areas": affected_areas,
            "severity": severity,
            "timestamp": datetime.now().isoformat(),
            "predicted_class": predicted_label,
            "additional_info": {
                "model": "prithivMLmods/Bone-Fracture-Detection",
                "image_size": f"{image.size[0]}x{image.size[1]}",
                "device": str(device),
                "processing_time_ms": processing_time_ms,
            },
        }
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the logs; the client gets the error payload.
        logger.exception(f"Error during prediction: {str(e)}")
        return {
            "error": str(e),
            "fracture_detected": False,
            "confidence": 0,
            "affected_areas": [],
            "severity": "error",
            "timestamp": datetime.now().isoformat(),
            "predicted_class": "error",
        }
@app.post("/predict-batch")
async def predict_batch(files: list[UploadFile] = File(...)):
    """
    Predict fractures from multiple X-ray images.

    Each file is scored with ``predict``. A client error on a single file
    (HTTP 400, e.g. an empty upload) is recorded as an error entry for that
    file instead of failing the whole batch. Other HTTP errors, such as 503
    "Model not loaded", still abort the request.
    """
    results = []
    for file in files:
        try:
            result = await predict(file)
        except HTTPException as exc:
            if exc.status_code != 400:
                raise
            # Same shape as predict()'s own error payload.
            result = {
                "error": exc.detail,
                "fracture_detected": False,
                "confidence": 0,
                "affected_areas": [],
                "severity": "error",
                "timestamp": datetime.now().isoformat(),
                "predicted_class": "error",
            }
        results.append(result)
    return {
        "results": results,
        "count": len(results),
        "timestamp": datetime.now().isoformat(),
    }
if __name__ == "__main__":
    # Running this file directly starts the server. uvicorn is imported
    # lazily, so it is only needed in that case.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)
|