File size: 4,380 Bytes
f177f7b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 | import fastapi
from fastapi import File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware # Added for frontend communication
import uvicorn
import pickle
import numpy as np
import tensorflow as tf
from PIL import Image
import io
import os
# --- FastAPI application object ---
# Every route and middleware below is registered on this single instance.
# The keyword arguments are descriptive metadata for the API
# (its title, a short description, and a version string).
app = fastapi.FastAPI(
    version="1.0.0",
    title="X-ray Fracture Detection API",
    description="An API to predict bone fractures from X-ray images.",
)
# --- CORS (Cross-Origin Resource Sharing) Middleware ---
# Lets an HTML frontend served from a different origin call this backend;
# without it the browser blocks the request and fetch() reports
# "Failed to fetch".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    # Credentials (cookies, Authorization headers) must not be combined with a
    # wildcard origin: it would let any website make credentialed requests on
    # behalf of a visitor. This API defines no authentication, so credentialed
    # CORS requests are disabled. If they are ever needed, replace the
    # wildcards with explicit origins/methods/headers and set this to True.
    allow_credentials=False,
    allow_methods=["*"],  # Allows all methods (GET, POST, etc.)
    allow_headers=["*"],  # Allows all headers
)
# --- Loading the Model and Class Indices ---
# Note: Ensure these .pkl files are in the same directory as this script.
# The paths are anchored to the script's own location (not the current working
# directory), so the server also starts when launched from another directory.
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(_BASE_DIR, 'fracture_detection_model.pkl')
CLASS_INDICES_PATH = os.path.join(_BASE_DIR, 'class_indices.pkl')

# Check if model files exist before loading, so a missing file produces a
# clear start-up error instead of a failure on the first request.
if not os.path.exists(MODEL_PATH) or not os.path.exists(CLASS_INDICES_PATH):
    raise RuntimeError(f"Model or class indices files not found. Please ensure '{MODEL_PATH}' and '{CLASS_INDICES_PATH}' are in the correct directory.")

# SECURITY NOTE: pickle.load() can execute arbitrary code embedded in the file
# being loaded. Only use .pkl files that you produced yourself or obtained
# from a source you trust.

# Load the trained model
try:
    with open(MODEL_PATH, 'rb') as f:
        model = pickle.load(f)
except Exception as e:
    # "from e" records the original error as the explicit cause.
    raise RuntimeError(f"Error loading the model: {e}") from e

# Load the class indices
try:
    with open(CLASS_INDICES_PATH, 'rb') as f:
        class_indices = pickle.load(f)
    # Invert the dictionary to map index to class name
    class_names = {v: k for k, v in class_indices.items()}
except Exception as e:
    raise RuntimeError(f"Error loading class indices: {e}") from e

print("--- Model and class indices loaded successfully! ---")
# --- Image Preprocessing Function ---
def preprocess_image(image_bytes: bytes, target_size=(150, 150)) -> np.ndarray:
    """
    Preprocesses the uploaded image to match the model's input requirements.

    The image is decoded, converted to RGB if necessary, resized to
    ``target_size``, divided by 255 and wrapped in a batch of one.

    Args:
        image_bytes: Raw bytes of the uploaded image file.
        target_size: Size handed to ``PIL.Image.resize``. Defaults to
            (150, 150) -- presumably the input size the model was trained
            with; confirm against the training code.

    Returns:
        A NumPy array holding a single-image batch (one extra leading axis)
        with pixel values scaled by 1/255.

    Raises:
        HTTPException: status 400 when the payload is empty or cannot be
            decoded/resized as an image.
    """
    # An empty upload would otherwise surface as an opaque Pillow error.
    if not image_bytes:
        raise HTTPException(
            status_code=400,
            detail="Image preprocessing failed: uploaded file is empty.",
        )

    # Only the decoding steps are guarded: failures here are caused by the
    # client's data. Pillow decoders raise several unrelated exception types,
    # hence the broad catch at this request-validation boundary.
    try:
        # Open the image from bytes
        img = Image.open(io.BytesIO(image_bytes))
        # Ensure image is in RGB format
        if img.mode != "RGB":
            img = img.convert("RGB")
        # Resize the image (this also forces Pillow to fully decode the data)
        img = img.resize(target_size)
    except Exception as e:
        # Raise an HTTPException for bad image data
        raise HTTPException(status_code=400, detail=f"Image preprocessing failed: {e}") from e

    # Array conversion is deliberately outside the try block: an error here is
    # a server-side fault and must not be reported as a client error (400).
    # Convert image to numpy array and scale pixel values
    img_array = tf.keras.preprocessing.image.img_to_array(img)
    img_array = img_array / 255.0
    # Expand dimensions to create a batch of 1
    return np.expand_dims(img_array, axis=0)
# --- Prediction Endpoint ---
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Accepts an X-ray image file and returns the predicted class.

    Returns a JSON object with two keys:
        prediction: name of the predicted class, looked up in ``class_names``.
        confidence: score of that class as a string with two decimals.

    Raises:
        HTTPException: status 400 (raised by ``preprocess_image``) when the
            upload cannot be processed as an image.
    """
    # 1. Read the image file uploaded by the user
    image_bytes = await file.read()

    # 2. Preprocess the image to prepare it for the model
    processed_image = preprocess_image(image_bytes)

    # 3. Make a prediction using the loaded model
    # NOTE(review): this is a blocking call inside an async handler, so other
    # requests wait while it runs -- consider offloading it if latency matters.
    prediction = model.predict(processed_image)

    # 4. Process the prediction result (scores of the single image in the batch)
    scores = np.asarray(prediction[0]).ravel()
    if scores.size == 1:
        # Single-output model: argmax over one value is always 0, which would
        # report class 0 for every image. Threshold the score instead.
        # Assumes the output is the probability of class index 1 (the usual
        # sigmoid/binary setup) -- TODO confirm against the training code.
        positive_score = float(scores[0])
        predicted_index = int(positive_score >= 0.5)
        confidence = positive_score if predicted_index == 1 else 1.0 - positive_score
    else:
        # Find the index of the highest probability; convert to a plain int so
        # the dictionary lookup does not depend on NumPy scalar hashing.
        predicted_index = int(np.argmax(scores))
        # Get the confidence score
        confidence = float(scores[predicted_index])

    # Get the corresponding class name
    predicted_class = class_names[predicted_index]

    # 5. Return the result in a JSON format
    return {
        "prediction": predicted_class,
        "confidence": f"{confidence:.2f}"
    }
# --- Root Endpoint ---
@app.get("/")
def read_root():
    """Landing endpoint: return a welcome message that points to /docs."""
    welcome_text = "Welcome to the Fracture Detection API. Please use the /docs endpoint for more information."
    return dict(message=welcome_text)
# --- To run this application ---
# 1. Install necessary libraries: pip install fastapi "uvicorn[standard]" tensorflow numpy Pillow python-multipart
# 2. Save your .pkl files in the same directory as this script.
# 3. Run from your terminal (this assumes the script is saved as main.py): uvicorn main:app --reload
|