chandu1617 committed on
Commit
7a59d7b
·
verified ·
1 Parent(s): 7ac2efb

Upload 10 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ accuracy_detail.png filter=lfs diff=lfs merge=lfs -text
37
+ training_analysis.png filter=lfs diff=lfs merge=lfs -text
accuracy_detail.png ADDED

Git LFS Details

  • SHA256: 0c96beffdfec86c8e7868c9107840c24b48b54b5728f321a527b7a7e35a9d857
  • Pointer size: 131 Bytes
  • Size of remote file: 210 kB
app (1).py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ from PIL import Image
4
+ import gradio as gr
5
+ import os
6
+
7
+ # Assuming ResNet50 class is defined in main.py or you copy it here
8
+ # For simplicity, I'll put a placeholder. In a real scenario, you'd import ResNet50
9
+ # from a separate models.py or main.py. For this example, let's assume it's available.
10
+
11
+ # --- ResNet50 Model Definition (copy-pasted from main.py for self-containment) ---
12
class Bottleneck(torch.nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand.

    The optional `downsample` module projects the shortcut when the
    stride or channel count changes.
    """

    expansion = 4  # channel multiplier of the final 1x1 conv

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Attribute names must stay stable: saved checkpoints key on them.
        self.conv1 = torch.nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(planes)
        self.conv2 = torch.nn.Conv2d(planes, planes, 3, stride, 1, bias=False)
        self.bn2 = torch.nn.BatchNorm2d(planes)
        self.conv3 = torch.nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = torch.nn.BatchNorm2d(planes * self.expansion)
        self.relu = torch.nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Run the three conv stages and add the (possibly projected) shortcut."""
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out = out + shortcut
        return self.relu(out)
39
+
40
class ResNet50(torch.nn.Module):
    """ResNet-50 classifier built from Bottleneck blocks (3-4-6-3 layout)."""

    def __init__(self, num_classes=101):
        super().__init__()
        self.inplanes = 64

        # Stem: 7x7 stride-2 conv + 3x3 stride-2 max-pool (4x downsampling).
        self.conv1 = torch.nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(64)
        self.relu = torch.nn.ReLU(inplace=True)
        self.maxpool = torch.nn.MaxPool2d(3, 2, 1)

        # Four residual stages; stages 2-4 halve the spatial resolution.
        self.layer1 = self._make_layer(Bottleneck, 64, 3)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, 2)
        self.layer3 = self._make_layer(Bottleneck, 256, 6, 2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, 2)

        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(512 * Bottleneck.expansion, num_classes)

        self._initialize_weights()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` bottlenecks; project the shortcut when shape changes."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = torch.nn.Sequential(
                torch.nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False),
                torch.nn.BatchNorm2d(planes * block.expansion),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        stage.extend(block(self.inplanes, planes) for _ in range(blocks - 1))
        return torch.nn.Sequential(*stage)

    def _initialize_weights(self):
        """Kaiming init for convs; unit-weight, zero-bias init for batch norms."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, torch.nn.BatchNorm2d):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return raw class logits of shape (N, num_classes)."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)
# --- End ResNet50 Model Definition ---
98
+
99
+
100
# Load class names produced by the training script (one name per line;
# assumes line order matches the training label order — TODO confirm).
with open('./outputs/food101_classes_simple.txt', 'r') as f:
    class_names = [line.strip() for line in f]

num_classes = len(class_names)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the network and load the trained weights; fail fast with a clear
# message when the weight file has not been produced yet.
model = ResNet50(num_classes=num_classes).to(device)
model_path = './outputs/food101_resnet50_final_weights.pth'
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model weights not found at {model_path}. Please train the model first.")

model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()  # inference mode: freezes batch-norm running statistics

# Preprocessing: fixed 224x224 resize + ImageNet mean/std normalization,
# matching the eval-time transform used during training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
122
+
123
def predict_image(image: Image.Image) -> dict:
    """Classify a food photo and return the top-5 class confidences.

    Args:
        image: PIL image uploaded by the user (any mode).

    Returns:
        Dict mapping class name -> confidence percentage (rounded to 2
        decimals) for the 5 most likely classes.
    """
    # Fix: uploads may be RGBA (PNG) or grayscale; the 3-channel Normalize
    # transform would fail on them, so force RGB first.
    tensor = transform(image.convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

    # Top-5 predictions as {name: percent}; .item() converts the index
    # tensors to plain ints for clean dict keys.
    top5_prob, top5_indices = torch.topk(probabilities, 5)
    return {
        class_names[idx.item()]: round(prob.item() * 100, 2)
        for idx, prob in zip(top5_indices, top5_prob)
    }
137
+
138
# Gradio UI: one image in, top-5 label confidences out.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload Food Image"),
    outputs=gr.Label(num_top_classes=5),
    title="Food101 ResNet50 Classifier",
    description="Upload an image of food and get predictions for 101 food categories. Model trained on Food101 dataset.",
    examples=[
        # Add some example images here if you have them, e.g.,
        # ["path/to/example_image1.jpg"],
        # ["path/to/example_image2.jpg"],
    ]
)

# Launch the Gradio app: bind to all interfaces so the container/VM host
# can reach it.
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=8000) # Use port 8000 for Lightning AI deployments
food101_classes.txt ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Food101 Classes (101 total)
2
+ ==============================
3
+
4
+ 1. Apple Pie
5
+ 2. Baby Back Ribs
6
+ 3. Baklava
7
+ 4. Beef Carpaccio
8
+ 5. Beef Tartare
9
+ 6. Beet Salad
10
+ 7. Beignets
11
+ 8. Bibimbap
12
+ 9. Bread Pudding
13
+ 10. Breakfast Burrito
14
+ 11. Bruschetta
15
+ 12. Caesar Salad
16
+ 13. Cannoli
17
+ 14. Caprese Salad
18
+ 15. Carrot Cake
19
+ 16. Ceviche
20
+ 17. Cheese Plate
21
+ 18. Cheesecake
22
+ 19. Chicken Curry
23
+ 20. Chicken Quesadilla
24
+ 21. Chicken Wings
25
+ 22. Chocolate Cake
26
+ 23. Chocolate Mousse
27
+ 24. Churros
28
+ 25. Clam Chowder
29
+ 26. Club Sandwich
30
+ 27. Crab Cakes
31
+ 28. Creme Brulee
32
+ 29. Croque Madame
33
+ 30. Cup Cakes
34
+ 31. Deviled Eggs
35
+ 32. Donuts
36
+ 33. Dumplings
37
+ 34. Edamame
38
+ 35. Eggs Benedict
39
+ 36. Escargots
40
+ 37. Falafel
41
+ 38. Filet Mignon
42
+ 39. Fish And Chips
43
+ 40. Foie Gras
44
+ 41. French Fries
45
+ 42. French Onion Soup
46
+ 43. French Toast
47
+ 44. Fried Calamari
48
+ 45. Fried Rice
49
+ 46. Frozen Yogurt
50
+ 47. Garlic Bread
51
+ 48. Gnocchi
52
+ 49. Greek Salad
53
+ 50. Grilled Cheese Sandwich
54
+ 51. Grilled Salmon
55
+ 52. Guacamole
56
+ 53. Gyoza
57
+ 54. Hamburger
58
+ 55. Hot And Sour Soup
59
+ 56. Hot Dog
60
+ 57. Huevos Rancheros
61
+ 58. Hummus
62
+ 59. Ice Cream
63
+ 60. Lasagna
64
+ 61. Lobster Bisque
65
+ 62. Lobster Roll Sandwich
66
+ 63. Macaroni And Cheese
67
+ 64. Macarons
68
+ 65. Miso Soup
69
+ 66. Mussels
70
+ 67. Nachos
71
+ 68. Omelette
72
+ 69. Onion Rings
73
+ 70. Oysters
74
+ 71. Pad Thai
75
+ 72. Paella
76
+ 73. Pancakes
77
+ 74. Panna Cotta
78
+ 75. Peking Duck
79
+ 76. Pho
80
+ 77. Pizza
81
+ 78. Pork Chop
82
+ 79. Poutine
83
+ 80. Prime Rib
84
+ 81. Pulled Pork Sandwich
85
+ 82. Ramen
86
+ 83. Ravioli
87
+ 84. Red Velvet Cake
88
+ 85. Risotto
89
+ 86. Samosa
90
+ 87. Sashimi
91
+ 88. Scallops
92
+ 89. Seaweed Salad
93
+ 90. Shrimp And Grits
94
+ 91. Spaghetti Bolognese
95
+ 92. Spaghetti Carbonara
96
+ 93. Spring Rolls
97
+ 94. Steak
98
+ 95. Strawberry Shortcake
99
+ 96. Sushi
100
+ 97. Tacos
101
+ 98. Takoyaki
102
+ 99. Tiramisu
103
+ 100. Tuna Tartare
104
+ 101. Waffles
food101_classes_simple.txt ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ apple_pie
2
+ baby_back_ribs
3
+ baklava
4
+ beef_carpaccio
5
+ beef_tartare
6
+ beet_salad
7
+ beignets
8
+ bibimbap
9
+ bread_pudding
10
+ breakfast_burrito
11
+ bruschetta
12
+ caesar_salad
13
+ cannoli
14
+ caprese_salad
15
+ carrot_cake
16
+ ceviche
17
+ cheese_plate
18
+ cheesecake
19
+ chicken_curry
20
+ chicken_quesadilla
21
+ chicken_wings
22
+ chocolate_cake
23
+ chocolate_mousse
24
+ churros
25
+ clam_chowder
26
+ club_sandwich
27
+ crab_cakes
28
+ creme_brulee
29
+ croque_madame
30
+ cup_cakes
31
+ deviled_eggs
32
+ donuts
33
+ dumplings
34
+ edamame
35
+ eggs_benedict
36
+ escargots
37
+ falafel
38
+ filet_mignon
39
+ fish_and_chips
40
+ foie_gras
41
+ french_fries
42
+ french_onion_soup
43
+ french_toast
44
+ fried_calamari
45
+ fried_rice
46
+ frozen_yogurt
47
+ garlic_bread
48
+ gnocchi
49
+ greek_salad
50
+ grilled_cheese_sandwich
51
+ grilled_salmon
52
+ guacamole
53
+ gyoza
54
+ hamburger
55
+ hot_and_sour_soup
56
+ hot_dog
57
+ huevos_rancheros
58
+ hummus
59
+ ice_cream
60
+ lasagna
61
+ lobster_bisque
62
+ lobster_roll_sandwich
63
+ macaroni_and_cheese
64
+ macarons
65
+ miso_soup
66
+ mussels
67
+ nachos
68
+ omelette
69
+ onion_rings
70
+ oysters
71
+ pad_thai
72
+ paella
73
+ pancakes
74
+ panna_cotta
75
+ peking_duck
76
+ pho
77
+ pizza
78
+ pork_chop
79
+ poutine
80
+ prime_rib
81
+ pulled_pork_sandwich
82
+ ramen
83
+ ravioli
84
+ red_velvet_cake
85
+ risotto
86
+ samosa
87
+ sashimi
88
+ scallops
89
+ seaweed_salad
90
+ shrimp_and_grits
91
+ spaghetti_bolognese
92
+ spaghetti_carbonara
93
+ spring_rolls
94
+ steak
95
+ strawberry_shortcake
96
+ sushi
97
+ tacos
98
+ takoyaki
99
+ tiramisu
100
+ tuna_tartare
101
+ waffles
main (3).py ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ βœ… OPTIMIZED Food101 + ResNet50 with major speed improvements
4
+ βœ… Mixed precision training (2x faster)
5
+ βœ… Better data loading (persistent workers)
6
+ βœ… Progress bars and better logging
7
+ βœ… Robust error handling and checkpointing
8
+ """
9
+
10
+ import os
11
+ import time
12
+ import copy
13
+ import numpy as np
14
+ import matplotlib.pyplot as plt
15
+ from tqdm import tqdm
16
+ import logging
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.optim as optim
21
+ import torchvision
22
+ import torchvision.transforms as transforms
23
+ from torch.utils.data import DataLoader
24
+ from torch.cuda.amp import autocast, GradScaler
25
+
26
+ # Setup logging
27
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
28
+ logger = logging.getLogger(__name__)
29
+
30
+ # -------------------------
31
+ # OPTIMIZED Data Loaders
32
+ # -------------------------
33
def get_food101_loaders(batch_size=64, num_workers=8):
    """Return (train_loader, val_loader, test_loader, class_names) for Food101.

    Downloads the dataset into ./data on first use.  The 75k-image train
    split is divided 90/10 into train/val with a fixed seed; the 25k-image
    test split is used as-is.

    Fix over the original: the validation subset previously inherited the
    *training* augmentations (random crop/flip/jitter), making the val
    metric noisy and pessimistic.  Validation now uses the deterministic
    eval transform while keeping the exact same seed-42 sample partition.
    """
    # Aggressive augmentation — applied to the training subset only.
    transform_train = transforms.Compose([
        transforms.Resize((256, 256)),      # resize larger first
        transforms.RandomCrop((224, 224)),  # then crop to avoid distortion
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Deterministic transform for validation and test.
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    try:
        # Two views of the same 75k-image train split: augmented for
        # training, deterministic for validation.
        full_train = torchvision.datasets.Food101(
            root='./data', split='train', download=True, transform=transform_train
        )
        full_train_eval = torchvision.datasets.Food101(
            root='./data', split='train', download=False, transform=transform_test
        )

        # Seed-42 permutation; mirrors random_split's internal randperm so
        # the partition is identical to the original implementation.
        torch.manual_seed(42)
        train_size = int(0.9 * len(full_train))
        perm = torch.randperm(len(full_train)).tolist()
        train_dataset = torch.utils.data.Subset(full_train, perm[:train_size])
        val_dataset = torch.utils.data.Subset(full_train_eval, perm[train_size:])

        # Held-out test split (25k images).
        test_dataset = torchvision.datasets.Food101(
            root='./data', split='test', download=True, transform=transform_test
        )

        logger.info(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

        # persistent_workers avoids respawning workers each epoch;
        # pin_memory speeds up host->GPU transfers.
        train_loader = DataLoader(
            train_dataset, batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, persistent_workers=True, drop_last=True
        )
        val_loader = DataLoader(
            val_dataset, batch_size, shuffle=False, num_workers=num_workers,
            pin_memory=True, persistent_workers=True
        )
        test_loader = DataLoader(
            test_dataset, batch_size, shuffle=False, num_workers=num_workers,
            pin_memory=True, persistent_workers=True
        )

        return train_loader, val_loader, test_loader, full_train.classes

    except Exception as e:
        logger.error(f"Error loading data: {e}")
        raise
94
+
95
+
96
+ # -------------------------
97
+ # ResNet Building Blocks (same as original but with better initialization)
98
+ # -------------------------
99
class BasicBlock(nn.Module):
    """Two-conv residual block (ResNet-18/34 style)."""

    expansion = 1  # output channels == planes

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Attribute names must stay stable: saved checkpoints key on them.
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Two 3x3 conv stages plus the (possibly projected) shortcut."""
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + shortcut
        return self.relu(out)
121
+
122
+
123
class Bottleneck(nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand."""

    expansion = 4  # channel multiplier of the final 1x1 conv

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Attribute names must stay stable: saved checkpoints key on them.
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Run the three conv stages and add the (possibly projected) shortcut."""
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out = out + shortcut
        return self.relu(out)
150
+
151
+
152
class ResNet50(nn.Module):
    """ResNet-50 classifier built from Bottleneck blocks (3-4-6-3 layout)."""

    def __init__(self, num_classes=101):
        super().__init__()
        self.inplanes = 64

        # Stem: 7x7 stride-2 conv + 3x3 stride-2 max-pool (4x downsampling).
        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(3, 2, 1)

        # Four residual stages; stages 2-4 halve the spatial resolution.
        self.layer1 = self._make_layer(Bottleneck, 64, 3)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, 2)
        self.layer3 = self._make_layer(Bottleneck, 256, 6, 2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, 2)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)

        # Kaiming/constant initialization (see _initialize_weights).
        self._initialize_weights()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` bottlenecks; project the shortcut when shape changes."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        stage.extend(block(self.inplanes, planes) for _ in range(blocks - 1))
        return nn.Sequential(*stage)

    def _initialize_weights(self):
        """Kaiming init for convs; unit-weight, zero-bias init for batch norms."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return raw class logits of shape (N, num_classes)."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)
210
+
211
+
212
+ # -------------------------
213
+ # OPTIMIZED Training Function with Mixed Precision
214
+ # -------------------------
215
def train_model(model, train_loader, val_loader, test_loader, device, num_epochs=100, resume_from=None):
    """Train `model` with SGD + cosine warm restarts under AMP mixed precision.

    Args:
        model: network to train (updated in place).
        train_loader / val_loader / test_loader: Food101 DataLoaders.
        device: torch.device to run on.
        num_epochs: total epoch count (including any resumed epochs).
        resume_from: optional checkpoint path to continue training from.

    Returns:
        (best_val_acc, train_losses, val_accuracies)

    Side effects: writes checkpoints, weight files and plots to ./outputs.

    Fixes over the original: the GradScaler and LR-scheduler states are now
    saved *and* restored on resume (previously the scaler state was saved
    but never loaded, and the scheduler state was never saved, so resumed
    runs silently restarted the LR schedule and loss scaling); the final
    checkpoint no longer crashes when the accuracy history is empty.
    """

    os.makedirs('./outputs', exist_ok=True)

    # Label smoothing regularizes the 101-way classifier.
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    # Gradient scaler for FP16 mixed-precision training.
    scaler = GradScaler()

    best_val_acc = 0.0
    train_losses, val_accuracies, learning_rates = [], [], []
    start_epoch = 0

    # Resume from checkpoint if provided.  Older checkpoints may lack the
    # scaler/scheduler entries, hence the .get guards.
    if resume_from and os.path.exists(resume_from):
        logger.info(f"Resuming from {resume_from}")
        checkpoint = torch.load(resume_from, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if checkpoint.get('scaler_state_dict') is not None:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        if checkpoint.get('scheduler_state_dict') is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch']
        best_val_acc = checkpoint.get('best_val_accuracy', 0.0)
        train_losses = checkpoint.get('train_losses', [])
        val_accuracies = checkpoint.get('val_accuracies', [])
        learning_rates = checkpoint.get('learning_rates', [])

    logger.info(f"πŸš€ Starting training from epoch {start_epoch+1} for {num_epochs} total epochs...")

    # Wall-clock time spent in the epoch loop.
    total_train_time = 0

    for epoch in range(start_epoch, num_epochs):
        epoch_start = time.time()

        # ---- Training phase ----
        model.train()
        running_loss = 0.0
        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]', leave=False)

        for batch_idx, (images, labels) in enumerate(train_pbar):
            images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)

            optimizer.zero_grad()

            # Forward in reduced precision; the loss is scaled before
            # backward to avoid FP16 gradient underflow.
            with autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            train_pbar.set_postfix({'loss': f'{loss.item():.4f}', 'lr': f'{optimizer.param_groups[0]["lr"]:.6f}'})

        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        learning_rates.append(optimizer.param_groups[0]['lr'])

        # ---- Validation phase ----
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]', leave=False)

        with torch.no_grad():
            for images, labels in val_pbar:
                images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)

                with autocast():
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                val_pbar.set_postfix({'acc': f'{100.*correct/total:.2f}%'})

        val_acc = 100. * correct / total
        val_accuracies.append(val_acc)
        avg_val_loss = val_loss / len(val_loader)

        # Track the best validation accuracy seen so far.
        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc

        # Checkpoint every 10 epochs, on every new best, and on the last epoch.
        if (epoch + 1) % 10 == 0 or is_best or epoch == num_epochs - 1:
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_accuracy': best_val_acc,
                'current_val_accuracy': val_acc,
                'train_losses': train_losses,
                'val_accuracies': val_accuracies,
                'learning_rates': learning_rates,
            }

            if is_best:
                torch.save(checkpoint, './outputs/food101_resnet50_best.pth')
                # Plain weights for easy loading by the inference app.
                torch.save(model.state_dict(), './outputs/food101_resnet50_best_weights.pth')

            if (epoch + 1) % 10 == 0:
                torch.save(checkpoint, f'./outputs/food101_resnet50_epoch_{epoch+1}.pth')

        # One scheduler step per epoch (warm restarts at 10, 30, 70, ... epochs).
        scheduler.step()
        epoch_time = time.time() - epoch_start
        total_train_time += epoch_time

        logger.info(f"Epoch {epoch+1:3d}/{num_epochs} | "
                    f"Train Loss: {avg_train_loss:.4f} | "
                    f"Val Loss: {avg_val_loss:.4f} | "
                    f"Val Acc: {val_acc:.2f}% | "
                    f"Best: {best_val_acc:.2f}% | "
                    f"LR: {optimizer.param_groups[0]['lr']:.6f} | "
                    f"Time: {epoch_time:.1f}s")

    # ---- Final artifacts ----
    final_checkpoint = {
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        # Guard against an empty history (e.g. resumed at num_epochs).
        'final_val_accuracy': val_accuracies[-1] if val_accuracies else 0.0,
        'best_val_accuracy': best_val_acc,
        'train_losses': train_losses,
        'val_accuracies': val_accuracies,
        'learning_rates': learning_rates,
        'total_train_time': total_train_time,
    }
    torch.save(final_checkpoint, './outputs/food101_resnet50_final.pth')
    torch.save(model.state_dict(), './outputs/food101_resnet50_final_weights.pth')

    logger.info(f"πŸ“Š Total training time: {total_train_time/3600:.2f} hours")

    # Final held-out evaluation.
    test_acc = evaluate_model(model, test_loader, device, "Test")
    logger.info(f"🎯 Final Test Accuracy: {test_acc:.2f}%")

    # Save comprehensive plots.
    plot_training_curves(train_losses, val_accuracies, learning_rates)

    return best_val_acc, train_losses, val_accuracies
369
+
370
+
371
def evaluate_model(model, test_loader, device, dataset_name="Test"):
    """Return top-1 accuracy (%) of `model` over `test_loader`.

    Runs under no_grad + autocast and shows a live tqdm progress bar.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    test_pbar = tqdm(test_loader, desc=f'{dataset_name} Evaluation', leave=False)

    with torch.no_grad():
        for images, labels in test_pbar:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            with autocast():
                logits = model(images)

            preds = logits.argmax(dim=1)
            n_seen += labels.size(0)
            n_correct += (preds == labels).sum().item()

            test_pbar.set_postfix({'acc': f'{100.*n_correct/n_seen:.2f}%'})

    return 100. * n_correct / n_seen
