"""Bone Fracture Detection API.

FastAPI service that classifies uploaded X-ray images with the Hugging Face
image-classification model ``prithivMLmods/Bone-Fracture-Detection``.
"""

import io
import logging
import time
from datetime import datetime
from typing import Optional

import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

MODEL_ID = "prithivMLmods/Bone-Fracture-Detection"

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Bone Fracture Detection API")

# Add CORS middleware to allow requests from mobile app.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# web origin make credentialed requests; restrict allow_origins if cookies or
# other credential-based auth are ever added.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load model and processor once at import time. On failure both globals are
# None and /predict answers 503 instead of the service failing to start.
try:
    logger.info("Loading model: %s", MODEL_ID)
    processor = AutoImageProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
    model.eval()
    logger.info("✅ Model loaded successfully")
except Exception as e:
    logger.exception("❌ Error loading model: %s", e)
    model = None
    processor = None

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if model is not None:
    # Guard: model is None when loading failed, and calling .to() on None
    # would crash startup on GPU hosts, defeating the 503 fallback above.
    model = model.to(device)
logger.info("✅ Using GPU" if device.type == "cuda" else "✅ Using CPU")

# Words that mark a classifier label as the negative ("no fracture") class.
_NEGATIVE_LABEL_WORDS = frozenset({"no", "not", "non", "normal", "healthy", "negative"})


def _is_fracture_label(label: str) -> bool:
    """Return True if a classifier label denotes a fracture.

    A plain substring test (``"fracture" in label``) also matches negative
    labels such as "Not Fractured" or "non-fracture", so the label is split
    into words and any negation word marks it as a non-fracture label.
    """
    words = label.lower().replace("_", " ").replace("-", " ").split()
    if any(word in _NEGATIVE_LABEL_WORDS for word in words):
        return False
    return any("fracture" in word for word in words)


def _severity_and_areas(fracture_detected: bool, confidence_score: float) -> tuple[str, list[str]]:
    """Map a prediction to the ``(severity, affected_areas)`` response fields.

    NOTE(review): the model is a whole-image classifier (softmax over class
    logits). It outputs neither a location nor a severity. "severity" here is
    only a confidence bucket, and the bone names are fixed placeholders that
    are not derived from the image. Clients must not present them as a
    diagnosis. Kept unchanged for API compatibility.
    """
    if not fracture_detected:
        return "none", []
    if confidence_score > 85:
        return "high", ["Radius", "Ulna", "Carpals", "Metacarpals"]
    if confidence_score > 70:
        return "medium", ["Radius", "Ulna"]
    return "low", ["Minor fracture detected"]


def _classify(contents: bytes, filename: Optional[str]) -> tuple[str, float, tuple[int, int]]:
    """Decode image bytes and run the classifier synchronously.

    This is blocking CPU/GPU work. Call it through ``run_in_threadpool`` so
    the event loop stays responsive.

    Returns:
        ``(predicted_label, confidence_percent, (width, height))``, where
        confidence_percent is the top softmax probability scaled to 0-100.
    """
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    logger.info("Processing image: %s, size: %s", filename, image.size)

    # Preprocess and move tensors to the model's device
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    probabilities = torch.nn.functional.softmax(logits, dim=1)
    confidence, predicted_class = torch.max(probabilities, 1)
    predicted_label = model.config.id2label[predicted_class.item()]
    return predicted_label, float(confidence[0]) * 100, image.size


@app.get("/health")
async def health():
    """Health check endpoint"""
    return {
        "status": "ok",
        "message": "Bone Fracture Detection API is running",
        "model": MODEL_ID,
        "device": str(device),
    }


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Predict bone fracture from X-ray image

    Returns:
        {
            "fracture_detected": bool,
            "confidence": float (0-100),
            "affected_areas": list (placeholder values, see _severity_and_areas),
            "severity": str (none/low/medium/high; "error" on failure),
            "timestamp": str,
            "predicted_class": str,
            "additional_info": dict (processing_time_ms is measured)
        }

    Raises:
        HTTPException 503 if the model failed to load, 400 for an empty upload.
        Any other failure (e.g. an undecodable image) is logged and reported as
        an error payload with HTTP 200, as before.
    """
    try:
        # Validate model is loaded
        if model is None or processor is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        # Read and validate image
        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="Empty file")

        started = time.perf_counter()
        predicted_label, confidence_score, (width, height) = await run_in_threadpool(
            _classify, contents, file.filename
        )
        processing_time_ms = round((time.perf_counter() - started) * 1000)

        logger.info("Prediction: %s, Confidence: %.2f%%", predicted_label, confidence_score)

        fracture_detected = _is_fracture_label(predicted_label)
        severity, affected_areas = _severity_and_areas(fracture_detected, confidence_score)

        return {
            "fracture_detected": fracture_detected,
            "confidence": round(confidence_score, 2),
            "affected_areas": affected_areas,
            "severity": severity,
            "timestamp": datetime.now().isoformat(),
            "predicted_class": predicted_label,
            "additional_info": {
                "model": MODEL_ID,
                "image_size": f"{width}x{height}",
                "device": str(device),
                "processing_time_ms": processing_time_ms,
            },
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error during prediction: %s", e)
        return {
            "error": str(e),
            "fracture_detected": False,
            "confidence": 0,
            "affected_areas": [],
            "severity": "error",
            "timestamp": datetime.now().isoformat(),
            "predicted_class": "error",
        }


@app.post("/predict-batch")
async def predict_batch(files: list[UploadFile] = File(...)):
    """
    Predict fractures from multiple X-ray images

    Files are processed sequentially by ``predict``. An HTTPException for any
    file (empty upload, model not loaded) aborts the whole batch.
    """
    results = [await predict(file) for file in files]
    return {
        "results": results,
        "count": len(results),
        "timestamp": datetime.now().isoformat(),
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)