# Page header captured with the file (not part of the program):
# codealchemist01's picture
# Update app.py
# 25bffa6 verified
# Workaround: monkey-patch Gradio's API-schema helpers right after gradio is
# imported and before the interface is built.
import sys
# Placeholder hook: client_utils is actually patched AFTER the gradio import below.
def patch_before_gradio():
    """Do nothing; placeholder for a pre-import patch hook.

    The patches in this file are installed at module level once gradio has
    been imported, so this function has no work to do.
    """
    return None
# Import gradio
import gradio as gr
import gradio.routes as routes_module
from gradio_client import utils as client_utils
# Patch 1: Fix client_utils.get_type() - THE ACTUAL BUG
# Keep a handle on the unpatched function so the wrapper below can delegate to it.
original_get_type = client_utils.get_type
def safe_get_type(schema):
    """Return the type name for *schema*, falling back to ``"Any"``.

    Delegates to the original ``get_type`` only when *schema* is a dict.
    Non-dict schemas, and dict schemas for which the original raises
    ``TypeError`` or ``AttributeError``, yield ``"Any"``.
    """
    if isinstance(schema, dict):
        try:
            return original_get_type(schema)
        except (TypeError, AttributeError):
            # Fall through to the generic answer below.
            pass
    return "Any"
# Install the tolerant wrapper in place of gradio_client.utils.get_type.
client_utils.get_type = safe_get_type
# Patch 2: Fix _json_schema_to_python_type
# Keep a handle on the unpatched function so the wrapper below can delegate to it.
original_json_schema = client_utils._json_schema_to_python_type
def safe_json_schema(schema, defs=None):
    """Convert a JSON schema to a Python type string, defaulting to ``"Any"``.

    Args:
        schema: The schema fragment to convert; only dicts are forwarded to
            the original converter.
        defs: Passed through unchanged as the original's second argument.

    Returns:
        The original converter's result, or ``"Any"`` when *schema* is not a
        dict or the original raises ``TypeError`` / ``AttributeError``.
    """
    fallback = "Any"
    if not isinstance(schema, dict):
        return fallback
    try:
        converted = original_json_schema(schema, defs)
    except (TypeError, AttributeError):
        return fallback
    return converted
# Install the tolerant wrapper in place of the private converter.
client_utils._json_schema_to_python_type = safe_json_schema
# Patch 3: Disable API generation
def empty_api_info(*args, **kwargs):
    """Return a stub API description, ignoring every argument.

    A fresh ``{"api": {}}`` dict is built on each call so callers cannot
    mutate a shared object.
    """
    stub = {"api": {}}
    return stub
