Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from PIL import Image | |
| import numpy as np | |
| from models.custom_cnn import create_custom_cnn | |
| from models.resnet18 import load_resnet18 | |
| from utils.data_loader import get_cifar10_info | |
| # CIFAR-10 preprocessing | |
| CIFAR10_MEAN = (0.4914, 0.4822, 0.4465) | |
| CIFAR10_STD = (0.2023, 0.1994, 0.2010) | |
| def preprocess_image(image): | |
| """Preprocess uploaded image for CIFAR-10 model.""" | |
| transform = transforms.Compose([ | |
| transforms.Resize((32, 32)), | |
| transforms.ToTensor(), | |
| transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD) | |
| ]) | |
| if image.mode != 'RGB': | |
| image = image.convert('RGB') | |
| return transform(image).unsqueeze(0) | |
| def predict_image(image, model_choice="CustomCNN"): | |
| """ | |
| Predict class label for uploaded image. | |
| Args: | |
| image: PIL Image uploaded by user | |
| model_choice: Which model to use for prediction | |
| Returns: | |
| str: Formatted prediction result | |
| """ | |
| try: | |
| # Load model | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| if model_choice == "CustomCNN": | |
| model = create_custom_cnn() | |
| model_path = "best_model_custom.pth" | |
| else: | |
| model = load_resnet18() | |
| model_path = "best_model_resnet18.pth" | |
| # Load trained weights | |
| try: | |
| checkpoint = torch.load(model_path, map_location=device) | |
| model.load_state_dict(checkpoint['model_state_dict']) | |
| except FileNotFoundError: | |
| return f"β Model weights not found: {model_path}\nTrain the model first!" | |
| model.to(device) | |
| model.eval() | |
| # Preprocess image | |
| input_tensor = preprocess_image(image).to(device) | |
| # Make prediction | |
| with torch.no_grad(): | |
| outputs = model(input_tensor) | |
| probabilities = F.softmax(outputs, dim=1) | |
| confidence, predicted = torch.max(probabilities, 1) | |
| # Get class info | |
| cifar10_info = get_cifar10_info() | |
| class_names = cifar10_info['class_names'] | |
| predicted_class = class_names[predicted.item()] | |
| confidence_score = confidence.item() * 100 | |
| # Format result with top-3 predictions | |
| top3_prob, top3_indices = torch.topk(probabilities, 3) | |
| result = f"π― **Prediction: {predicted_class}** ({confidence_score:.1f}% confidence)\n\n" | |
| result += f"π **Top 3 Predictions:**\n" | |
| for i in range(3): | |
| class_name = class_names[top3_indices[0][i].item()] | |
| prob = top3_prob[0][i].item() * 100 | |
| result += f"{i+1}. {class_name}: {prob:.1f}%\n" | |
| result += f"\nπ§ **Model:** {model_choice}\n" | |
| result += f"π± **Device:** {device}" | |
| return result | |
| except Exception as e: | |
| return f"β **Error:** {str(e)}\n\nPlease ensure the image is valid and model is trained." | |
| def create_gradio_interface(): | |
| """Create and launch Gradio interface.""" | |
| # Custom CSS for better styling | |
| css = """ | |
| .gradio-container { | |
| font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; | |
| } | |
| .output-text { | |
| font-family: 'Courier New', monospace; | |
| font-size: 14px; | |
| } | |
| """ | |
| # Create interface | |
| interface = gr.Interface( | |
| fn=predict_image, | |
| inputs=[ | |
| gr.Image(type="pil", label="Upload Image", height=300), | |
| gr.Dropdown( | |
| choices=["CustomCNN", "ResNet18"], | |
| value="CustomCNN", | |
| label="Select Model" | |
| ) | |
| ], | |
| outputs=gr.Textbox( | |
| label="Prediction Result", | |
| lines=10, | |
| elem_classes=["output-text"] | |
| ), | |
| title="π§ CIFAR-10 CNN Benchmark", | |
| description=""" | |
| Upload an image to test our trained models! | |
| **Models:** | |
| - **CustomCNN**: Lightweight 3M parameter model | |
| - **ResNet18**: Standard 11M parameter baseline | |
| **Note:** Images will be resized to 32x32 pixels (CIFAR-10 format) | |
| **CIFAR-10 Classes:** airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck | |
| """, | |
| css=css, | |
| theme=gr.themes.Soft(), | |
| allow_flagging="never" | |
| ) | |
| return interface | |
| if __name__ == "__main__": | |
| # Create and launch interface | |
| demo = create_gradio_interface() | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=True, | |
| show_error=True | |
| ) |