File size: 6,732 Bytes
8008b77
9e438ea
450676e
8008b77
 
 
 
450676e
8008b77
450676e
9e438ea
8008b77
9e438ea
8008b77
 
 
 
 
 
 
 
 
 
9e438ea
8008b77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e438ea
49c65e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8008b77
49c65e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8008b77
49c65e2
 
 
 
 
 
 
 
8008b77
49c65e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8008b77
49c65e2
 
 
8008b77
49c65e2
 
 
8008b77
49c65e2
 
 
8008b77
 
49c65e2
 
8008b77
 
49c65e2
8008b77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25bffa6
8008b77
 
 
 
 
 
 
 
 
 
 
49c65e2
 
8008b77
49c65e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# ULTIMATE FIX - Patch everything before Gradio loads
import sys

# Patch client_utils BEFORE gradio import
# NOTE(review): dead code — this placeholder is defined but never called
# anywhere in this file; the actual patching happens below, after
# `import gradio`. Safe to delete in a later cleanup.
def patch_before_gradio():
    """No-op placeholder; real monkey-patching is done after gradio import."""
    # We'll patch after gradio loads but before it's used
    pass

# Import gradio
import gradio as gr
import gradio.routes as routes_module
from gradio_client import utils as client_utils

# Patch 1: Fix client_utils.get_type() - THE ACTUAL BUG
# Wrap the original so non-dict schemas (e.g. bare booleans emitted by some
# component configs) and internal TypeError/AttributeError failures degrade
# to the type string "Any" instead of crashing API-info generation.
_original_get_type = client_utils.get_type

def safe_get_type(schema):
    """Defensive replacement for gradio_client.utils.get_type."""
    if isinstance(schema, dict):
        try:
            return _original_get_type(schema)
        except (TypeError, AttributeError):
            return "Any"
    return "Any"

client_utils.get_type = safe_get_type

# Patch 2: Fix _json_schema_to_python_type
# Forward any extra positional/keyword arguments so this wrapper keeps
# working across gradio_client versions whose private signature differs
# (some take `defs` positionally-required, some add more parameters).
original_json_schema = client_utils._json_schema_to_python_type

def safe_json_schema(schema, defs=None, *args, **kwargs):
    """Defensive replacement for gradio_client.utils._json_schema_to_python_type.

    Returns "Any" for non-dict schemas or when the original helper raises
    TypeError/AttributeError on a malformed schema.
    """
    if not isinstance(schema, dict):
        return "Any"
    try:
        return original_json_schema(schema, defs, *args, **kwargs)
    except (TypeError, AttributeError):
        return "Any"

client_utils._json_schema_to_python_type = safe_json_schema

# Patch 3: Disable API generation
def empty_api_info(*args, **kwargs):
    """Return an empty API description, ignoring all arguments."""
    return {"api": {}}

# Replace the route-level API-info generator so the buggy schema walk
# is never invoked at all.
routes_module.api_info = empty_api_info

import os
import shutil
from huggingface_hub import hf_hub_download
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import json
import cv2
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2

from models.hybrid_model import HybridFoodClassifier

# Hugging Face Hub repo that hosts the trained checkpoint and class mapping;
# override via the MODEL_REPO_ID environment variable.
REPO_ID = os.getenv("MODEL_REPO_ID", "codealchemist01/food-image-classifier-hybrid")

