import asyncio
import io
import logging
import time
from datetime import datetime

import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Root logging is configured once, at import, at INFO verbosity; this module
# then logs through its own named logger.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger(__name__)

app = FastAPI(title="Bone Fracture Detection API")

# Cross-origin policy: every origin, method and header is accepted.
# NOTE(review): the FastAPI CORS docs say wildcard origins/methods/headers
# should not be combined with allow_credentials=True (Starlette reflects the
# request Origin in that case -- verify). Nothing in this file reads cookies or
# auth headers, so credentials could likely be disabled; confirm with the
# frontend before changing.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# Load the image processor and classifier once, at import time.
# On any failure both globals are set to None; the /predict handler checks
# for that and answers 503 instead of crashing.
_MODEL_ID = "prithivMLmods/Bone-Fracture-Detection"

try:
    logger.info(f"Loading model: {_MODEL_ID}")
    processor = AutoImageProcessor.from_pretrained(_MODEL_ID)
    model = AutoModelForImageClassification.from_pretrained(_MODEL_ID)
    # Switch layers such as dropout to inference behaviour.
    model.eval()
    logger.info("✅ Model loaded successfully")
except Exception as e:
    # logger.exception (unlike logger.error) records the traceback, which is
    # what is needed to diagnose download / config failures.
    logger.exception(f"❌ Error loading model: {e}")
    model = None
    processor = None

# Inference device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # `model` is None when loading failed; calling .to() on None would crash
    # the whole module at import and defeat the 503 fallback in /predict.
    if model is not None:
        model = model.to(device)
    logger.info("✅ Using GPU")
else:
    logger.info("✅ Using CPU")

@app.get("/health")
async def health():
    """Health check endpoint.

    Reports that the service is up, which model it serves and the inference
    device. ``model_loaded`` tells whether the model and processor actually
    loaded. When they did not, /predict answers 503 even though this endpoint
    is reachable. ``status`` stays "ok" for backward compatibility.
    """
    return {
        "status": "ok",
        "message": "Bone Fracture Detection API is running",
        "model": "prithivMLmods/Bone-Fracture-Detection",
        "model_loaded": model is not None and processor is not None,
        "device": str(device),
    }

# Words that negate a "fracture" label, e.g. "Not Fractured", "no_fracture",
# "non-fracture". Checked explicitly because a bare substring test for
# "fracture" also matches those negative labels.
_NEGATION_TOKENS = frozenset({"no", "not", "non", "without"})


def _is_fracture_label(label: str) -> bool:
    """Return True if a classifier label names a positive fracture class.

    ``"fracture" in label.lower()`` alone is not enough: it is also True for
    negated labels such as "Not Fractured". The substring test is kept, but
    labels containing a negation word, or a run-together form such as
    "nofracture" or "nonfractured", are rejected.
    """
    normalized = label.lower().replace("_", " ").replace("-", " ")
    if "fracture" not in normalized:
        return False
    tokens = normalized.split()
    if any(token in _NEGATION_TOKENS for token in tokens):
        return False
    return not any(token.startswith("no") and "fracture" in token for token in tokens)


def _grade_severity(fracture_detected: bool, confidence_score: float) -> tuple[str, list[str]]:
    """Map a fracture decision and its confidence (0-100) to (severity, affected_areas).

    Thresholds: > 85 -> "high", > 70 -> "medium", otherwise "low".
    No fracture -> ("none", []).

    NOTE(review): both values come only from the classifier's top-1
    confidence. The model output used here carries no localisation, so the
    listed bones are fixed placeholders, not findings.
    """
    if not fracture_detected:
        return "none", []
    if confidence_score > 85:
        return "high", ["Radius", "Ulna", "Carpals", "Metacarpals"]
    if confidence_score > 70:
        return "medium", ["Radius", "Ulna"]
    return "low", ["Minor fracture detected"]


def _classify_image(image: Image.Image) -> tuple[str, float]:
    """Run the loaded classifier on one RGB PIL image.

    Returns the top-1 label from ``model.config.id2label`` and its softmax
    probability as a percentage (0-100). This is synchronous, compute-bound
    work; async callers should run it in a worker thread.
    """
    inputs = processor(images=image, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        confidence, predicted_class = torch.max(probabilities, 1)
    label = model.config.id2label[predicted_class.item()]
    return label, float(confidence[0]) * 100


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Predict bone fracture from X-ray image

    Returns:
        {
            "fracture_detected": bool,
            "confidence": float (0-100),
            "affected_areas": list,
            "severity": str (none/low/medium/high),
            "timestamp": str,
            "predicted_class": str,
            "additional_info": dict (model, image_size, device,
                                     measured processing_time_ms)
        }

    Raises HTTPException 503 when the model failed to load and 400 for an
    empty upload. Any other failure (e.g. bytes PIL cannot decode) is logged
    with its traceback. It is then returned as a normal response carrying an
    "error" key and severity "error".
    """
    started = time.perf_counter()
    try:
        if model is None or processor is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="Empty file")

        image = Image.open(io.BytesIO(contents)).convert('RGB')
        logger.info(f"Processing image: {file.filename}, size: {image.size}")

        # Torch inference is blocking; running it in a worker thread keeps
        # the event loop free for other requests (e.g. /health) meanwhile.
        predicted_label, confidence_score = await asyncio.to_thread(_classify_image, image)

        logger.info(f"Prediction: {predicted_label}, Confidence: {confidence_score:.2f}%")

        fracture_detected = _is_fracture_label(predicted_label)
        severity, affected_areas = _grade_severity(fracture_detected, confidence_score)

        # Wall-clock time from the start of this handler until the response
        # is built (replaces a previously hard-coded 250).
        elapsed_ms = (time.perf_counter() - started) * 1000

        return {
            "fracture_detected": fracture_detected,
            "confidence": round(confidence_score, 2),
            "affected_areas": affected_areas,
            "severity": severity,
            "timestamp": datetime.now().isoformat(),
            "predicted_class": predicted_label,
            "additional_info": {
                "model": "prithivMLmods/Bone-Fracture-Detection",
                "image_size": f"{image.size[0]}x{image.size[1]}",
                "device": str(device),
                "processing_time_ms": round(elapsed_ms, 2),
            },
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error during prediction: {str(e)}")
        return {
            "error": str(e),
            "fracture_detected": False,
            "confidence": 0,
            "affected_areas": [],
            "severity": "error",
            "timestamp": datetime.now().isoformat(),
            "predicted_class": "error",
        }

@app.post("/predict-batch")
async def predict_batch(files: list[UploadFile] = File(...)):
    """
    Predict fractures from multiple X-ray images

    Each upload is passed to ``predict`` one after another, in request order.
    Per-file results are collected as-is, including predict's error-shaped
    responses. An HTTPException raised for any file (e.g. empty upload or
    model not loaded) is not caught here and aborts the whole batch.
    """
    per_file = [await predict(upload) for upload in files]
    return {
        "results": per_file,
        "count": len(per_file),
        "timestamp": datetime.now().isoformat(),
    }

if __name__ == "__main__":
    # Script entry point: serve the API with uvicorn on all interfaces, port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")