# detection/app.py — "Create app.py" by kris524 (commit e3980af, verified)
# NOTE(review): the lines above were Hugging Face file-viewer page residue,
# not Python source; converted to comments so the module parses.
import io
import logging
import time
from datetime import datetime

import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Setup logging
logging.basicConfig(level=logging.INFO)
# Module-level logger; used by the model-loading code and both endpoints below.
logger = logging.getLogger(__name__)
# FastAPI application instance; the title appears in the auto-generated docs UI.
app = FastAPI(title="Bone Fracture Detection API")
# Add CORS middleware to allow requests from mobile app
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# maximally permissive — consider pinning the mobile app's origin(s) before
# exposing this beyond development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load model and processor at import time. On failure the service stays up:
# both globals are set to None and /predict answers 503 until restart.
try:
    logger.info("Loading model: prithivMLmods/Bone-Fracture-Detection")
    processor = AutoImageProcessor.from_pretrained("prithivMLmods/Bone-Fracture-Detection")
    model = AutoModelForImageClassification.from_pretrained("prithivMLmods/Bone-Fracture-Detection")
    model.eval()  # inference mode: disables dropout/batch-norm updates
    logger.info("✅ Model loaded successfully")
except Exception as e:
    logger.error(f"❌ Error loading model: {e}")
    model = None
    processor = None

# Device setup: prefer GPU when available.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
    # BUGFIX: only move the model if it actually loaded. The original called
    # model.to(device) unconditionally, which raised AttributeError (None has
    # no .to) and crashed startup whenever the download failed on a CUDA box.
    if model is not None:
        model = model.to(device)
    logger.info("✅ Using GPU")
else:
    logger.info("✅ Using CPU")
@app.get("/health")
async def health():
    """Liveness probe: report API status, the model identifier, and the compute device."""
    payload = {
        "status": "ok",
        "message": "Bone Fracture Detection API is running",
        "model": "prithivMLmods/Bone-Fracture-Detection",
        "device": str(device),
    }
    return payload
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Predict bone fracture from X-ray image

    Returns:
        {
            "fracture_detected": bool,
            "confidence": float (0-100),
            "affected_areas": list,
            "severity": str (low/medium/high),
            "timestamp": str,
            "additional_info": dict
        }

    Raises:
        HTTPException 503 if the model failed to load at startup,
        HTTPException 400 for an empty or undecodable upload.
    """
    try:
        # Validate model is loaded (startup download may have failed).
        if model is None or processor is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        # Read and validate image bytes.
        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="Empty file")

        start_time = time.perf_counter()

        # Open and convert image. BUGFIX: a payload PIL cannot decode now gets
        # a clean 400 (consistent with the empty-file check) instead of falling
        # through to the broad handler and returning HTTP 200 with an error body.
        try:
            image = Image.open(io.BytesIO(contents)).convert('RGB')
        except Exception as exc:
            raise HTTPException(status_code=400, detail="Invalid image file") from exc
        logger.info(f"Processing image: {file.filename}, size: {image.size}")

        # Preprocess and move tensors to the model's device.
        inputs = processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Run inference without building autograd state.
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            probabilities = torch.nn.functional.softmax(logits, dim=1)
            confidence, predicted_class = torch.max(probabilities, 1)

        # Map the argmax class index to its human-readable label.
        id2label = model.config.id2label
        predicted_label = id2label[predicted_class.item()]
        confidence_score = float(confidence[0]) * 100
        logger.info(f"Prediction: {predicted_label}, Confidence: {confidence_score:.2f}%")

        # Classifier is label-based: any label containing "fracture" counts as positive.
        fracture_detected = "fracture" in predicted_label.lower()

        # NOTE(review): severity and affected_areas are heuristic placeholders
        # derived from confidence alone — this classifier does not localize
        # fractures, so the anatomical lists below are not model output.
        if fracture_detected:
            if confidence_score > 85:
                severity = "high"
                affected_areas = ["Radius", "Ulna", "Carpals", "Metacarpals"]
            elif confidence_score > 70:
                severity = "medium"
                affected_areas = ["Radius", "Ulna"]
            else:
                severity = "low"
                affected_areas = ["Minor fracture detected"]
        else:
            severity = "none"
            affected_areas = []

        # BUGFIX: measure actual wall-clock processing time; the original
        # returned a hard-coded 250 ms regardless of the real duration.
        elapsed_ms = (time.perf_counter() - start_time) * 1000

        return {
            "fracture_detected": fracture_detected,
            "confidence": round(confidence_score, 2),
            "affected_areas": affected_areas,
            "severity": severity,
            "timestamp": datetime.now().isoformat(),
            "predicted_class": predicted_label,
            "additional_info": {
                "model": "prithivMLmods/Bone-Fracture-Detection",
                "image_size": f"{image.size[0]}x{image.size[1]}",
                "device": str(device),
                "processing_time_ms": round(elapsed_ms, 1)
            }
        }
    except HTTPException:
        # Re-raise deliberate HTTP errors untouched so FastAPI sets the status code.
        raise
    except Exception as e:
        # Last-resort handler: preserve the mobile-app contract (HTTP 200 with
        # an error payload) for genuinely unexpected failures.
        logger.error(f"Error during prediction: {str(e)}")
        return {
            "error": str(e),
            "fracture_detected": False,
            "confidence": 0,
            "affected_areas": [],
            "severity": "error",
            "timestamp": datetime.now().isoformat(),
            "predicted_class": "error"
        }
@app.post("/predict-batch")
async def predict_batch(files: list[UploadFile] = File(...)):
    """
    Predict fractures from multiple X-ray images

    Delegates each upload to the single-image predict() handler and
    aggregates the per-file results in request order.
    """
    results = [await predict(upload) for upload in files]
    return {
        "results": results,
        "count": len(results),
        "timestamp": datetime.now().isoformat(),
    }
# Local/dev entry point: serve the app directly (port 7860 is the
# Hugging Face Spaces default).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)