# Hugging Face Spaces capture header ("Spaces: Runtime error") — the hosted
# app was failing at startup when this source was captured.
import os
import re
import time

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image, ImageOps
from pinecone import Pinecone, ServerlessSpec
from pinecone_text.sparse import BM25Encoder
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
# ------------------- Pinecone Setup -------------------
# SECURITY: a Pinecone API key used to be hard-coded on this line and is now
# part of the public source history — treat it as leaked, revoke it, and
# supply a fresh key via the PINECONE_API_KEY environment variable.
api_key = os.environ.get('PINECONE_API_KEY')
if not api_key:
    raise RuntimeError("PINECONE_API_KEY environment variable is not set")

pc = Pinecone(api_key=api_key)

# Serverless placement is configurable via env vars, defaulting to AWS
# us-east-1.  (The original code re-assigned `spec` with hard-coded values
# right after this, silently defeating the env-var configuration.)
cloud = os.environ.get('PINECONE_CLOUD') or 'aws'
region = os.environ.get('PINECONE_REGION') or 'us-east-1'
spec = ServerlessSpec(cloud=cloud, region=region)

index_name = "hybrid-image-search"

# Create the index on first run: dimension=512 matches the CLIP ViT-B/32
# embeddings used below, and 'dotproduct' is the metric Pinecone requires
# for hybrid (dense + sparse) queries.
if index_name not in pc.list_indexes().names():
    pc.create_index(
        index_name,
        dimension=512,
        metric='dotproduct',
        spec=spec,
    )
    # Block until the new index is ready to accept queries.
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)

# connect to index
index = pc.Index(index_name)
# view index stats (value unused; handy when running interactively)
index.describe_index_stats()
# ------------------- Dataset Loading -------------------
# Fashion product dataset from the Hugging Face hub.  The "image" column is
# kept as-is (indexed by row id at query time); the remaining columns
# (productDisplayName etc.) become a pandas DataFrame used to fit the BM25
# encoder and to build result captions.
fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
images = fashion["image"]
metadata = fashion.remove_columns("image").to_pandas()
# ------------------- Encoders -------------------
# Sparse encoder: BM25 fitted on the product titles (keyword signal).
bm25 = BM25Encoder()
bm25.fit(metadata["productDisplayName"])

# Dense encoder: CLIP ViT-B/32 via sentence-transformers (512-dim vectors,
# matching the index dimension).  The original code constructed this model
# twice and left a bare `model` REPL artifact; one load is sufficient.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SentenceTransformer(
    'sentence-transformers/clip-ViT-B-32',
    device=device,
)
# ------------------- Hybrid Scaling -------------------
def hybrid_scale(dense, sparse, alpha: float):
    """Convexly weight the dense and sparse query vectors for hybrid search.

    alpha=1 keeps only the dense (semantic) signal, alpha=0 only the sparse
    (keyword) signal.  Returns (scaled_dense, scaled_sparse).
    """
    if alpha < 0 or alpha > 1:
        raise ValueError("Alpha must be between 0 and 1")
    sparse_weight = 1 - alpha
    scaled_sparse = {
        'indices': sparse['indices'],
        'values': [sparse_weight * value for value in sparse['values']],
    }
    scaled_dense = [alpha * value for value in dense]
    return scaled_dense, scaled_sparse
