# grfdjiwsd's picture
# Update app.py
# fabc462 verified
import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image
# --- 1. SCRIPT CONFIGURATION ---
# File name of the ONNX model that is opened once at startup.
MODEL_PATH = "snap-attractiveness-classifier.onnx"
# Display labels for the model outputs; entry i is paired with output score i.
CLASS_NAMES = ["Attractive"]
# --- 2. LOAD THE ONNX MODEL AND CREATE AN INFERENCE SESSION ---
# Executed a single time when the script is first run. Any failure is reported
# on stdout and leaves session / input_name / input_shape set to None, so the
# remainder of the script still executes.
try:
    session = ort.InferenceSession(MODEL_PATH)
    # Only the first declared model input is inspected: its name is the feed
    # key for session.run, and its shape is read later to size the image.
    input_name, input_shape = (
        session.get_inputs()[0].name,
        session.get_inputs()[0].shape,
    )
    print(f"✅ Model loaded successfully. Input name: {input_name}, Input shape: {input_shape}")
except Exception as e:
    print(f"❌ Error loading the ONNX model: {e}")
    session = input_name = input_shape = None
# --- 3. DEFINE THE PREDICTION FUNCTION ---
# Spatial size used only when the model declares dynamic (non-integer) height/width.
# NOTE(review): 64x64 is taken from the unused `inputshape = [1, 3, 64, 64]`
# literal this function previously carried -- confirm against the exported model.
_FALLBACK_HW = (64, 64)


def _resolve_hw(shape):
    """Return ``(height, width)`` from a [batch, channels, height, width] shape.

    Falls back to ``_FALLBACK_HW`` when the shape is missing, too short, or its
    spatial entries are not positive integers (e.g. symbolic/dynamic dimensions).
    """
    if shape is not None and len(shape) >= 4:
        height, width = shape[2], shape[3]
        if isinstance(height, int) and isinstance(width, int) and height > 0 and width > 0:
            return height, width
    return _FALLBACK_HW


def _preprocess(image, height, width):
    """Convert a PIL image into a float32 array of shape (1, 3, height, width)."""
    # Uploads may be RGBA, palette or grayscale; the three-element mean/std
    # below only line up with exactly three colour channels.
    rgb = image.convert("RGB").resize((width, height), Image.Resampling.LANCZOS)
    pixels = np.asarray(rgb, dtype=np.float32) / 255.0
    # Standard ImageNet per-channel statistics, kept in float32 to avoid upcasting.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    # HWC -> CHW, then add the leading batch axis.
    chw = ((pixels - mean) / std).transpose(2, 0, 1)
    return np.ascontiguousarray(chw[np.newaxis, ...], dtype=np.float32)


def _to_probabilities(scores):
    """Map a 1-D array of raw scores to probabilities.

    Two or more scores go through a max-shifted softmax. A single score goes
    through a sigmoid, because softmax over one value is always exactly 1.0.
    """
    if scores.size == 1:
        # NOTE(review): assumes the single output is a raw logit -- confirm; if
        # the model already applies a sigmoid, return the score unchanged.
        # The tanh form of the sigmoid cannot overflow for large |score|.
        return 0.5 * (1.0 + np.tanh(0.5 * scores.astype(np.float64)))
    shifted = np.exp(scores - np.max(scores))
    return shifted / shifted.sum()


def predict(image):
    """Classify one image with the loaded ONNX session.

    Args:
        image: PIL image supplied by the UI, or None when nothing was uploaded.

    Returns:
        dict mapping each name in CLASS_NAMES to a float probability. If the
        model yields fewer scores than there are names, the extra names are
        omitted.

    Raises:
        gr.Error: if the model failed to load or no image was provided, so the
            message is shown in the interface.
    """
    if session is None:
        raise gr.Error("Model not loaded. Please check the logs.")
    if image is None:
        raise gr.Error("Please upload an image before classifying.")

    # --- Preprocessing ---
    height, width = _resolve_hw(input_shape)
    input_tensor = _preprocess(image, height, width)

    # --- Inference ---
    results = session.run(None, {input_name: input_tensor})

    # --- Post-processing ---
    # First output, first batch element; flattened so a scalar output also works.
    scores = np.atleast_1d(results[0][0]).ravel()
    probabilities = _to_probabilities(scores)
    return {label: float(prob) for label, prob in zip(CLASS_NAMES, probabilities)}
# --- 4. CREATE THE GRADIO INTERFACE ---
# Layout: title and instructions on top, a row holding the image input next to
# the prediction label, then the trigger button and an (empty) examples list.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🖼️ ONNX Image Classifier")
    gr.Markdown("Upload an image and the model will predict its class.")

    with gr.Row():
        # type="pil" makes the component hand predict() a PIL image.
        image_input = gr.Image(label="Upload Image", type="pil")
        # Show as many entries as there are configured class names.
        label_output = gr.Label(label="Predictions", num_top_classes=len(CLASS_NAMES))

    submit_button = gr.Button("Classify Image", variant="primary")
    # Clicking the button feeds the uploaded image to predict() and renders
    # its result in the label component.
    submit_button.click(fn=predict, inputs=image_input, outputs=label_output)

    gr.Examples(
        examples=[],  # Add path to example images if you have them
        fn=predict,
        inputs=image_input,
        outputs=label_output,
    )
# --- 5. LAUNCH THE APP ---
# Start the server only when this file is executed directly, not on import.
if __name__ == "__main__":
    # share=True asks Gradio for a public share link; debug=True turns on
    # Gradio's debug mode for the launched app.
    demo.launch(debug=True, share=True)