|
|
|
|
|
""" |
|
|
HuggingFace Spaces App for ImageNet ResNet50 Classifier |
|
|
Trained from scratch to 77%+ Top-1 accuracy |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import json |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The last 1x1 convolution widens the features to ``expansion * planes``
    channels. The block input is routed through ``shortcut`` and added to
    that result before the final ReLU.
    """

    # Ratio between the block's output channel count and ``planes``.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        """Create the conv/batch-norm pairs and the shortcut branch.

        Args:
            in_planes: Channel count of the block input.
            planes: Channel count used by the first two convolutions.
            stride: Stride of the 3x3 convolution, and of the projection
                shortcut when one is needed.
        """
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch: reduce (1x1), transform (3x3), expand (1x1).
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # The input can only be added to the output as-is when both the
        # spatial size and the channel count are unchanged; otherwise it is
        # projected with a strided 1x1 convolution + batch norm.
        shape_changes = stride != 1 or in_planes != out_planes
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes, out_planes, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the main branch, add the shortcut, and return the ReLU of the sum."""
        features = self.conv1(x)
        features = F.relu(self.bn1(features))

        features = self.conv2(features)
        features = F.relu(self.bn2(features))

        features = self.bn3(self.conv3(features))
        features += self.shortcut(x)
        return F.relu(features)