class FoodClassifier:
    """Food-image classifier wrapping a hybrid CNN/ViT model.

    Loads a checkpoint from ``model_path``, restores the human-readable
    class names from ``real_class_mapping.json`` (falling back to generic
    placeholders), and exposes :meth:`predict` for the Gradio interface.
    """

    def __init__(self, model_path: str):
        # CPU-only inference.
        self.device = 'cpu'
        # weights_only=False: the checkpoint carries extra metadata
        # (num_classes, state dict) beyond raw tensors.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        num_classes = checkpoint.get('num_classes', 101)

        # Fall back to "class_i" placeholders if the mapping file is
        # missing or malformed. Was a bare `except:`; narrowed so real
        # bugs (and KeyboardInterrupt) are not silently swallowed.
        try:
            with open('real_class_mapping.json', 'r') as f:
                self.class_names = json.load(f)['real_class_names']
        except (OSError, KeyError, json.JSONDecodeError):
            self.class_names = [f"class_{i}" for i in range(num_classes)]

        self.model = HybridFoodClassifier(num_classes=num_classes, pretrained=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Standard ImageNet normalization at the model's 224x224 input size.
        self.transform = A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

        print(f"βœ… Model loaded successfully! Classes: {num_classes}")

    @staticmethod
    def _normalized_attention(att_tensor):
        """Resize one attention map to 224x224 and min-max normalize to [0, 1].

        `att_tensor` is assumed to be (1, 1, H, W) — indexing [0, 0] fixes
        that rank; epsilon guards a constant (zero-range) map.
        """
        att = cv2.resize(att_tensor.cpu().numpy()[0, 0], (224, 224))
        return (att - att.min()) / (att.max() - att.min() + 1e-8)

    def _attention_figure(self, image, attention_maps):
        """Render original image plus CNN/ViT attention overlays as a PIL image."""
        img_np = np.array(image.resize((224, 224)))
        cnn_att = self._normalized_attention(attention_maps['cnn_attention'])
        vit_att = self._normalized_attention(attention_maps['vit_attention'])

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(img_np)
        axes[0].set_title('Original Image')
        axes[0].axis('off')
        axes[1].imshow(img_np)
        axes[1].imshow(cnn_att, alpha=0.6, cmap='jet')
        axes[1].set_title('CNN Attention')
        axes[1].axis('off')
        axes[2].imshow(img_np)
        axes[2].imshow(vit_att, alpha=0.6, cmap='jet')
        axes[2].set_title('ViT Attention')
        axes[2].axis('off')
        plt.tight_layout()

        import io
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        attention_img = Image.open(buf)
        plt.close(fig)  # release the figure so per-request memory is reclaimed
        return attention_img

    def predict(self, image, top_k=5):
        """Classify a PIL image.

        Returns a tuple ``(results_text, attention_image)``; ``("", None)``
        when no image is provided.
        """
        if image is None:
            return "", None

        if image.mode != 'RGB':
            image = image.convert('RGB')

        img_tensor = self.transform(image=np.array(image))['image'].unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(img_tensor, return_features=True)
            probs = F.softmax(outputs['logits'], dim=1).cpu().numpy()[0]
            attention_maps = self.model.get_attention_maps(img_tensor)

        # Top-k class indices, highest probability first; int() because the
        # Gradio slider may deliver the value as a float.
        top_indices = np.argsort(probs)[::-1][:int(top_k)]
        results = "\n".join([
            f"{i+1}. {self.class_names[idx]}: {probs[idx]:.3f}"
            for i, idx in enumerate(top_indices)
        ])

        return results, self._attention_figure(image, attention_maps)

# Fetch the model artifacts from the Hub, then stage the class-name mapping
# next to the app so FoodClassifier.__init__ can open it by relative path.
print("πŸ“₯ Downloading model from Hugging Face Hub...")
checkpoint_file = hf_hub_download(repo_id=REPO_ID, filename="best_model.pth")
class_map_file = hf_hub_download(repo_id=REPO_ID, filename="real_class_mapping.json")
shutil.copy(class_map_file, "real_class_mapping.json")
print("βœ… Model files downloaded successfully!")

classifier = FoodClassifier(checkpoint_file)

# Long-form Markdown shown above the interface.
APP_DESCRIPTION = """
    # πŸ• AI-Powered Food Classification System
    
    This application uses a **Hybrid CNN-ViT Architecture** to classify food images into 101 different categories.
    
    ## πŸš€ How to Use:
    1. **Upload** a food image (or drag & drop)
    2. **Adjust** the "Top K" slider to see more/less predictions
    3. **View** the results:
       - **Classification Results**: Top food categories with confidence scores
       - **Attention Maps**: Visual representation of what the AI focuses on
    
    ## 🧠 Model Architecture:
    - **CNN Branch**: ResNet50 (spatial feature extraction)
    - **ViT Branch**: DeiT-Base (global context understanding)
    - **Fusion Module**: Adaptive attention-based fusion
    
    ## πŸ“Š Performance:
    - **101 Food Categories** from Food-101 dataset (https://www.kaggle.com/datasets/dansbecker/food-101)
    - **Validation Accuracy**: ~82.5%
    - **Top-5 Accuracy**: >95%
    
    ## 🎯 Model Capabilities:
    The model can classify various food types including:
    - Pizza, Burger, Sushi, Pasta, Salad, and 96 more categories!
    
    **Try uploading a food image now!** 🍽️
    """

# Create Gradio Interface: image + top-k slider in, text + attention map out.
demo = gr.Interface(
    fn=classifier.predict,
    inputs=[
        gr.Image(type="pil", label="πŸ“· Upload Food Image", height=300),
        gr.Slider(1, 10, 5, step=1, label="πŸ” Top K Predictions")
    ],
    outputs=[
        gr.Textbox(label="🎯 Classification Results", lines=10),
        gr.Image(label="πŸ‘οΈ Attention Maps", height=400)
    ],
    title="πŸ• Food Image Classifier",
    description=APP_DESCRIPTION,
    theme=gr.themes.Soft(),
    examples=None  # No examples to avoid cache issues
)

print("πŸš€ Starting Gradio interface...")
# Bind to all interfaces on port 7860 — the port Hugging Face Spaces expects.
demo.launch(server_name="0.0.0.0", server_port=7860)