# classification/app.py
# Author: rajkumarrawal — commit c10d8f3
# Commit message: "fix: improve model loading with meta tensor handling and CPU fallbacks"
import gradio as gr
from transformers import AutoModel, AutoProcessor
import torch
import requests
from PIL import Image
from io import BytesIO
# Labels every image is scored against (zero-shot, via text embeddings).
fashion_items = ["top", "trousers", "jumper"]

# Hugging Face Hub id used for both the model and its processor.
model_name = "Marqo/marqo-fashionSigLIP"

# Pin all inference to the CPU to sidestep device-mapping issues.
device = torch.device("cpu")
def _load_model(name, target_device):
    """Load *name* with fully materialised weights and move it to *target_device*.

    The model is first requested with ``low_cpu_mem_usage=False`` so that
    transformers builds real tensors instead of placeholders on the ``meta``
    device; that avoids the "Cannot copy out of meta tensor" error on ``.to()``.
    ``Module.to_empty()`` is deliberately NOT used: it replaces every parameter
    and buffer with uninitialised memory, discarding the pretrained weights.

    Args:
        name: Hugging Face Hub id (or local path) of the model.
        target_device: ``torch.device`` to run inference on.

    Returns:
        The loaded model on *target_device*, in eval mode.

    Raises:
        RuntimeError: if any parameter or buffer is still on the ``meta``
            device after loading, i.e. the model has no real values for it and
            would produce meaningless outputs.
    """
    load_kwargs = {"trust_remote_code": True, "torch_dtype": torch.float32}
    try:
        loaded = AutoModel.from_pretrained(name, low_cpu_mem_usage=False, **load_kwargs)
    except Exception as e:
        print(f"Primary loading method failed: {e}")
        # Fallback: let transformers choose its default loading strategy; the
        # meta-tensor check below catches the case where that leaves holes.
        loaded = AutoModel.from_pretrained(name, **load_kwargs)

    # Fail loudly instead of silently serving predictions from empty tensors.
    meta_names = [
        tensor_name
        for tensor_name, tensor in (*loaded.named_parameters(), *loaded.named_buffers())
        if tensor.is_meta
    ]
    if meta_names:
        raise RuntimeError(
            f"{len(meta_names)} tensor(s) of {name!r} were not materialised "
            f"(e.g. {meta_names[0]!r}); refusing to run with uninitialised weights."
        )

    loaded = loaded.to(target_device)
    loaded.eval()
    return loaded


# Load the model and its matching processor once at start-up.
model = _load_model(model_name, device)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
# Embed every label once at start-up and L2-normalise it, so each request only
# has to embed the image.
with torch.no_grad():
    # SigLIP models are used with max-length padding (the Marqo model card and
    # the HF SigLIP docs both pass padding="max_length"); padding only to the
    # longest label yields different text embeddings.
    processed_texts = processor(
        text=fashion_items,
        return_tensors="pt",
        truncation=True,  # clip labels that exceed the model's input size
        padding="max_length",
    )["input_ids"]
    text_features = model.get_text_features(processed_texts)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Seconds to wait on the remote server before giving up on an image URL.
REQUEST_TIMEOUT_SECONDS = 10


def predict_from_url(url):
    """Classify the image at *url* against ``fashion_items``.

    Args:
        url: Address of an image, as typed into the Gradio textbox.

    Returns:
        A dict mapping each label in ``fashion_items`` to its softmax
        probability, or ``{"Error": message}`` when the URL is empty or the
        image cannot be downloaded/decoded.
    """
    if not url:
        return {"Error": "Please input a URL"}
    try:
        # The timeout keeps a stalled server from blocking this worker forever;
        # raise_for_status reports HTTP errors (404, 500, ...) directly rather
        # than as an opaque "cannot identify image file" from PIL.
        response = requests.get(url, timeout=REQUEST_TIMEOUT_SECONDS)
        response.raise_for_status()
        # Image.open is lazy: convert() forces decoding inside this try and
        # normalises RGBA / greyscale / palette images to RGB.
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        # User-facing boundary: report any fetch/decode failure as a message.
        return {"Error": f"Failed to load image: {str(e)}"}

    processed_image = processor(images=image, return_tensors="pt")["pixel_values"]
    with torch.no_grad():
        image_features = model.get_image_features(processed_image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Scaled cosine similarities -> probabilities over the label set.
        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)
    return {label: float(prob) for label, prob in zip(fashion_items, text_probs[0])}
# Gradio UI: an image-URL textbox in, a label/probability panel out.
demo = gr.Interface(
    fn=predict_from_url,
    inputs=gr.Textbox(label="Enter Image URL"),
    outputs=gr.Label(label="Classification Results"),
    title="Fashion Item Classifier",
    allow_flagging="never",
)

# Start the web server only when executed as a script (`python app.py`), so
# importing this module (e.g. to reuse predict_from_url) does not launch it.
if __name__ == "__main__":
    demo.launch()