Upload 10 files
Browse files- .gitattributes +2 -0
- accuracy_detail.png +3 -0
- app (1).py +154 -0
- food101_classes.txt +104 -0
- food101_classes_simple.txt +101 -0
- main (3).py +548 -0
- outputs.zip +3 -0
- outputs_food101_resnet50_final.pth +3 -0
- outputs_food101_resnet50_final_weights.pth +3 -0
- requirements (3).txt +6 -0
- training_analysis.png +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
accuracy_detail.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
training_analysis.png filter=lfs diff=lfs merge=lfs -text
|
accuracy_detail.png
ADDED
|
Git LFS Details
|
app (1).py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchvision.transforms as transforms
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
# Assuming ResNet50 class is defined in main.py or you copy it here
|
| 8 |
+
# For simplicity, I'll put a placeholder. In a real scenario, you'd import ResNet50
|
| 9 |
+
# from a separate models.py or main.py. For this example, let's assume it's available.
|
| 10 |
+
|
| 11 |
+
# --- ResNet50 Model Definition (copy-pasted from main.py for self-containment) ---
|
| 12 |
+
class Bottleneck(torch.nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand.

    Output channel count is ``planes * expansion``. When the identity branch
    needs reshaping (stride != 1 or channel mismatch), the caller supplies a
    ``downsample`` module that projects the input.
    """
    expansion = 4  # channel expansion factor of the final 1x1 conv

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(planes)
        self.conv2 = torch.nn.Conv2d(planes, planes, 3, stride, 1, bias=False)
        self.bn2 = torch.nn.BatchNorm2d(planes)
        self.conv3 = torch.nn.Conv2d(planes, planes*self.expansion, 1, bias=False)
        self.bn3 = torch.nn.BatchNorm2d(planes*self.expansion)
        self.relu = torch.nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        # Explicit None check: truth-testing an nn.Module is fragile (an empty
        # nn.Sequential is falsy) and deprecated in recent PyTorch versions.
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
|
| 39 |
+
|
| 40 |
+
class ResNet50(torch.nn.Module):
    """ResNet-50 image classifier assembled from Bottleneck blocks (3-4-6-3)."""

    def __init__(self, num_classes=101):
        super().__init__()
        self.inplanes = 64

        # Stem: 7x7 stride-2 conv, batch norm, ReLU, 3x3 stride-2 max-pool.
        self.conv1 = torch.nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(64)
        self.relu = torch.nn.ReLU(inplace=True)
        self.maxpool = torch.nn.MaxPool2d(3, 2, 1)

        # Four residual stages; stages 2-4 halve spatial resolution.
        self.layer1 = self._make_layer(Bottleneck, 64, 3)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, 2)
        self.layer3 = self._make_layer(Bottleneck, 256, 6, 2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, 2)

        # Head: global average pool + linear classifier.
        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(512*Bottleneck.expansion, num_classes)

        self._initialize_weights()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage consisting of `blocks` bottleneck units."""
        downsample = None
        needs_projection = stride != 1 or self.inplanes != planes * block.expansion
        if needs_projection:
            # Project the identity branch so it matches the block output.
            downsample = torch.nn.Sequential(
                torch.nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False),
                torch.nn.BatchNorm2d(planes * block.expansion),
            )

        units = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        units.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return torch.nn.Sequential(*units)

    def _initialize_weights(self):
        """Kaiming-normal init for convs; unit weight / zero bias for batch norms."""
        for module in self.modules():
            if isinstance(module, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, torch.nn.BatchNorm2d):
                torch.nn.init.constant_(module.weight, 1)
                torch.nn.init.constant_(module.bias, 0)

    def forward(self, x):
        # Stem
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        # Residual stages
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        # Classification head
        x = self.avgpool(x)
        return self.fc(torch.flatten(x, 1))
|
| 97 |
+
# --- End ResNet50 Model Definition ---
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# ---- App setup: class names, device, model weights, eval transform ----
def _resolve_path(*candidates):
    """Return the first existing path among `candidates`.

    Falls back to the first candidate when none exist so later error
    messages still point at the canonical location. Needed because the
    repo ships the assets at the root (e.g. `food101_classes_simple.txt`,
    `outputs_food101_resnet50_final_weights.pth`) while the original code
    expected them under `./outputs/`.
    """
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return candidates[0]


# Load class names (one class per line; order must match training label order)
_classes_path = _resolve_path('./outputs/food101_classes_simple.txt',
                              './food101_classes_simple.txt')
with open(_classes_path, 'r') as f:
    class_names = [line.strip() for line in f]

