Spaces:
Sleeping
Sleeping
File size: 3,317 Bytes
cd112fe c10d8f3 cd112fe 6f9b46a 826cc00 c10d8f3 6f9b46a cd112fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import gradio as gr
from transformers import AutoModel, AutoProcessor
import torch
import requests
from PIL import Image
from io import BytesIO
# Candidate labels that every image is scored against.
fashion_items = ['top', 'trousers', 'jumper']

# Hugging Face Hub id of the checkpoint. Every load below passes
# trust_remote_code=True so the checkpoint's own modelling code is used.
model_name = 'Marqo/marqo-fashionSigLIP'

# Force CPU usage to avoid device mapping issues
device = torch.device('cpu')

# Load the model, retrying with progressively simpler arguments when a
# load attempt raises.
try:
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float32,
    )
    # Deliberately .to() and not .to_empty(): PyTorch documents to_empty() as
    # moving parameters and buffers *without copying storage*, so calling it
    # on a freshly loaded model would replace the pretrained weights with
    # uninitialised memory. If any parameter is still on the meta device,
    # .to() raises and the fallbacks below take over.
    model = model.to(device)
except Exception as e:
    print(f"Primary loading method failed: {e}")
    # Fallback method - load with minimal configuration
    try:
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        # Move to CPU after loading
        model = model.to(device)
    except Exception as e2:
        print(f"Fallback method also failed: {e2}")
        # Last resort - explicitly turn off low_cpu_mem_usage to avoid
        # accelerate issues. A failure here propagates and stops the app.
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            low_cpu_mem_usage=False,
        )
        model = model.to(device)

processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
# Embed the label texts once at start-up; each prediction reuses the
# L2-normalised result instead of re-encoding the labels.
with torch.no_grad():
    processed_texts = processor(
        text=fashion_items,
        return_tensors="pt",
        truncation=True,  # Ensure text is truncated to fit model input size
        # SigLIP checkpoints are documented to be used with fixed-length
        # padding; padding=True (pad to the longest label in the batch)
        # changes the sequence length the text tower sees.
        # NOTE(review): relies on the processor defining a maximum length;
        # confirm against the checkpoint's processor config.
        padding="max_length",
    )['input_ids']
    text_features = model.get_text_features(processed_texts)
    # Divide by the L2 norm along the last dimension.
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Prediction function
def predict_from_url(url):
    """Classify the image found at *url* against the ``fashion_items`` labels.

    Args:
        url: Address of the image to download, as entered in the textbox.

    Returns:
        A dict mapping each entry of ``fashion_items`` to its softmax
        probability, or a single-entry ``{"Error": message}`` dict when the
        URL is blank or the image cannot be downloaded or decoded.
    """
    # Treat a missing or whitespace-only value as "no URL given".
    if not url or not url.strip():
        return {"Error": "Please input a URL"}

    try:
        # NOTE(review): the URL is user-supplied and fetched server-side;
        # consider restricting schemes/hosts if this is exposed publicly.
        # A timeout keeps an unresponsive host from blocking the request
        # handler indefinitely.
        response = requests.get(url.strip(), timeout=10)
        # Report HTTP errors (404, 403, ...) as such instead of trying to
        # decode an error page as an image.
        response.raise_for_status()
        # PIL opens lazily; convert() forces the decode inside this try so a
        # corrupt file is reported here, and normalises the mode to RGB.
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        return {"Error": f"Failed to load image: {str(e)}"}

    processed_image = processor(images=image, return_tensors="pt")['pixel_values']
    with torch.no_grad():
        image_features = model.get_image_features(processed_image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Scaled similarity against the precomputed label embeddings, turned
        # into a probability distribution over the labels.
        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)

    # Row 0 is the single image that was passed in.
    return {label: float(prob) for label, prob in zip(fashion_items, text_probs[0])}
# Build the web UI: one textbox for the image URL in, one label widget with
# the per-class scores out.
url_input = gr.Textbox(label="Enter Image URL")
results_output = gr.Label(label="Classification Results")

demo = gr.Interface(
    title="Fashion Item Classifier",
    fn=predict_from_url,
    inputs=url_input,
    outputs=results_output,
    allow_flagging="never",
)

# Start serving the interface.
demo.launch()
|