#!/usr/bin/env python3
"""
HuggingFace Spaces App for ImageNet ResNet50 Classifier
Trained from scratch to 77%+ Top-1 accuracy (77.09% validation)
"""

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import json

# ============================================================================
# MODEL DEFINITION
# ============================================================================


class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand, plus shortcut.

    The block outputs ``expansion * planes`` channels. The stride is applied on
    the 3x3 convolution. When the stride or channel count changes, the shortcut
    is a strided 1x1 conv + BatchNorm projection; otherwise it is an identity
    (an empty ``nn.Sequential``, which has no parameters).
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Attribute names (conv1..bn3, shortcut.0, shortcut.1) must stay as-is:
        # they are the state_dict keys the trained checkpoint is loaded into.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                          bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet50(nn.Module):
    """ResNet-50: stem + four bottleneck stages of [3, 4, 6, 3] blocks + FC head.

    Args:
        num_classes: Size of the final linear layer (1000 for ImageNet).
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        self.in_planes = 64  # running channel count consumed by _make_layer
        # Stem: 7x7/2 conv -> BN -> (ReLU in forward) -> 3x3/2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(Bottleneck, 64, 3, stride=1)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
        self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out


# ============================================================================
# MODEL LOADING
# ============================================================================

# Leading prefixes that training wrappers add to state_dict keys:
# "module." comes from DataParallel / DistributedDataParallel and
# "_orig_mod." from torch.compile. Both must go before loading a bare model.
_STATE_DICT_PREFIXES = ("module.", "_orig_mod.")


def _strip_wrapper_prefixes(key):
    """Remove any leading wrapper prefixes (in any order/combination) from key."""
    changed = True
    while changed:
        changed = False
        for prefix in _STATE_DICT_PREFIXES:
            if key.startswith(prefix):
                key = key[len(prefix):]
                changed = True
    return key


def _extract_state_dict(checkpoint):
    """Return the state_dict held by checkpoint.

    Accepts either a bare state_dict or a training checkpoint dict that nests it
    under 'model', 'state_dict' or 'model_state_dict' (checked in that order).
    """
    if isinstance(checkpoint, dict):
        for key in ("model", "state_dict", "model_state_dict"):
            if key in checkpoint:
                return checkpoint[key]
    return checkpoint


def load_model(checkpoint_path="best_model_final.pth"):
    """Build a ResNet50 and load trained weights (CPU-optimized for HuggingFace).

    Args:
        checkpoint_path: Checkpoint file to load. The default file is uploaded
            to the Space separately.

    Returns:
        ResNet50 in eval mode. If the checkpoint cannot be loaded for any
        reason, a warning is printed and the randomly initialized model is
        returned so the demo still starts.
    """
    model = ResNet50(num_classes=1000)
    try:
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Since PyTorch 2.6 the default is weights_only=True, which
        # rejects checkpoints that pickle arbitrary Python objects.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        state_dict = _extract_state_dict(checkpoint)
        # Strip only *leading* wrapper prefixes; a global str.replace could also
        # rewrite legitimate inner key segments.
        state_dict = {_strip_wrapper_prefixes(k): v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        print(f"✅ Model loaded successfully from {checkpoint_path}")
    except Exception as e:
        # Deliberate best-effort: keep the Space usable even without weights.
        print(f"⚠️ Could not load checkpoint: {e}")
        print("Using randomly initialized model for demo purposes")

    model.eval()
    return model


# ============================================================================
# IMAGE PREPROCESSING
# ============================================================================

# Standard ImageNet evaluation pipeline: resize short side to 256, center-crop
# 224x224, scale to [0, 1], then normalize with the ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ============================================================================
# IMAGENET CLASS LABELS
# ============================================================================

# Small fallback subset of ImageNet-1k labels, used only if the full mapping
# file cannot be loaded. Keys are str(index) so they match the string lookup
# performed at prediction time. They follow the standard ImageNet-1k index
# order; if the model was trained with a different class ordering, these
# labels will not line up with its outputs.
IMAGENET_CLASSES = {
    "0": "tench", "1": "goldfish", "2": "great white shark", "3": "tiger shark",
    "4": "hammerhead", "5": "electric ray", "6": "stingray", "7": "cock",
    "8": "hen", "9": "ostrich", "10": "brambling", "11": "goldfinch",
    "12": "house finch", "13": "junco", "14": "indigo bunting", "15": "robin",
    "151": "Chihuahua", "207": "golden retriever",
    "281": "tabby cat", "282": "tiger cat", "283": "Persian cat",
    "285": "Egyptian cat", "288": "leopard", "290": "jaguar", "291": "lion",
    "292": "tiger", "404": "airliner", "468": "cab", "510": "container ship",
    "511": "convertible", "555": "fire engine", "569": "garbage truck",
    "609": "jeep", "627": "limousine", "717": "pickup", "751": "racer",
    "779": "school bus", "817": "sports car",
    # Add more as needed
}
# Load full class names - MUST use the corrected mapping!
# This model was trained with folders named 0-999 (lexicographically sorted),
# NOT with standard ImageNet WordNet IDs.
try:
    with open('imagenet_classes_corrected.json', 'r', encoding='utf-8') as f:
        loaded_classes = json.load(f)

    # Normalize to a dict keyed by str(index). JSON object keys are always
    # strings, and a JSON list is converted to that same shape here.
    if isinstance(loaded_classes, list):
        IMAGENET_CLASSES = {str(i): name for i, name in enumerate(loaded_classes)}
    else:
        IMAGENET_CLASSES = loaded_classes
    print(f"✅ Loaded corrected ImageNet class mapping with {len(IMAGENET_CLASSES)} classes")
except FileNotFoundError:
    print("⚠️ WARNING: imagenet_classes_corrected.json not found! Using fallback mapping.")
    print("   Model predictions will be INCORRECT without the corrected mapping!")
except Exception as e:
    print(f"⚠️ WARNING: Failed to load class mapping: {e}")

# ============================================================================
# INFERENCE FUNCTION
# ============================================================================


def _class_name(idx):
    """Return the label for class index idx from IMAGENET_CLASSES.

    Tries the string key first (the format produced by the JSON mapping), then
    the integer key, so either key style resolves. Falls back to "Class {idx}".
    """
    for key in (str(idx), idx):
        if key in IMAGENET_CLASSES:
            return IMAGENET_CLASSES[key]
    return f"Class {idx}"


def _top_k_labels(probabilities, k=5):
    """Map the k most probable classes to an ordered {label: probability} dict.

    Args:
        probabilities: 1-D tensor of class probabilities.
        k: Number of predictions to return (clamped to the number of classes).

    Returns:
        dict with string keys and float values, highest probability first.
        Label lists can contain duplicate names, so a repeated name gets its
        class index appended instead of overwriting the earlier entry.
    """
    k = min(k, probabilities.numel())
    top_prob, top_idx = torch.topk(probabilities, k)
    results = {}
    for prob, idx in zip(top_prob.tolist(), top_idx.tolist()):
        name = _class_name(idx)
        if name in results:
            name = f"{name} ({idx})"
        results[name] = float(prob)
    return results


def predict(image):
    """
    Predict ImageNet class for input image

    Args:
        image: PIL Image (or None when nothing has been uploaded)

    Returns:
        dict: Top-5 predictions as {label: confidence}. On failure, a dict of
        placeholder labels with 0.0 values, so gr.Label can still render it.
    """
    if image is None:
        return {"Error": 0.0, "Please upload an image": 0.0}

    try:
        # Normalize expects exactly 3 channels; RGBA, palette or greyscale
        # inputs would otherwise fail, so force RGB first.
        img_tensor = transform(image.convert("RGB")).unsqueeze(0)  # add batch dim

        with torch.no_grad():
            outputs = model(img_tensor)
        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

        return _top_k_labels(probabilities, k=5)
    except Exception as e:
        # Return valid format even for errors
        return {"Prediction Error": 0.0, f"Details: {str(e)[:50]}": 0.0}
# ============================================================================
# GRADIO INTERFACE
# ============================================================================

# Load model globally so every request reuses the same weights.
print("Loading model...")
model = load_model()
# load_model() itself reports whether the checkpoint loaded or whether it fell
# back to random initialization, so don't claim success unconditionally here.
print("Model ready.")

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🔥 ImageNet ResNet50 Classifier

    **Trained from scratch to 77%+ Top-1 accuracy on ImageNet!**

    Upload any image and get top-5 predictions with confidence scores.
    """)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            predict_btn = gr.Button("Classify Image", variant="primary")

            gr.Markdown("""
            ### 📝 Tips:
            - Works best with **clear, centered objects**
            - Supports **1000 ImageNet classes** (animals, vehicles, objects, etc.)
            - Try images from different categories!
            """)

        with gr.Column():
            output = gr.Label(num_top_classes=5, label="Top-5 Predictions")

            gr.Markdown("""
            ### 🎯 Model Info:
            - **Architecture:** ResNet50 (25.5M params)
            - **Training:** From scratch (no pretrained weights)
            - **Dataset:** ImageNet (1.2M images, 1000 classes)
            - **Accuracy:** 77.09% Top-1 validation

            ### 🔗 Links:
            - [GitHub Repository](https://github.com/Shwethaamrutha/TSAI-S9)
            """)

    # Example images (files expected next to this script in the Space).
    gr.Markdown("### 🖼️ Try These Examples:")
    gr.Examples(
        examples=[
            ["GermanShephard.jpg"],
            ["Goldfish.jpg"],
            ["Tiger.jpg"],
            ["SilkyTerrier.avif"],
        ],
        inputs=image_input,
        outputs=output,
        fn=predict,
        cache_examples=False,
    )

    # Connect button
    predict_btn.click(fn=predict, inputs=image_input, outputs=output)

# Launch
if __name__ == "__main__":
    demo.launch()