File size: 4,380 Bytes
f177f7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import fastapi
from fastapi import File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware # Added for frontend communication
import uvicorn
import pickle
import numpy as np
import tensorflow as tf
from PIL import Image
import io
import os

# --- Basic FastAPI App Setup ---
# The title/description/version metadata populates the auto-generated
# OpenAPI schema, which the interactive /docs page renders.
app = fastapi.FastAPI(
    version="1.0.0",
    title="X-ray Fracture Detection API",
    description="An API to predict bone fractures from X-ray images.",
)

# --- CORS (Cross-Origin Resource Sharing) Middleware ---
# Lets a browser frontend served from a different origin (e.g. a local HTML
# file or another port) call this API. Without it the browser blocks the
# response and fetch() reports "Failed to fetch".
#
# Credentials (cookies / HTTP auth) are deliberately disabled. The FastAPI
# docs state that allow_credentials=True must not be combined with
# allow_origins=["*"]; Starlette handles that combination by echoing back
# *any* requesting origin with credentials allowed. The endpoints defined
# here read no cookies or auth headers. If credentials are ever needed,
# replace "*" with an explicit list of trusted origins first.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=False,  # Must stay False while origins is a wildcard
    allow_methods=["*"],  # Allows all methods (GET, POST, etc.)
    allow_headers=["*"],  # Allows all headers
)

# --- Loading the Model and Class Indices ---
# Both .pkl files are resolved relative to the directory containing this
# script rather than the current working directory. This way the server
# finds them no matter where uvicorn is launched from.
#
# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only load model/class-index files that you produced yourself or otherwise
# trust; never point these paths at user-supplied files.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, 'fracture_detection_model.pkl')
CLASS_INDICES_PATH = os.path.join(BASE_DIR, 'class_indices.pkl')

# Fail fast at import time, naming both expected paths, if either is missing.
if not os.path.exists(MODEL_PATH) or not os.path.exists(CLASS_INDICES_PATH):
    raise RuntimeError(f"Model or class indices files not found. Please ensure '{MODEL_PATH}' and '{CLASS_INDICES_PATH}' are in the correct directory.")

# Load the trained model. The /predict endpoint calls model.predict(...) on it.
try:
    with open(MODEL_PATH, 'rb') as f:
        model = pickle.load(f)
except Exception as e:
    raise RuntimeError(f"Error loading the model: {e}") from e

# Load the class indices.
# NOTE(review): presumably a {class_name: index} dict such as Keras'
# flow_from_directory(...).class_indices — confirm against the training code.
try:
    with open(CLASS_INDICES_PATH, 'rb') as f:
        class_indices = pickle.load(f)
    # Invert the dictionary so a predicted index maps back to its class name.
    class_names = {v: k for k, v in class_indices.items()}
except Exception as e:
    raise RuntimeError(f"Error loading class indices: {e}") from e

print("--- Model and class indices loaded successfully! ---")

# --- Image Preprocessing Function ---
def preprocess_image(image_bytes: bytes, target_size=(150, 150)) -> np.ndarray:
    """
    Decode raw image bytes into a model-ready batch containing one image.

    Steps:
        1. Decode with Pillow.
        2. Convert to RGB.
        3. Resize to ``target_size``.
        4. Convert to a float array with Keras' ``img_to_array``.
        5. Scale pixel values by 1/255.
        6. Prepend a batch axis.

    Args:
        image_bytes: Raw bytes of the uploaded image file.
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``.
            NOTE(review): presumably must match the input size the model was
            trained with — confirm against the training code.

    Returns:
        A batch of one image, values in ``[0, 1]``. Under Keras' default
        channels-last data format its shape is ``(1, height, width, 3)``.

    Raises:
        HTTPException: status 400 if any step fails (e.g. the bytes are not a
            decodable image). The underlying error is chained as the cause.
    """
    try:
        # The context manager releases Pillow's handle on the buffer. The
        # resized image is an independent copy, so it stays usable afterwards.
        with Image.open(io.BytesIO(image_bytes)) as img:
            # Normalize grayscale / RGBA / palette uploads to 3 channels.
            rgb = img if img.mode == "RGB" else img.convert("RGB")
            resized = rgb.resize(target_size)

        # Convert to a float array and scale pixel values from [0, 255] to [0, 1].
        img_array = tf.keras.preprocessing.image.img_to_array(resized) / 255.0

        # Expand dimensions to create a batch of 1.
        return np.expand_dims(img_array, axis=0)
    except Exception as e:
        # Any decoding/processing failure is the client's bad input: report 400.
        raise HTTPException(status_code=400, detail=f"Image preprocessing failed: {e}") from e


# --- Prediction Endpoint ---
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Accept an uploaded X-ray image and return the predicted class.

    Two model output layouts are supported:
      * Multi-unit output (e.g. softmax): the most probable index wins.
      * Single-unit output (e.g. sigmoid for a binary model): the value is
        treated as the probability of class index 1 and thresholded at 0.5.
        Taking argmax over a one-element vector would always pick index 0.

    Returns:
        JSON object with ``"prediction"`` (class name from ``class_names``)
        and ``"confidence"`` (probability of that class, as a string with
        two decimal places).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image
            (raised by ``preprocess_image``).
    """
    # 1. Read the image file uploaded by the user.
    image_bytes = await file.read()

    # 2. Preprocess the image to prepare it for the model.
    processed_image = preprocess_image(image_bytes)

    # 3. Make a prediction using the loaded model.
    # NOTE(review): model.predict is synchronous. Inside this async endpoint
    # it runs on the event-loop thread and blocks other requests while it runs.
    prediction = model.predict(processed_image)
    scores = np.asarray(prediction[0]).ravel()

    # 4. Turn the raw scores into a class index and a confidence score.
    if scores.size == 1:
        # Single sigmoid unit. NOTE(review): assumes the value is P(class
        # index 1), as with Keras class_mode='binary' — confirm with training.
        positive_prob = float(scores[0])
        predicted_index = 1 if positive_prob >= 0.5 else 0
        confidence = positive_prob if predicted_index == 1 else 1.0 - positive_prob
    else:
        # One score per class: pick the highest-probability index.
        predicted_index = int(np.argmax(scores))
        confidence = float(scores[predicted_index])

    # Get the corresponding class name.
    predicted_class = class_names[predicted_index]

    # 5. Return the result in a JSON format.
    return {
        "prediction": predicted_class,
        "confidence": f"{confidence:.2f}",
    }

# --- Root Endpoint ---
@app.get("/")
def read_root():
    """Return a short welcome payload that points clients at the /docs page."""
    welcome_text = (
        "Welcome to the Fracture Detection API. "
        "Please use the /docs endpoint for more information."
    )
    payload = {"message": welcome_text}
    return payload

# --- To run this application ---
# 1. Install necessary libraries: pip install fastapi "uvicorn[standard]" tensorflow numpy Pillow python-multipart
# 2. Save your .pkl files in the same directory as this script.
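# 3. Run from your terminal (assuming this file is saved as main.py): uvicorn main:app --reload
#    (--reload restarts the server on code changes; use it for development only.)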