"""Hybrid CLIP recommender: image + text similarity search.

Loads CLIP (ViT-B/32), precomputed image embeddings for a 3000-item
sample of the Stanford Online Products corpus, and serves a Gradio UI
that returns the top-K most similar catalog images for an uploaded
image, a text query, or a weighted blend of both.
"""

# Core libraries
import numpy as np
import pandas as pd
import torch
import gradio as gr

# Dataset loader and CLIP model/processor
from datasets import load_dataset
from transformers import CLIPModel, CLIPProcessor

# -----------------------------
# Setup
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "openai/clip-vit-base-patch32"
TOP_K = 3  # number of recommendations shown in the gallery

model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)
model.eval()  # inference only; disables dropout/batch-norm updates

# Precomputed image embeddings for the sampled subset.
# NOTE(review): assumed to be L2-normalized already -- the cosine
# similarity below relies on it; confirm against the script that
# produced the parquet file.
emb_df = pd.read_parquet("clip_embeddings_3000.parquet")
embeddings = emb_df.drop(columns=["image_id"]).values.astype(np.float32)

# Indices of the sampled rows, so gallery images line up row-for-row
# with `embeddings`.
sampled_indices = np.load("sampled_indices_3000.npy").astype(int).tolist()

# Load dataset and select the same sampled subset.
ds = load_dataset("JamieSJS/stanford-online-products", "corpus", split="corpus")
sampled_dataset = ds.select(sampled_indices)


# -----------------------------
# Embedding helpers
# -----------------------------
def l2_normalize(vec: np.ndarray) -> np.ndarray:
    """Scale *vec* to unit L2 norm (epsilon guards against zero vectors)."""
    return vec / (np.linalg.norm(vec) + 1e-12)


def embed_image(image) -> np.ndarray:
    """Return the L2-normalized CLIP image embedding for a PIL image."""
    inputs = processor(images=[image], return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        feats = model.get_image_features(**inputs)
    vec = feats.cpu().numpy().reshape(-1).astype(np.float32)
    return l2_normalize(vec)


def embed_text(text: str) -> np.ndarray:
    """Return the L2-normalized CLIP text embedding for *text*."""
    inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        feats = model.get_text_features(**inputs)
    vec = feats.cpu().numpy().reshape(-1).astype(np.float32)
    return l2_normalize(vec)


def combine_embeddings(image_vec, text_vec, alpha: float):
    """Blend image and text embeddings into one unit query vector.

    alpha is the image weight; (1 - alpha) is the text weight.
    Returns None when both inputs are None, the lone vector when only
    one is given, otherwise the re-normalized weighted sum.
    """
    if image_vec is None and text_vec is None:
        return None
    if image_vec is None:
        return text_vec
    if text_vec is None:
        return image_vec
    combo = alpha * image_vec + (1.0 - alpha) * text_vec
    return l2_normalize(combo.astype(np.float32))


# -----------------------------
# Recommendation function
# -----------------------------
def recommend(image, text, alpha):
    """Gradio callback: return (gallery images, details message).

    Accepts an optional PIL image, an optional text query, and the
    image-vs-text blend weight from the slider. Errors are caught and
    surfaced in the details box rather than crashing the UI.
    """
    try:
        # Normalize the text input once instead of re-testing it in
        # three separate places.
        query = str(text).strip() if text is not None else ""
        if image is None and not query:
            return [], "Please upload an image and/or enter a text description."

        image_vec = embed_image(image) if image is not None else None
        text_vec = embed_text(query) if query else None

        user_vec = combine_embeddings(image_vec, text_vec, float(alpha))
        if user_vec is None:
            return [], "Could not compute an embedding from the given inputs."

        # Dot product == cosine similarity because both sides are
        # unit-normalized.
        scores = embeddings @ user_vec

        # Top-K indices by descending similarity.
        top_idx = np.argsort(scores)[::-1][:TOP_K]
        top_scores = scores[top_idx]
        results = [sampled_dataset[int(i)]["image"] for i in top_idx]

        mode = []
        if image is not None:
            mode.append("Image")
        if query:
            mode.append("Text")
        mode_str = " + ".join(mode)

        # Join over however many scores came back instead of indexing
        # [0], [1], [2] -- avoids an IndexError if fewer than TOP_K
        # items are available.
        score_str = ", ".join(f"{s:.3f}" for s in top_scores)
        msg = (
            f"Mode: {mode_str}\n"
            f"Alpha (image weight): {float(alpha):.2f}\n"
            f"Top-{len(top_scores)} cosine similarity scores: {score_str}"
        )
        return results, msg
    except Exception as e:
        # Deliberate broad catch at the UI boundary: show the error in
        # the details box instead of a stack trace.
        return [], f"Error: {str(e)}"


# -----------------------------
# Gradio UI
# -----------------------------
demo = gr.Interface(
    fn=recommend,
    inputs=[
        gr.Image(type="pil", label="Upload an image (optional)"),
        gr.Textbox(label="Text description (optional)", placeholder="e.g., 'small handheld vacuum'"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.05, label="Alpha (image vs text weight)"),
    ],
    outputs=[
        gr.Gallery(label="Top-3 Recommended Images"),
        gr.Textbox(label="Details"),
    ],
    title="Hybrid CLIP Recommender (Image + Text)",
    description="Upload an image, type a description, or combine both. Recommendations are based on CLIP embeddings + cosine similarity."
)

# Guard the launch so importing this module (e.g. for testing the
# helpers) does not start a web server.
if __name__ == "__main__":
    demo.launch(show_error=True, ssr_mode=False)