Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import onnxruntime as ort | |
| import numpy as np | |
| from PIL import Image | |
# --- 1. SCRIPT CONFIGURATION --- #
# Path to the ONNX model file and the ordered list of class labels;
# the label order must match the order the model was trained with.
MODEL_PATH = 'snap-attractiveness-classifier.onnx'
CLASS_NAMES = ['Attractive']

# --- 2. LOAD THE ONNX MODEL AND CREATE AN INFERENCE SESSION --- #
# Done once at startup. On failure we fall back to None so the UI can
# surface a friendly error instead of crashing at import time.
try:
    session = ort.InferenceSession(MODEL_PATH)
    model_input = session.get_inputs()[0]
    # Input tensor name and expected layout, e.g. [1, 3, 224, 224].
    input_name = model_input.name
    input_shape = model_input.shape
    print(f"✅ Model loaded successfully. Input name: {input_name}, Input shape: {input_shape}")
except Exception as e:
    print(f"❌ Error loading the ONNX model: {e}")
    session = input_name = input_shape = None
# --- 3. DEFINE THE PREDICTION FUNCTION --- #
def predict(image):
    """
    Classify a PIL image with the loaded ONNX model.

    Preprocesses the image (RGB conversion, resize, ImageNet-style
    normalization, HWC->NCHW), runs inference, and softmaxes the scores.

    Args:
        image: PIL.Image.Image as delivered by the Gradio image widget.

    Returns:
        dict mapping each entry of CLASS_NAMES to its probability, or an
        error dict when the model failed to load at startup.
    """
    if session is None:
        return {"error": "Model not loaded. Please check the logs."}

    # --- Preprocessing ---
    # Force 3-channel RGB: Gradio can deliver RGBA (PNG) or grayscale
    # uploads, which would break the channel-wise normalization below.
    image = image.convert("RGB")

    # Resize to the model's expected spatial size.
    # input_shape is [batch, channels, height, width]; assumes the dims
    # are concrete ints rather than symbolic names -- TODO confirm for
    # this particular model.
    img_height, img_width = input_shape[2], input_shape[3]
    image = image.resize((img_width, img_height), Image.Resampling.LANCZOS)

    # Scale to [0, 1] and apply standard ImageNet mean/std normalization.
    # mean/std are float32 so the arithmetic does not upcast to float64.
    image_data = np.array(image).astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    normalized_data = (image_data - mean) / std

    # HWC -> CHW, then prepend the batch dimension.
    transposed_data = normalized_data.transpose(2, 0, 1)
    input_tensor = np.expand_dims(transposed_data, axis=0)

    # ONNX Runtime requires float32 input; cast explicitly to guarantee
    # no float64 sneaks in and triggers a dtype-mismatch error.
    input_tensor = input_tensor.astype(np.float32)

    # --- Inference ---
    results = session.run(None, {input_name: input_tensor})

    # --- Post-processing ---
    prediction_scores = results[0][0]  # drop the batch dimension

    # Numerically stable softmax (subtract the max before exponentiating).
    exp_scores = np.exp(prediction_scores - np.max(prediction_scores))
    probabilities = exp_scores / exp_scores.sum()

    # Map each class name to its probability for the gr.Label widget.
    return {CLASS_NAMES[i]: float(probabilities[i]) for i in range(len(CLASS_NAMES))}
# --- 4. CREATE THE GRADIO INTERFACE --- #
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header and usage hint.
    gr.Markdown("# 🖼️ ONNX Image Classifier")
    gr.Markdown("Upload an image and the model will predict its class.")

    # Input image on the left, prediction labels on the right.
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        label_output = gr.Label(num_top_classes=len(CLASS_NAMES), label="Predictions")

    # Button wired straight to the prediction function.
    classify_btn = gr.Button("Classify Image", variant="primary")
    classify_btn.click(fn=predict, inputs=image_input, outputs=label_output)

    # Placeholder for example images; populate the list with file paths.
    gr.Examples(examples=[], inputs=image_input, outputs=label_output, fn=predict)
# --- 5. LAUNCH THE APP --- #
if __name__ == "__main__":
    # share=True exposes a temporary public Gradio link;
    # debug=True prints server errors/tracebacks to the console.
    demo.launch(share=True, debug=True)