Chanakya Hosamani committed on
Commit a89cc74 · 1 Parent(s): f83e2e3

Update HF Space with best checkpoint

README.md CHANGED
@@ -1,12 +1,151 @@
  ---
- title: Cifar 100 Resnet
- emoji: 🏆
- colorFrom: green
- colorTo: red
  sdk: gradio
  sdk_version: 5.49.1
  app_file: app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: ResNet-18 CIFAR-100 Classifier
+ emoji: 🖼️
+ colorFrom: blue
+ colorTo: purple
  sdk: gradio
  sdk_version: 5.49.1
  app_file: app.py
  pinned: false
  ---

+ # ResNet-18 CIFAR-100 Image Classifier 🎯
+
+ A high-performance image classifier trained on the CIFAR-100 dataset, reaching **77.18% test accuracy**.
+
+ ## Model Details
+
+ - **Architecture:** ResNet-18 (custom CIFAR-100 variant)
+ - **Parameters:** ~11 million
+ - **Test Accuracy:** 77.18%
+ - **Train Accuracy:** 98.25%
+ - **Training Time:** ~70 minutes on an RTX 4070 Laptop GPU
+ - **Dataset:** CIFAR-100 (60,000 32×32 color images in 100 classes)
+
+ ## Training Configuration
+
+ ### Advanced Techniques Used
+ 1. **OneCycle Learning Rate Policy** - gradual warmup followed by extended annealing
+ 2. **Cutout Augmentation** - randomly masks 8×8 patches during training
+ 3. **Label Smoothing** (0.1) - prevents overconfident predictions
+ 4. **Gradient Clipping** - stabilizes training during the high-LR phase
+ 5. **Data Augmentation** - random crops, horizontal flips, normalization
+
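The Cutout step (item 2 above) can be sketched as follows; this hypothetical `cutout` helper is not the exact training code, but it shows the core idea of zeroing a random square patch, with the 8×8 patch size used here:

```python
import numpy as np

def cutout(img: np.ndarray, size: int = 8, rng=None) -> np.ndarray:
    """Zero out a random size×size patch (clipped at the borders) of an HxWxC image."""
    if rng is None:
        rng = np.random.default_rng()
    h, w = img.shape[:2]
    cy, cx = rng.integers(0, h), rng.integers(0, w)  # random patch center
    half = size // 2
    y0, y1 = max(0, cy - half), min(h, cy + half)
    x0, x1 = max(0, cx - half), min(w, cx + half)
    out = img.copy()
    out[y0:y1, x0:x1] = 0.0                          # mask the patch
    return out

augmented = cutout(np.ones((32, 32, 3)), size=8, rng=np.random.default_rng(0))
```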
+ ### Hyperparameters
+ - Epochs: 100
+ - Batch Size: 128
+ - Max Learning Rate: 0.1
+ - Weight Decay: 5e-4
+ - Optimizer: SGD with momentum (0.9)
+
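The hyperparameters above can be wired together roughly as in this sketch. It is not the repository's training script: a toy linear model stands in for the real ResNet-18, `steps_per_epoch=391` assumes 50,000 training images at batch size 128, and the `max_norm=1.0` clipping threshold is an assumed value:

```python
import torch
import torch.nn as nn

model = nn.Linear(3 * 32 * 32, 100)  # toy stand-in for ResNet-18
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)        # label smoothing (0.1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.1, epochs=100, steps_per_epoch=391)  # warmup + annealing

# One training step: clip gradients before the optimizer update
x, y = torch.randn(128, 3 * 32 * 32), torch.randint(0, 100, (128,))
loss = criterion(model(x), y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
scheduler.step()
```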
+ ## 100 Classes
+
+ The model can classify images into these categories:
+
+ **Animals (45 classes):**
+ - Mammals: bear, beaver, camel, cattle, chimpanzee, dolphin, elephant, fox, hamster, kangaroo, leopard, lion, mouse, otter, porcupine, possum, rabbit, raccoon, seal, shrew, skunk, squirrel, tiger, whale, wolf
+ - Aquatic: aquarium_fish, crab, crocodile, flatfish, lobster, ray, shark, trout
+ - Insects and other small creatures: bee, beetle, butterfly, caterpillar, cockroach, snail, snake, spider, turtle, worm
+ - Reptiles: dinosaur, lizard
+
+ **Vehicles (9 classes):**
+ bicycle, bus, motorcycle, pickup_truck, rocket, streetcar, tank, tractor, train
+
+ **Household Items (12 classes):**
+ bed, chair, clock, couch, cup, keyboard, lamp, plate, table, telephone, television, wardrobe
+
+ **Food (5 classes):**
+ apple, mushroom, orange, pear, sweet_pepper
+
+ **Nature (16 classes):**
+ - Trees: maple_tree, oak_tree, palm_tree, pine_tree, willow_tree
+ - Flowers: orchid, poppy, rose, sunflower, tulip
+ - Landscapes: cloud, forest, mountain, plain, road, sea
+
+ **People (5 classes):**
+ baby, boy, girl, man, woman
+
+ **Structures (4 classes):**
+ bridge, castle, house, skyscraper
+
+ **Other (4 classes):**
+ bottle, bowl, can, lawn_mower
+
+ ## Usage
+
+ ### On Hugging Face Spaces
+ Simply upload an image and get instant predictions with confidence scores!
+
+ ### Local Usage
+
+ ```python
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+
+ # ResNet and BasicBlock are the model definitions from app.py
+ checkpoint = torch.load('checkpoints/resnet18_best.pth', map_location='cpu')
+ model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=100)
+ model.load_state_dict(checkpoint['model_state_dict'])
+ model.eval()
+
+ # Preprocess image with the CIFAR-100 channel statistics
+ transform = transforms.Compose([
+     transforms.Resize((32, 32)),
+     transforms.ToTensor(),
+     transforms.Normalize((0.5071, 0.4867, 0.4408),
+                          (0.2675, 0.2565, 0.2761))
+ ])
+
+ image = Image.open('your_image.jpg').convert('RGB')
+ img_tensor = transform(image).unsqueeze(0)
+
+ # Predict
+ with torch.no_grad():
+     output = model(img_tensor)
+     pred = output.argmax(dim=1)
+ ```
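To turn raw logits into readable predictions, a softmax followed by `topk` works; the logits below are a toy stand-in for the model's `output` (in practice the indices would be looked up in the CIFAR-100 class-name list defined in `app.py`):

```python
import torch
import torch.nn.functional as F

logits = torch.zeros(1, 100)   # stand-in for `output` from the model
logits[0, 3] = 5.0             # pretend class index 3 scored highest
probs = F.softmax(logits, dim=1)[0]
top5_prob, top5_idx = torch.topk(probs, 5)
best = top5_idx[0].item()      # index of the most confident class
```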
+
+ ## Performance Notes
+
+ - **Best for:** small objects, centered subjects, simple backgrounds
+ - **Optimized for:** 32×32 images (inputs are resized automatically)
+ - **Categories:** works best on the 100 CIFAR-100 classes listed above
+
+ ## Training Curves
+
+ The model showed steady improvement throughout training:
+ - Epochs 1-30: warmup phase (13.89% → 59.38%)
+ - Epochs 31-60: peak learning (59.38% → 63.88%)
+ - Epochs 61-100: fine-tuning (63.88% → 77.18%)
+
+ ## Key Achievements
+
+ ✅ Exceeded the 73% target accuracy by **4.18 percentage points**
+ ✅ Stable training with no divergence
+ ✅ Effective use of the OneCycleLR scheduler
+ ✅ Combined regularization techniques
+ ✅ Fast training (~70 minutes for 100 epochs)
+
+ ## Repository
+
+ Full training code, logs, and checkpoints are available at: [GitHub Repository](https://github.com/godsofheaven/Resnet-Model-Implementation-for-CIFAR-100-Dataset)
+
+ ## Citation
+
+ If you use this model, please cite:
+
+ ```bibtex
+ @misc{resnet18-cifar100,
+   author = {Your Name},
+   title = {ResNet-18 CIFAR-100 Image Classifier},
+   year = {2024},
+   publisher = {Hugging Face},
+   howpublished = {\url{https://huggingface.co/spaces/yourusername/resnet18-cifar100}}
+ }
+ ```
+
+ ## License
+
+ MIT License - feel free to use it for research and educational purposes!
+
app.py ADDED
@@ -0,0 +1,218 @@
+ import gradio as gr
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision import transforms
+ from PIL import Image
+
+ # CIFAR-100 class names, alphabetical (matches the dataset's label indices)
+ CIFAR100_CLASSES = [
+     'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
+     'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
+     'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
+     'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
+     'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
+     'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
+     'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
+     'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
+     'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
+     'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
+     'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
+     'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
+     'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
+     'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
+ ]
+
+ # ResNet-18 basic block: two 3×3 convs with a projection shortcut when the shape changes
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, in_planes, planes, stride=1):
+         super(BasicBlock, self).__init__()
+         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes)
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != self.expansion * planes:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(self.expansion * planes)
+             )
+
+     def forward(self, x):
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         out += self.shortcut(x)
+         out = F.relu(out)
+         return out
+
+
+ class ResNet(nn.Module):
+     def __init__(self, block, num_blocks, num_classes=100):
+         super(ResNet, self).__init__()
+         self.in_planes = 64
+
+         # For 32×32 CIFAR inputs, use a 3×3 stem with stride 1 (not 7×7/stride 2 as for ImageNet)
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+         self.linear = nn.Linear(512 * block.expansion, num_classes)
+
+     def _make_layer(self, block, planes, num_blocks, stride):
+         strides = [stride] + [1] * (num_blocks - 1)
+         layers = []
+         for stride in strides:
+             layers.append(block(self.in_planes, planes, stride))
+             self.in_planes = planes * block.expansion
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = self.layer1(out)
+         out = self.layer2(out)
+         out = self.layer3(out)
+         out = self.layer4(out)
+         out = F.avg_pool2d(out, 4)
+         out = out.view(out.size(0), -1)
+         out = self.linear(out)
+         return out
+
+
+ # Load the trained model, falling back to random weights if the checkpoint is missing
+ def load_model():
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=100)
+
+     try:
+         checkpoint = torch.load('checkpoints/resnet18_best.pth', map_location=device)
+         model.load_state_dict(checkpoint['model_state_dict'])
+         print(f"Model loaded successfully! Best accuracy: {checkpoint.get('best_acc', 'N/A')}%")
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         print("Using randomly initialized model (for demo purposes)")
+
+     model = model.to(device)
+     model.eval()
+     return model, device
+
+
+ # Resize to 32×32 and normalize with the CIFAR-100 channel statistics
+ def preprocess_image(image):
+     transform = transforms.Compose([
+         transforms.Resize((32, 32)),
+         transforms.ToTensor(),
+         transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
+     ])
+
+     if image.mode != 'RGB':
+         image = image.convert('RGB')
+
+     return transform(image).unsqueeze(0)
+
+
+ # Return the top-5 classes with confidence scores for the Gradio Label output
+ def predict(image):
+     if image is None:
+         return None
+
+     img_tensor = preprocess_image(image).to(device)
+
+     with torch.no_grad():
+         outputs = model(img_tensor)
+         probabilities = F.softmax(outputs, dim=1)[0]
+
+     top5_prob, top5_idx = torch.topk(probabilities, 5)
+
+     results = {}
+     for i in range(5):
+         class_name = CIFAR100_CLASSES[top5_idx[i]]
+         results[class_name] = top5_prob[i].item()
+
+     return results
+
+
+ # Initialize model
+ print("Loading model...")
+ model, device = load_model()
+ print("Model loaded!")
+
+ # Create Gradio interface
+ title = "ResNet-18 CIFAR-100 Image Classifier"
+ description = """
+ ## 🎯 ResNet-18 trained on the CIFAR-100 Dataset
+ This model achieves **77.18% test accuracy** on CIFAR-100!
+
+ **How to use:**
+ 1. Upload an image or use one of the examples
+ 2. The model will classify it into one of 100 categories
+ 3. See the top 5 predictions with confidence scores
+
+ **Note:** This model was trained on 32×32 images from CIFAR-100, so it works best with:
+ - Small objects
+ - Centered subjects
+ - Simple backgrounds
+ - Animals, vehicles, household items, plants, etc.
+
+ **Training Details:**
+ - Architecture: ResNet-18 (11M parameters)
+ - Dataset: CIFAR-100 (100 classes)
+ - Techniques: OneCycleLR, Cutout, Label Smoothing
+ - Training Time: ~70 minutes on an RTX 4070
+ """
+
+ article = """
+ ### Model Performance
+ - **Test Accuracy:** 77.18%
+ - **Train Accuracy:** 98.25%
+ - **Total Epochs:** 100
+ - **Training Time:** ~70 minutes
+
+ ### Classes
+ The model can recognize 100 different classes including:
+ - **Animals (44 classes):** bear, beaver, bee, beetle, butterfly, camel, caterpillar, cattle, chimpanzee, cockroach, crab, crocodile, dinosaur, dolphin, elephant, flatfish, fox, hamster, kangaroo, leopard, lion, lizard, lobster, mouse, otter, porcupine, possum, rabbit, raccoon, ray, seal, shark, shrew, skunk, snail, snake, spider, squirrel, tiger, trout, turtle, whale, wolf, worm
+ - **Vehicles (10 classes):** bicycle, bus, lawn_mower, motorcycle, pickup_truck, rocket, streetcar, tank, tractor, train
+ - **Household Items (15 classes):** bed, bottle, bowl, can, chair, clock, couch, cup, keyboard, lamp, plate, table, telephone, television, wardrobe
+ - **People (5 classes):** baby, boy, girl, man, woman
+ - **Plants:**
+   - **Flowers (5 classes):** orchid, poppy, rose, sunflower, tulip
+   - **Trees (5 classes):** maple_tree, oak_tree, palm_tree, pine_tree, willow_tree
+ - **Food (5 classes):** apple, mushroom, orange, pear, sweet_pepper
+ - **Nature & Structures (11 classes):** aquarium_fish, bridge, castle, cloud, forest, house, mountain, plain, road, sea, skyscraper
+
+ ---
+ **Repository:** [GitHub](https://github.com/godsofheaven/Resnet-Model-Implementation-for-CIFAR-100-Dataset)
+ """
+
+ # Create interface
+ demo = gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type="pil", label="Upload Image"),
+     outputs=gr.Label(num_top_classes=5, label="Predictions"),
+     title=title,
+     description=description,
+     article=article,
+     examples=[
+         # Users can add their own example images here
+     ],
+     theme=gr.themes.Soft(),
+     analytics_enabled=False,
+ )
+
+ # Launch
+ if __name__ == "__main__":
+     demo.launch()
checkpoints/resnet18_best.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cd77ace4a7ea7f77284fe12b1c5944ed14bff3072dbc099325b4d9883ecb8bac
+ size 89865411
checkpoints/resnet18_last.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:752718b32f4b5eb6f8b6c2a2b3f76753f71c4241d249a85bdfab41a9ecaa9a2e
+ size 89865475
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch>=2.0.0
+ torchvision>=0.15.0
+ gradio>=4.0.0
+ pillow>=9.0.0
+ numpy>=1.20.0