# NOTE: removed non-Python artifact lines here (file-size / commit-hash /
# line-number residue from a repository viewer) that would break the script.
# ULTIMATE FIX - Patch everything before Gradio loads
import sys


def patch_before_gradio():
    """Placeholder hook; the real monkey-patching happens after gradio loads.

    Intentionally a no-op: gradio internals cannot be patched before the
    package is imported, so the actual patches are applied immediately after
    the gradio imports below.
    """
    return None
# Import gradio
import gradio as gr
import gradio.routes as routes_module
from gradio_client import utils as client_utils
# Patch 1: Fix client_utils.get_type() - THE ACTUAL BUG
original_get_type = client_utils.get_type


def safe_get_type(schema):
    """Defensive wrapper around ``client_utils.get_type``.

    Defers to the original implementation for dict schemas; yields "Any"
    for non-dict schemas or when the original raises TypeError/AttributeError
    (the crash being patched).
    """
    if isinstance(schema, dict):
        try:
            return original_get_type(schema)
        except (TypeError, AttributeError):
            pass
    return "Any"


client_utils.get_type = safe_get_type
# Patch 2: Fix _json_schema_to_python_type
original_json_schema = client_utils._json_schema_to_python_type


def safe_json_schema(schema, defs=None):
    """Defensive wrapper around ``client_utils._json_schema_to_python_type``.

    Same contract as the original for dict schemas; falls back to "Any"
    for non-dict schemas or when the original raises
    TypeError/AttributeError.
    """
    if isinstance(schema, dict):
        try:
            return original_json_schema(schema, defs)
        except (TypeError, AttributeError):
            pass
    return "Any"


client_utils._json_schema_to_python_type = safe_json_schema
# Patch 3: Disable API generation
def empty_api_info(*args, **kwargs):
    """Stand-in for ``routes.api_info`` that reports no API endpoints.

    Accepts and ignores any arguments so it is signature-compatible with
    whatever gradio passes in.
    """
    return {"api": {}}


