codealchemist01 commited on
Commit
49c65e2
·
verified ·
1 Parent(s): 287465e

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +123 -0
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from huggingface_hub import hf_hub_download
4
+ import gradio as gr
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ from PIL import Image
9
+ import json
10
+ import cv2
11
+ import matplotlib
12
+ matplotlib.use('Agg')
13
+ import matplotlib.pyplot as plt
14
+ import albumentations as A
15
+ from albumentations.pytorch import ToTensorV2
16
+
17
+ from models.hybrid_model import HybridFoodClassifier
18
+
19
+ # MONKEY PATCH: Completely disable API generation
20
+ import gradio.routes as routes_module
21
+ def empty_api_info(*args, **kwargs):
22
+ return {"api": {}}
23
+ routes_module.api_info = empty_api_info
24
+
25
+ REPO_ID = os.getenv("MODEL_REPO_ID", "codealchemist01/food-image-classifier-hybrid")
26
+
27
+ class FoodClassifier:
28
+ def __init__(self, model_path: str):
29
+ self.device = 'cpu'
30
+ checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
31
+ num_classes = checkpoint.get('num_classes', 101)
32
+
33
+ try:
34
+ with open('real_class_mapping.json', 'r') as f:
35
+ self.class_names = json.load(f)['real_class_names']
36
+ except:
37
+ self.class_names = [f"class_{i}" for i in range(num_classes)]
38
+
39
+ self.model = HybridFoodClassifier(num_classes=num_classes, pretrained=False)
40
+ self.model.load_state_dict(checkpoint['model_state_dict'])
41
+ self.model = self.model.to(self.device)
42
+ self.model.eval()
43
+
44
+ self.transform = A.Compose([
45
+ A.Resize(224, 224),
46
+ A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
47
+ ToTensorV2()
48
+ ])
49
+
50
+ print(f"Model loaded! Classes: {num_classes}")
51
+
52
+ def predict(self, image, top_k=5):
53
+ if image is None:
54
+ return "", None
55
+
56
+ if image.mode != 'RGB':
57
+ image = image.convert('RGB')
58
+
59
+ img_tensor = self.transform(image=np.array(image))['image'].unsqueeze(0).to(self.device)
60
+
61
+ with torch.no_grad():
62
+ outputs = self.model(img_tensor, return_features=True)
63
+ probs = F.softmax(outputs['logits'], dim=1).cpu().numpy()[0]
64
+ attention_maps = self.model.get_attention_maps(img_tensor)
65
+
66
+ top_indices = np.argsort(probs)[::-1][:int(top_k)]
67
+ results = "\n".join([
68
+ f"{i+1}. {self.class_names[idx]}: {probs[idx]:.3f}"
69
+ for i, idx in enumerate(top_indices)
70
+ ])
71
+
72
+ # Attention viz
73
+ img_np = np.array(image.resize((224, 224)))
74
+ cnn_att = cv2.resize(attention_maps['cnn_attention'].cpu().numpy()[0, 0], (224, 224))
75
+ cnn_att = (cnn_att - cnn_att.min()) / (cnn_att.max() - cnn_att.min() + 1e-8)
76
+ vit_att = cv2.resize(attention_maps['vit_attention'].cpu().numpy()[0, 0], (224, 224))
77
+ vit_att = (vit_att - vit_att.min()) / (vit_att.max() - vit_att.min() + 1e-8)
78
+
79
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
80
+ axes[0].imshow(img_np)
81
+ axes[0].set_title('Original')
82
+ axes[0].axis('off')
83
+ axes[1].imshow(img_np)
84
+ axes[1].imshow(cnn_att, alpha=0.6, cmap='jet')
85
+ axes[1].set_title('CNN Attention')
86
+ axes[1].axis('off')
87
+ axes[2].imshow(img_np)
88
+ axes[2].imshow(vit_att, alpha=0.6, cmap='jet')
89
+ axes[2].set_title('ViT Attention')
90
+ axes[2].axis('off')
91
+ plt.tight_layout()
92
+
93
+ import io
94
+ buf = io.BytesIO()
95
+ plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
96
+ buf.seek(0)
97
+ attention_img = Image.open(buf)
98
+ plt.close(fig)
99
+
100
+ return results, attention_img
101
+
102
+ print("Downloading model...")
103
+ ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="best_model.pth")
104
+ mapping_path = hf_hub_download(repo_id=REPO_ID, filename="real_class_mapping.json")
105
+ shutil.copy(mapping_path, "real_class_mapping.json")
106
+
107
+ classifier = FoodClassifier(ckpt_path)
108
+
109
+ demo = gr.Interface(
110
+ fn=classifier.predict,
111
+ inputs=[
112
+ gr.Image(type="pil", label="Upload Food Image"),
113
+ gr.Slider(1, 10, 5, step=1, label="Top K")
114
+ ],
115
+ outputs=[
116
+ gr.Textbox(label="Results", lines=10),
117
+ gr.Image(label="Attention Maps", height=400)
118
+ ],
119
+ title="Food Classifier",
120
+ description="Upload food images for classification"
121
+ )
122
+
123
+ demo.launch(server_name="0.0.0.0", server_port=7860)