|
|
|
|
|
import sys |
|
|
|
|
|
|
|
|
def patch_before_gradio():
    """Placeholder hook for patches that must run before Gradio is imported.

    Currently a no-op: the actual monkey-patches (get_type,
    _json_schema_to_python_type, api_info) are applied at module import
    time below. Appears unused in this chunk -- kept for compatibility.
    """
    pass
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
import gradio.routes as routes_module |
|
|
from gradio_client import utils as client_utils |
|
|
|
|
|
|
|
|
# Harden gradio_client's schema-type resolution: malformed or non-dict
# schemas fall back to "Any" instead of crashing API-docs generation.
original_get_type = client_utils.get_type

def safe_get_type(schema):
    """Resolve a schema's type via the original helper; return "Any" on failure."""
    if isinstance(schema, dict):
        try:
            return original_get_type(schema)
        except (TypeError, AttributeError):
            pass
    return "Any"

client_utils.get_type = safe_get_type
|
|
|
|
|
|
|
|
# Same hardening for JSON-schema -> python-type conversion: degrade to
# "Any" rather than raising on unexpected schema shapes.
original_json_schema = client_utils._json_schema_to_python_type

def safe_json_schema(schema, defs=None):
    """Convert a JSON schema to a python type string; return "Any" on failure."""
    if isinstance(schema, dict):
        try:
            return original_json_schema(schema, defs)
        except (TypeError, AttributeError):
            pass
    return "Any"

client_utils._json_schema_to_python_type = safe_json_schema
|
|
|
|
|
|
|
|
def empty_api_info(*args, **kwargs):
    """Stand-in for gradio.routes.api_info that reports no endpoints.

    Disables the /info schema introspection entirely, sidestepping any
    schema-parsing crashes at startup.
    """
    return dict(api={})

routes_module.api_info = empty_api_info
|
|
|
|
|
import os |
|
|
import shutil |
|
|
from huggingface_hub import hf_hub_download |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import json |
|
|
import cv2 |
|
|
import matplotlib |
|
|
# Headless backend: render figures off-screen (no display in the container).
matplotlib.use('Agg')
|
|
import matplotlib.pyplot as plt |
|
|
import albumentations as A |
|
|
from albumentations.pytorch import ToTensorV2 |
|
|
|
|
|
from models.hybrid_model import HybridFoodClassifier |
|
|
|
|
|
# Hugging Face Hub repo holding the trained checkpoint; overridable via env var.
REPO_ID = os.getenv("MODEL_REPO_ID", "codealchemist01/food-image-classifier-hybrid")
|
|
|
|
|
class FoodClassifier:
    """CPU inference wrapper around the hybrid CNN-ViT food classifier.

    Loads weights from a checkpoint file, resolves human-readable class
    names (falling back to generic placeholders), and exposes `predict`,
    which returns a top-k ranking string plus an attention-map image.
    """

    def __init__(self, model_path: str):
        self.device = 'cpu'
        # NOTE(review): weights_only=False unpickles arbitrary objects and can
        # execute code -- only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        num_classes = checkpoint.get('num_classes', 101)

        # Prefer the real class names shipped alongside the model; fall back to
        # placeholders if the mapping file is missing, malformed, or mis-keyed.
        # (Was a bare `except:` -- narrowed to the failures that can occur here.)
        try:
            with open('real_class_mapping.json', 'r') as f:
                self.class_names = json.load(f)['real_class_names']
        except (OSError, KeyError, json.JSONDecodeError):
            self.class_names = [f"class_{i}" for i in range(num_classes)]

        self.model = HybridFoodClassifier(num_classes=num_classes, pretrained=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # ImageNet mean/std normalization -- presumably matches the training
        # pipeline; verify against the training transforms.
        self.transform = A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

        # Reconstructed from a print statement broken mid-string by a
        # mis-encoded emoji in the original file.
        print(f"✅ Model loaded successfully! Classes: {num_classes}")

    def predict(self, image, top_k=5):
        """Classify a PIL image and visualize the model's attention.

        Args:
            image: PIL image (any mode; converted to RGB) or None.
            top_k: number of top predictions to report (coerced to int).

        Returns:
            (results_text, attention_image): a newline-joined ranking string
            and a PIL image with original / CNN / ViT attention panels, or
            ("", None) when no image was supplied.
        """
        if image is None:
            return "", None

        if image.mode != 'RGB':
            image = image.convert('RGB')

        img_tensor = self.transform(image=np.array(image))['image'].unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(img_tensor, return_features=True)
            probs = F.softmax(outputs['logits'], dim=1).cpu().numpy()[0]
            attention_maps = self.model.get_attention_maps(img_tensor)

        # Rank classes by probability, highest first.
        top_indices = np.argsort(probs)[::-1][:int(top_k)]
        results = "\n".join(
            f"{i+1}. {self.class_names[idx]}: {probs[idx]:.3f}"
            for i, idx in enumerate(top_indices)
        )

        attention_img = self._render_attention(image, attention_maps)
        return results, attention_img

    def _render_attention(self, image, attention_maps):
        """Render a 3-panel figure (original, CNN attention, ViT attention) as a PIL image."""
        img_np = np.array(image.resize((224, 224)))
        cnn_att = self._normalized_map(attention_maps['cnn_attention'])
        vit_att = self._normalized_map(attention_maps['vit_attention'])

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = [
            (None, 'Original Image'),
            (cnn_att, 'CNN Attention'),
            (vit_att, 'ViT Attention'),
        ]
        for ax, (overlay, title) in zip(axes, panels):
            ax.imshow(img_np)
            if overlay is not None:
                ax.imshow(overlay, alpha=0.6, cmap='jet')
            ax.set_title(title)
            ax.axis('off')
        plt.tight_layout()

        # Round-trip through an in-memory PNG so Gradio gets a plain PIL image.
        import io
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        attention_img = Image.open(buf)
        plt.close(fig)  # release the figure; Agg backend never shows it
        return attention_img

    @staticmethod
    def _normalized_map(att, size=(224, 224)):
        """Resize an attention tensor to `size` and min-max normalize to [0, 1]."""
        # att is assumed to be a (1, 1, H, W)-indexable tensor -- TODO confirm
        # against HybridFoodClassifier.get_attention_maps.
        att = cv2.resize(att.cpu().numpy()[0, 0], size)
        return (att - att.min()) / (att.max() - att.min() + 1e-8)
|
|
|
|
|
# Fetch the trained checkpoint and class-name mapping from the Hub cache,
# then copy the mapping next to the script where FoodClassifier expects it.
print("π₯ Downloading model from Hugging Face Hub...")
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="best_model.pth")
mapping_path = hf_hub_download(repo_id=REPO_ID, filename="real_class_mapping.json")
shutil.copy(mapping_path, "real_class_mapping.json")
# Reconstructed from a print statement broken mid-string by a mis-encoded
# emoji in the original file.
print("✅ Model files downloaded successfully!")

