classification / app.py
rajkumarrawal's picture
refactor: simplify Gradio interface by removing allow_flagging
7176315
# HF_HOME has to be exported before gradio / transformers are imported:
# both pull in huggingface_hub, which resolves its cache directories once,
# at import time. Assigning the variable after those imports has no effect.
import os

os.environ['HF_HOME'] = '/tmp/hf_cache'  # Use temporary cache directory

from io import BytesIO

import gradio as gr
from transformers import AutoModel, AutoProcessor
import torch
import requests
from PIL import Image

# Candidate labels the classifier scores every image against
fashion_items = ['top', 'trousers', 'jumper']

# Hugging Face Hub id of the checkpoint used for both model and processor
model_name = 'Marqo/marqo-fashionSigLIP'

# Force CPU usage to avoid device mapping issues
device = torch.device('cpu')
# Targeted patching of open_clip to prevent meta tensor issues.
#
# Two monkey patches are installed. The whole section is best-effort: if
# anything here raises (open_clip not installed, the private helper missing,
# ...) the error is printed and the script carries on unpatched.
try:
    import open_clip
    import torch.nn as nn

    # Keep the originals around; both patches fall back to original_to.
    original_to = nn.Module.to
    original_set_model_device_and_precision = open_clip.factory._set_model_device_and_precision

    def patched_set_model_device_and_precision(model, device, precision, is_timm_model):
        """Place *model* on the CPU, ignoring the requested device and precision.

        Installed in place of ``open_clip.factory._set_model_device_and_precision``.
        ``device``, ``precision`` and ``is_timm_model`` are accepted only to
        keep the call signature; none of them is used.

        NOTE: ``Module.to_empty`` allocates storage on the target device
        without copying the existing values, so the parameters are
        uninitialised afterwards. Presumably the real weights are loaded later
        by ``from_pretrained`` -- verify if this patch is reused elsewhere.
        """
        cpu_device = torch.device('cpu')
        if hasattr(model, 'to_empty'):
            model.to_empty(device=cpu_device)
        else:
            # Fallback to the original method but with the CPU device
            try:
                original_to(model, device=cpu_device)
            except Exception:
                # Was a bare ``except:``; narrowed so KeyboardInterrupt and
                # SystemExit are no longer swallowed.
                # If the bulk move fails, move parameters individually.
                for param in model.parameters():
                    if param.device != cpu_device:
                        param.data = param.data.to(cpu_device)
                        if param.grad is not None:
                            param.grad.data = param.grad.data.to(cpu_device)

    # Apply the patch
    open_clip.factory._set_model_device_and_precision = patched_set_model_device_and_precision

    def patched_to(self, *args, **kwargs):
        """``nn.Module.to`` replacement that materialises meta-device modules.

        If any parameter of the module (searched recursively) lives on the
        ``meta`` device, the module is materialised on the CPU with
        uninitialised storage and returned; the ``*args`` / ``**kwargs`` of the
        call are ignored in that case. Otherwise the call is forwarded
        unchanged to the original ``nn.Module.to``.
        """
        has_meta_param = hasattr(self, 'parameters') and any(
            param.device.type == 'meta' for param in self.parameters()
        )
        if has_meta_param:
            # Use to_empty instead of to for meta tensors
            if hasattr(self, 'to_empty'):
                return self.to_empty(device=torch.device('cpu'))

            # No to_empty available: create new tensors with the same shape
            # for this module's own (non-recursive) parameters and buffers.
            cpu_device = torch.device('cpu')
            for name, meta_param in self.named_parameters(recurse=False):
                if meta_param.device.type == 'meta':
                    new_param = torch.empty_like(meta_param, device=cpu_device)
                    setattr(self, name, torch.nn.Parameter(new_param))
            for name, buffer in self.named_buffers(recurse=False):
                if buffer.device.type == 'meta':
                    new_buffer = torch.empty_like(buffer, device=cpu_device)
                    setattr(self, name, new_buffer)
            return self

        # No meta tensors involved: behave exactly like the original method
        return original_to(self, *args, **kwargs)

    # Apply the patch
    nn.Module.to = patched_to
except Exception as e:
    print(f"Could not patch open_clip: {e}")
# Instantiate the model on the target device. The first attempt pins the
# weights to float32; if anything in that attempt fails, the error is
# printed and loading is retried once with the library's default dtype.
_load_kwargs = {"trust_remote_code": True}
try:
    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        **_load_kwargs,
    ).to(device)
except Exception as load_error:
    print(f"Model loading failed: {load_error}")
    # Second attempt is deliberately unguarded: a failure here is fatal.
    model = AutoModel.from_pretrained(model_name, **_load_kwargs).to(device)

# The processor is loaded from the same checkpoint as the model
processor = AutoProcessor.from_pretrained(model_name, **_load_kwargs)
# Embed the candidate labels once at start-up and keep the unit-length
# result in ``text_features`` for reuse by every prediction.
with torch.no_grad():
    text_batch = processor(
        text=fashion_items,
        return_tensors="pt",
        truncation=True,  # Cut labels that exceed the model's input size
        padding=True,  # Pad shorter labels so all rows share one length
    )
    processed_texts = text_batch['input_ids']
    text_features = model.get_text_features(processed_texts)
    # Divide each row by its L2 norm
    text_norms = text_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_norms
# Seconds to wait on the remote server before giving up on a download
_REQUEST_TIMEOUT_SECONDS = 10


# Prediction function
def predict_from_url(url):
    """Classify the image found at *url* against the ``fashion_items`` labels.

    Args:
        url: Address of the image to download. A falsy value (empty string,
            ``None``) is rejected without any network access.

    Returns:
        A dict mapping each entry of ``fashion_items`` to its softmax
        probability as a Python float, or ``{"Error": <message>}`` when the
        URL is empty or the image cannot be downloaded or decoded.
    """
    # Check if the URL is empty
    if not url:
        return {"Error": "Please input a URL"}

    try:
        # Without a timeout a stalled server would block this handler forever.
        response = requests.get(url, timeout=_REQUEST_TIMEOUT_SECONDS)
        # Surface HTTP failures (404, 500, ...) as such instead of trying to
        # decode the error page as an image.
        response.raise_for_status()
        # Image.open is lazy; convert() forces the decode while still inside
        # the try block and hands the processor an RGB image regardless of
        # the source mode (palette, greyscale, alpha, ...).
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        return {"Error": f"Failed to load image: {str(e)}"}

    processed_image = processor(images=image, return_tensors="pt")['pixel_values']

    with torch.no_grad():
        image_features = model.get_image_features(processed_image)
        # Unit-normalise so the matrix product below is a cosine similarity
        # against the (already normalised) text features.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)

    # Row 0 is the single image in the batch; pair each label with its score.
    return {label: float(prob) for label, prob in zip(fashion_items, text_probs[0])}
# Build the web UI: one text box for the image URL in, one label widget
# showing the per-class scores out.
url_input = gr.Textbox(label="Enter Image URL")
results_output = gr.Label(label="Classification Results")

demo = gr.Interface(
    title="Fashion Item Classifier",
    fn=predict_from_url,
    inputs=url_input,
    outputs=results_output,
)

# Start serving the interface
demo.launch()