|
|
|
|
|
|
|
|
class ResNet50(nn.Module):
    """ResNet-50 image classifier assembled from ``Bottleneck`` blocks.

    Layout: 7x7 stride-2 convolution + batch norm + ReLU, 3x3 max pool,
    four stages of 3, 4, 6 and 3 bottleneck blocks, adaptive average
    pooling to 1x1, and a linear layer producing ``num_classes`` logits.
    """

    def __init__(self, num_classes=1000):
        """Build all layers.

        Args:
            num_classes: Size of the final linear layer's output.
        """
        super().__init__()
        # Channel count fed to the next block; advanced by _make_layer.
        self.in_planes = 64

        # Stem.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (planes, number of blocks, stride of the first block) per stage,
        # registered as layer1 .. layer4 in this order.
        stage_specs = (
            (64, 3, 1),
            (128, 4, 2),
            (256, 6, 2),
            (512, 3, 2),
        )
        for stage_number, (planes, num_blocks, stride) in enumerate(stage_specs, start=1):
            stage = self._make_layer(Bottleneck, planes, num_blocks, stride=stride)
            setattr(self, f"layer{stage_number}", stage)

        # Classification head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * 4, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``block`` instances into one stage.

        Only the first block receives ``stride``; the remaining ones use
        stride 1. ``self.in_planes`` is updated after every block so that the
        next block (and the next stage) sees the right input channel count.

        Args:
            block: Block class; called as ``block(in_planes, planes, stride)``
                and expected to expose an ``expansion`` attribute.
            planes: Base channel width passed to each block.
            num_blocks: Number of blocks in the stage.
            stride: Stride for the first block.

        Returns:
            An ``nn.Sequential`` holding the created blocks.
        """
        per_block_strides = [stride, *([1] * (num_blocks - 1))]
        stage_blocks = []
        for block_stride in per_block_strides:
            stage_blocks.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage_blocks)

    def forward(self, x):
        """Run the stem, the four stages and the head; return the logits."""
        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.maxpool(out)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)

        pooled = self.avgpool(out)
        flat = torch.flatten(pooled, 1)
        return self.fc(flat)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model():
    """Build a ResNet50 and load trained weights from ``best_model_final.pth``.

    The checkpoint is loaded onto the CPU. It may be either a bare state
    dict or a dict that wraps the state dict under ``'model'``,
    ``'state_dict'`` or ``'model_state_dict'``. A leading ``module.`` on a
    parameter name (typically left behind when the model was saved while
    wrapped in ``nn.DataParallel`` / ``DistributedDataParallel``) is removed
    before loading.

    Loading is best-effort: if anything goes wrong the error is printed and
    the model keeps its random initialisation so the demo can still start.

    Returns:
        The ``ResNet50`` instance, switched to eval mode.
    """
    model = ResNet50(num_classes=1000)
    checkpoint_path = "best_model_final.pth"
    wrapper_prefix = "module."

    try:
        # NOTE(security): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be placed at this path. Consider
        # passing weights_only=True once the checkpoint format is confirmed
        # to contain tensors only.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Unwrap the state dict if the checkpoint stores it under a known key.
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for wrapper_key in ('model', 'state_dict', 'model_state_dict'):
                if wrapper_key in checkpoint:
                    state_dict = checkpoint[wrapper_key]
                    break

        # Strip only a *leading* "module." so that names containing the
        # substring elsewhere are left untouched.
        new_state_dict = {
            (k[len(wrapper_prefix):] if k.startswith(wrapper_prefix) else k): v
            for k, v in state_dict.items()
        }

        model.load_state_dict(new_state_dict)
        print(f"✅ Model loaded successfully from {checkpoint_path}")

    except Exception as e:
        # Deliberate broad catch: a missing or incompatible checkpoint must
        # not prevent the app from starting.
        print(f"⚠️ Could not load checkpoint: {e}")
        print("Using randomly initialized model for demo purposes")

    model.eval()
    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Preprocessing applied to every uploaded image before inference:
# resize, centre-crop to 224x224, convert to a tensor, then normalise each
# channel with the fixed mean / standard deviation below.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Small built-in subset of class names, used only when the JSON mapping
# below cannot be loaded. Indices missing from the active mapping are shown
# by predict() as "Class <index>".
_FALLBACK_IMAGENET_CLASSES = {
    0: "tench", 1: "goldfish", 2: "great white shark", 3: "tiger shark",
    4: "hammerhead", 5: "electric ray", 6: "stingray", 7: "cock",
    8: "hen", 9: "ostrich", 10: "brambling", 11: "goldfinch",
    12: "house finch", 13: "junco", 14: "indigo bunting", 15: "robin",
    151: "Chihuahua", 207: "golden retriever", 281: "tabby cat",
    282: "tiger cat", 283: "Persian cat", 285: "Egyptian cat",
    291: "lion", 292: "tiger", 293: "jaguar", 294: "leopard",
    404: "airliner", 407: "container ship", 468: "cab",
    511: "convertible", 609: "jeep", 627: "limousine",
    817: "sports car", 751: "racer", 779: "school bus",
    555: "fire engine", 569: "garbage truck", 717: "pickup",
}

# The mapping is always keyed by the class index *as a string*: predict()
# looks names up with str(index), and keys read from a JSON object are
# strings as well. Keeping the fallback in the same form makes it usable.
IMAGENET_CLASSES = {
    str(index): name for index, name in _FALLBACK_IMAGENET_CLASSES.items()
}

# Prefer the full mapping shipped next to the app, if present. Accepts
# either a JSON list (position = class index) or a JSON object.
try:
    with open('imagenet_classes_corrected.json', 'r', encoding='utf-8') as f:
        loaded_classes = json.load(f)

    if isinstance(loaded_classes, list):
        IMAGENET_CLASSES = {str(i): name for i, name in enumerate(loaded_classes)}
    elif isinstance(loaded_classes, dict):
        IMAGENET_CLASSES = {str(key): name for key, name in loaded_classes.items()}
    else:
        # Anything else cannot be used as an index -> name mapping; keep the
        # fallback and report it through the generic handler below.
        raise TypeError(
            f"expected a JSON list or object, got {type(loaded_classes).__name__}"
        )
    print(f"✅ Loaded corrected ImageNet class mapping with {len(IMAGENET_CLASSES)} classes")
except FileNotFoundError:
    print("⚠️ WARNING: imagenet_classes_corrected.json not found! Using fallback mapping.")
    print(" Model predictions will be INCORRECT without the corrected mapping!")
except Exception as e:
    print(f"⚠️ WARNING: Failed to load class mapping: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(image):
    """Classify an image and return its most likely classes.

    Args:
        image: PIL image to classify, or ``None`` when nothing was uploaded.

    Returns:
        dict: Class label -> probability for the top predictions (at most
        five). When ``image`` is ``None`` or inference fails, a two-entry
        dict describing the problem is returned instead, with both values
        set to ``0.0``.
    """
    if image is None:
        return {"Error": 0.0, "Please upload an image": 0.0}

    try:
        # The network's first convolution takes 3 input channels, so
        # grayscale / RGBA / palette images are converted to RGB first.
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")

        # Preprocess and add the batch dimension.
        img_tensor = transform(image).unsqueeze(0)

        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = F.softmax(outputs[0], dim=0)

        # Never ask topk for more entries than there are classes.
        top_k = min(5, probabilities.numel())
        top_prob, top_indices = torch.topk(probabilities, top_k)

        results = {}
        for idx, prob in zip(top_indices.tolist(), top_prob.tolist()):
            # The mapping is normally keyed by str(index); also accept an
            # int-keyed mapping before falling back to a generic label.
            class_name = IMAGENET_CLASSES.get(str(idx))
            if class_name is None:
                class_name = IMAGENET_CLASSES.get(idx, f"Class {idx}")
            # Two different indices can share a label; keep both entries
            # instead of letting the later one overwrite the earlier one.
            if class_name in results:
                class_name = f"{class_name} ({idx})"
            results[class_name] = float(prob)

        return results

    except Exception as e:
        # Surface the failure in the label widget rather than crashing the UI.
        return {"Prediction Error": 0.0, f"Details: {str(e)[:50]}": 0.0}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build the model once at import time so every request reuses it.
# load_model() prints whether the checkpoint was actually loaded or a
# randomly initialised model is being used, so the closing message here
# must not claim success unconditionally.
print("Loading model...")
model = load_model()
print("Model ready.")
|
|
|
|
|
|
|
|
# Gradio UI: image upload + button on the left, top-5 label output and model
# information on the right, clickable example images underneath.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🔥 ImageNet ResNet50 Classifier

    **Trained from scratch to 77%+ Top-1 accuracy on ImageNet!**

    Upload any image and get top-5 predictions with confidence scores.
    """)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            predict_btn = gr.Button("Classify Image", variant="primary")

            gr.Markdown("""
            ### 📝 Tips:
            - Works best with **clear, centered objects**
            - Supports **1000 ImageNet classes** (animals, vehicles, objects, etc.)
            - Try images from different categories!
            """)

        with gr.Column():
            output = gr.Label(num_top_classes=5, label="Top-5 Predictions")

            gr.Markdown("""
            ### 🎯 Model Info:
            - **Architecture:** ResNet50 (25.5M params)
            - **Training:** From scratch (no pretrained weights)
            - **Dataset:** ImageNet (1.2M images, 1000 classes)
            - **Accuracy:** 77.09% Top-1 validation

            ### 🔗 Links:
            - [GitHub Repository](https://github.com/Shwethaamrutha/TSAI-S9)
            """)

    gr.Markdown("### 🖼️ Try These Examples:")
    gr.Examples(
        # Paths are relative to the working directory of the app.
        # NOTE(review): the .avif example presumably needs AVIF decoding
        # support in the installed Pillow - confirm it opens in the
        # deployment environment.
        examples=[
            ["GermanShephard.jpg"],
            ["Goldfish.jpg"],
            ["Tiger.jpg"],
            ["SilkyTerrier.avif"],
        ],
        inputs=image_input,
        outputs=output,
        fn=predict,
        cache_examples=False,
    )

    # Run the classifier when the button is pressed.
    predict_btn.click(fn=predict, inputs=image_input, outputs=output)


if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|