Santhosh V committed
Commit 5008b38 · 1 Parent(s): 8826d8a

Add CIFAR-100 ResNet-18 Gradio app with 77.45% accuracy model

Files changed (4)
  1. .gitignore +46 -0
  2. README.md +56 -6
  3. app.py +214 -0
  4. requirements.txt +6 -0
.gitignore ADDED
@@ -0,0 +1,46 @@
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # PyTorch
+ *.pth
+ *.pt
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # Environment
+ .env
+ .venv
+ env/
+ venv/
+
+ # IDE
+ .vscode/
+ .idea/
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Temporary files
+ *.tmp
+ *.temp
+ *.log
README.md CHANGED
@@ -1,13 +1,63 @@
  ---
- title: ERA V4 S8 Assignment
- emoji: ⚡
- colorFrom: pink
- colorTo: green
+ title: 🏆 CIFAR-100 ResNet-18 Classifier
+ emoji: 🎯
+ colorFrom: blue
+ colorTo: purple
  sdk: gradio
  sdk_version: 5.49.1
  app_file: app.py
  pinned: false
- short_description: ERA V4 S8 Assignment
+ short_description: CIFAR-100 ResNet-18 model achieving 77.45% accuracy - Upload images for instant classification!
+ license: mit
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # 🏆 CIFAR-100 ResNet-18 Classifier - 77.45% Accuracy
+
+ **Upload an image to classify it into one of 100 CIFAR-100 categories!**
+
+ ## 🎯 Model Performance
+
+ | Metric | Target | Achieved | Status |
+ |--------|--------|----------|--------|
+ | 🏅 **Test Accuracy** | 73% | **77.45%** | ✅ **+4.45%** |
+ | 📦 **Parameters** | ~11M | **11.22M** | ✅ **Optimal** |
+ | ⏱️ **Training Time** | 100 epochs | **49 minutes** | ⚡ **Fast** |
+ | 🎯 **Target Achievement** | Epoch 100 | **Epoch 58** | ✅ **58% through** |
+
+ ## 🏗️ Model Architecture
+
+ - **ResNet-18** with BasicBlocks optimized for CIFAR-100
+ - **11.22M parameters** with 133-pixel receptive field
+ - **Advanced augmentation** pipeline (Albumentations + Mixup + CutMix)
+ - **OneCycle scheduler** for optimal learning rate progression
+
+ ## 🏆 Top Performing Classes
+
+ | Rank | Class | Accuracy | Performance |
+ |------|-------|----------|-------------|
+ | 1 | **wardrobe** | 97.00% | 🏆 Exceptional |
+ | 2 | **motorcycle** | 93.00% | 🥈 Excellent |
+ | 3 | **bicycle** | 93.00% | 🥉 Excellent |
+ | 4 | **aquarium_fish** | 92.00% | ⭐ Strong |
+
+ ## 📚 CIFAR-100 Categories
+
+ The model classifies images into **100 fine-grained categories** across **20 superclasses**:
+
+ - **Animals:** mammals, fish, insects, reptiles
+ - **Vehicles:** cars, trucks, motorcycles, bicycles
+ - **Household:** furniture, electrical devices, containers
+ - **Nature:** trees, flowers, natural landscapes
+ - **People:** different age groups and genders
+
+ ## 🚀 Usage
+
+ Simply upload an image and get instant predictions with confidence scores for the top 5 most likely classes.
+
+ ## 📖 Documentation
+
+ For complete technical details, training logs, and model analysis, visit the [GitHub Repository](https://github.com/santhoshv6/era_v4_s8_assignment).
+
+ ---
+
+ **Model trained as part of ERA V4 Course Session 8 - Deep Learning Specialization**
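
The 11.22M parameter figure quoted in the README can be cross-checked by hand against the layer shapes of the `ResNet18` class added in app.py. The following is a minimal sketch of that arithmetic in plain Python, not code from the commit: the helper names (`conv_params`, `bn_params`, `basic_block_params`, `resnet18_params`) are hypothetical, and it assumes bias-free convolutions and counts only the learnable scale and shift of each BatchNorm layer (running statistics are buffers, not parameters).

```python
def conv_params(c_in, c_out, k):
    """Weights of a k x k convolution with bias=False."""
    return c_in * c_out * k * k


def bn_params(channels):
    """Learnable scale and shift of BatchNorm2d (running stats excluded)."""
    return 2 * channels


def basic_block_params(c_in, c_out, stride):
    """Two 3x3 conv+BN pairs, plus a 1x1 projection shortcut when shapes change."""
    total = conv_params(c_in, c_out, 3) + bn_params(c_out)
    total += conv_params(c_out, c_out, 3) + bn_params(c_out)
    if stride != 1 or c_in != c_out:
        total += conv_params(c_in, c_out, 1) + bn_params(c_out)
    return total


def resnet18_params(num_classes=100):
    """Parameter count for the CIFAR-style ResNet-18 layout (3x3 stem, 4 stages of 2 blocks)."""
    total = conv_params(3, 64, 3) + bn_params(64)  # stem
    c_in = 64
    for c_out, stride in [(64, 1), (128, 2), (256, 2), (512, 2)]:
        total += basic_block_params(c_in, c_out, stride)  # first block of the stage
        total += basic_block_params(c_out, c_out, 1)      # second block of the stage
        c_in = c_out
    total += 512 * num_classes + num_classes  # final linear layer with bias
    return total


if __name__ == "__main__":
    print(f"{resnet18_params(100):,}")  # 11,220,132
```

Under these assumptions the count comes to 11,220,132, which rounds to the 11.22M stated in the README.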
app.py ADDED
@@ -0,0 +1,214 @@
+ import gradio as gr
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.transforms as transforms
+ from PIL import Image
+ import numpy as np
+ import requests
+ from io import BytesIO
+
+ # CIFAR-100 class names
+ CIFAR100_CLASSES = [
+     'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
+     'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
+     'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
+     'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
+     'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
+     'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
+     'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
+     'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
+     'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
+     'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea',
+     'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
+     'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
+     'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
+     'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman',
+     'worm'
+ ]
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, in_planes, planes, stride=1):
+         super(BasicBlock, self).__init__()
+         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes)
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != self.expansion*planes:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(self.expansion*planes)
+             )
+
+     def forward(self, x):
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         out += self.shortcut(x)
+         out = F.relu(out)
+         return out
+
+ class ResNet18(nn.Module):
+     def __init__(self, num_classes=100):
+         super(ResNet18, self).__init__()
+         self.in_planes = 64
+
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.layer1 = self._make_layer(BasicBlock, 64, 2, stride=1)
+         self.layer2 = self._make_layer(BasicBlock, 128, 2, stride=2)
+         self.layer3 = self._make_layer(BasicBlock, 256, 2, stride=2)
+         self.layer4 = self._make_layer(BasicBlock, 512, 2, stride=2)
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+         self.linear = nn.Linear(512*BasicBlock.expansion, num_classes)
+
+     def _make_layer(self, block, planes, num_blocks, stride):
+         strides = [stride] + [1]*(num_blocks-1)
+         layers = []
+         for stride in strides:
+             layers.append(block(self.in_planes, planes, stride))
+             self.in_planes = planes * block.expansion
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = self.layer1(out)
+         out = self.layer2(out)
+         out = self.layer3(out)
+         out = self.layer4(out)
+         out = self.avgpool(out)
+         out = out.view(out.size(0), -1)
+         out = self.linear(out)
+         return out
+
+ # Initialize model
+ model = ResNet18(num_classes=100)
+
+ # Load the pre-trained model
+ @torch.no_grad()
+ def load_model():
+     try:
+         # Try to download the model from your GitHub releases
+         model_url = "https://github.com/santhoshv6/era_v4_s8_assignment/releases/download/v1.0/model_best.pth"
+         response = requests.get(model_url)
+         response.raise_for_status()
+
+         # Load the model state dict
+         checkpoint = torch.load(BytesIO(response.content), map_location='cpu')
+         model.load_state_dict(checkpoint['state_dict'])
+         model.eval()
+         return True
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         return False
+
+ # Define image preprocessing
+ transform = transforms.Compose([
+     transforms.Resize((32, 32)),
+     transforms.ToTensor(),
+     transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
+ ])
+
+ def predict(image):
+     """
+     Predict the class of an input image using the trained ResNet-18 model.
+
+     Args:
+         image: PIL Image or numpy array
+
+     Returns:
+         Dictionary with predictions and confidence scores
+     """
+     try:
+         # Convert to PIL Image if needed
+         if isinstance(image, np.ndarray):
+             image = Image.fromarray(image)
+
+         # Convert to RGB if needed
+         if image.mode != 'RGB':
+             image = image.convert('RGB')
+
+         # Preprocess the image
+         input_tensor = transform(image).unsqueeze(0)
+
+         # Make prediction
+         with torch.no_grad():
+             outputs = model(input_tensor)
+             probabilities = F.softmax(outputs, dim=1)
+
+         # Get top 5 predictions
+         top5_prob, top5_idx = torch.topk(probabilities, 5, dim=1)
+
+         # Create results dictionary
+         results = {}
+         for i in range(5):
+             class_idx = top5_idx[0][i].item()
+             class_name = CIFAR100_CLASSES[class_idx]
+             confidence = top5_prob[0][i].item()
+             results[f"{class_name}"] = confidence
+
+         return results
+
+     except Exception as e:
+         return {"Error": f"Prediction failed: {str(e)}"}
+
+ # Load model on startup
+ model_loaded = load_model()
+
+ # Create Gradio interface
+ def create_interface():
+     if not model_loaded:
+         return gr.Interface(
+             fn=lambda x: {"Error": "Model failed to load. Please try again later."},
+             inputs=gr.Image(type="pil"),
+             outputs=gr.Label(num_top_classes=5),
+             title="❌ Model Loading Error",
+             description="The CIFAR-100 ResNet model could not be loaded."
+         )
+
+     return gr.Interface(
+         fn=predict,
+         inputs=gr.Image(type="pil", label="Upload an Image"),
+         outputs=gr.Label(num_top_classes=5, label="Top 5 Predictions"),
+         title="🏆 CIFAR-100 ResNet-18 Classifier - 77.45% Accuracy",
+         description="""
+         **Upload an image to classify it into one of 100 CIFAR-100 categories!**
+
+         🎯 **Model Performance:** 77.45% test accuracy (4.45% above target)
+         🏗️ **Architecture:** ResNet-18 with 11.22M parameters
+         📊 **Training:** 100 epochs on Tesla P100, reached target at epoch 58
+
+         **Best performing classes:** wardrobe (97%), motorcycle (93%), bicycle (93%), aquarium_fish (92%)
+
+         *This model excels at furniture, vehicles, and distinctive objects. For best results, upload clear images similar to CIFAR-100 style.*
+         """,
+         examples=[
+             # You can add example images here if available
+         ],
+         article="""
+         ### 📚 About This Model
+
+         This ResNet-18 model was trained on CIFAR-100 dataset achieving **77.45% accuracy**, exceeding the 73% target by 4.45%.
+
+         **Key Features:**
+         - 🏗️ **Optimized Architecture:** ResNet-18 with BasicBlocks
+         - 🎨 **Advanced Augmentation:** Albumentations + Mixup + CutMix
+         - ⚡ **Fast Training:** OneCycle learning rate scheduler
+         - 🔍 **Interpretable:** GradCAM visualizations available
+
+         **CIFAR-100 Categories:** 100 fine-grained classes across 20 superclasses including animals, vehicles, household items, and natural objects.
+
+         📖 **Full Documentation:** [GitHub Repository](https://github.com/santhoshv6/era_v4_s8_assignment)
+         """,
+         theme=gr.themes.Soft(),
+         allow_flagging="never"
+     )
+
+ # Create and launch the interface
+ demo = create_interface()
+
+ if __name__ == "__main__":
+     demo.launch()
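
The `predict` function in app.py turns raw logits into a label-to-confidence dictionary by applying a softmax and keeping the five largest probabilities. The sketch below reproduces that step in dependency-free Python so it can be followed without PyTorch; it is an illustration rather than the committed code, and the function names `softmax` and `top_k` as well as the sample labels and logits are made up for the example.

```python
import math


def softmax(logits):
    """Convert logits to probabilities; subtracting the max keeps exp() stable."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


def top_k(labels, logits, k=5):
    """Return the k most probable labels mapped to their probabilities, best first."""
    probabilities = softmax(logits)
    ranked = sorted(zip(labels, probabilities), key=lambda pair: pair[1], reverse=True)
    return dict(ranked[:k])


if __name__ == "__main__":
    # Hypothetical six-class example; the real app has 100 CIFAR-100 classes.
    sample_labels = ["apple", "bear", "bus", "cup", "fox", "sea"]
    sample_logits = [2.0, 0.5, 1.0, -1.0, 3.0, 0.0]
    for label, probability in top_k(sample_labels, sample_logits).items():
        print(f"{label}: {probability:.3f}")
```

The resulting dictionary has the same shape as the one `predict` hands to the `gr.Label` output, with class names as keys and confidences as values.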
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch>=2.0.0
+ torchvision>=0.15.0
+ pillow>=9.0.0
+ numpy>=1.21.0
+ requests>=2.25.0
+ gradio>=4.0.0
+ gradio>=4.0.0
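
requirements.txt gives only lower bounds (for example `gradio>=4.0.0`), while the README front matter in the same commit sets `sdk_version: 5.49.1`, so the versions actually installed at runtime are not evident from the diff alone. One way to see them, sketched here with the standard library's `importlib.metadata`, is to query each distribution by name. The helper `installed_versions` and the `REQUIRED` list are hypothetical names introduced for this example and are not part of the commit.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from requirements.txt in this commit.
REQUIRED = ["torch", "torchvision", "pillow", "numpy", "requests", "gradio"]


def installed_versions(packages=REQUIRED):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, installed in installed_versions().items():
        print(f"{name}: {installed or 'not installed'}")
```

Running this inside the target environment prints one line per package, which makes it easy to spot a package that resolved to an unexpected version.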