routes_module.api_info = empty_api_info
import os
import shutil
from huggingface_hub import hf_hub_download
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import json
import cv2
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from models.hybrid_model import HybridFoodClassifier
REPO_ID = os.getenv("MODEL_REPO_ID", "codealchemist01/food-image-classifier-hybrid")
class FoodClassifier:
    """Food-101 image classifier wrapping the hybrid CNN-ViT model (CPU inference)."""

    def __init__(self, model_path: str):
        """Load checkpoint, class names, weights, and preprocessing pipeline.

        Args:
            model_path: Path to a torch checkpoint containing
                'model_state_dict' and optionally 'num_classes'.
        """
        self.device = 'cpu'
        # NOTE(review): weights_only=False unpickles arbitrary objects --
        # only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        num_classes = checkpoint.get('num_classes', 101)
        # Fall back to generic names if the mapping file is absent or malformed
        # (was a bare `except:`, which also swallowed KeyboardInterrupt etc.).
        try:
            with open('real_class_mapping.json', 'r') as f:
                self.class_names = json.load(f)['real_class_names']
        except (OSError, KeyError, json.JSONDecodeError):
            self.class_names = [f"class_{i}" for i in range(num_classes)]
        self.model = HybridFoodClassifier(num_classes=num_classes, pretrained=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()
        # ImageNet normalization statistics to match the pretrained backbones.
        self.transform = A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
        # Reconstructed: the original print was a mojibake-broken multi-line
        # f-string (a syntax error as pasted).
        print(f"✅ Model loaded successfully! Classes: {num_classes}")

    def predict(self, image, top_k=5):
        """Classify a PIL image and render CNN/ViT attention overlays.

        Args:
            image: PIL image (any mode; converted to RGB) or None.
            top_k: Number of top predictions to include in the result text.

        Returns:
            Tuple of (results_text, attention_image); ("", None) when no
            image is provided.
        """
        if image is None:
            return "", None
        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_tensor = self.transform(image=np.array(image))['image'].unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(img_tensor, return_features=True)
            probs = F.softmax(outputs['logits'], dim=1).cpu().numpy()[0]
            attention_maps = self.model.get_attention_maps(img_tensor)
        top_indices = np.argsort(probs)[::-1][:int(top_k)]
        results = "\n".join([
            f"{i+1}. {self.class_names[idx]}: {probs[idx]:.3f}"
            for i, idx in enumerate(top_indices)
        ])
        # Attention visualization: upscale each map to the display size and
        # min-max normalize to [0, 1] (epsilon guards a constant map).
        img_np = np.array(image.resize((224, 224)))
        cnn_att = cv2.resize(attention_maps['cnn_attention'].cpu().numpy()[0, 0], (224, 224))
        cnn_att = (cnn_att - cnn_att.min()) / (cnn_att.max() - cnn_att.min() + 1e-8)
        vit_att = cv2.resize(attention_maps['vit_attention'].cpu().numpy()[0, 0], (224, 224))
        vit_att = (vit_att - vit_att.min()) / (vit_att.max() - vit_att.min() + 1e-8)
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(img_np)
        axes[0].set_title('Original Image')
        axes[0].axis('off')
        axes[1].imshow(img_np)
        axes[1].imshow(cnn_att, alpha=0.6, cmap='jet')
        axes[1].set_title('CNN Attention')
        axes[1].axis('off')
        axes[2].imshow(img_np)
        axes[2].imshow(vit_att, alpha=0.6, cmap='jet')
        axes[2].set_title('ViT Attention')
        axes[2].axis('off')
        plt.tight_layout()
        import io
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        attention_img = Image.open(buf)
        # Close the figure (not the buffer: PIL reads from it lazily).
        plt.close(fig)
        return results, attention_img
# Fetch the model weights and class-name mapping from the Hub at startup.
print("π₯ Downloading model from Hugging Face Hub...")
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="best_model.pth")
mapping_path = hf_hub_download(repo_id=REPO_ID, filename="real_class_mapping.json")
# Copy next to the script so FoodClassifier.__init__ can open it by relative path.
shutil.copy(mapping_path, "real_class_mapping.json")
# Reconstructed: the original print was a mojibake-broken multi-line string
# (a syntax error as pasted).
print("✅ Model files downloaded successfully!")
classifier = FoodClassifier(ckpt_path)
# Markdown shown under the app title (kept verbatim from the original UI copy).
_DESCRIPTION = """
# π AI-Powered Food Classification System
This application uses a **Hybrid CNN-ViT Architecture** to classify food images into 101 different categories.
## π How to Use:
1. **Upload** a food image (or drag & drop)
2. **Adjust** the "Top K" slider to see more/less predictions
3. **View** the results:
- **Classification Results**: Top food categories with confidence scores
- **Attention Maps**: Visual representation of what the AI focuses on
## π§ Model Architecture:
- **CNN Branch**: ResNet50 (spatial feature extraction)
- **ViT Branch**: DeiT-Base (global context understanding)
- **Fusion Module**: Adaptive attention-based fusion
## π Performance:
- **101 Food Categories** from Food-101 dataset (https://www.kaggle.com/datasets/dansbecker/food-101)
- **Validation Accuracy**: ~82.5%
- **Top-5 Accuracy**: >95%
## π― Model Capabilities:
The model can classify various food types including:
- Pizza, Burger, Sushi, Pasta, Salad, and 96 more categories!
**Try uploading a food image now!** π½οΈ
"""

# Create Gradio Interface
demo = gr.Interface(
    fn=classifier.predict,
    inputs=[
        gr.Image(type="pil", label="π· Upload Food Image", height=300),
        gr.Slider(1, 10, 5, step=1, label="π Top K Predictions"),
    ],
    outputs=[
        gr.Textbox(label="π― Classification Results", lines=10),
        gr.Image(label="ποΈ Attention Maps", height=400),
    ],
    title="π Food Image Classifier",
    description=_DESCRIPTION,
    theme=gr.themes.Soft(),
    examples=None,  # No examples to avoid cache issues
)
# Launch only when executed as a script, so the module can be imported
# (e.g. by tooling or a server wrapper) without starting the web app.
if __name__ == "__main__":
    print("π Starting Gradio interface...")
    # Bind to all interfaces on the port Hugging Face Spaces expects.
    demo.launch(server_name="0.0.0.0", server_port=7860)