File size: 10,198 Bytes
2b5084a f61162c 2b5084a f61162c 2b5084a f61162c 2b5084a 106a88d 2b5084a f61162c 106a88d f61162c 2b5084a f61162c 2b5084a f61162c 2b5084a f61162c 106a88d 2b5084a f61162c 2b5084a f61162c 2b5084a f61162c 2b5084a f61162c 2b5084a f61162c 7406b0f f61162c 7406b0f f61162c 2b5084a f61162c 2b5084a f61162c 2b5084a f61162c 2b5084a f61162c 2b5084a f61162c 2b5084a f61162c 33e3bf9 2b5084a f61162c 2b5084a f61162c 2b5084a f61162c 2b5084a f61162c 2b5084a f61162c 2b5084a db245b9 33e3bf9 f61162c 2b5084a f61162c 2b5084a f61162c db245b9 f61162c db245b9 765672e f61162c 2b5084a f61162c 2b5084a db245b9 2b5084a f61162c 2b5084a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 |
#!/usr/bin/env python3
"""
HuggingFace Spaces App for ImageNet ResNet50 Classifier
Trained from scratch to 78%+ Top-1 accuracy
"""
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import json
# ============================================================================
# MODEL DEFINITION
# ============================================================================
class Bottleneck(nn.Module):
    """ResNet bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions).

    The first 1x1 conv maps ``in_planes`` to ``planes`` channels, the 3x3 conv
    carries the stride, and the last 1x1 conv expands to
    ``planes * expansion`` channels. The block input is added back through the
    ``shortcut``: an empty ``nn.Sequential`` (identity) when shapes already
    match, otherwise a strided 1x1 conv + BatchNorm projection.

    The attribute names (conv1..conv3, bn1..bn3, shortcut) become the
    state_dict keys the checkpoint is loaded against, so keep them stable.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # A projection is needed whenever the residual's spatial size or
        # channel count differs from the main path's output.
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes, out_planes, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # An empty Sequential returns its input unchanged (identity skip).
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the three conv/BN stages, add the shortcut, then ReLU."""
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        return F.relu(h + self.shortcut(x))
class ResNet50(nn.Module):
    """ResNet-50 classifier assembled from ``Bottleneck`` blocks.

    Layout:
    - stem: 7x7 stride-2 conv, BN, ReLU, then a 3x3 stride-2 max-pool;
    - four residual stages of 3/4/6/3 bottleneck blocks with base widths
      64/128/256/512, where stages 2-4 downsample by 2 in their first block;
    - global average pooling and a linear layer over
      ``512 * Bottleneck.expansion`` features.

    The submodule names are the checkpoint's state_dict keys; keep them stable.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # Channel count fed into the next block; _make_layer advances it.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(Bottleneck, 64, 3, stride=1)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
        self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage as a Sequential; only its first block uses ``stride``."""
        block_strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for block_stride in block_strides:
            blocks.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return raw class logits (no softmax) for a batch of images."""
        x = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)
# ============================================================================
# MODEL LOADING
# ============================================================================
def extract_state_dict(checkpoint):
    """Return the model state_dict held in ``checkpoint``.

    If ``checkpoint`` is a dict, the first of ``'model'``, ``'state_dict'``
    or ``'model_state_dict'`` that is present is unwrapped. Otherwise the
    object itself is treated as the state_dict. A leading ``'module.'``
    (added to every key when the model was wrapped in ``nn.DataParallel``)
    is then stripped. Only the prefix is removed, never later occurrences
    inside a key.
    """
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        for wrapper_key in ("model", "state_dict", "model_state_dict"):
            if wrapper_key in checkpoint:
                state_dict = checkpoint[wrapper_key]
                break

    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model(checkpoint_path="best_model_final.pth"):
    """Build a ResNet50 and load trained weights onto the CPU.

    Args:
        checkpoint_path: File written by ``torch.save``. It may hold either a
            bare state_dict or a dict wrapping one (see ``extract_state_dict``).

    Returns:
        The ResNet50 in eval mode. Loading is best-effort so the demo still
        starts: if anything goes wrong, the error is printed and the randomly
        initialized model is returned instead.
    """
    model = ResNet50(num_classes=1000)
    try:
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Recent PyTorch releases default to weights_only=True,
        # which rejects checkpoints containing arbitrary pickled objects; such
        # a failure lands in the fallback branch below.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(extract_state_dict(checkpoint))
        print(f"✅ Model loaded successfully from {checkpoint_path}")
    except Exception as e:
        # Deliberate best-effort: keep the app usable even without weights.
        print(f"⚠️ Could not load checkpoint: {e}")
        print("Using randomly initialized model for demo purposes")
    model.eval()
    return model
# ============================================================================
# IMAGE PREPROCESSING
# ============================================================================
# Evaluation-time preprocessing, applied in this order:
# - Resize(256);
# - a 224x224 center crop;
# - conversion to a tensor;
# - per-channel normalisation with the ImageNet RGB mean/std.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# ============================================================================
# IMAGENET CLASS LABELS
# ============================================================================
# Small fallback subset of ImageNet-1k labels, keyed by the standard ImageNet
# class index. It is only used when the corrected JSON mapping below is missing.
# Keys are stored as strings to match predict()'s ``str(idx)`` lookup and the
# string keys of the JSON mapping.
# NOTE(review): per the note further down, this model's output order is not the
# standard index order, so these fallback labels may not match its outputs.
_FALLBACK_IMAGENET_CLASSES = {
    0: "tench", 1: "goldfish", 2: "great white shark", 3: "tiger shark",
    4: "hammerhead", 5: "electric ray", 6: "stingray", 7: "cock",
    8: "hen", 9: "ostrich", 10: "brambling", 11: "goldfinch",
    12: "house finch", 13: "junco", 14: "indigo bunting", 15: "robin",
    151: "Chihuahua", 207: "golden retriever", 281: "tabby cat",
    282: "tiger cat", 283: "Persian cat", 285: "Egyptian cat",
    291: "lion", 292: "tiger", 293: "jaguar", 294: "leopard",
    404: "airliner", 407: "container ship", 468: "cab",
    511: "convertible", 609: "jeep", 627: "limousine",
    817: "sports car", 751: "racer", 779: "school bus",
    555: "fire engine", 569: "garbage truck", 717: "pickup",
    # Add more as needed
}
IMAGENET_CLASSES = {str(idx): name for idx, name in _FALLBACK_IMAGENET_CLASSES.items()}


def load_class_mapping(path="imagenet_classes_corrected.json"):
    """Load a class-index -> label mapping from a JSON file.

    Accepts a JSON list (the position is the class index) or a JSON object
    (keys are class indices). Always returns a dict with string keys.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
        AttributeError: if the JSON is neither a list nor an object.
    """
    with open(path, "r", encoding="utf-8") as f:
        loaded = json.load(f)
    if isinstance(loaded, list):
        return {str(i): name for i, name in enumerate(loaded)}
    return {str(key): name for key, name in loaded.items()}


# Load full class names - MUST use the corrected mapping!
# This model was trained with folders named 0-999 (lexicographically sorted)
# NOT with standard ImageNet WordNet IDs
try:
    IMAGENET_CLASSES = load_class_mapping("imagenet_classes_corrected.json")
    print(f"✅ Loaded corrected ImageNet class mapping with {len(IMAGENET_CLASSES)} classes")
except FileNotFoundError:
    print("⚠️ WARNING: imagenet_classes_corrected.json not found! Using fallback mapping.")
    print(" Model predictions will be INCORRECT without the corrected mapping!")
except Exception as e:
    # Keep the fallback mapping on any other load/parse failure.
    print(f"⚠️ WARNING: Failed to load class mapping: {e}")
# ============================================================================
# INFERENCE FUNCTION
# ============================================================================
def predict(image):
    """Classify ``image`` and return its top-5 ImageNet labels.

    Args:
        image: PIL Image from the Gradio input, or None when nothing was
            uploaded.

    Returns:
        dict mapping label -> probability (float), the format ``gr.Label``
        expects. For a missing image or any inference failure, a placeholder
        dict with 0.0 scores is returned instead of raising, so the UI always
        receives a valid value.
    """
    if image is None:
        return {"Error": 0.0, "Please upload an image": 0.0}

    try:
        # Normalize() uses 3 per-channel means/stds, so RGBA, greyscale or
        # palette images must be converted to 3-channel RGB first.
        img_tensor = transform(image.convert("RGB")).unsqueeze(0)  # add batch dim

        with torch.no_grad():
            logits = model(img_tensor)
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)

        k = min(5, probabilities.numel())
        top_prob, top_idx = torch.topk(probabilities, k)

        # Format results - MUST be dict with string keys and float values
        results = {}
        for prob, idx in zip(top_prob.tolist(), top_idx.tolist()):
            # Mapping keys are strings (JSON mapping); also accept int keys.
            class_name = IMAGENET_CLASSES.get(
                str(idx), IMAGENET_CLASSES.get(idx, f"Class {idx}")
            )
            label = str(class_name)
            if label in results:
                # Distinct indices can share a display name; keep both entries.
                label = f"{label} ({idx})"
            results[label] = float(prob)
        return results
    except Exception as e:
        # Return valid format even for errors
        return {"Prediction Error": 0.0, f"Details: {str(e)[:50]}": 0.0}
# ============================================================================
# GRADIO INTERFACE
# ============================================================================
# Load model globally: predict() reads this module-level ``model``.
print("Loading model...")
model = load_model()
# load_model() itself reports whether the checkpoint loaded or it fell back to
# random weights, so don't claim success unconditionally here.
print("Model ready.")
# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🔥 ImageNet ResNet50 Classifier

    **Trained from scratch to 77%+ Top-1 accuracy on ImageNet!**

    Upload any image and get top-5 predictions with confidence scores.
    """)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            predict_btn = gr.Button("Classify Image", variant="primary")
            gr.Markdown("""
            ### 📝 Tips:
            - Works best with **clear, centered objects**
            - Supports **1000 ImageNet classes** (animals, vehicles, objects, etc.)
            - Try images from different categories!
            """)

        with gr.Column():
            output = gr.Label(num_top_classes=5, label="Top-5 Predictions")
            gr.Markdown("""
            ### 🎯 Model Info:
            - **Architecture:** ResNet50 (25.5M params)
            - **Training:** From scratch (no pretrained weights)
            - **Dataset:** ImageNet (1.2M images, 1000 classes)
            - **Accuracy:** 77.09% Top-1 validation

            ### 🔗 Links:
            - [GitHub Repository](https://github.com/Shwethaamrutha/TSAI-S9)
            """)

    # Example images
    gr.Markdown("### 🖼️ Try These Examples:")
    gr.Examples(
        examples=[
            ["GermanShephard.jpg"],
            ["Goldfish.jpg"],
            ["Tiger.jpg"],
            ["SilkyTerrier.avif"],
        ],
        inputs=image_input,
        outputs=output,
        fn=predict,
        cache_examples=False,
    )

    # Connect button
    predict_btn.click(fn=predict, inputs=image_input, outputs=output)

# Launch
if __name__ == "__main__":
    demo.launch()
|