| |
| """ |
| HuggingFace Spaces App for ImageNet ResNet50 Classifier |
| Trained from scratch to 77%+ Top-1 accuracy (77.09% validation) |
| """ |
|
|
| import gradio as gr |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision import transforms |
| from PIL import Image |
| import json |
|
|
| |
| |
| |
|
|
class Bottleneck(nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand.

    The spatial stride is applied by the 3x3 convolution. The block emits
    ``planes * expansion`` channels. When that differs from ``in_planes``, or
    the stride is not 1, the identity path is replaced by a strided 1x1
    conv + BatchNorm projection so the residual sum has matching shapes.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        width_out = planes * self.expansion

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out)

        # The attribute names (and Sequential indices 0/1) become the
        # state_dict keys, so they must not change or checkpoints won't load.
        needs_projection = stride != 1 or in_planes != width_out
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_planes, width_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width_out),
            )
            if needs_projection
            else nn.Sequential()  # empty Sequential == identity
        )

    def forward(self, x):
        residual = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        return F.relu(y + residual)
|
|
|
|
class ResNet50(nn.Module):
    """ResNet-50 classifier assembled from ``Bottleneck`` blocks.

    Stem: 7x7 stride-2 conv + BN + ReLU + 3x3 stride-2 max-pool. Then four
    stages of 3, 4, 6 and 3 bottleneck blocks with base widths 64/128/256/512;
    the last three stages downsample by 2 in their first block. Finally,
    global average pooling and a linear head.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # Channel count fed into the next block; _make_layer advances it.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(Bottleneck, 64, 3, stride=1)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
        self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Derive the head width from the block's expansion rather than a bare 4.
        self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Side effect: sets ``self.in_planes`` to this stage's output width so
        the next stage is built with the correct input channel count.
        """
        layers = []
        # Separate loop name so the `stride` argument is not shadowed.
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of 3-channel images ``(N, 3, H, W)`` to logits ``(N, num_classes)``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out
|
|
|
|
| |
| |
| |
|
|
def _extract_state_dict(checkpoint):
    """Return the state dict held by *checkpoint*.

    A dict checkpoint is unwrapped from its ``'model'``, ``'state_dict'`` or
    ``'model_state_dict'`` entry (checked in that order). Anything else is
    returned unchanged and is assumed to already be a state dict.
    """
    if isinstance(checkpoint, dict):
        for key in ("model", "state_dict", "model_state_dict"):
            if key in checkpoint:
                return checkpoint[key]
    return checkpoint


