# Author: Raushan2709
# Commit 661ac18: "Update app.py" (verified)
import gradio as gr
import pickle
import numpy as np
import tensorflow as tf
from PIL import Image
import os
# --- Loading the Model and Class Indices ---
# Both artifacts are resolved relative to the current working directory.
MODEL_PATH = 'fracture_detection_model.pkl'
CLASS_INDICES_PATH = 'class_indices.pkl'

# Check if model files exist before loading, so a missing upload produces a
# clear message instead of a bare FileNotFoundError from open().
if not os.path.exists(MODEL_PATH) or not os.path.exists(CLASS_INDICES_PATH):
    raise RuntimeError(f"Model or class indices files not found. Please ensure '{MODEL_PATH}' and '{CLASS_INDICES_PATH}' are uploaded to your Space.")

# Load the trained model.
# NOTE: pickle.load can execute arbitrary code embedded in the file; this is
# only acceptable because the file is an artifact shipped with the app, never
# user-supplied data.
try:
    with open(MODEL_PATH, 'rb') as f:
        model = pickle.load(f)
except Exception as e:
    # Chain the original exception so the real cause stays in the traceback.
    raise RuntimeError(f"Error loading the model: {e}") from e

# Load the class indices. Expected to be a mapping of class name -> integer
# index (TODO confirm against the training script).
try:
    with open(CLASS_INDICES_PATH, 'rb') as f:
        class_indices = pickle.load(f)
    # Invert the mapping so a model output position can be turned back into
    # a class name: {index: class_name}.
    class_names = {v: k for k, v in class_indices.items()}
except Exception as e:
    raise RuntimeError(f"Error loading class indices: {e}") from e

print("--- Model and class indices loaded successfully! ---")
# --- Prediction Function (Adapted for Gradio) ---
def predict_fracture(img):
    """Classify an uploaded X-ray image with the loaded model.

    Args:
        img: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            None when nothing has been uploaded.

    Returns:
        A dict mapping each class name (from the module-level ``class_names``)
        to a float score, in the shape ``gr.Label`` expects. If no image was
        given or prediction failed, a human-readable message string instead.
    """
    if img is None:
        return "Please upload an image."
    try:
        # 1. The model takes 3-channel input; uploads may be grayscale ("L")
        #    or carry an alpha channel, so normalize the mode first.
        if img.mode != "RGB":
            img = img.convert("RGB")
        # 2. Resize to the 150x150 target size used by the original pipeline.
        img = img.resize((150, 150))
        # 3. HxWx3 float32 array scaled to [0, 1]. This is what
        #    tf.keras.preprocessing.image.img_to_array(img) / 255.0 produced,
        #    without relying on that legacy preprocessing namespace.
        img_array = np.asarray(img, dtype=np.float32) / 255.0
        # 4. Expand dimensions to create a batch of 1.
        img_batch = np.expand_dims(img_array, axis=0)
        # 5. Predict and keep the score vector for the single image.
        scores = np.ravel(model.predict(img_batch)[0])
        # 6a. A two-class model with a single (sigmoid) output unit returns only
        #     the score for class index 1. Derive class 0 as its complement
        #     rather than indexing past the end of the array.
        if scores.size == 1 and set(class_names) == {0, 1}:
            p1 = float(scores[0])
            return {class_names[0]: 1.0 - p1, class_names[1]: p1}
        # 6b. Otherwise, one score per class, at that class's index. gr.Label
        #     renders this as a bar chart, e.g. {"Fractured": 0.85, "Normal": 0.15}.
        return {name: float(scores[i]) for i, name in class_names.items()}
    except Exception as e:
        # Keep the UI responsive by reporting the failure as text, but also
        # log the full traceback server-side so the cause is diagnosable.
        import traceback
        traceback.print_exc()
        return f"Prediction failed: {e}"
# --- Create the Gradio Web Interface ---
# Pick up to two image files in the working directory as clickable examples.
# Listing once, sorting (stable choice across restarts), matching extensions
# case-insensitively (so ".JPG"/".PNG" count), and skipping directories.
# An empty result becomes None so Gradio shows no examples panel.
_example_images = sorted(
    name for name in os.listdir('.')
    if name.lower().endswith(('.jpg', '.png', '.jpeg')) and os.path.isfile(name)
)[:2] or None

demo = gr.Interface(
    fn=predict_fracture,
    inputs=gr.Image(type="pil", label="Upload X-ray Image"),
    # gr.Label highlights the highest-scoring class.
    outputs=gr.Label(num_top_classes=3, label="Analysis Result"),
    title="🩻 X-ray Fracture Detection System",
    description="Upload a bone X-ray image to identify potential fractures. Powered by TensorFlow and Gradio.",
    examples=_example_images,
)

# Start the web server when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()