File size: 5,596 Bytes
7a59d7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import torch
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
import os

# The ResNet50 architecture is duplicated below instead of being imported from
# main.py, so this app runs as a single self-contained file. Keep the copy in
# sync with the definition used for training, because the saved weights are
# loaded into it by parameter name.

# --- ResNet50 Model Definition (copy-pasted from main.py for self-containment) ---
class Bottleneck(torch.nn.Module):
    """Residual block made of 1x1 -> 3x3 -> 1x1 convolutions plus a skip path.

    Each convolution is followed by batch normalization. The final 1x1
    convolution widens the features to ``planes * expansion`` channels, the
    skip tensor is added, and a ReLU is applied to the sum.

    Args:
        inplanes: Number of channels of the input tensor.
        planes: Channel width of the first two convolutions.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to build the skip
            tensor. ``None`` means the input itself is used as the skip tensor.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(planes)
        self.conv2 = torch.nn.Conv2d(planes, planes, 3, stride, 1, bias=False)
        self.bn2 = torch.nn.BatchNorm2d(planes)
        self.conv3 = torch.nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = torch.nn.BatchNorm2d(planes * self.expansion)
        self.relu = torch.nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Run the block on ``x`` and return the activated residual sum."""
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Compare against None explicitly: modules that define __len__ (for
        # example an empty container) are falsy, and a truthiness test would
        # silently skip their projection.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        return self.relu(out)

class ResNet50(torch.nn.Module):
    """ResNet-50 style classifier assembled from ``Bottleneck`` blocks.

    Layout: a 7x7 stride-2 convolution with batch norm, ReLU and max pooling,
    then four residual stages of 3, 4, 6 and 3 blocks, global average pooling,
    and a linear layer that outputs ``num_classes`` scores.

    Args:
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, num_classes=101):
        super().__init__()
        # Channel count fed into the next stage; _make_layer advances it.
        self.inplanes = 64

        # Stem.
        self.conv1 = torch.nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = torch.nn.BatchNorm2d(64)
        self.relu = torch.nn.ReLU(inplace=True)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages; every stage after the first uses stride 2.
        self.layer1 = self._make_layer(Bottleneck, 64, 3)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
        self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2)

        # Classification head.
        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(Bottleneck.expansion * 512, num_classes)

        self._initialize_weights()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks as a ``Sequential``.

        Only the first block receives ``stride`` and, when the stride is not 1
        or the channel count changes, a 1x1 conv + batch norm on the skip path.
        ``self.inplanes`` is updated to the stage's output channel count.
        """
        out_channels = planes * block.expansion

        shortcut = None
        if stride != 1 or self.inplanes != out_channels:
            shortcut = torch.nn.Sequential(
                torch.nn.Conv2d(
                    self.inplanes, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                torch.nn.BatchNorm2d(out_channels),
            )

        stage = [block(self.inplanes, planes, stride, shortcut)]
        self.inplanes = out_channels
        stage.extend(block(self.inplanes, planes) for _ in range(blocks - 1))
        return torch.nn.Sequential(*stage)

    def _initialize_weights(self):
        """Apply Kaiming-normal init to convolutions and reset batch norms to weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(
                    module.weight, mode='fan_out', nonlinearity='relu'
                )
                continue
            if isinstance(module, torch.nn.BatchNorm2d):
                torch.nn.init.constant_(module.weight, 1)
                torch.nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return the ``num_classes`` scores computed from input batch ``x``."""
        features = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        pooled = torch.flatten(self.avgpool(features), 1)
        return self.fc(pooled)
# --- End ResNet50 Model Definition ---


# Load class names: one label per line. predict_image indexes this list with
# the model's output indices, so the line order defines the label of each output.
with open('./outputs/food101_classes_simple.txt', 'r', encoding='utf-8') as f:
    class_names = [line.strip() for line in f]

num_classes = len(class_names)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Check for the weights file first so a missing file fails fast, before the
# network is built and moved to the device.
model_path = './outputs/food101_resnet50_final_weights.pth'
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model weights not found at {model_path}. Please train the model first.")

# Load the model
model = ResNet50(num_classes=num_classes).to(device)

# NOTE: torch.load unpickles the file. weights_only=True restricts loading to
# tensors and plain containers, which is all a state_dict needs, so a tampered
# checkpoint cannot execute arbitrary code. Requires a PyTorch version that
# supports the weights_only argument.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode: batch norm uses its running statistics

# Define the image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Per-channel mean/std for a 3-channel image.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def predict_image(image: Image.Image):
    """Classify an image and return its highest-scoring classes.

    Args:
        image: PIL image to classify, or ``None`` when no image was supplied
            (for example, the form was submitted empty).

    Returns:
        A dict mapping class name to probability in the range [0, 1] for the
        top five classes (fewer if fewer classes exist). An empty dict is
        returned when ``image`` is ``None``.
    """
    if image is None:
        return {}

    # The normalization step is configured for three channels, so grayscale,
    # palette, or alpha-channel images must be converted first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Apply transformations and add the batch dimension.
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(batch)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(5, probabilities.numel())
    top_prob, top_indices = torch.topk(probabilities, k)

    # Return raw probabilities (0-1), not percentages: the label output
    # component expects confidences in [0, 1] and formats them itself.
    return {
        class_names[idx]: prob
        for idx, prob in zip(top_indices.tolist(), top_prob.tolist())
    }

# Assemble the Gradio UI: one image upload wired to a top-5 label display.
_image_input = gr.Image(type="pil", label="Upload Food Image")
_label_output = gr.Label(num_top_classes=5)
_example_images = [
    # Add some example images here if you have them, e.g.,
    # ["path/to/example_image1.jpg"],
    # ["path/to/example_image2.jpg"],
]

iface = gr.Interface(
    fn=predict_image,
    inputs=_image_input,
    outputs=_label_output,
    title="Food101 ResNet50 Classifier",
    description="Upload an image of food and get predictions for 101 food categories. Model trained on Food101 dataset.",
    examples=_example_images,
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    # Listen on all interfaces; port 8000 is used for Lightning AI deployments.
    iface.launch(server_name="0.0.0.0", server_port=8000)