File size: 4,722 Bytes
9e438ea 49c65e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
# PATCH GRADIO BEFORE ANYTHING ELSE LOADS
import sys
import gradio.routes as routes_module
import gradio.blocks as blocks_module

# Stub out gradio's API-schema introspection: the route-level api_info
# handler is replaced so it reports an empty API immediately instead of
# introspecting every component.
original_api_info = getattr(routes_module, 'api_info', None)

def patched_api_info(*args, **kwargs):
    return {"api": {}}

routes_module.api_info = patched_api_info

# Likewise neutralize Blocks.get_api_info, but only on gradio versions
# that actually define that attribute.
original_get_api_info = getattr(blocks_module.Blocks, 'get_api_info', None)

def patched_get_api_info(self):
    return {}

if original_get_api_info:
    blocks_module.Blocks.get_api_info = patched_get_api_info
import os
import shutil
from huggingface_hub import hf_hub_download
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import json
import cv2
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from models.hybrid_model import HybridFoodClassifier
REPO_ID = os.getenv("MODEL_REPO_ID", "codealchemist01/food-image-classifier-hybrid")
class FoodClassifier:
    """Hybrid CNN/ViT food-image classifier with attention-map visualization.

    Loads a checkpoint produced by the training pipeline, reads the
    human-readable class names from ``real_class_mapping.json`` in the
    current working directory, and exposes :meth:`predict` for Gradio.
    """

    def __init__(self, model_path: str):
        # CPU-only inference (HF Spaces free tier has no GPU).
        self.device = 'cpu'
        # weights_only=False because the checkpoint stores metadata
        # (num_classes) beyond raw tensors.
        # NOTE(security): weights_only=False unpickles arbitrary objects —
        # only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        num_classes = checkpoint.get('num_classes', 101)
        try:
            with open('real_class_mapping.json', 'r') as f:
                self.class_names = json.load(f)['real_class_names']
        except (OSError, KeyError, json.JSONDecodeError):
            # Mapping file missing or malformed: fall back to generic labels
            # rather than crashing the whole app.
            self.class_names = [f"class_{i}" for i in range(num_classes)]
        self.model = HybridFoodClassifier(num_classes=num_classes, pretrained=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()
        # ImageNet normalization stats, matching the backbone's pretraining.
        self.transform = A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
        print(f"Model loaded! Classes: {num_classes}")

    def predict(self, image, top_k=5):
        """Classify a PIL image.

        Returns a tuple ``(results, attention_img)`` where ``results`` is a
        newline-separated "rank. class: prob" string for the top-k classes
        and ``attention_img`` is a PIL image of the attention overlays.
        Returns ``("", None)`` when no image is supplied.
        """
        if image is None:
            return "", None
        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_tensor = self.transform(image=np.array(image))['image'].unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(img_tensor, return_features=True)
            probs = F.softmax(outputs['logits'], dim=1).cpu().numpy()[0]
            attention_maps = self.model.get_attention_maps(img_tensor)
        # int(top_k): the Gradio slider may deliver a float.
        top_indices = np.argsort(probs)[::-1][:int(top_k)]
        results = "\n".join(
            f"{i+1}. {self.class_names[idx]}: {probs[idx]:.3f}"
            for i, idx in enumerate(top_indices)
        )
        return results, self._render_attention(image, attention_maps)

    @staticmethod
    def _prepare_overlay(att_tensor, size=(224, 224)):
        """Resize an attention tensor to ``size`` and min-max scale to [0, 1]."""
        att = cv2.resize(att_tensor.cpu().numpy()[0, 0], size)
        # +1e-8 guards against division by zero on a constant map.
        return (att - att.min()) / (att.max() - att.min() + 1e-8)

    def _render_attention(self, image, attention_maps):
        """Render original image plus CNN/ViT attention overlays as one PIL image."""
        import io
        img_np = np.array(image.resize((224, 224)))
        panels = [
            (None, 'Original'),
            (self._prepare_overlay(attention_maps['cnn_attention']), 'CNN Attention'),
            (self._prepare_overlay(attention_maps['vit_attention']), 'ViT Attention'),
        ]
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, (overlay, title) in zip(axes, panels):
            ax.imshow(img_np)
            if overlay is not None:
                ax.imshow(overlay, alpha=0.6, cmap='jet')
            ax.set_title(title)
            ax.axis('off')
        plt.tight_layout()
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        attention_img = Image.open(buf)
        plt.close(fig)  # release the figure to avoid leaking memory per request
        return attention_img
# Pull the trained weights and the human-readable class mapping from the Hub.
print("Downloading model...")
class_map_file = hf_hub_download(repo_id=REPO_ID, filename="real_class_mapping.json")
# FoodClassifier reads the mapping from the current working directory.
shutil.copy(class_map_file, "real_class_mapping.json")
weights_file = hf_hub_download(repo_id=REPO_ID, filename="best_model.pth")
classifier = FoodClassifier(weights_file)
# Build the Gradio UI: image + top-k slider in, text results + attention figure out.
image_input = gr.Image(type="pil", label="Upload Food Image")
topk_input = gr.Slider(1, 10, 5, step=1, label="Top K")
results_output = gr.Textbox(label="Results", lines=10)
attention_output = gr.Image(label="Attention Maps", height=400)

demo = gr.Interface(
    fn=classifier.predict,
    inputs=[image_input, topk_input],
    outputs=[results_output, attention_output],
    title="Food Classifier",
    description="Upload food images for classification"
)

# Bind to all interfaces on port 7860 (the standard HF Spaces configuration).
demo.launch(server_name="0.0.0.0", server_port=7860)
|