# app.py — Hybrid CLIP Recommender (Hugging Face Space by Barvero, commit 9591d95)
# Import core libraries
import numpy as np
import pandas as pd
import torch
import gradio as gr
# Import dataset loader
from datasets import load_dataset
# Import CLIP model and processor
from transformers import CLIPModel, CLIPProcessor
# -----------------------------
# Setup
# -----------------------------
# Use the GPU when available; every tensor fed to the model must live on `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)
model.eval()  # inference only: disables dropout / training-mode behavior
# Load precomputed embeddings (image embeddings for the sampled subset)
# NOTE(review): rows are assumed to be L2-normalized so the dot product in
# recommend() equals cosine similarity — confirm when regenerating the parquet.
emb_df = pd.read_parquet("clip_embeddings_3000.parquet")
embeddings = emb_df.drop(columns=["image_id"]).values.astype(np.float32)
# Load sampled indices (to fetch the same 3000 images)
sampled_indices = np.load("sampled_indices_3000.npy").astype(int).tolist()
# Load dataset and select the sampled subset
ds = load_dataset("JamieSJS/stanford-online-products", "corpus", split="corpus")
# After this select, row i of `embeddings` corresponds to sampled_dataset[i].
sampled_dataset = ds.select(sampled_indices)
# -----------------------------
# Embedding helpers
# -----------------------------
def l2_normalize(vec: np.ndarray) -> np.ndarray:
    """Scale *vec* to unit Euclidean length; the epsilon guards the zero vector."""
    norm = np.linalg.norm(vec)
    return vec / (norm + 1e-12)
def embed_image(image) -> np.ndarray:
    """Encode one image into a unit-length CLIP feature vector (float32, 1-D)."""
    # Resize/normalize the image into model-ready tensors, moved to `device`.
    batch = processor(images=[image], return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    # Forward pass with autograd disabled — inference only.
    with torch.no_grad():
        features = model.get_image_features(**batch)
    flat = features.cpu().numpy().reshape(-1).astype(np.float32)
    return l2_normalize(flat)
def embed_text(text: str) -> np.ndarray:
    """Encode one text query into a unit-length CLIP feature vector (float32, 1-D)."""
    # Tokenize with padding/truncation so arbitrary-length queries fit the model.
    batch = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    # Forward pass with autograd disabled — inference only.
    with torch.no_grad():
        features = model.get_text_features(**batch)
    flat = features.cpu().numpy().reshape(-1).astype(np.float32)
    return l2_normalize(flat)
def combine_embeddings(image_vec, text_vec, alpha: float) -> np.ndarray:
    """Blend two unit vectors: `alpha` weights the image, ``1 - alpha`` the text.

    Falls back to whichever embedding is present when the other is missing;
    returns None when both are None.
    """
    if image_vec is None:
        return text_vec  # None here means neither modality was provided
    if text_vec is None:
        return image_vec
    blended = (alpha * image_vec + (1.0 - alpha) * text_vec).astype(np.float32)
    return l2_normalize(blended)
# -----------------------------
# Recommendation function
# -----------------------------
def recommend(image, text, alpha):
    """Return (gallery_images, details_message) for the Gradio interface.

    image: optional PIL image; text: optional query string; alpha in [0, 1]
    is the image weight used when both modalities are supplied.
    Never raises: any failure is reported via the details message.
    """
    try:
        # Normalize the text input once (the original re-checked/stripped it
        # in three separate places).
        query = str(text).strip() if text is not None else ""
        if image is None and not query:
            return [], "Please upload an image and/or enter a text description."
        image_vec = embed_image(image) if image is not None else None
        text_vec = embed_text(query) if query else None
        # Blend modalities (or pass through the single available embedding).
        user_vec = combine_embeddings(image_vec, text_vec, float(alpha))
        if user_vec is None:
            return [], "Could not compute an embedding from the given inputs."
        # Dot product equals cosine similarity because both sides are normalized.
        scores = embeddings @ user_vec
        # Top-k, robust to a corpus with fewer than 3 rows — the original
        # hard-coded top_scores[0..2] and raised IndexError in that case.
        k = min(3, scores.shape[0])
        top_idx = np.argsort(scores)[::-1][:k]
        top_scores = scores[top_idx]
        results = [sampled_dataset[int(i)]["image"] for i in top_idx]
        # Details message
        mode = []
        if image is not None:
            mode.append("Image")
        if query:
            mode.append("Text")
        mode_str = " + ".join(mode)
        score_str = ", ".join(f"{s:.3f}" for s in top_scores)
        msg = (
            f"Mode: {mode_str}\n"
            f"Alpha (image weight): {float(alpha):.2f}\n"
            f"Top-3 cosine similarity scores: {score_str}"
        )
        return results, msg
    except Exception as e:
        return [], f"Error: {str(e)}"
# -----------------------------
# Gradio UI
# -----------------------------
# Simple single-function app: recommend() receives the three inputs in order
# (image, text, alpha) and returns (gallery images, details string).
demo = gr.Interface(
    fn=recommend,
    inputs=[
        gr.Image(type="pil", label="Upload an image (optional)"),
        gr.Textbox(label="Text description (optional)", placeholder="e.g., 'small handheld vacuum'"),
        # Slider default 0.7 biases toward the image when both inputs are given.
        gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.05, label="Alpha (image vs text weight)"),
    ],
    outputs=[
        gr.Gallery(label="Top-3 Recommended Images"),
        gr.Textbox(label="Details"),
    ],
    title="Hybrid CLIP Recommender (Image + Text)",
    description="Upload an image, type a description, or combine both. Recommendations are based on CLIP embeddings + cosine similarity."
)
# show_error surfaces recommend() errors in the UI; ssr_mode=False disables
# server-side rendering (the Spaces-compatible setting used here).
demo.launch(show_error=True, ssr_mode=False)