# ------------------- Metadata Filter Extraction -------------------
def extract_metadata_filters(query: str):
    """Derive structured Pinecone filters from a free-text query.

    Args:
        query: the raw user query.

    Returns:
        A tuple (gender, category, subcategory, color); each element is a
        dataset value (e.g. "Men", "Tshirts", "Topwear", "Red") or None.

    Matching is done on whole words: the previous substring matching had
    real bugs — "women" triggered the "men" filter, "tshirt" triggered
    "shirt", and "laptop" would have triggered the "top" subcategory.
    """
    query_lower = query.lower()
    # Tokenize on word characters, keeping hyphenated forms ("t-shirt") whole.
    tokens = set(re.findall(r"[a-z0-9]+(?:-[a-z0-9]+)*", query_lower))
    # Add naive singular forms so plurals like "shirts"/"watches" still match
    # singular map keys.
    expanded = set(tokens)
    for tok in tokens:
        if tok.endswith("es"):
            expanded.add(tok[:-2])
        if tok.endswith("s"):
            expanded.add(tok[:-1])

    gender = None
    category = None
    subcategory = None
    color = None

    # --- Gender Mapping ---
    gender_map = {
        "men": "Men", "man": "Men", "mens": "Men", "mans": "Men", "male": "Men",
        "women": "Women", "woman": "Women", "womens": "Women", "female": "Women",
        "boys": "Boys", "boy": "Boys",
        "girls": "Girls", "girl": "Girls",
        "kids": "Kids", "unisex": "Unisex"
    }
    for term, mapped_value in gender_map.items():
        if term in expanded:
            gender = mapped_value
            break

    # --- Category Mapping ---
    category_map = {
        "shirt": "Shirts",
        "tshirt": "Tshirts", "t-shirt": "Tshirts",
        "jeans": "Jeans",
        "watch": "Watches",
        "kurta": "Kurtas",
        "dress": "Dresses", "dresses": "Dresses",
        "trousers": "Trousers", "pants": "Trousers",
        "shorts": "Shorts",
        "footwear": "Footwear",
        "shoes": "Footwear",
        "fashion": "Apparel"
    }
    for term, mapped_value in category_map.items():
        if term in expanded:
            category = mapped_value
            break

    # --- SubCategory Mapping ---
    subCategory_list = [
        "Accessories", "Apparel Set", "Bags", "Bath and Body", "Beauty Accessories",
        "Belts", "Bottomwear", "Cufflinks", "Dress", "Eyes", "Eyewear", "Flip Flops",
        "Fragrance", "Free Gifts", "Gloves", "Hair", "Headwear", "Home Furnishing",
        "Innerwear", "Jewellery", "Lips", "Loungewear and Nightwear", "Makeup",
        "Mufflers", "Nails", "Perfumes", "Sandal", "Saree", "Scarves", "Shoe Accessories",
        "Shoes", "Skin", "Skin Care", "Socks", "Sports Accessories", "Sports Equipment",
        "Stoles", "Ties", "Topwear", "Umbrellas", "Vouchers", "Wallets", "Watches",
        "Water Bottle", "Wristbands"
    ]
    if "topwear" in expanded or "top" in expanded:
        subcategory = "Topwear"
    else:
        for subcat in subCategory_list:
            sub_lower = subcat.lower()
            # Multi-word entries ("Apparel Set") fall back to a phrase search
            # on the raw query; single words use whole-token matching.
            if (" " in sub_lower and sub_lower in query_lower) or sub_lower in expanded:
                subcategory = subcat
                break

    # --- Color Extraction ---
    colors = [
        "red", "blue", "green", "yellow", "black", "white",
        "orange", "pink", "purple", "brown", "grey", "beige"
    ]
    for c in colors:
        if c in expanded:
            color = c.capitalize()
            break

    # --- Invalid pairs: drop the gender rather than return zero results ---
    invalid_pairs = {
        ("Men", "Dresses"), ("Men", "Sarees"), ("Men", "Skirts"),
        ("Boys", "Dresses"), ("Boys", "Sarees"),
        ("Girls", "Boxers"), ("Men", "Heels")
    }
    if (gender, category) in invalid_pairs:
        print(f"⚠️ Invalid pair: {gender} + {category}, dropping gender")
        gender = None

    # Fallback: a gender without a category still narrows to apparel.
    if gender and not category:
        category = "Apparel"
    return gender, category, subcategory, color
def search_fashion(query: str, alpha: float):
    """Hybrid text search: BM25 sparse + CLIP dense, with metadata pre-filters.

    Args:
        query: free-text search phrase.
        alpha: hybrid weight (0 = purely sparse/keyword, 1 = purely dense).

    Returns:
        A list of (PIL.Image, caption) tuples for display in a gr.Gallery.
    """
    gender, category, subcategory, color = extract_metadata_filters(query)

    # Build the Pinecone metadata filter.  (Renamed from `filter`, which
    # shadowed the builtin of the same name.)
    filters = {}
    if gender:
        filters["gender"] = gender
    if category:
        filters["articleType"] = category
    if subcategory:
        filters["subCategory"] = subcategory
    if color:
        filters["baseColour"] = color
    print(f"🔍 Using filter: {filters}")

    # Encode the query both ways and blend the two signals by alpha.
    sparse = bm25.encode_queries(query)
    dense = model.encode(query).tolist()
    hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha)

    def _query(active_filters):
        # Single Pinecone round-trip; None disables metadata filtering.
        return index.query(
            top_k=12,
            vector=hdense,
            sparse_vector=hsparse,
            include_metadata=True,
            filter=active_filters or None,
        )

    result = _query(filters)

    # Fallback: an over-strict gender filter can empty the result set —
    # retry once without it.
    if gender and len(result["matches"]) == 0:
        print(f"⚠️ No results with gender {gender}, relaxing gender filter")
        filters.pop("gender")
        result = _query(filters)

    # Pair each hit with its image and display caption.
    imgs_with_captions = []
    for match in result["matches"]:
        idx = int(match["id"])  # vector ids are dataset row indices
        img = images[idx]
        meta = match.get("metadata", {})
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.array(img))
        padded = ImageOps.pad(img, (256, 256), color="white")
        caption = str(meta.get("productDisplayName", "Unknown Product"))
        imgs_with_captions.append((padded, caption))
    return imgs_with_captions
