Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| from transformers import AutoFeatureExtractor, AutoModelForImageClassification | |
# Hugging Face Hub checkpoint used for both preprocessing and classification.
# ResNet-18 is picked as a small model that should fit the hosting
# environment's resource limits.
model_name = "microsoft/resnet-18"

# Preprocessor and classifier are loaded once, at import time, from the same
# checkpoint id and are then shared by every call to classify_image.
feature_extractor, model = (
    AutoFeatureExtractor.from_pretrained(model_name),
    AutoModelForImageClassification.from_pretrained(model_name),
)
def _predict_top_k(image, k=5):
    """Run the classifier on *image* and return its ``k`` most likely labels.

    Returns a ``(labels, percentages)`` pair of equal-length lists ordered
    from most to least likely; percentages are softmax probabilities * 100.
    ``k`` is clamped to the number of classes, so models with fewer than
    ``k`` labels do not make ``torch.topk`` raise.
    """
    # Normalise to 3-channel RGB. The preprocessor presumably expects RGB;
    # RGBA / grayscale inputs are converted rather than failing downstream.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class dimension, for the first (only) image in the batch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    top_probs, top_indices = torch.topk(probs, min(k, probs.numel()))

    labels = [model.config.id2label[idx] for idx in top_indices.tolist()]
    percentages = [p * 100 for p in top_probs.tolist()]
    return labels, percentages


def _plot_predictions(labels, percentages):
    """Return a horizontal bar chart of *percentages*, one bar per label.

    The figure is built with ``matplotlib.figure.Figure`` instead of
    ``pyplot``. pyplot keeps every figure it creates alive until it is
    explicitly closed, which leaked one figure per request. Its implicit
    "current figure" state is also shared across Gradio's concurrent handlers.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()

    bars = ax.barh(labels, percentages, color='#4C72B0')
    # barh draws the first entry at the bottom; flip the axis so the most
    # likely class is listed first.
    ax.invert_yaxis()
    ax.set_xlabel('Probability (%)')
    ax.set_title(f'Top {len(labels)} Predictions')

    # Annotate each bar with its percentage just past the end of the bar.
    for bar, pct in zip(bars, percentages):
        ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height() / 2,
                f'{pct:.1f}%', va='center', fontsize=10)

    # Operate on this figure explicitly rather than pyplot's current figure.
    fig.tight_layout()
    return fig


def classify_image(image):
    """Classify *image* and chart the model's five most likely labels.

    Args:
        image: The image to classify (a PIL image when driven by the Gradio
            interface, which uses ``gr.Image(type="pil")``), or ``None``.

    Returns:
        A ``(prediction, figure)`` tuple: the top-1 label and a matplotlib
        figure of the top-5 probabilities. If *image* is ``None`` or any step
        fails, the first element is a human-readable message and the figure
        is ``None``, so the UI shows the message instead of crashing.
    """
    if image is None:
        return "No image provided", None

    try:
        labels, percentages = _predict_top_k(image, k=5)
        fig = _plot_predictions(labels, percentages)
    except Exception as e:
        # Top-level UI boundary: surface the failure in the textbox, but keep
        # the full traceback in the server log for debugging.
        import logging
        logging.getLogger(__name__).exception("Image classification failed")
        return f"Error: {e}", None

    # labels is sorted by descending probability, so the textbox shows the
    # same class as the chart's first bar.
    return labels[0], fig
# Gradio UI: one PIL image in; the predicted label and a confidence chart out.
_image_input = gr.Image(type="pil")
_result_outputs = [
    gr.Textbox(label="Prediction"),
    gr.Plot(label="Confidence Levels"),
]

demo = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=_result_outputs,
    title="🖼️ Image Classification Tool",
    description="Upload an image to see what the AI recognizes in it!",
    allow_flagging="never",  # flagging disabled
    examples=[],  # no example inputs, so the app needs no extra files
    theme=gr.themes.Soft(),
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()