Zalaid commited on
Commit
7fad421
Β·
verified Β·
1 Parent(s): 2ac03d0

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +91 -0
  2. checkpoint_food101_stage3.pth +3 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ import timm
5
+ import torch.nn as nn
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ # ── class names ───────────────────────────────────────
9
+ class_names = [
10
+ 'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare',
11
+ 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito',
12
+ 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake',
13
+ 'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry', 'chicken_quesadilla',
14
+ 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder',
15
+ 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes',
16
+ 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict',
17
+ 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras',
18
+ 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice',
19
+ 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich',
20
+ 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup',
21
+ 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna',
22
+ 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup',
23
+ 'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters',
24
+ 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck',
25
+ 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib',
26
+ 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto',
27
+ 'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits',
28
+ 'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake',
29
+ 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles'
30
+ ]
31
+
32
+ # ── model architecture ────────────────────────────────
33
+ class Food101ModelV2(nn.Module):
34
+ def __init__(self, output_shape: int):
35
+ super().__init__()
36
+ self.base = timm.create_model("efficientnet_b2", pretrained=False)
37
+ in_features = self.base.classifier.in_features
38
+ self.base.classifier = nn.Sequential(
39
+ nn.Linear(in_features, 512),
40
+ nn.BatchNorm1d(512),
41
+ nn.ReLU(),
42
+ nn.Dropout(0.4),
43
+ nn.Linear(512, output_shape)
44
+ )
45
+
46
+ def forward(self, x):
47
+ return self.base(x)
48
+
49
+ # ── load model ────────────────────────────────────────
50
+ device = "cuda" if torch.cuda.is_available() else "cpu"
51
+
52
+ model_path = hf_hub_download(
53
+ repo_id="Zalaid/food-classifier",
54
+ filename="checkpoint_food101_stage3.pth"
55
+ )
56
+
57
+ model = Food101ModelV2(output_shape=101).to(device)
58
+ checkpoint = torch.load(model_path, map_location=device)
59
+ model.load_state_dict(checkpoint["model_state_dict"])
60
+ model.eval()
61
+ print(f"Model loaded! Best acc: {checkpoint['best_test_acc']:.4f}")
62
+
63
+ # ── transform ─────────────────────────────────────────
64
+ transform = transforms.Compose([
65
+ transforms.Resize(280),
66
+ transforms.CenterCrop(260),
67
+ transforms.ToTensor(),
68
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
69
+ std=[0.229, 0.224, 0.225]),
70
+ ])
71
+
72
+ # ── predict function ──────────────────────────────────
73
+ def predict(image):
74
+ img_tensor = transform(image).unsqueeze(0).to(device)
75
+ with torch.inference_mode():
76
+ output = model(img_tensor)
77
+ probs = torch.softmax(output, dim=1)[0]
78
+ top5_probs, top5_idxs = torch.topk(probs, 5)
79
+ return {class_names[i]: p.item() for i, p in zip(top5_idxs, top5_probs)}
80
+
81
+ # ── gradio app ────────────────────────────────────────
82
+ demo = gr.Interface(
83
+ fn=predict,
84
+ inputs=gr.Image(type="pil", label="Upload a food image"),
85
+ outputs=gr.Label(num_top_classes=5, label="Top 5 Predictions"),
86
+ title="πŸ• Food-101 Classifier",
87
+ description="Upload any food image β€” model will predict what food it is! Trained on 101 categories using EfficientNet-B2.",
88
+ theme=gr.themes.Soft(),
89
+ )
90
+
91
+ demo.launch()
checkpoint_food101_stage3.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ded24e1401d637f3f6bf4f455fefa04ffc580a3d725cffaa2fa6399d9e3daa86
3
+ size 90920923
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ timm
5
+ huggingface_hub