# Hybrid CLIP Recommender — Gradio app (Hugging Face Space)
# Core numerics and data handling
import numpy as np
import pandas as pd
import torch
import gradio as gr

# Dataset loading
from datasets import load_dataset

# CLIP model and its preprocessing pipeline
from transformers import CLIPModel, CLIPProcessor

# -----------------------------
# Setup
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "openai/clip-vit-base-patch32"

model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)
model.eval()  # inference only — disable dropout / training behavior

# Precomputed CLIP image embeddings for the sampled subset
# (assumed L2-normalized so dot product == cosine downstream — verify upstream pipeline).
emb_df = pd.read_parquet("clip_embeddings_3000.parquet")
embeddings = emb_df.drop(columns=["image_id"]).values.astype(np.float32)

# Indices of the same 3000 sampled rows, so gallery images line up with embeddings.
sampled_indices = np.load("sampled_indices_3000.npy").astype(int).tolist()

# Fetch the matching 3000-image subset from the Stanford Online Products corpus.
ds = load_dataset("JamieSJS/stanford-online-products", "corpus", split="corpus")
sampled_dataset = ds.select(sampled_indices)
# -----------------------------
# Embedding helpers
# -----------------------------
def l2_normalize(vec: np.ndarray) -> np.ndarray:
    """Scale *vec* to unit L2 length; tiny epsilon guards against divide-by-zero."""
    norm = np.linalg.norm(vec) + 1e-12
    return vec / norm
def embed_image(image) -> np.ndarray:
    """Encode one image into a unit-length CLIP feature vector (1-D float32)."""
    # Preprocess for CLIP and move tensors onto the active device.
    batch = processor(images=[image], return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    # Inference only — no gradients needed.
    with torch.no_grad():
        features = model.get_image_features(**batch)
    flat = features.cpu().numpy().reshape(-1).astype(np.float32)
    return l2_normalize(flat)
def embed_text(text: str) -> np.ndarray:
    """Encode a text query into a unit-length CLIP feature vector (1-D float32)."""
    # Tokenize (pad/truncate to the model's limits) and move onto the device.
    batch = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    # Inference only — no gradients needed.
    with torch.no_grad():
        features = model.get_text_features(**batch)
    flat = features.cpu().numpy().reshape(-1).astype(np.float32)
    return l2_normalize(flat)
def combine_embeddings(image_vec, text_vec, alpha: float) -> np.ndarray:
    """Blend image and text embeddings into one query vector.

    ``alpha`` weights the image vector and ``1 - alpha`` weights the text
    vector. If only one modality is present it is returned unchanged; if
    neither is present, None is returned.
    """
    if image_vec is None:
        return text_vec  # may itself be None when no inputs were given
    if text_vec is None:
        return image_vec
    blended = alpha * image_vec + (1.0 - alpha) * text_vec
    return l2_normalize(blended.astype(np.float32))
# -----------------------------
# Recommendation function
# -----------------------------
def recommend(image, text, alpha):
    """Return (gallery_images, details_message) for the Gradio UI.

    Embeds the provided image and/or text with CLIP, blends them with
    weight ``alpha`` (image) vs ``1 - alpha`` (text), ranks the indexed
    embeddings by cosine similarity, and returns the top matches.

    Fix: the original message hard-indexed top_scores[0..2], which would
    raise IndexError if fewer than 3 items were indexed; the message is
    now built for the actual top-k (identical output when k == 3).
    """
    try:
        text_query = "" if text is None else str(text).strip()
        # Nothing to search with — ask the user for input.
        if image is None and not text_query:
            return [], "Please upload an image and/or enter a text description."

        image_vec = embed_image(image) if image is not None else None
        text_vec = embed_text(text_query) if text_query else None

        # Combine modalities into a single normalized query vector.
        user_vec = combine_embeddings(image_vec, text_vec, float(alpha))
        if user_vec is None:
            return [], "Could not compute an embedding from the given inputs."

        # Dot product == cosine similarity because both sides are normalized.
        scores = embeddings @ user_vec

        # Top-k (k capped by the index size, so small indexes don't crash).
        k = min(3, scores.shape[0])
        top_idx = np.argsort(scores)[::-1][:k]
        top_scores = scores[top_idx]
        results = [sampled_dataset[int(i)]["image"] for i in top_idx]

        # Describe which modalities drove the query.
        mode = []
        if image is not None:
            mode.append("Image")
        if text_query:
            mode.append("Text")
        score_str = ", ".join(f"{s:.3f}" for s in top_scores)
        msg = (
            f"Mode: {' + '.join(mode)}\n"
            f"Alpha (image weight): {float(alpha):.2f}\n"
            f"Top-{k} cosine similarity scores: {score_str}"
        )
        return results, msg
    except Exception as e:
        # UI boundary: surface the error text instead of crashing the app.
        return [], f"Error: {str(e)}"
# -----------------------------
# Gradio UI
# -----------------------------
# Inputs: optional image, optional text, and the image/text blending weight.
input_widgets = [
    gr.Image(type="pil", label="Upload an image (optional)"),
    gr.Textbox(label="Text description (optional)", placeholder="e.g., 'small handheld vacuum'"),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.05, label="Alpha (image vs text weight)"),
]
# Outputs: the recommended images plus a textual summary of the query.
output_widgets = [
    gr.Gallery(label="Top-3 Recommended Images"),
    gr.Textbox(label="Details"),
]
demo = gr.Interface(
    fn=recommend,
    inputs=input_widgets,
    outputs=output_widgets,
    title="Hybrid CLIP Recommender (Image + Text)",
    description="Upload an image, type a description, or combine both. Recommendations are based on CLIP embeddings + cosine similarity.",
)
demo.launch(show_error=True, ssr_mode=False)