# Point gradio.routes.api_info at the stub defined above.
routes_module.api_info = empty_api_info
import os
import shutil
from huggingface_hub import hf_hub_download
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import json
import cv2
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from models.hybrid_model import HybridFoodClassifier
# Hub repository the model files are downloaded from; override with the
# MODEL_REPO_ID environment variable.
REPO_ID = os.getenv("MODEL_REPO_ID", "codealchemist01/food-image-classifier-hybrid")
class FoodClassifier:
    """Load a ``HybridFoodClassifier`` checkpoint and classify images on CPU.

    ``predict`` returns the top-k class report together with a rendered
    figure that overlays the model's CNN and ViT attention maps on the input.
    """

    def __init__(self, model_path: str):
        """Load the checkpoint, class names and preprocessing pipeline.

        Args:
            model_path: Path to a checkpoint readable by ``torch.load``. It
                must contain ``'model_state_dict'`` and may contain
                ``'num_classes'`` (101 is used when absent).

        Class names are read from ``real_class_mapping.json`` in the current
        working directory (key ``'real_class_names'``). If that file is
        missing or malformed, generic ``class_<i>`` names are used instead.
        """
        self.device = 'cpu'
        # NOTE(security): weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints from a repository you trust.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        num_classes = checkpoint.get('num_classes', 101)

        try:
            with open('real_class_mapping.json', 'r', encoding='utf-8') as f:
                self.class_names = json.load(f)['real_class_names']
        except (OSError, ValueError, KeyError, TypeError):
            # Best-effort: an unreadable/invalid mapping (JSONDecodeError is a
            # ValueError) must not prevent the model from loading.
            self.class_names = [f"class_{i}" for i in range(num_classes)]

        self.model = HybridFoodClassifier(num_classes=num_classes, pretrained=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        self.transform = A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
        print(f"✅ Model loaded successfully! Classes: {num_classes}")

    def predict(self, image, top_k=5):
        """Classify *image* and visualise the model's attention.

        Args:
            image: A PIL image, or ``None`` when nothing was uploaded.
            top_k: Number of predictions to report; coerced with ``int`` and
                treated as 0 when negative.

        Returns:
            A ``(report, attention_image)`` tuple. ``report`` has one
            ``"rank. name: probability"`` line per prediction and
            ``attention_image`` is a PIL image of the rendered figure.
            Returns ``("", None)`` when *image* is ``None``.
        """
        if image is None:
            return "", None
        if image.mode != 'RGB':
            image = image.convert('RGB')

        img_tensor = self.transform(image=np.array(image))['image'].unsqueeze(0).to(self.device)
        # Both calls run without autograd so their outputs can be converted to
        # NumPy arrays directly.
        with torch.no_grad():
            outputs = self.model(img_tensor, return_features=True)
            probs = F.softmax(outputs['logits'], dim=1).cpu().numpy()[0]
            attention_maps = self.model.get_attention_maps(img_tensor)

        results = self._format_top_k(probs, top_k)
        attention_img = self._render_attention(image, attention_maps)
        return results, attention_img

    def _format_top_k(self, probs, top_k):
        """Return the ranked ``"rank. name: probability"`` report for *probs*."""
        # Clamp at zero: a negative slice bound would drop classes from the
        # end instead of limiting the count.
        k = max(int(top_k), 0)
        top_indices = np.argsort(probs)[::-1][:k]
        return "\n".join(
            f"{rank}. {self.class_names[idx]}: {probs[idx]:.3f}"
            for rank, idx in enumerate(top_indices, start=1)
        )

    @staticmethod
    def _normalized_overlay(attention):
        """Resize one attention tensor to 224x224 and min-max scale it to [0, 1]."""
        # Takes element [0, 0]; presumably (batch, channel, H, W) - confirm
        # against HybridFoodClassifier.get_attention_maps.
        att = cv2.resize(attention.cpu().numpy()[0, 0], (224, 224))
        # The epsilon avoids division by zero for a constant map.
        return (att - att.min()) / (att.max() - att.min() + 1e-8)

    def _render_attention(self, image, attention_maps):
        """Render the original image and both attention overlays as a PIL image."""
        import io

        img_np = np.array(image.resize((224, 224)))
        panels = (
            ('Original Image', None),
            ('CNN Attention', self._normalized_overlay(attention_maps['cnn_attention'])),
            ('ViT Attention', self._normalized_overlay(attention_maps['vit_attention'])),
        )

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        try:
            for ax, (title, overlay) in zip(axes, panels):
                ax.imshow(img_np)
                if overlay is not None:
                    ax.imshow(overlay, alpha=0.6, cmap='jet')
                ax.set_title(title)
                ax.axis('off')
            # Use the Figure's own methods: the pyplot equivalents act on the
            # global "current figure", which may belong to a concurrent request.
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        finally:
            # Always release the figure, even if rendering failed.
            plt.close(fig)

        buf.seek(0)
        attention_img = Image.open(buf)
        # Decode now so the returned image does not rely on lazy reads.
        attention_img.load()
        return attention_img
# Fetch the checkpoint and the class-name mapping from the Hub repository.
print("📥 Downloading model from Hugging Face Hub...")
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="best_model.pth")
mapping_path = hf_hub_download(repo_id=REPO_ID, filename="real_class_mapping.json")
# FoodClassifier opens "real_class_mapping.json" relative to the working
# directory, so place a copy there before constructing it.
shutil.copy(mapping_path, "real_class_mapping.json")
print("✅ Model files downloaded successfully!")
classifier = FoodClassifier(ckpt_path)
# Create Gradio Interface: an image plus a top-k slider in, the text report
# and the attention figure out (matching FoodClassifier.predict's return tuple).
demo = gr.Interface(
    fn=classifier.predict,
    inputs=[
        gr.Image(type="pil", label="📷 Upload Food Image", height=300),
        gr.Slider(1, 10, 5, step=1, label="🔍 Top K Predictions"),
    ],
    outputs=[
        gr.Textbox(label="🎯 Classification Results", lines=10),
        gr.Image(label="👁️ Attention Maps", height=400),
    ],
    title="🍕 Food Image Classifier",
    description="""
# 🍕 AI-Powered Food Classification System
This application uses a **Hybrid CNN-ViT Architecture** to classify food images into 101 different categories.
## 🚀 How to Use:
1. **Upload** a food image (or drag & drop)
2. **Adjust** the "Top K" slider to see more/less predictions
3. **View** the results:
- **Classification Results**: Top food categories with confidence scores
- **Attention Maps**: Visual representation of what the AI focuses on
## 🧠 Model Architecture:
- **CNN Branch**: ResNet50 (spatial feature extraction)
- **ViT Branch**: DeiT-Base (global context understanding)
- **Fusion Module**: Adaptive attention-based fusion
## 📊 Performance:
- **101 Food Categories** from Food-101 dataset (https://www.kaggle.com/datasets/dansbecker/food-101)
- **Validation Accuracy**: ~82.5%
- **Top-5 Accuracy**: >95%
## 🎯 Model Capabilities:
The model can classify various food types including:
- Pizza, Burger, Sushi, Pasta, Salad, and 96 more categories!
**Try uploading a food image now!** 🍽️
""",
    theme=gr.themes.Soft(),
    examples=None,  # No examples to avoid cache issues
)
print("🚀 Starting Gradio interface...")
# Bind to all interfaces on port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)