classifier = FoodClassifier(ckpt_path)
|
|
|
|
|
|
|
|
# Gradio UI: image + top-k slider in, ranked-text results + attention figure out.
demo = gr.Interface(
    fn=classifier.predict,
    inputs=[
        gr.Image(type="pil", label="π· Upload Food Image", height=300),
        gr.Slider(1, 10, 5, step=1, label="π Top K Predictions")
    ],
    outputs=[
        gr.Textbox(label="π― Classification Results", lines=10),
        gr.Image(label="ποΈ Attention Maps", height=400)
    ],
    title="π Food Image Classifier",
    description="""
# π AI-Powered Food Classification System

This application uses a **Hybrid CNN-ViT Architecture** to classify food images into 101 different categories.

## π How to Use:
1. **Upload** a food image (or drag & drop)
2. **Adjust** the "Top K" slider to see more/less predictions
3. **View** the results:
- **Classification Results**: Top food categories with confidence scores
- **Attention Maps**: Visual representation of what the AI focuses on

## π§ Model Architecture:
- **CNN Branch**: ResNet50 (spatial feature extraction)
- **ViT Branch**: DeiT-Base (global context understanding)
- **Fusion Module**: Adaptive attention-based fusion

## π Performance:
- **101 Food Categories** from Food-101 dataset (https://www.kaggle.com/datasets/dansbecker/food-101)
- **Validation Accuracy**: ~82.5%
- **Top-5 Accuracy**: >95%

## π― Model Capabilities:
The model can classify various food types including:
- Pizza, Burger, Sushi, Pasta, Salad, and 96 more categories!

**Try uploading a food image now!** π½οΈ
""",
    theme=gr.themes.Soft(),
    examples=None  # no bundled example images
)
|
|
|
|
|
print("π Starting Gradio interface...")
# Bind to all interfaces on port 7860 -- the convention for HF Spaces / Docker.
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|