num_classes = len(class_names)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the model
model = ResNet50(num_classes=num_classes).to(device)
model_path = _resolve_path('./outputs/food101_resnet50_final_weights.pth',
                           './outputs_food101_resnet50_final_weights.pth')
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model weights not found at {model_path}. Please train the model first.")

model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()  # inference mode: freezes batch-norm statistics, disables dropout

# Define the image transformations (must mirror the eval transform used in training)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
|
| 122 |
+
|
| 123 |
+
def predict_image(image: Image.Image):
    """Classify a food image and return the top-5 classes with confidences.

    Args:
        image: Input picture in any PIL mode. It is converted to RGB first,
            because Normalize/conv1 expect exactly 3 channels and Gradio can
            deliver RGBA (PNG) or grayscale uploads.

    Returns:
        dict mapping class name -> confidence in percent (rounded to 2 dp)
        for the five most probable classes.
    """
    # Guard against non-RGB uploads (RGBA/L) that would break the 3-channel
    # normalization and the first conv layer.
    image = image.convert('RGB')

    # Apply transformations and add the batch dimension
    tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

    # Get top 5 predictions
    top5_prob, top5_indices = torch.topk(probabilities, 5)

    predictions = {class_names[idx]: round(prob.item() * 100, 2) for idx, prob in zip(top5_indices, top5_prob)}

    return predictions
|
| 137 |
+
|
| 138 |
+
# Create Gradio interface: wires predict_image to an image-upload widget.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload Food Image"),
    # gr.Label renders the {class: percent} dict returned by predict_image
    # as a ranked confidence list.
    outputs=gr.Label(num_top_classes=5),
    title="Food101 ResNet50 Classifier",
    description="Upload an image of food and get predictions for 101 food categories. Model trained on Food101 dataset.",
    examples=[
        # Add some example images here if you have them, e.g.,
        # ["path/to/example_image1.jpg"],
        # ["path/to/example_image2.jpg"],
    ]
)

# Launch the Gradio app
if __name__ == "__main__":
    # Bind to all interfaces so the server is reachable from outside the
    # container/VM, not just localhost.
    iface.launch(server_name="0.0.0.0", server_port=8000) # Use port 8000 for Lightning AI deployments
