import io
import os
import pickle

import fastapi
import numpy as np
import tensorflow as tf
import uvicorn
from fastapi import File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
|
|
|
|
|
# Create the FastAPI application together with its title, description,
# and version metadata.
app = fastapi.FastAPI(
    title="X-ray Fracture Detection API",
    description="An API to predict bone fractures from X-ray images.",
    version="1.0.0",
)
|
|
|
|
|
|
|
|
|
# Cross-origin access: any origin may call the API with any method and any
# request headers.
#
# Credentials stay disabled on purpose. A wildcard origin combined with
# allow_credentials=True would let every website issue cookie-bearing requests
# to this API, and the FastAPI CORS documentation states that origins, methods
# and headers must all be listed explicitly whenever credentials are enabled.
# If credentialed requests are ever required, replace the "*" entries with
# explicit values first.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
# Locations of the serialized artifacts. The paths are relative, so they are
# resolved against the working directory of the server process.
MODEL_PATH = 'fracture_detection_model.pkl'
CLASS_INDICES_PATH = 'class_indices.pkl'

# Fail fast at import time when either artifact is missing, instead of
# surfacing the problem on the first request.
if not os.path.exists(MODEL_PATH) or not os.path.exists(CLASS_INDICES_PATH):
    raise RuntimeError(
        "Model or class indices files not found. "
        f"Please ensure '{MODEL_PATH}' and '{CLASS_INDICES_PATH}' "
        "are in the correct directory."
    )

# SECURITY: pickle.load can execute arbitrary code embedded in the file being
# read. Only ever point MODEL_PATH / CLASS_INDICES_PATH at artifacts produced
# by a trusted pipeline.
try:
    with open(MODEL_PATH, 'rb') as f:
        model = pickle.load(f)
except Exception as e:
    # Chain the original exception so its traceback is preserved as the cause.
    raise RuntimeError(f"Error loading the model: {e}") from e

try:
    with open(CLASS_INDICES_PATH, 'rb') as f:
        class_indices = pickle.load(f)

    # Invert the mapping (value -> key) so entries can be looked up by the
    # values stored in the file; presumably that is index -> label name,
    # to be confirmed against how class_indices.pkl was produced.
    class_names = {v: k for k, v in class_indices.items()}
except Exception as e:
    raise RuntimeError(f"Error loading class indices: {e}") from e

print("--- Model and class indices loaded successfully! ---")
|
|
|
|
|
def preprocess_image(image_bytes: bytes, target_size=(150, 150)) -> np.ndarray:
    """Decode an uploaded image and turn it into a single-image model batch.

    The image is converted to RGB when it is in any other mode, resized to
    ``target_size``, converted to an array, divided by 255, and finally given
    a leading batch axis.

    Args:
        image_bytes: Raw bytes of the encoded image file.
        target_size: Size handed to ``PIL.Image.resize``, which interprets it
            as ``(width, height)``. The default is square, so the order only
            matters when a non-square size is supplied.

    Returns:
        The scaled pixel array with an extra leading axis of length 1.

    Raises:
        HTTPException: With status 400 when any step of decoding, converting
            or resizing the image fails. The original error is attached as
            the exception's cause.
    """
    try:
        img = Image.open(io.BytesIO(image_bytes))

        # Normalise grayscale, palette, RGBA, ... inputs to three channels.
        if img.mode != "RGB":
            img = img.convert("RGB")

        img = img.resize(target_size)

        img_array = tf.keras.preprocessing.image.img_to_array(img)
        img_array = img_array / 255.0

        # Add the batch axis in front of the image array.
        return np.expand_dims(img_array, axis=0)
    except Exception as e:
        # Anything that goes wrong here is reported to the client as a bad
        # request; "from e" keeps the underlying traceback for server logs.
        raise HTTPException(
            status_code=400, detail=f"Image preprocessing failed: {e}"
        ) from e
|
|
|
|
|
|
|
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded X-ray image.

    Args:
        file: The uploaded image file from the multipart request.

    Returns:
        A dict with two keys: ``"prediction"``, the entry of ``class_names``
        for the highest-scoring output, and ``"confidence"``, that score
        formatted as a string with two decimal places.

    Raises:
        HTTPException: With status 400 when ``preprocess_image`` cannot
            process the upload.
    """
    image_bytes = await file.read()

    # Image decoding and model inference are blocking, CPU-bound calls.
    # Running them directly in this coroutine would stall the event loop (and
    # every other in-flight request) for their whole duration, so they are
    # pushed to the worker threadpool instead.
    processed_image = await run_in_threadpool(preprocess_image, image_bytes)
    prediction = await run_in_threadpool(model.predict, processed_image)

    # The batch holds exactly one image, so only the first row is relevant.
    scores = prediction[0]

    # NOTE(review): argmax over a single-element row is always 0. If the model
    # ends in one sigmoid unit rather than one score per class, this would
    # always report the same class - confirm the model's output layer.
    # int(...) turns the NumPy integer into a plain Python int for the lookup.
    predicted_index = int(np.argmax(scores))
    predicted_class = class_names[predicted_index]
    confidence = float(scores[predicted_index])

    return {
        "prediction": predicted_class,
        "confidence": f"{confidence:.2f}",
    }
|
|
|
|
|
@app.get("/")
def read_root():
    """Return a static welcome message that points clients at /docs."""
    welcome_text = (
        "Welcome to the Fracture Detection API. "
        "Please use the /docs endpoint for more information."
    )
    return {"message": welcome_text}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|