classification / app.py
rajkumarrawal's picture
feat: monkey-patch open_clip to avoid meta tensors and load model on CPU
053d849
raw
history blame
3.16 kB
import os

# Point the Hugging Face cache at a temporary directory.
# This must run before transformers (and anything else that pulls in
# huggingface_hub) is imported: huggingface_hub reads HF_HOME into
# module-level constants on first import, so assigning it afterwards does
# not relocate the cache.
os.environ['HF_HOME'] = '/tmp/hf_cache'

from io import BytesIO  # noqa: E402

import gradio as gr  # noqa: E402
import requests  # noqa: E402
import torch  # noqa: E402
from PIL import Image  # noqa: E402
from transformers import AutoModel, AutoProcessor  # noqa: E402

# Candidate labels the classifier scores every image against.
fashion_items = ['top', 'trousers', 'jumper']

# Hub repository that provides both the model and its processor.
model_name = 'Marqo/marqo-fashionSigLIP'

# Force CPU usage to avoid device mapping issues
device = torch.device('cpu')
# Best-effort monkey patch: route every open_clip.factory.create_model call
# through a wrapper that pins the device and precision. Any failure while
# installing the patch is reported and the app carries on unpatched.
try:
    import open_clip

    original_create_model = open_clip.factory.create_model

    def patched_create_model(*args, **kwargs):
        """Call the original factory with device='cpu' and precision='fp32'.

        Whatever the caller passed for these two keywords is overridden;
        all other positional and keyword arguments are forwarded unchanged.
        """
        # Pin CPU + float32 so the factory does not create meta tensors.
        kwargs.update(device='cpu', precision='fp32')
        return original_create_model(*args, **kwargs)

    open_clip.factory.create_model = patched_create_model
except Exception as exc:
    print(f"Could not patch open_clip: {exc}")
# Load the model onto `device`. The first attempt requests float32 weights
# explicitly; if anything in that attempt raises, the error is printed and
# the load is retried without the dtype argument.
# NOTE: trust_remote_code=True runs Python code shipped in the model
# repository, so only use it with repositories you trust.
try:
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float32,
    ).to(device)
except Exception as exc:
    print(f"Model loading failed: {exc}")
    # Fallback - same repository, library-default dtype handling
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)

# The processor handles both the text and the image preprocessing below.
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
# Embed the candidate labels once at start-up; every prediction reuses
# these unit-length text features.
with torch.no_grad():
    processed_texts = processor(
        text=fashion_items,
        return_tensors="pt",
        truncation=True,  # Ensure text is truncated to fit model input size
        # Pad to the model's full context length rather than to the longest
        # label in the batch: the SigLIP usage notes call for
        # padding="max_length" because that is how the model was trained.
        padding='max_length',
    )['input_ids']
    text_features = model.get_text_features(processed_texts)
    # L2-normalise so the later matrix product is a cosine similarity.
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Prediction function
def predict_from_url(url):
    """Classify the image at *url* against the ``fashion_items`` labels.

    Args:
        url: Address of the image to download. ``None``, empty and
            whitespace-only values are rejected without a network call.

    Returns:
        A dict mapping each entry of ``fashion_items`` to its softmax
        probability as a float, or ``{"Error": message}`` when no URL was
        supplied or the image could not be downloaded or opened.
    """
    # Pasted URLs often carry stray whitespace; blank input is not a URL.
    url = url.strip() if url else ""
    if not url:
        return {"Error": "Please input a URL"}

    # NOTE(review): the URL is user-supplied and fetched from this server;
    # consider restricting schemes/hosts if internal addresses are reachable.
    try:
        # Without a timeout an unresponsive host would block this call forever.
        response = requests.get(url, timeout=10)
        # Report HTTP errors as such instead of handing an error page to PIL.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    except Exception as e:
        return {"Error": f"Failed to load image: {str(e)}"}

    processed_image = processor(images=image, return_tensors="pt")['pixel_values']
    with torch.no_grad():
        image_features = model.get_image_features(processed_image)
        # Unit-length features make the product below a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)

    # Row 0 is the single input image; pair each label with its probability.
    return {label: float(prob) for label, prob in zip(fashion_items, text_probs[0])}
# Gradio interface: one text box for the image URL in, one label widget
# with the per-class scores out. Flagging is switched off.
_url_box = gr.Textbox(label="Enter Image URL")
_result_label = gr.Label(label="Classification Results")

demo = gr.Interface(
    title="Fashion Item Classifier",
    fn=predict_from_url,
    inputs=_url_box,
    outputs=_result_label,
    allow_flagging="never",
)

# Start serving the interface
demo.launch()