|
food101_classes.txt
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Food101 Classes (101 total)
|
| 2 |
+
==============================
|
| 3 |
+
|
| 4 |
+
1. Apple Pie
|
| 5 |
+
2. Baby Back Ribs
|
| 6 |
+
3. Baklava
|
| 7 |
+
4. Beef Carpaccio
|
| 8 |
+
5. Beef Tartare
|
| 9 |
+
6. Beet Salad
|
| 10 |
+
7. Beignets
|
| 11 |
+
8. Bibimbap
|
| 12 |
+
9. Bread Pudding
|
| 13 |
+
10. Breakfast Burrito
|
| 14 |
+
11. Bruschetta
|
| 15 |
+
12. Caesar Salad
|
| 16 |
+
13. Cannoli
|
| 17 |
+
14. Caprese Salad
|
| 18 |
+
15. Carrot Cake
|
| 19 |
+
16. Ceviche
|
| 20 |
+
17. Cheese Plate
|
| 21 |
+
18. Cheesecake
|
| 22 |
+
19. Chicken Curry
|
| 23 |
+
20. Chicken Quesadilla
|
| 24 |
+
21. Chicken Wings
|
| 25 |
+
22. Chocolate Cake
|
| 26 |
+
23. Chocolate Mousse
|
| 27 |
+
24. Churros
|
| 28 |
+
25. Clam Chowder
|
| 29 |
+
26. Club Sandwich
|
| 30 |
+
27. Crab Cakes
|
| 31 |
+
28. Creme Brulee
|
| 32 |
+
29. Croque Madame
|
| 33 |
+
30. Cup Cakes
|
| 34 |
+
31. Deviled Eggs
|
| 35 |
+
32. Donuts
|
| 36 |
+
33. Dumplings
|
| 37 |
+
34. Edamame
|
| 38 |
+
35. Eggs Benedict
|
| 39 |
+
36. Escargots
|
| 40 |
+
37. Falafel
|
| 41 |
+
38. Filet Mignon
|
| 42 |
+
39. Fish And Chips
|
| 43 |
+
40. Foie Gras
|
| 44 |
+
41. French Fries
|
| 45 |
+
42. French Onion Soup
|
| 46 |
+
43. French Toast
|
| 47 |
+
44. Fried Calamari
|
| 48 |
+
45. Fried Rice
|
| 49 |
+
46. Frozen Yogurt
|
| 50 |
+
47. Garlic Bread
|
| 51 |
+
48. Gnocchi
|
| 52 |
+
49. Greek Salad
|
| 53 |
+
50. Grilled Cheese Sandwich
|
| 54 |
+
51. Grilled Salmon
|
| 55 |
+
52. Guacamole
|
| 56 |
+
53. Gyoza
|
| 57 |
+
54. Hamburger
|
| 58 |
+
55. Hot And Sour Soup
|
| 59 |
+
56. Hot Dog
|
| 60 |
+
57. Huevos Rancheros
|
| 61 |
+
58. Hummus
|
| 62 |
+
59. Ice Cream
|
| 63 |
+
60. Lasagna
|
| 64 |
+
61. Lobster Bisque
|
| 65 |
+
62. Lobster Roll Sandwich
|
| 66 |
+
63. Macaroni And Cheese
|
| 67 |
+
64. Macarons
|
| 68 |
+
65. Miso Soup
|
| 69 |
+
66. Mussels
|
| 70 |
+
67. Nachos
|
| 71 |
+
68. Omelette
|
| 72 |
+
69. Onion Rings
|
| 73 |
+
70. Oysters
|
| 74 |
+
71. Pad Thai
|
| 75 |
+
72. Paella
|
| 76 |
+
73. Pancakes
|
| 77 |
+
74. Panna Cotta
|
| 78 |
+
75. Peking Duck
|
| 79 |
+
76. Pho
|
| 80 |
+
77. Pizza
|
| 81 |
+
78. Pork Chop
|
| 82 |
+
79. Poutine
|
| 83 |
+
80. Prime Rib
|
| 84 |
+
81. Pulled Pork Sandwich
|
| 85 |
+
82. Ramen
|
| 86 |
+
83. Ravioli
|
| 87 |
+
84. Red Velvet Cake
|
| 88 |
+
85. Risotto
|
| 89 |
+
86. Samosa
|
| 90 |
+
87. Sashimi
|
| 91 |
+
88. Scallops
|
| 92 |
+
89. Seaweed Salad
|
| 93 |
+
90. Shrimp And Grits
|
| 94 |
+
91. Spaghetti Bolognese
|
| 95 |
+
92. Spaghetti Carbonara
|
| 96 |
+
93. Spring Rolls
|
| 97 |
+
94. Steak
|
| 98 |
+
95. Strawberry Shortcake
|
| 99 |
+
96. Sushi
|
| 100 |
+
97. Tacos
|
| 101 |
+
98. Takoyaki
|
| 102 |
+
99. Tiramisu
|
| 103 |
+
100. Tuna Tartare
|
| 104 |
+
101. Waffles
|
food101_classes_simple.txt
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
apple_pie
|
| 2 |
+
baby_back_ribs
|
| 3 |
+
baklava
|
| 4 |
+
beef_carpaccio
|
| 5 |
+
beef_tartare
|
| 6 |
+
beet_salad
|
| 7 |
+
beignets
|
| 8 |
+
bibimbap
|
| 9 |
+
bread_pudding
|
| 10 |
+
breakfast_burrito
|
| 11 |
+
bruschetta
|
| 12 |
+
caesar_salad
|
| 13 |
+
cannoli
|
| 14 |
+
caprese_salad
|
| 15 |
+
carrot_cake
|
| 16 |
+
ceviche
|
| 17 |
+
cheese_plate
|
| 18 |
+
cheesecake
|
| 19 |
+
chicken_curry
|
| 20 |
+
chicken_quesadilla
|
| 21 |
+
chicken_wings
|
| 22 |
+
chocolate_cake
|
| 23 |
+
chocolate_mousse
|
| 24 |
+
churros
|
| 25 |
+
clam_chowder
|
| 26 |
+
club_sandwich
|
| 27 |
+
crab_cakes
|
| 28 |
+
creme_brulee
|
| 29 |
+
croque_madame
|
| 30 |
+
cup_cakes
|
| 31 |
+
deviled_eggs
|
| 32 |
+
donuts
|
| 33 |
+
dumplings
|
| 34 |
+
edamame
|
| 35 |
+
eggs_benedict
|
| 36 |
+
escargots
|
| 37 |
+
falafel
|
| 38 |
+
filet_mignon
|
| 39 |
+
fish_and_chips
|
| 40 |
+
foie_gras
|
| 41 |
+
french_fries
|
| 42 |
+
french_onion_soup
|
| 43 |
+
french_toast
|
| 44 |
+
fried_calamari
|
| 45 |
+
fried_rice
|
| 46 |
+
frozen_yogurt
|
| 47 |
+
garlic_bread
|
| 48 |
+
gnocchi
|
| 49 |
+
greek_salad
|
| 50 |
+
grilled_cheese_sandwich
|
| 51 |
+
grilled_salmon
|
| 52 |
+
guacamole
|
| 53 |
+
gyoza
|
| 54 |
+
hamburger
|
| 55 |
+
hot_and_sour_soup
|
| 56 |
+
hot_dog
|
| 57 |
+
huevos_rancheros
|
| 58 |
+
hummus
|
| 59 |
+
ice_cream
|
| 60 |
+
lasagna
|
| 61 |
+
lobster_bisque
|
| 62 |
+
lobster_roll_sandwich
|
| 63 |
+
macaroni_and_cheese
|
| 64 |
+
macarons
|
| 65 |
+
miso_soup
|
| 66 |
+
mussels
|
| 67 |
+
nachos
|
| 68 |
+
omelette
|
| 69 |
+
onion_rings
|
| 70 |
+
oysters
|
| 71 |
+
pad_thai
|
| 72 |
+
paella
|
| 73 |
+
pancakes
|
| 74 |
+
panna_cotta
|
| 75 |
+
peking_duck
|
| 76 |
+
pho
|
| 77 |
+
pizza
|
| 78 |
+
pork_chop
|
| 79 |
+
poutine
|
| 80 |
+
prime_rib
|
| 81 |
+
pulled_pork_sandwich
|
| 82 |
+
ramen
|
| 83 |
+
ravioli
|
| 84 |
+
red_velvet_cake
|
| 85 |
+
risotto
|
| 86 |
+
samosa
|
| 87 |
+
sashimi
|
| 88 |
+
scallops
|
| 89 |
+
seaweed_salad
|
| 90 |
+
shrimp_and_grits
|
| 91 |
+
spaghetti_bolognese
|
| 92 |
+
spaghetti_carbonara
|
| 93 |
+
spring_rolls
|
| 94 |
+
steak
|
| 95 |
+
strawberry_shortcake
|
| 96 |
+
sushi
|
| 97 |
+
tacos
|
| 98 |
+
takoyaki
|
| 99 |
+
tiramisu
|
| 100 |
+
tuna_tartare
|
| 101 |
+
waffles
|
main (3).py
ADDED
|
@@ -0,0 +1,548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
β
OPTIMIZED Food101 + ResNet50 with major speed improvements
|
| 4 |
+
β
Mixed precision training (2x faster)
|
| 5 |
+
β
Better data loading (persistent workers)
|
| 6 |
+
β
Progress bars and better logging
|
| 7 |
+
β
Robust error handling and checkpointing
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import time
|
| 12 |
+
import copy
|
| 13 |
+
import numpy as np
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
import logging
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.optim as optim
|
| 21 |
+
import torchvision
|
| 22 |
+
import torchvision.transforms as transforms
|
| 23 |
+
from torch.utils.data import DataLoader
|
| 24 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 25 |
+
|
| 26 |
+
# Setup logging
|
| 27 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
# -------------------------
|
| 31 |
+
# OPTIMIZED Data Loaders
|
| 32 |
+
# -------------------------
|
| 33 |
+
def get_food101_loaders(batch_size=64, num_workers=8): # Increased batch size and workers
    """Return optimized Food101 train/val/test DataLoaders plus class names.

    Downloads the dataset into ``./data`` on first use. The 75k-image train
    split is divided 90/10 into train/val with a fixed seed; the 25k-image
    test split uses deterministic preprocessing only.

    Args:
        batch_size: samples per batch for all three loaders.
        num_workers: DataLoader worker processes; 0 disables multiprocessing.

    Returns:
        (train_loader, val_loader, test_loader, class_names)

    Raises:
        Exception: re-raises any dataset download/construction failure after
            logging it.
    """
    # More aggressive data augmentation
    transform_train = transforms.Compose([
        transforms.Resize((256, 256)),  # Resize larger first
        transforms.RandomCrop((224, 224)),  # Then crop to avoid distortion
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    try:
        # Full train split (75k images)
        full_train = torchvision.datasets.Food101(
            root='./data', split='train', download=True, transform=transform_train
        )

        # 90/10 train/val split with fixed seed for reproducibility.
        # NOTE(review): the val subset inherits transform_train, so validation
        # metrics are computed on augmented images — confirm this is intended.
        torch.manual_seed(42)
        train_size = int(0.9 * len(full_train))
        val_size = len(full_train) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            full_train, [train_size, val_size]
        )

        # Test split (25k images)
        test_dataset = torchvision.datasets.Food101(
            root='./data', split='test', download=True, transform=transform_test
        )

        logger.info(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

        # persistent_workers=True is only valid when num_workers > 0; guard it
        # so num_workers=0 (e.g. debugging, Windows/macOS) doesn't raise
        # ValueError from DataLoader.
        keep_workers = num_workers > 0

        # Optimized DataLoaders with persistent workers
        train_loader = DataLoader(
            train_dataset, batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, persistent_workers=keep_workers, drop_last=True
        )
        val_loader = DataLoader(
            val_dataset, batch_size, shuffle=False, num_workers=num_workers,
            pin_memory=True, persistent_workers=keep_workers
        )
        test_loader = DataLoader(
            test_dataset, batch_size, shuffle=False, num_workers=num_workers,
            pin_memory=True, persistent_workers=keep_workers
        )

        return train_loader, val_loader, test_loader, full_train.classes

    except Exception as e:
        logger.error(f"Error loading data: {e}")
        raise
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# -------------------------
|
| 97 |
+
# ResNet Building Blocks (same as original but with better initialization)
|
| 98 |
+
# -------------------------
|
| 99 |
+
class BasicBlock(nn.Module):
    """ResNet basic residual block: two 3x3 convolutions plus a skip branch.

    Output channel count equals ``planes`` (expansion == 1). An optional
    ``downsample`` module projects the identity branch on shape change.
    Unused by the ResNet50 below but kept for smaller ResNet variants.
    """
    expansion = 1  # basic blocks do not expand channels

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Explicit None check: truth-testing an nn.Module is fragile (an empty
        # nn.Sequential is falsy) and deprecated in recent PyTorch versions.
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class Bottleneck(nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand.

    Output channel count is ``planes * expansion``. When the identity branch
    needs reshaping (stride != 1 or channel mismatch), the caller supplies a
    ``downsample`` module that projects the input.
    """
    expansion = 4  # channel expansion factor of the final 1x1 conv

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes*self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        # Explicit None check: truth-testing an nn.Module is fragile (an empty
        # nn.Sequential is falsy) and deprecated in recent PyTorch versions.
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class ResNet50(nn.Module):
    """ResNet-50 image classifier assembled from Bottleneck blocks (3-4-6-3)."""

    def __init__(self, num_classes=101):
        super().__init__()
        self.inplanes = 64

        # Stem: 7x7 stride-2 conv, batch norm, ReLU, 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(3, 2, 1)

        # Four residual stages; stages 2-4 halve spatial resolution.
        self.layer1 = self._make_layer(Bottleneck, 64, 3)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, 2)
        self.layer3 = self._make_layer(Bottleneck, 256, 6, 2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, 2)

        # Head: global average pool + linear classifier.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512*Bottleneck.expansion, num_classes)

        # Better initialization
        self._initialize_weights()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage consisting of `blocks` bottleneck units."""
        downsample = None
        needs_projection = stride != 1 or self.inplanes != planes * block.expansion
        if needs_projection:
            # Project the identity branch so it matches the block output.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        units = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        units.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*units)

    def _initialize_weights(self):
        """Kaiming-normal init for convs; unit weight / zero bias for batch norms."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        # Stem
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        # Residual stages
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        # Classification head
        x = self.avgpool(x)
        return self.fc(torch.flatten(x, 1))
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# -------------------------
|
| 213 |
+
# OPTIMIZED Training Function with Mixed Precision
|
| 214 |
+
# -------------------------
|
| 215 |
+
def train_model(model, train_loader, val_loader, test_loader, device, num_epochs=100, resume_from=None):
    """Train `model` with mixed precision, per-epoch validation and checkpointing.

    Checkpoints are written under ./outputs: the best model (by validation
    accuracy), periodic snapshots every 10 epochs, and a final checkpoint.
    After training, the model is evaluated on the test split and training
    curves are plotted.

    Args:
        model: network to train (moved to `device` by the caller).
        train_loader / val_loader / test_loader: DataLoaders from
            get_food101_loaders.
        device: torch.device used for tensors.
        num_epochs: total number of epochs (including resumed ones).
        resume_from: optional path to a checkpoint dict to resume from.

    Returns:
        (best_val_accuracy, train_losses, val_accuracies)
    """

    os.makedirs('./outputs', exist_ok=True)

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # Label smoothing for better generalization
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True)
    # Cosine schedule with warm restarts: first restart after 10 epochs,
    # each cycle twice as long as the previous one.
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    # Mixed precision scaler
    scaler = GradScaler()

    best_val_acc = 0.0
    train_losses, val_accuracies, learning_rates = [], [], []
    start_epoch = 0

    # Resume from checkpoint if provided.
    # NOTE(review): 'scaler_state_dict' is saved below but never restored
    # here, so resumed runs restart the GradScaler from scratch — confirm
    # whether that is intentional.
    if resume_from and os.path.exists(resume_from):
        logger.info(f"Resuming from {resume_from}")
        checkpoint = torch.load(resume_from, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        best_val_acc = checkpoint.get('best_val_accuracy', 0.0)
        train_losses = checkpoint.get('train_losses', [])
        val_accuracies = checkpoint.get('val_accuracies', [])
        learning_rates = checkpoint.get('learning_rates', [])

    logger.info(f"π Starting training from epoch {start_epoch+1} for {num_epochs} total epochs...")

    # Track timing
    total_train_time = 0

    for epoch in range(start_epoch, num_epochs):
        epoch_start = time.time()

        # Training phase
        model.train()
        running_loss = 0.0
        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]', leave=False)

        for batch_idx, (images, labels) in enumerate(train_pbar):
            # non_blocking pairs with the loaders' pin_memory=True for
            # asynchronous host-to-device copies.
            images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)

            optimizer.zero_grad()

            # Mixed precision forward pass
            with autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)

            # Mixed precision backward pass: scale loss to avoid fp16
            # gradient underflow, then unscale inside scaler.step.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            train_pbar.set_postfix({'loss': f'{loss.item():.4f}', 'lr': f'{optimizer.param_groups[0]["lr"]:.6f}'})

        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        learning_rates.append(optimizer.param_groups[0]['lr'])

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]', leave=False)

        with torch.no_grad():
            for images, labels in val_pbar:
                images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)

                with autocast():
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                val_pbar.set_postfix({'acc': f'{100.*correct/total:.2f}%'})

        val_acc = 100. * correct / total
        val_accuracies.append(val_acc)
        avg_val_loss = val_loss / len(val_loader)

        # Save best model
        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc

        # Save checkpoint every 10 epochs and if best
        if (epoch + 1) % 10 == 0 or is_best or epoch == num_epochs - 1:
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'best_val_accuracy': best_val_acc,
                'current_val_accuracy': val_acc,
                'train_losses': train_losses,
                'val_accuracies': val_accuracies,
                'learning_rates': learning_rates,
            }

            if is_best:
                torch.save(checkpoint, './outputs/food101_resnet50_best.pth')
                # Save just the weights for easier loading
                torch.save(model.state_dict(), './outputs/food101_resnet50_best_weights.pth')

            if (epoch + 1) % 10 == 0:
                torch.save(checkpoint, f'./outputs/food101_resnet50_epoch_{epoch+1}.pth')

        # Per-epoch LR schedule step (after the optimizer steps of the epoch).
        scheduler.step()
        epoch_time = time.time() - epoch_start
        total_train_time += epoch_time

        logger.info(f"Epoch {epoch+1:3d}/{num_epochs} | "
                    f"Train Loss: {avg_train_loss:.4f} | "
                    f"Val Loss: {avg_val_loss:.4f} | "
                    f"Val Acc: {val_acc:.2f}% | "
                    f"Best: {best_val_acc:.2f}% | "
                    f"LR: {optimizer.param_groups[0]['lr']:.6f} | "
                    f"Time: {epoch_time:.1f}s")

    # Save final model.
    # NOTE(review): val_accuracies[-1] raises IndexError if the loop ran zero
    # epochs (e.g. resuming at start_epoch == num_epochs) — confirm callers
    # never do that.
    final_checkpoint = {
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'final_val_accuracy': val_accuracies[-1],
        'best_val_accuracy': best_val_acc,
        'train_losses': train_losses,
        'val_accuracies': val_accuracies,
        'learning_rates': learning_rates,
        'total_train_time': total_train_time,
    }
    torch.save(final_checkpoint, './outputs/food101_resnet50_final.pth')
    torch.save(model.state_dict(), './outputs/food101_resnet50_final_weights.pth')

    logger.info(f"π Total training time: {total_train_time/3600:.2f} hours")

    # Test final accuracy
    test_acc = evaluate_model(model, test_loader, device, "Test")
    logger.info(f"π― Final Test Accuracy: {test_acc:.2f}%")

    # Save comprehensive plots
    plot_training_curves(train_losses, val_accuracies, learning_rates)

    return best_val_acc, train_losses, val_accuracies
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def evaluate_model(model, test_loader, device, dataset_name="Test"):
    """Evaluate top-1 classification accuracy over a data loader.

    Runs the model in eval mode under mixed precision (autocast) with a
    tqdm progress bar showing the running accuracy.

    Args:
        model: Trained ``torch.nn.Module`` producing class logits.
        test_loader: DataLoader yielding ``(images, labels)`` batches.
        device: ``torch.device`` to run inference on.
        dataset_name: Label used in the progress-bar description.

    Returns:
        Top-1 accuracy as a percentage (float). Returns 0.0 for an
        empty loader instead of raising ZeroDivisionError.
    """
    model.eval()
    correct = 0
    total = 0
    test_pbar = tqdm(test_loader, desc=f'{dataset_name} Evaluation', leave=False)

    with torch.no_grad():
        for images, labels in test_pbar:
            # non_blocking pairs with pinned-memory loaders for async H2D copies
            images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)

            # Mixed-precision forward pass (matches training-time autocast)
            with autocast():
                outputs = model(images)

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            test_pbar.set_postfix({'acc': f'{100.*correct/total:.2f}%'})

    # Guard the final division: an empty loader leaves total == 0
    return 100. * correct / total if total else 0.0
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def plot_training_curves(train_losses, val_accuracies, learning_rates):
    """Render and save training diagnostics as PNGs under ./outputs/.

    Produces a 2x2 analysis figure (log-scale loss, validation accuracy,
    LR schedule, and a combined twin-axis view) plus a standalone
    accuracy plot.
    """
    epochs = np.arange(1, len(train_losses) + 1)
    peak_acc = max(val_accuracies)

    plt.style.use('default')
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Food101 ResNet50 Training Analysis', fontsize=16, fontweight='bold')

    loss_ax, acc_ax = axes[0, 0], axes[0, 1]
    lr_ax, combo_ax = axes[1, 0], axes[1, 1]

    # Training loss — log scale makes late-stage improvements visible
    loss_ax.plot(epochs, train_losses, 'b-', linewidth=2, alpha=0.8)
    loss_ax.set_title('Training Loss Over Time', fontweight='bold')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.grid(True, alpha=0.3)
    loss_ax.set_yscale('log')

    # Validation accuracy with a dashed marker at the best value
    acc_ax.plot(epochs, val_accuracies, 'r-', linewidth=2, alpha=0.8)
    acc_ax.set_title('Validation Accuracy Over Time', fontweight='bold')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.grid(True, alpha=0.3)
    acc_ax.axhline(y=peak_acc, color='r', linestyle='--', alpha=0.7,
                   label=f'Best: {peak_acc:.2f}%')
    acc_ax.legend()

    # Learning-rate schedule (log scale for decay schedules)
    lr_ax.plot(epochs, learning_rates, 'g-', linewidth=2, alpha=0.8)
    lr_ax.set_title('Learning Rate Schedule', fontweight='bold')
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.grid(True, alpha=0.3)
    lr_ax.set_yscale('log')

    # Loss vs accuracy overlaid on twin y-axes
    combo_ax.plot(epochs, train_losses, 'b-', label='Train Loss', linewidth=2, alpha=0.8)
    combo_ax.set_xlabel('Epoch')
    combo_ax.set_ylabel('Loss', color='b')
    combo_ax.tick_params(axis='y', labelcolor='b')
    combo_ax.set_yscale('log')

    twin = combo_ax.twinx()
    twin.plot(epochs, val_accuracies, 'r-', label='Val Accuracy', linewidth=2, alpha=0.8)
    twin.set_ylabel('Accuracy (%)', color='r')
    twin.tick_params(axis='y', labelcolor='r')

    combo_ax.set_title('Loss vs Accuracy', fontweight='bold')
    combo_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('./outputs/training_analysis.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Standalone, larger accuracy chart with shaded area under the curve
    plt.figure(figsize=(12, 6))
    plt.plot(epochs, val_accuracies, 'r-', linewidth=2, alpha=0.8)
    plt.fill_between(epochs, val_accuracies, alpha=0.3)
    plt.title('Validation Accuracy Progress', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.grid(True, alpha=0.3)
    plt.axhline(y=peak_acc, color='r', linestyle='--', alpha=0.7,
                label=f'Peak Accuracy: {peak_acc:.2f}%')
    plt.legend()
    plt.tight_layout()
    plt.savefig('./outputs/accuracy_detail.png', dpi=300, bbox_inches='tight')
    plt.close()

    logger.info("π Saved enhanced training visualizations")
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def save_classes(classes):
    """Persist the Food101 class names to ./outputs/.

    Writes two files:
      * food101_classes.txt        - numbered, human-readable listing
      * food101_classes_simple.txt - one raw class name per line, for
                                     programmatic loading at inference time

    Args:
        classes: Iterable of class-name strings (e.g. 'apple_pie').
    """
    os.makedirs('./outputs', exist_ok=True)

    # Sort once and reuse for both output formats
    ordered = sorted(classes)

    # Explicit UTF-8 so the files are identical across platforms/locales
    with open('./outputs/food101_classes.txt', 'w', encoding='utf-8') as f:
        f.write("Food101 Classes (101 total)\n")
        f.write("=" * 30 + "\n\n")
        for i, cls in enumerate(ordered, 1):
            f.write(f"{i:3d}. {cls.replace('_', ' ').title()}\n")

    # Also save as a simple list for easy loading
    with open('./outputs/food101_classes_simple.txt', 'w', encoding='utf-8') as f:
        for cls in ordered:
            f.write(f"{cls}\n")

    logger.info("π Saved class names to ./outputs/")
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def print_system_info():
    """Log PyTorch/CUDA/CPU environment details for run reproducibility."""
    logger.info("π₯οΈ System Information:")
    logger.info(f"PyTorch version: {torch.__version__}")

    cuda_ok = torch.cuda.is_available()
    logger.info(f"CUDA available: {cuda_ok}")
    if cuda_ok:
        logger.info(f"CUDA version: {torch.version.cuda}")
        logger.info(f"GPU: {torch.cuda.get_device_name()}")
        gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        logger.info(f"GPU memory: {gpu_mem_gb:.1f} GB")
    logger.info(f"Number of CPU cores: {os.cpu_count()}")
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
# -------------------------
|
| 497 |
+
# MAIN
|
| 498 |
+
# -------------------------
|
| 499 |
+
def main():
    """Entry point: load data, build ResNet50, and run the training pipeline."""
    print_system_info()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    try:
        # Data loaders tuned for throughput (large batch, many workers)
        logger.info("π₯ Loading Food101 dataset...")
        train_loader, val_loader, test_loader, classes = get_food101_loaders(batch_size=64, num_workers=8)
        save_classes(classes)

        # Build the network and report its size
        logger.info("ποΈ Building ResNet50...")
        model = ResNet50(num_classes=101).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info(f"Total parameters: {total_params/1e6:.1f}M")
        logger.info(f"Trainable parameters: {trainable_params/1e6:.1f}M")

        # torch.compile (PyTorch 2.0+) gives a graph-mode speedup when present
        if hasattr(torch, 'compile'):
            logger.info("π Compiling model for faster training...")
            model = torch.compile(model)

        # Resume from the best checkpoint when one already exists on disk
        best_ckpt_path = './outputs/food101_resnet50_best.pth'
        resume_path = best_ckpt_path if os.path.exists(best_ckpt_path) else None

        best_val_acc, losses, accuracies = train_model(
            model, train_loader, val_loader, test_loader, device,
            num_epochs=100, resume_from=resume_path
        )

        logger.info(f"\nπ TRAINING COMPLETE!")
        logger.info(f"π Best Validation Accuracy: {best_val_acc:.2f}%")
        logger.info(f"\nπ SAVED FILES:")
        logger.info(f" β’ ./outputs/food101_resnet50_best.pth (best checkpoint)")
        logger.info(f" β’ ./outputs/food101_resnet50_best_weights.pth (best weights only)")
        logger.info(f" β’ ./outputs/food101_resnet50_final.pth (final checkpoint)")
        logger.info(f" β’ ./outputs/food101_resnet50_final_weights.pth (final weights only)")
        logger.info(f" β’ ./outputs/training_analysis.png (comprehensive plots)")
        logger.info(f" β’ ./outputs/accuracy_detail.png (detailed accuracy)")
        logger.info(f" β’ ./outputs/food101_classes.txt (formatted class list)")
        logger.info(f" β’ ./outputs/food101_classes_simple.txt (simple class list)")

    except Exception as e:
        # Log the failure for the run record, then re-raise so the process
        # exits non-zero instead of silently swallowing the error.
        logger.error(f"β Training failed with error: {e}")
        raise
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
# Script entry point: run the full training pipeline only when this file
# is executed directly (not when imported as a module).
if __name__ == "__main__":
    main()
|
outputs.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bc29b904cd5db7b7f53148c7a2e796b8b406be107124127d03f74800f889b114
|
| 3 |
+
size 2294595563
|
outputs_food101_resnet50_final.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7b1fde9993b9a56d998e3e420fd6ce7a82a2d94cfbbbb3b9e953d7ae347b5460
|
| 3 |
+
size 190101638
|
outputs_food101_resnet50_final_weights.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6a39d616ca1f70273efcb80ad5aeea2544be752c9696dd49b1b27af5864ed924
|
| 3 |
+
size 95190217
|
requirements (3).txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
torchvision>=0.15.0
|
| 3 |
+
matplotlib>=3.5.0
|
| 4 |
+
numpy>=1.21.0
|
| 5 |
+
Pillow>=8.0.0
|
| 6 |
+
tqdm>=4.64.0
|
training_analysis.png
ADDED
|
Git LFS Details
|