# CLIP vision model + processor for image->vector queries.  Presumably the
# same ViT-B/32 backbone as the sentence-transformers text model above, so
# its 512-dim image features live in the same index space — TODO confirm.
from transformers import CLIPProcessor, CLIPModel
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
def search_by_image(uploaded_image, alpha=0.5):
    """Find products visually similar to an uploaded PIL image.

    The image is embedded with CLIP and used as a dense-only Pinecone query;
    `alpha` is accepted for interface parity with search_fashion but has no
    effect here.  Returns a list of (PIL.Image, caption) tuples.
    """
    # Embed the image with CLIP; no gradients are needed at inference time.
    inputs = clip_processor(images=uploaded_image, return_tensors="pt").to(device)
    with torch.no_grad():
        features = clip_model.get_image_features(**inputs)
    query_vector = features.cpu().numpy().flatten().tolist()

    # Dense-only query against the index.
    response = index.query(
        top_k=12,
        vector=query_vector,
        include_metadata=True
    )

    gallery_items = []
    for match in response["matches"]:
        row = int(match["id"])
        product_img = images[row]
        match_meta = match.get("metadata", {})
        if not isinstance(product_img, Image.Image):
            product_img = Image.fromarray(np.array(product_img))
        framed = ImageOps.pad(product_img, (256, 256), color="white")
        title = str(match_meta.get("productDisplayName", "Unknown Product"))
        gallery_items.append((framed, title))
    return gallery_items
# CSS tweaks injected into the Gradio Blocks layout below (full-width search
# button, tighter row/slider spacing).
custom_css = """
.search-btn {
width: 100%;
}
.gr-row {
gap: 8px !important; /* slightly tighter column gap */
}
.query-slider > div {
margin-bottom: 4px !important; /* reduce space between textbox and slider */
}
"""
# (A commented-out, upload-only draft of the UI layout was removed here; the
# live, webcam-enabled layout follows.)
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# 🛍️ Fashion Product Hybrid Search")

    with gr.Row(equal_height=True):
        with gr.Column(scale=5, elem_classes="query-slider"):
            query = gr.Textbox(
                label="Enter your fashion search query",
                placeholder="Type something or leave blank to only use the image"
            )
            alpha = gr.Slider(
                0, 1, value=0.5,
                label="Hybrid Weight (alpha: 0=sparse, 1=dense)"
            )
        with gr.Column(scale=1):
            # Gradio 4.x removed the `source=` kwarg in favour of `sources=`;
            # the old `source="webcam"` raises TypeError at startup (the
            # likely cause of this Space's "Runtime error").  Allow both
            # webcam capture and file upload.
            image_input = gr.Image(
                sources=["webcam", "upload"],
                type="pil",
                label="📷 Capture or Upload Image",
                height=256,
                width=356,
                show_label=True
            )

    search_btn = gr.Button("Search", elem_classes="search-btn")
    gallery = gr.Gallery(
        label="Search Results",
        columns=6,
        height="40vh"
    )

    def unified_search(q, uploaded_image, a):
        """Route to image search when an image is given, else text search."""
        if uploaded_image is not None:
            return search_by_image(uploaded_image, a)
        # Guard against None as well as empty/whitespace-only queries.
        if q and q.strip():
            return search_fashion(q, a)
        return []

    search_btn.click(
        unified_search,
        inputs=[query, image_input, alpha],
        outputs=gallery
    )

    gr.Markdown("Powered by your hybrid AI search model 🚀")

demo.launch()