File size: 4,722 Bytes
9e438ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49c65e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# PATCH GRADIO BEFORE ANYTHING ELSE LOADS
# NOTE(review): this monkey-patch disables Gradio's API schema generation
# (presumably a workaround for schema-introspection crashes on some gradio
# versions — confirm which version motivated it before removing).
import sys  # NOTE(review): `sys` appears unused in this file chunk — verify before removing.
import gradio.routes as routes_module
import gradio.blocks as blocks_module

# Patch api_info to return empty dict immediately.
# The original function object is kept around (never restored) for debugging.
original_api_info = getattr(routes_module, 'api_info', None)
def patched_api_info(*args, **kwargs):
    # Minimal stand-in: report an empty API so no schema introspection runs.
    return {"api": {}}
routes_module.api_info = patched_api_info

# Also patch get_api_info in Blocks
original_get_api_info = getattr(blocks_module.Blocks, 'get_api_info', None)
def patched_get_api_info(self):
    # Return an empty API description for any Blocks instance.
    return {}
if original_get_api_info:
    # Only replace the method when it exists on this gradio version.
    blocks_module.Blocks.get_api_info = patched_get_api_info

import os
import shutil
from huggingface_hub import hf_hub_download
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import json
import cv2
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2

from models.hybrid_model import HybridFoodClassifier

REPO_ID = os.getenv("MODEL_REPO_ID", "codealchemist01/food-image-classifier-hybrid")

class FoodClassifier:
    """Wraps the hybrid CNN+ViT food model for CPU inference in the demo.

    Loads the checkpoint and class-name mapping, then exposes ``predict``,
    which returns a formatted top-k prediction text block plus an
    attention-map visualization image.
    """

    def __init__(self, model_path: str):
        """Load the checkpoint at *model_path* and prepare the model.

        Args:
            model_path: Path to a torch checkpoint containing
                ``model_state_dict`` and (optionally) ``num_classes``.
        """
        self.device = 'cpu'
        # weights_only=False is needed because the checkpoint stores more
        # than raw tensors. NOTE(review): this deserializes arbitrary
        # pickled objects — only load checkpoints from a trusted repo.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        num_classes = checkpoint.get('num_classes', 101)

        # Human-readable class names; fall back to generic labels when the
        # mapping file is missing, unreadable, or malformed.
        # (Narrowed from a bare ``except:`` so real bugs and
        # KeyboardInterrupt are no longer silently swallowed;
        # json.JSONDecodeError is a ValueError subclass.)
        try:
            with open('real_class_mapping.json', 'r') as f:
                self.class_names = json.load(f)['real_class_names']
        except (OSError, KeyError, ValueError):
            self.class_names = [f"class_{i}" for i in range(num_classes)]

        self.model = HybridFoodClassifier(num_classes=num_classes, pretrained=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Standard ImageNet normalization at the model's 224x224 input size.
        self.transform = A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

        print(f"Model loaded! Classes: {num_classes}")

    def predict(self, image, top_k=5):
        """Classify *image* and render its attention maps.

        Args:
            image: PIL image (any mode; converted to RGB), or None.
            top_k: Number of top predictions to report (coerced to int so
                Gradio slider floats are accepted).

        Returns:
            Tuple ``(results_text, attention_image)``; ``("", None)`` when
            no image was provided.
        """
        if image is None:
            return "", None

        if image.mode != 'RGB':
            image = image.convert('RGB')

        img_tensor = self.transform(image=np.array(image))['image'].unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(img_tensor, return_features=True)
            probs = F.softmax(outputs['logits'], dim=1).cpu().numpy()[0]
            attention_maps = self.model.get_attention_maps(img_tensor)

        # Highest-probability classes first; numpy slicing tolerates
        # top_k larger than the class count.
        top_indices = np.argsort(probs)[::-1][:int(top_k)]
        results = "\n".join([
            f"{i+1}. {self.class_names[idx]}: {probs[idx]:.3f}"
            for i, idx in enumerate(top_indices)
        ])

        attention_img = self._render_attention(image, attention_maps)
        return results, attention_img

    def _render_attention(self, image, attention_maps):
        """Render the original image plus CNN/ViT attention overlays as a PIL image."""
        img_np = np.array(image.resize((224, 224)))

        # Upsample each 2D attention map to display size and min-max
        # normalize into [0, 1]; the epsilon guards flat (constant) maps.
        overlays = {}
        for key in ('cnn_attention', 'vit_attention'):
            att = cv2.resize(attention_maps[key].cpu().numpy()[0, 0], (224, 224))
            overlays[key] = (att - att.min()) / (att.max() - att.min() + 1e-8)

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        titles = ('Original', 'CNN Attention', 'ViT Attention')
        maps = (None, overlays['cnn_attention'], overlays['vit_attention'])
        for ax, title, overlay in zip(axes, titles, maps):
            ax.imshow(img_np)
            if overlay is not None:
                ax.imshow(overlay, alpha=0.6, cmap='jet')
            ax.set_title(title)
            ax.axis('off')
        plt.tight_layout()

        import io
        buf = io.BytesIO()
        try:
            plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
            buf.seek(0)
            attention_img = Image.open(buf)
            # Image.open is lazy — force a full decode so the image stays
            # valid after the buffer goes out of scope.
            attention_img.load()
        finally:
            # Always release the figure, even if savefig fails, to avoid
            # leaking matplotlib figures across requests.
            plt.close(fig)
        return attention_img

# --- Application startup: fetch artifacts, build the classifier, serve the UI ---
print("Downloading model...")
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="best_model.pth")
mapping_path = hf_hub_download(repo_id=REPO_ID, filename="real_class_mapping.json")
# FoodClassifier reads the mapping from the working directory by name.
shutil.copy(mapping_path, "real_class_mapping.json")

classifier = FoodClassifier(ckpt_path)

# Assemble the Gradio interface around classifier.predict.
image_input = gr.Image(type="pil", label="Upload Food Image")
topk_input = gr.Slider(1, 10, 5, step=1, label="Top K")
results_output = gr.Textbox(label="Results", lines=10)
attention_output = gr.Image(label="Attention Maps", height=400)

demo = gr.Interface(
    fn=classifier.predict,
    inputs=[image_input, topk_input],
    outputs=[results_output, attention_output],
    title="Food Classifier",
    description="Upload food images for classification",
)

# Bind to all interfaces on the conventional Gradio/Spaces port.
demo.launch(server_name="0.0.0.0", server_port=7860)