def _strip_module_prefix(state_dict, prefix="module."):
    """Remove a leading *prefix* (``module.`` by default) from every key.

    Models wrapped in ``nn.DataParallel`` / ``DistributedDataParallel`` save
    their keys with a ``module.`` prefix. Only the leading occurrence is
    removed; ``str.replace`` would also delete matches inside a key.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model(checkpoint_path="best_model_final.pth"):
    """Build a ResNet50 and load trained weights onto the CPU.

    Loading is best-effort so the demo always starts. If anything fails
    (missing file, unreadable checkpoint, key or shape mismatch), a warning is
    printed and a fresh, randomly initialized model is used instead.

    Args:
        checkpoint_path: path to the checkpoint file; defaults to
            ``best_model_final.pth`` in the working directory.

    Returns:
        ResNet50: the model, switched to eval mode.
    """
    model = ResNet50(num_classes=1000)

    try:
        # NOTE: depending on the torch version, torch.load may unpickle
        # arbitrary objects; only point this at trusted checkpoint files.
        # map_location='cpu' lets weights saved on GPU load on a CPU host.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        state_dict = _strip_module_prefix(_extract_state_dict(checkpoint))
        model.load_state_dict(state_dict)
        print(f"✅ Model loaded successfully from {checkpoint_path}")

    except Exception as e:
        print(f"⚠️ Could not load checkpoint: {e}")
        print("Using randomly initialized model for demo purposes")
        # load_state_dict copies every tensor it can match *before* raising on
        # missing/unexpected/mis-shaped keys, so `model` may now be a mix of
        # checkpoint and random weights. Rebuild it so the fallback really is
        # the randomly initialized model the message above promises.
        model = ResNet50(num_classes=1000)

    model.eval()
    return model
|
|
|
|
| |
| |
| |
|
|
# Inference preprocessing (presumably the same pipeline used for validation
# during training; confirm). Steps:
#   - resize the shorter side to 256
#   - take the central 224x224 crop
#   - convert to a float tensor
#   - normalize each of the 3 channels with the mean/std constants below
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
|
|
| |
| |
| |
|
|
| |
# Partial fallback mapping (class index -> label), used only when
# imagenet_classes_corrected.json cannot be loaded.
# - Keys are strings because predict() looks labels up with str(idx); the
#   previous int keys could never match.
# - Indices not listed here are displayed as "Class {idx}".
# - Entries follow the standard ImageNet-1k index order: 288 leopard,
#   290 jaguar, 293 cheetah, 407 ambulance, 510 container ship.
IMAGENET_CLASSES = {
    "0": "tench", "1": "goldfish", "2": "great white shark", "3": "tiger shark",
    "4": "hammerhead", "5": "electric ray", "6": "stingray", "7": "cock",
    "8": "hen", "9": "ostrich", "10": "brambling", "11": "goldfinch",
    "12": "house finch", "13": "junco", "14": "indigo bunting", "15": "robin",
    "151": "Chihuahua", "207": "golden retriever",
    "281": "tabby cat", "282": "tiger cat", "283": "Persian cat", "285": "Egyptian cat",
    "288": "leopard", "290": "jaguar", "291": "lion", "292": "tiger", "293": "cheetah",
    "404": "airliner", "407": "ambulance", "468": "cab", "510": "container ship",
    "511": "convertible", "555": "fire engine", "569": "garbage truck",
    "609": "jeep", "627": "limousine", "717": "pickup", "751": "racer",
    "779": "school bus", "817": "sports car",
}
|
|
| |
| |
| |
# Replace the fallback mapping with the full one shipped alongside the app, if
# present. Two JSON layouts are accepted:
#   - a list of labels ordered by class index;
#   - an object keyed by class index.
# Keys are normalized to str to match the str(idx) lookup in predict(). Any
# other JSON type is rejected so the fallback stays in place.
try:
    with open('imagenet_classes_corrected.json', 'r', encoding='utf-8') as fh:
        loaded_classes = json.load(fh)

    if isinstance(loaded_classes, list):
        IMAGENET_CLASSES = {str(i): name for i, name in enumerate(loaded_classes)}
    elif isinstance(loaded_classes, dict):
        IMAGENET_CLASSES = {str(k): v for k, v in loaded_classes.items()}
    else:
        raise TypeError(
            f"expected a JSON list or object, got {type(loaded_classes).__name__}"
        )
    print(f"✅ Loaded corrected ImageNet class mapping with {len(IMAGENET_CLASSES)} classes")
except FileNotFoundError:
    print("⚠️ WARNING: imagenet_classes_corrected.json not found! Using fallback mapping.")
    print(" Model predictions will be INCORRECT without the corrected mapping!")
except Exception as e:
    print(f"⚠️ WARNING: Failed to load class mapping: {e}")
|
|
|
|
| |
| |
| |
|
|
def predict(image):
    """
    Predict ImageNet class for input image

    Args:
        image: PIL Image (any mode; non-RGB images are converted to RGB first),
            or None when nothing has been uploaded.

    Returns:
        dict: label -> confidence for the top-5 classes, the format gr.Label
        renders. If there is no input or an error occurs, a dict of
        placeholder messages with 0.0 scores is returned instead.
    """
    if image is None:
        return {"Error": 0.0, "Please upload an image": 0.0}

    try:
        # transform ends in a 3-channel Normalize; grayscale / RGBA / palette
        # images would not have 3 channels and would fail there.
        if image.mode != "RGB":
            image = image.convert("RGB")
        img_tensor = transform(image).unsqueeze(0)  # add batch dimension

        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

        top_prob, top_idx = torch.topk(probabilities, min(5, probabilities.numel()))

        results = {}
        for prob, idx in zip(top_prob.tolist(), top_idx.tolist()):
            # Mappings loaded from JSON are str-keyed; also accept int keys.
            class_name = IMAGENET_CLASSES.get(
                str(idx), IMAGENET_CLASSES.get(idx, f"Class {idx}")
            )
            # A mapping may repeat a label. Plain assignment would overwrite
            # the higher-ranked score with a lower one, so disambiguate by index.
            if class_name in results:
                class_name = f"{class_name} ({idx})"
            results[class_name] = prob

        return results

    except Exception as e:
        # Keep the full error in the server log; the UI only shows a snippet.
        print(f"⚠️ Prediction failed: {e!r}")
        return {"Prediction Error": 0.0, f"Details: {str(e)[:50]}": 0.0}
|
|
|
|
| |
| |
| |
|
|
| |
# Build the model once at import time; predict() reads this module-level
# `model` on every request. load_model() itself reports whether the checkpoint
# or the random-init fallback was used, so the closing message only states
# that the model is ready rather than claiming the weights loaded.
print("Loading model...")
model = load_model()
print("Model ready.")
|
|
| |
# Gradio UI. Inside gr.Blocks, components are placed in the order they are
# created within each Row/Column context, so the statement order below is the
# page layout: header, then a two-column row, then an examples section.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🔥 ImageNet ResNet50 Classifier

    **Trained from scratch to 77%+ Top-1 accuracy on ImageNet!**

    Upload any image and get top-5 predictions with confidence scores.
    """)

    with gr.Row():
        # Left column: image input, classify button, usage tips.
        with gr.Column():
            # type="pil" -> predict() receives a PIL Image (or None when empty).
            image_input = gr.Image(type="pil", label="Upload Image")
            predict_btn = gr.Button("Classify Image", variant="primary")

            gr.Markdown("""
            ### 📝 Tips:
            - Works best with **clear, centered objects**
            - Supports **1000 ImageNet classes** (animals, vehicles, objects, etc.)
            - Try images from different categories!
            """)

        # Right column: prediction output and model card.
        with gr.Column():
            # Renders the {label: confidence} dict returned by predict(),
            # showing at most 5 rows.
            output = gr.Label(num_top_classes=5, label="Top-5 Predictions")

            gr.Markdown("""
            ### 🎯 Model Info:
            - **Architecture:** ResNet50 (25.5M params)
            - **Training:** From scratch (no pretrained weights)
            - **Dataset:** ImageNet (1.2M images, 1000 classes)
            - **Accuracy:** 77.09% Top-1 validation

            ### 🔗 Links:
            - [GitHub Repository](https://github.com/Shwethaamrutha/TSAI-S9)
            """)


    gr.Markdown("### 🖼️ Try These Examples:")
    # Example paths are relative to the working directory; presumably these
    # files are shipped next to this script in the Space (TODO confirm).
    # cache_examples=False: example outputs are not precomputed at startup.
    gr.Examples(
        examples=[
            ["GermanShephard.jpg"],
            ["Goldfish.jpg"],
            ["Tiger.jpg"],
            ["Eagle.jpg"],
        ],
        inputs=image_input,
        outputs=output,
        fn=predict,
        cache_examples=False,
    )


    # Run predict() on the current image when the button is clicked.
    predict_btn.click(fn=predict, inputs=image_input, outputs=output)
| |
|
|
|
|
| |
def main():
    """Start the Gradio server for the `demo` app."""
    demo.launch()


if __name__ == "__main__":
    main()
|
|
|
|