392
+
393
+
394
def plot_training_curves(train_losses, val_accuracies, learning_rates):
    """Save two PNGs of the training history to ./outputs.

    Writes training_analysis.png (2x2 grid: loss, val accuracy, LR
    schedule, combined twin-axis view) and accuracy_detail.png (single
    large accuracy plot).  All three inputs are per-epoch lists of equal
    length.
    """
    epochs = np.arange(1, len(train_losses) + 1)

    plt.style.use('default')
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Food101 ResNet50 Training Analysis', fontsize=16, fontweight='bold')

    # Training loss (log scale so late-epoch changes stay visible).
    axes[0, 0].plot(epochs, train_losses, 'b-', linewidth=2, alpha=0.8)
    axes[0, 0].set_title('Training Loss Over Time', fontweight='bold')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].grid(True, alpha=0.3)
    axes[0, 0].set_yscale('log')

    # Validation accuracy, with a dashed line marking the best epoch.
    axes[0, 1].plot(epochs, val_accuracies, 'r-', linewidth=2, alpha=0.8)
    axes[0, 1].set_title('Validation Accuracy Over Time', fontweight='bold')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy (%)')
    axes[0, 1].grid(True, alpha=0.3)
    axes[0, 1].axhline(y=max(val_accuracies), color='r', linestyle='--', alpha=0.7,
                       label=f'Best: {max(val_accuracies):.2f}%')
    axes[0, 1].legend()

    # Learning-rate schedule (log scale shows the cosine warm restarts).
    axes[1, 0].plot(epochs, learning_rates, 'g-', linewidth=2, alpha=0.8)
    axes[1, 0].set_title('Learning Rate Schedule', fontweight='bold')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Learning Rate')
    axes[1, 0].grid(True, alpha=0.3)
    axes[1, 0].set_yscale('log')

    # Combined loss/accuracy view on twin y-axes.
    ax_combined = axes[1, 1]
    ax_combined.plot(epochs, train_losses, 'b-', label='Train Loss', linewidth=2, alpha=0.8)
    ax_combined.set_xlabel('Epoch')
    ax_combined.set_ylabel('Loss', color='b')
    ax_combined.tick_params(axis='y', labelcolor='b')
    ax_combined.set_yscale('log')

    ax2 = ax_combined.twinx()
    ax2.plot(epochs, val_accuracies, 'r-', label='Val Accuracy', linewidth=2, alpha=0.8)
    ax2.set_ylabel('Accuracy (%)', color='r')
    ax2.tick_params(axis='y', labelcolor='r')

    ax_combined.set_title('Loss vs Accuracy', fontweight='bold')
    ax_combined.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('./outputs/training_analysis.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Additional, larger stand-alone accuracy plot.
    plt.figure(figsize=(12, 6))
    plt.plot(epochs, val_accuracies, 'r-', linewidth=2, alpha=0.8)
    plt.fill_between(epochs, val_accuracies, alpha=0.3)
    plt.title('Validation Accuracy Progress', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.grid(True, alpha=0.3)
    plt.axhline(y=max(val_accuracies), color='r', linestyle='--', alpha=0.7,
                label=f'Peak Accuracy: {max(val_accuracies):.2f}%')
    plt.legend()
    plt.tight_layout()
    plt.savefig('./outputs/accuracy_detail.png', dpi=300, bbox_inches='tight')
    plt.close()

    logger.info("πŸ“Š Saved enhanced training visualizations")
464
+
465
+
466
def save_classes(classes):
    """Write the Food101 class names to ./outputs in two formats:

    a human-readable numbered listing (food101_classes.txt) and a plain
    one-name-per-line file (food101_classes_simple.txt) consumed by the
    inference app.  Names are sorted in both files.
    """
    os.makedirs('./outputs', exist_ok=True)
    ordered = sorted(classes)

    # Pretty, numbered listing for humans.
    with open('./outputs/food101_classes.txt', 'w') as f:
        f.write("Food101 Classes (101 total)\n")
        f.write("=" * 30 + "\n\n")
        for i, cls in enumerate(ordered, 1):
            f.write(f"{i:3d}. {cls.replace('_', ' ').title()}\n")

    # Plain list, easy to load programmatically.
    with open('./outputs/food101_classes_simple.txt', 'w') as f:
        f.writelines(f"{cls}\n" for cls in ordered)

    logger.info("πŸ“ Saved class names to ./outputs/")
482
+
483
+
484
def print_system_info():
    """Log environment details useful for debugging: PyTorch version,
    CUDA availability/version, GPU name and memory, and CPU core count."""
    logger.info("πŸ–₯️ System Information:")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        logger.info(f"CUDA version: {torch.version.cuda}")
        logger.info(f"GPU: {torch.cuda.get_device_name()}")
        # Device 0 only; multi-GPU setups report just the first card.
        logger.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    logger.info(f"Number of CPU cores: {os.cpu_count()}")
494
+
495
+
496
+ # -------------------------
497
+ # MAIN
498
+ # -------------------------
499
def main():
    """Entry point: load Food101, build (and optionally compile) ResNet50,
    then run the full training pipeline, auto-resuming from the best
    checkpoint when one exists."""
    print_system_info()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    try:
        # Load data with optimized settings
        logger.info("πŸ“₯ Loading Food101 dataset...")
        train_loader, val_loader, test_loader, classes = get_food101_loaders(batch_size=64, num_workers=8)
        save_classes(classes)

        # Model
        logger.info("πŸ—οΈ Building ResNet50...")
        model = ResNet50(num_classes=101).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info(f"Total parameters: {total_params/1e6:.1f}M")
        logger.info(f"Trainable parameters: {trainable_params/1e6:.1f}M")

        # Enable compilation for PyTorch 2.0+
        # NOTE(review): torch.compile wraps the module, and the wrapper's
        # state_dict keys are typically prefixed with '_orig_mod.'; that
        # may make the weight files saved by train_model incompatible with
        # the plain ResNet50 loader in the inference app — confirm before
        # relying on the exported weights.
        if hasattr(torch, 'compile'):
            logger.info("πŸš€ Compiling model for faster training...")
            model = torch.compile(model)

        # Train (resume_from points at the best checkpoint when present).
        best_val_acc, losses, accuracies = train_model(
            model, train_loader, val_loader, test_loader, device,
            num_epochs=100, resume_from='./outputs/food101_resnet50_best.pth' if os.path.exists('./outputs/food101_resnet50_best.pth') else None
        )

        logger.info(f"\nπŸŽ‰ TRAINING COMPLETE!")
        logger.info(f"πŸ† Best Validation Accuracy: {best_val_acc:.2f}%")
        logger.info(f"\nπŸ“ SAVED FILES:")
        logger.info(f" β€’ ./outputs/food101_resnet50_best.pth (best checkpoint)")
        logger.info(f" β€’ ./outputs/food101_resnet50_best_weights.pth (best weights only)")
        logger.info(f" β€’ ./outputs/food101_resnet50_final.pth (final checkpoint)")
        logger.info(f" β€’ ./outputs/food101_resnet50_final_weights.pth (final weights only)")
        logger.info(f" β€’ ./outputs/training_analysis.png (comprehensive plots)")
        logger.info(f" β€’ ./outputs/accuracy_detail.png (detailed accuracy)")
        logger.info(f" β€’ ./outputs/food101_classes.txt (formatted class list)")
        logger.info(f" β€’ ./outputs/food101_classes_simple.txt (simple class list)")

    except Exception as e:
        logger.error(f"❌ Training failed with error: {e}")
        raise


if __name__ == "__main__":
    main()
outputs.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc29b904cd5db7b7f53148c7a2e796b8b406be107124127d03f74800f889b114
3
+ size 2294595563
outputs_food101_resnet50_final.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b1fde9993b9a56d998e3e420fd6ce7a82a2d94cfbbbb3b9e953d7ae347b5460
3
+ size 190101638
outputs_food101_resnet50_final_weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a39d616ca1f70273efcb80ad5aeea2544be752c9696dd49b1b27af5864ed924
3
+ size 95190217
requirements (3).txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ matplotlib>=3.5.0
4
+ numpy>=1.21.0
5
+ Pillow>=8.0.0
6
+ tqdm>=4.64.0
training_analysis.png ADDED

Git LFS Details

  • SHA256: 780868aad101284bc2123fc5e3b7dd9f512ddfa4d242fe035ee1b6cb54b81923
  • Pointer size: 131 Bytes
  • Size of remote file: 637 kB