File size: 4,319 Bytes
49c65e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import shutil
from huggingface_hub import hf_hub_download
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import json
import cv2
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2

from models.hybrid_model import HybridFoodClassifier

# MONKEY PATCH: Completely disable API generation
import gradio.routes as routes_module
def empty_api_info(*args, **kwargs):
    """Return ``{"api": {}}`` regardless of the arguments passed."""
    # Arguments are accepted only so the stub can be called with any
    # signature; a fresh dict is built on every call.
    stub = {"api": {}}
    return stub
# Install the stub in place of gradio.routes.api_info.
# NOTE(review): this only has an effect if the installed Gradio version looks
# up `api_info` on the routes module at call time — confirm.
routes_module.api_info = empty_api_info

# Hub repository the model artifacts are downloaded from; the MODEL_REPO_ID
# environment variable overrides the default.
REPO_ID = os.getenv("MODEL_REPO_ID", "codealchemist01/food-image-classifier-hybrid")

class FoodClassifier:
    """Run a HybridFoodClassifier checkpoint on single images, on CPU.

    Attributes are set in ``__init__``:
        device: Always ``'cpu'``.
        class_names: Labels read from ``real_class_mapping.json`` in the
            working directory, or generated ``class_<i>`` names as a fallback.
        model: The loaded ``HybridFoodClassifier`` in eval mode.
        transform: Albumentations pipeline (resize to 224x224, normalise with
            the fixed mean/std below, convert to tensor).
    """

    def __init__(self, model_path: str):
        """Load the checkpoint, the class names and the preprocessing pipeline.

        Args:
            model_path: Path to a checkpoint readable by ``torch.load``. It
                must contain ``'model_state_dict'`` and may contain
                ``'num_classes'`` (101 is assumed when absent).
        """
        self.device = 'cpu'
        # SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary
        # objects, which can execute code. Only load checkpoints from a trusted
        # source. Left unchanged because the checkpoint layout is not visible
        # here and may depend on full unpickling.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        num_classes = checkpoint.get('num_classes', 101)

        # Best-effort: a missing, unreadable or malformed mapping file falls
        # back to generic names. Only the expected failures are caught so that
        # KeyboardInterrupt/SystemExit and programming errors still surface.
        try:
            with open('real_class_mapping.json', 'r', encoding='utf-8') as f:
                self.class_names = json.load(f)['real_class_names']
        except (OSError, ValueError, KeyError, TypeError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            self.class_names = [f"class_{i}" for i in range(num_classes)]

        self.model = HybridFoodClassifier(num_classes=num_classes, pretrained=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        self.transform = A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

        print(f"Model loaded! Classes: {num_classes}")

    def predict(self, image, top_k=5):
        """Classify *image* and render its attention maps.

        Args:
            image: A PIL image, or ``None`` when nothing was uploaded.
            top_k: How many of the highest-probability classes to list. It is
                converted with ``int()`` and clamped to the range
                ``[0, number of classes]``.

        Returns:
            A ``(results, attention_img)`` tuple: ``results`` is one line per
            class formatted as ``"<rank>. <label>: <probability>"`` and
            ``attention_img`` is a PIL image of the rendered figure. When
            *image* is ``None`` the tuple is ``("", None)``.
        """
        if image is None:
            return "", None

        if image.mode != 'RGB':
            image = image.convert('RGB')

        img_tensor = self.transform(image=np.array(image))['image'].unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(img_tensor, return_features=True)
            probs = F.softmax(outputs['logits'], dim=1).cpu().numpy()[0]
            attention_maps = self.model.get_attention_maps(img_tensor)

        # Clamp so a negative value cannot turn the slice below into
        # "everything except the last N".
        k = max(0, min(int(top_k), len(probs)))
        top_indices = np.argsort(probs)[::-1][:k]
        results = "\n".join(
            f"{rank}. {self._class_label(idx)}: {probs[idx]:.3f}"
            for rank, idx in enumerate(top_indices, start=1)
        )

        return results, self._render_attention(image, attention_maps)

    def _class_label(self, idx):
        """Return the label for class *idx*, or ``class_<idx>`` if unmapped."""
        try:
            return self.class_names[idx]
        except (IndexError, KeyError):
            # The mapping file may list fewer classes than the checkpoint has.
            return f"class_{idx}"

    @staticmethod
    def _normalized_heatmap(attention, size=(224, 224)):
        """Resize one attention map to *size* and min-max scale it to [0, 1].

        Assumes ``attention[0, 0]`` selects a 2-D map (first batch item, first
        channel/head) — TODO confirm against the model's output layout.
        """
        heat = cv2.resize(attention.cpu().numpy()[0, 0], size)
        # The epsilon avoids division by zero for a constant map.
        return (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)

    def _render_attention(self, image, attention_maps):
        """Draw the image beside its CNN and ViT attention overlays as a PNG.

        Args:
            image: RGB PIL image.
            attention_maps: Mapping with ``'cnn_attention'`` and
                ``'vit_attention'`` tensors.

        Returns:
            A PIL image decoded from the in-memory PNG of the figure.
        """
        import io

        img_np = np.array(image.resize((224, 224)))
        overlays = (
            ('CNN Attention', self._normalized_heatmap(attention_maps['cnn_attention'])),
            ('ViT Attention', self._normalized_heatmap(attention_maps['vit_attention'])),
        )

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        try:
            axes[0].imshow(img_np)
            axes[0].set_title('Original')
            axes[0].axis('off')
            for ax, (title, heat) in zip(axes[1:], overlays):
                ax.imshow(img_np)
                ax.imshow(heat, alpha=0.6, cmap='jet')
                ax.set_title(title)
                ax.axis('off')
            # Figure-bound calls rather than plt.* so the result does not
            # depend on pyplot's global "current figure", which another
            # concurrent request could have changed.
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        finally:
            # Always release the figure, even when drawing or saving fails.
            plt.close(fig)

        buf.seek(0)
        return Image.open(buf)

# Fetch the checkpoint and the class-name mapping from the Hub repository.
print("Downloading model...")
ckpt_path, mapping_path = (
    hf_hub_download(repo_id=REPO_ID, filename=artifact)
    for artifact in ("best_model.pth", "real_class_mapping.json")
)
# FoodClassifier opens 'real_class_mapping.json' relative to the working
# directory, so place a copy of the downloaded file there first.
shutil.copy(mapping_path, "real_class_mapping.json")

classifier = FoodClassifier(ckpt_path)

# Web UI: an image upload plus a top-k slider in; ranked text plus the
# attention figure out.
demo = gr.Interface(
    fn=classifier.predict,
    inputs=[
        gr.Image(type="pil", label="Upload Food Image"),
        gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Top K"),
    ],
    outputs=[
        gr.Textbox(label="Results", lines=10),
        gr.Image(label="Attention Maps", height=400),
    ],
    title="Food Classifier",
    description="Upload food images for classification",
)

# Listen on all interfaces, port 7860.
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
)