# NOTE: the original file began with a paste artifact from the Hugging Face
# Spaces status page ("Spaces: Runtime error" x2); it is not application code.
| # import os | |
| # from pinecone import Pinecone, ServerlessSpec | |
| # from PIL import Image, ImageOps | |
| # import numpy as np | |
| # from datasets import load_dataset | |
| # from pinecone_text.sparse import BM25Encoder | |
| # from sentence_transformers import SentenceTransformer | |
| # import torch | |
| # from tqdm.auto import tqdm | |
| # import gradio as gr | |
| # # ------------------- Pinecone Setup ------------------- | |
| # os.environ["PINECONE_API_KEY"] = "pcsk_TMCYK_LrbmZMTDhkxTjUXcr8iTcQ8LxurwKBFDvv4ahFis8SVob7QexVPPEt6g2zW6d3g" | |
| # api_key = os.environ.get('PINECONE_API_KEY') | |
| # pc = Pinecone(api_key=api_key) | |
| # cloud = os.environ.get('PINECONE_CLOUD') or 'aws' | |
| # region = os.environ.get('PINECONE_REGION') or 'us-east-1' | |
| # spec = ServerlessSpec(cloud=cloud, region=region) | |
| # index_name = "hybrid-image-search" | |
| # spec = ServerlessSpec(cloud="aws", region="us-east-1") | |
| # # choose a name for your index | |
| # index_name = "hybrid-image-search" | |
| # import time | |
| # # check if index already exists (it shouldn't if this is first time) | |
| # if index_name not in pc.list_indexes().names(): | |
| # # if does not exist, create index | |
| # pc.create_index( | |
| # index_name, | |
| # dimension=512, | |
| # metric='dotproduct', | |
| # spec=spec | |
| # ) | |
| # # wait for index to be initialized | |
| # while not pc.describe_index(index_name).status['ready']: | |
| # time.sleep(1) | |
| # # connect to index | |
| # index = pc.Index(index_name) | |
| # # view index stats | |
| # index.describe_index_stats() | |
| # # ------------------- Dataset Loading ------------------- | |
| # fashion = load_dataset("ashraq/fashion-product-images-small", split="train") | |
| # images = fashion["image"] | |
| # metadata = fashion.remove_columns("image").to_pandas() | |
| # # ------------------- Encoders ------------------- | |
| # bm25 = BM25Encoder() | |
| # bm25.fit(metadata["productDisplayName"]) | |
| # model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device='cuda' if torch.cuda.is_available() else 'cpu') | |
| # from sentence_transformers import SentenceTransformer | |
| # import torch | |
| # device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| # # load a CLIP model from huggingface | |
| # model = SentenceTransformer( | |
| # 'sentence-transformers/clip-ViT-B-32', | |
| # device=device | |
| # ) | |
| # model | |
| # # ------------------- Hybrid Scaling ------------------- | |
| # def hybrid_scale(dense, sparse, alpha: float): | |
| # if alpha < 0 or alpha > 1: | |
| # raise ValueError("Alpha must be between 0 and 1") | |
| # # scale sparse and dense vectors to create hybrid search vecs | |
| # hsparse = { | |
| # 'indices': sparse['indices'], | |
| # 'values': [v * (1 - alpha) for v in sparse['values']] | |
| # } | |
| # hdense = [v * alpha for v in dense] | |
| # return hdense, hsparse | |
| # # ------------------- Metadata Filter Extraction ------------------- | |
| # from PIL import Image, ImageOps | |
| # import numpy as np | |
| # from PIL import Image, ImageOps | |
| # import numpy as np | |
| # from PIL import Image, ImageOps | |
| # import numpy as np | |
| # from transformers import CLIPProcessor, CLIPModel | |
| # clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device) | |
| # clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") | |
| # def extract_metadata_filters(query: str): | |
| # query_lower = query.lower() | |
| # gender = None | |
| # category = None | |
| # subcategory = None | |
| # color = None | |
| # # --- Gender Mapping --- | |
| # gender_map = { | |
| # "men": "Men", "man": "Men", "mens": "Men", "mans": "Men", "male": "Men", | |
| # "women": "Women", "woman": "Women", "womens": "Women", "female": "Women", | |
| # "boys": "Boys", "boy": "Boys", | |
| # "girls": "Girls", "girl": "Girls", | |
| # "kids": "Kids","kid": "Kids", | |
| # "unisex": "Unisex" | |
| # } | |
| # for term, mapped_value in gender_map.items(): | |
| # if term in query_lower: | |
| # gender = mapped_value | |
| # break | |
| # # --- Category Mapping --- | |
| # category_map = { | |
| # "shirt": "Shirts", | |
| # "tshirt": "Tshirts", "t-shirt": "Tshirts", | |
| # "jeans": "Jeans", | |
| # "watch": "Watches", | |
| # "kurta": "Kurtas", | |
| # "dress": "Dresses", "dresses": "Dresses", | |
| # "trousers": "Trousers", "pants": "Trousers", | |
| # "shorts": "Shorts", | |
| # "footwear": "Footwear", | |
| # "shoes": "Shoes", # note kept as Shoes | |
| # "fashion": "Apparel" | |
| # } | |
| # for term, mapped_value in category_map.items(): | |
| # if term in query_lower: | |
| # category = mapped_value | |
| # break | |
| # # --- SubCategory Mapping --- | |
| # subCategory_list = [ | |
| # "Accessories", "Apparel Set", "Bags", "Bath and Body", "Beauty Accessories", | |
| # "Belts", "Bottomwear", "Cufflinks", "Dress", "Eyes", "Eyewear", "Flip Flops", | |
| # "Fragrance", "Free Gifts", "Gloves", "Hair", "Headwear", "Home Furnishing", | |
| # "Innerwear", "Jewellery", "Lips", "Loungewear and Nightwear", "Makeup", | |
| # "Mufflers", "Nails", "Perfumes", "Sandal", "Saree", "Scarves", "Shoe Accessories", | |
| # "Shoes", "Skin", "Skin Care", "Socks", "Sports Accessories", "Sports Equipment", | |
| # "Stoles", "Ties", "Topwear", "Umbrellas", "Vouchers", "Wallets", "Watches", | |
| # "Water Bottle", "Wristbands" | |
| # ] | |
| # if "topwear" in query_lower or "top" in query_lower: | |
| # subcategory = "Topwear" | |
| # else: | |
| # for subcat in subCategory_list: | |
| # if subcat.lower() in query_lower: | |
| # subcategory = subcat | |
| # break | |
| # # --- Color Extraction --- | |
| # colors = [ | |
| # "red","blue","green","yellow","black","white", | |
| # "orange","pink","purple","brown","grey","beige" | |
| # ] | |
| # for c in colors: | |
| # if c in query_lower: | |
| # color = c.capitalize() | |
| # break | |
| # # --- Invalid pairs --- | |
| # invalid_pairs = { | |
| # ("Men", "Dresses"), ("Men", "Sarees"), ("Men", "Skirts"), | |
| # ("Boys", "Dresses"), ("Boys", "Sarees"), | |
| # ("Girls", "Boxers"), ("Men", "Heels") | |
| # } | |
| # if (gender, category) in invalid_pairs: | |
| # print(f"⚠️ Invalid pair: {gender} + {category}, dropping gender") | |
| # gender = None | |
| # # fallback | |
| # if gender and not category: | |
| # category = "Apparel" | |
| # return gender, category, subcategory, color | |
| # def search_fashion(query: str, alpha: float, start: int = 0, end: int = 12, gender_override: str = None): | |
| # gender, category, subcategory, color = extract_metadata_filters(query) | |
| # # override from dropdown | |
| # if gender_override: | |
| # gender = gender_override | |
| # # --- Pinecone Filter --- | |
| # filter = {} | |
| # if gender: | |
| # filter["gender"] = gender | |
| # if category: | |
| # if category in ["Footwear", "Shoes"]: | |
| # shoe_article_types = [ | |
| # "Casual Shoes", "Sports Shoes", "Formal Shoes", "Training Shoes", | |
| # "Sneakers", "Sandals", "Slippers", "Boots", "Flip Flops" | |
| # ] | |
| # filter["articleType"] = {"$in": shoe_article_types} | |
| # else: | |
| # filter["articleType"] = category | |
| # if subcategory: | |
| # filter["subCategory"] = subcategory | |
| # if color: | |
| # filter["baseColour"] = color | |
| # print(f"🔍 Using filter: {filter} (showing {start} to {end})") | |
| # sparse = bm25.encode_queries(query) | |
| # dense = model.encode(query).tolist() | |
| # hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha) | |
| # result = index.query( | |
| # top_k=end, | |
| # vector=hdense, | |
| # sparse_vector=hsparse, | |
| # include_metadata=True, | |
| # filter=filter if filter else None | |
| # ) | |
| # # fallback if no results | |
| # if len(result["matches"]) == 0: | |
| # print("⚠️ No results, retrying with alpha=0 sparse only") | |
| # hdense, hsparse = hybrid_scale(dense, sparse, alpha=0) | |
| # result = index.query( | |
| # top_k=end, | |
| # vector=hdense, | |
| # sparse_vector=hsparse, | |
| # include_metadata=True, | |
| # filter=filter if filter else None | |
| # ) | |
| # # fallback if no results with gender | |
| # if gender and len(result["matches"]) == 0: | |
| # print(f"⚠️ No results for gender {gender}, relaxing gender filter") | |
| # filter.pop("gender", None) | |
| # result = index.query( | |
| # top_k=end, | |
| # vector=hdense, | |
| # sparse_vector=hsparse, | |
| # include_metadata=True, | |
| # filter=filter if filter else None | |
| # ) | |
| # matches = result["matches"][start:end] | |
| # imgs_with_captions = [] | |
| # for r in matches: | |
| # idx = int(r["id"]) | |
| # img = images[idx] | |
| # meta = r.get("metadata", {}) | |
| # if not isinstance(img, Image.Image): | |
| # img = Image.fromarray(np.array(img)) | |
| # padded = ImageOps.pad(img, (256, 256), color="white") | |
| # caption = str(meta.get("productDisplayName", "Unknown Product")) | |
| # imgs_with_captions.append((padded, caption)) | |
| # return imgs_with_captions | |
| # # this is working code block | |
| # from PIL import Image, ImageOps | |
| # import numpy as np | |
| # def search_by_image(uploaded_image, alpha=0.5, start=0, end=12): | |
| # """ | |
| # Search visually similar products with support for pagination. | |
| # """ | |
| # # Preprocess image for CLIP | |
| # processed = clip_processor(images=uploaded_image, return_tensors="pt").to(device) | |
| # with torch.no_grad(): | |
| # image_vec = clip_model.get_image_features(**processed) | |
| # image_vec = image_vec.cpu().numpy().flatten().tolist() | |
| # # Query a larger top_k so you have enough to paginate | |
| # result = index.query( | |
| # top_k=end, | |
| # vector=image_vec, | |
| # include_metadata=True | |
| # ) | |
| # matches = result["matches"][start:end] # slice for pagination | |
| # imgs_with_captions = [] | |
| # for r in matches: | |
| # idx = int(r["id"]) | |
| # img = images[idx] | |
| # meta = r.get("metadata", {}) | |
| # if not isinstance(img, Image.Image): | |
| # img = Image.fromarray(np.array(img)) | |
| # padded = ImageOps.pad(img, (256, 256), color="white") | |
| # caption = str(meta.get("productDisplayName", "Unknown Product")) | |
| # imgs_with_captions.append((padded, caption)) | |
| # return imgs_with_captions | |
| # # with gr.Blocks(css=custom_css) as demo: | |
| # # gr.Markdown("# 🛍️ Fashion Product Hybrid Search") | |
| # # with gr.Row(equal_height=True): | |
| # # with gr.Column(scale=5, elem_classes="query-slider"): | |
| # # query = gr.Textbox( | |
| # # label="Enter your fashion search query", | |
| # # placeholder="Type something or leave blank to only use the image" | |
| # # ) | |
| # # alpha = gr.Slider( | |
| # # 0, 1, value=0.5, | |
| # # label="Hybrid Weight (alpha: 0=sparse, 1=dense)" | |
| # # ) | |
| # # with gr.Column(scale=1): | |
| # # image_input = gr.Image( | |
| # # type="pil", | |
| # # label="Upload an image (optional)", | |
| # # height=256, | |
| # # width=356, | |
| # # show_label=True | |
| # # ) | |
| # # search_btn = gr.Button("Search", elem_classes="search-btn") | |
| # # gallery = gr.Gallery( | |
| # # label="Search Results", | |
| # # columns=6, | |
| # # height="40vh" | |
| # # ) | |
| # import gradio as gr | |
| # custom_css = """ | |
| # .search-btn { | |
| # width: 100%; | |
| # } | |
| # .gr-row { | |
| # gap: 8px !important; | |
| # } | |
| # .query-slider > div { | |
| # margin-bottom: 4px !important; | |
| # } | |
| # .upload-box .icon-container { | |
| # display: none !important; | |
| # } | |
| # """ | |
| # with gr.Blocks(css=custom_css) as demo: | |
| # gr.Markdown("# 🛍️ Fashion Product Hybrid Search") | |
| # with gr.Row(equal_height=True): | |
| # with gr.Column(scale=5, elem_classes="query-slider"): | |
| # query = gr.Textbox( | |
| # label="Enter your fashion search query", | |
| # placeholder="Type something or leave blank to only use the image" | |
| # ) | |
| # alpha = gr.Slider(0, 1, value=0.5, label="Hybrid Weight (alpha: 0=sparse, 1=dense)") | |
| # gender_dropdown = gr.Dropdown( | |
| # ["", "Men", "Women", "Boys", "Girls", "Kids", "Unisex"], | |
| # label="Gender Filter (optional)" | |
| # ) | |
| # # with gr.Column(scale=1): | |
| # # image_input = gr.Image( | |
| # # type="pil", | |
| # # label="Upload an image (optional)", | |
| # # height=256, | |
| # # width=356 | |
| # # ) | |
| # with gr.Column(scale=1): | |
| # image_input = gr.Image( | |
| # type="pil", | |
| # label="Upload an image (optional)", | |
| # height=256, | |
| # width=356, | |
| # sources=["upload", "clipboard"] # only upload and paste allowed | |
| # ) | |
| # search_btn = gr.Button("Search", elem_classes="search-btn") | |
| # gallery = gr.Gallery(label="Search Results", columns=6, height="50vh") | |
| # load_more_btn = gr.Button("Load More") | |
| # # States to track | |
| # search_offset = gr.State(0) | |
| # current_query = gr.State("") | |
| # current_image = gr.State(None) | |
| # current_gender = gr.State("") | |
| # shown_results = gr.State([]) # new: store the list of shown images | |
| # def unified_search(q, uploaded_image, a, offset, gender_ui): | |
| # start = 0 | |
| # end = 12 | |
| # gender_override = gender_ui if gender_ui else None | |
| # if uploaded_image is not None: | |
| # results = search_by_image(uploaded_image, a, start, end) | |
| # elif q.strip() != "": | |
| # results = search_fashion(q, a, start, end, gender_override) | |
| # else: | |
| # results = [] | |
| # # reset shown_results to just these first 12 | |
| # return results, end, q, uploaded_image, gender_ui, results | |
| # search_btn.click( | |
| # unified_search, | |
| # inputs=[query, image_input, alpha, search_offset, gender_dropdown], | |
| # outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results] | |
| # ) | |
| # def load_more_fn(a, offset, q, img, gender_ui, prev_results): | |
| # start = offset | |
| # end = offset + 12 | |
| # gender_override = gender_ui if gender_ui else None | |
| # if img is not None: | |
| # new_results = search_by_image(img, a, start, end) | |
| # elif q.strip() != "": | |
| # new_results = search_fashion(q, a, start, end, gender_override) | |
| # else: | |
| # new_results = [] | |
| # combined_results = prev_results + new_results | |
| # return combined_results, end, combined_results | |
| # load_more_btn.click( | |
| # load_more_fn, | |
| # inputs=[alpha, search_offset, current_query, current_image, current_gender, shown_results], | |
| # outputs=[gallery, search_offset, shown_results] | |
| # ) | |
| # gr.Markdown("Powered by your hybrid AI search model 🚀") | |
| # demo.launch() | |
| # app.py | |
| import os | |
| import time | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image, ImageOps | |
| from tqdm.auto import tqdm | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
| from pinecone import Pinecone, ServerlessSpec | |
| from pinecone_text.sparse import BM25Encoder | |
| from transformers import CLIPProcessor, CLIPModel | |
| import openai | |
# ------------------- Keys & Setup -------------------
# NOTE(review): credentials come from the environment; OPENAI_API_KEY and
# PINECONE_API_KEY must be set before launch or the calls below will fail.
openai.api_key = os.getenv("OPENAI_API_KEY")
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
# Serverless spec defaults to AWS us-east-1 when the env vars are unset.
spec = ServerlessSpec(cloud=os.getenv("PINECONE_CLOUD") or "aws", region=os.getenv("PINECONE_REGION") or "us-east-1")
index_name = "hybrid-image-search"
# Create the index on first run: 512 dims matches the CLIP ViT-B/32 embedding
# size, and 'dotproduct' is the metric used for hybrid (dense+sparse) queries.
if index_name not in pc.list_indexes().names():
    pc.create_index(index_name, dimension=512, metric='dotproduct', spec=spec)
# Block until the index reports ready (no-op if it already exists and is ready).
while not pc.describe_index(index_name).status['ready']:
    time.sleep(1)
index = pc.Index(index_name)
# ------------------- Models & Dataset -------------------
# Fashion product images plus tabular metadata (gender, articleType, baseColour, ...).
fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
images = fashion["image"]
metadata = fashion.remove_columns("image").to_pandas()
# Sparse lexical encoder fitted on the product titles (BM25 term statistics).
bm25 = BM25Encoder()
bm25.fit(metadata["productDisplayName"])
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Dense text encoder (CLIP text tower via sentence-transformers).
model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device=device)
# Full CLIP model + processor for the image-side embeddings in image search.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
| # ------------------- Helper Functions ------------------- | |
def hybrid_scale(dense, sparse, alpha: float):
    """Convex-weight a dense vector and a sparse vector for hybrid search.

    alpha=1 keeps only the dense signal, alpha=0 only the sparse signal.
    Returns (scaled_dense, scaled_sparse).
    """
    if alpha < 0 or alpha > 1:
        raise ValueError("Alpha must be between 0 and 1")
    sparse_weight = 1 - alpha
    scaled_sparse = {
        'indices': sparse['indices'],
        'values': [sparse_weight * v for v in sparse['values']],
    }
    scaled_dense = [alpha * v for v in dense]
    return scaled_dense, scaled_sparse
def extract_intent_from_openai(query: str):
    """Ask an LLM to turn a free-text fashion query into structured filters.

    Returns a dict of the form {"include": {...}, "exclude": {...}} where each
    inner dict may carry "category", "gender", "subcategory" and "color" keys.
    On any failure the empty structure {"include": {}, "exclude": {}} is
    returned so callers can proceed unfiltered.

    Fixes vs. previous version:
    - eval() on model output was arbitrary code execution; replaced with
      ast.literal_eval (plus a json.loads fallback so JSON-style `null`
      replies also parse).
    - the prompt asked for a flat dict while callers (search_fashion,
      handle_voice_search) expect the include/exclude shape; the prompt now
      requests that shape and flat replies are normalized into it.
    """
    import ast
    import json

    prompt = f"""
You are an assistant for a fashion search engine. Extract the user's intent from the following query.
Return a Python dictionary with two keys, "include" and "exclude". Each maps to a dictionary
with keys: category, gender, subcategory, color. If something is missing, use null.
Query: "{query}"
Only return the dictionary.
"""
    try:
        response = openai.ChatCompletion.create(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )
        raw = response.choices[0].message['content']
        # SECURITY: never eval() model output. literal_eval only accepts
        # Python literals; fall back to JSON for `null`/`true`-style replies.
        try:
            structured = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            structured = json.loads(raw)
        if not isinstance(structured, dict):
            raise ValueError(f"expected dict, got {type(structured).__name__}")
        # Normalize a flat attribute dict into the include/exclude shape.
        if "include" not in structured and "exclude" not in structured:
            structured = {"include": structured, "exclude": {}}
        structured.setdefault("include", {})
        structured.setdefault("exclude", {})
        return structured
    except Exception as e:
        print(f"⚠️ OpenAI intent extraction failed: {e}")
        return {"include": {}, "exclude": {}}
| #-----------------below changed------------------------------# | |
| import imagehash | |
| from PIL import Image | |
def is_duplicate(img, existing_hashes, hash_size=16, tolerance=0):
    """Report whether *img* is a (near-)duplicate of an already-seen image.

    Compares the image's perceptual hash against every hash in
    *existing_hashes*; two hashes within *tolerance* Hamming distance count as
    duplicates. A novel image has its hash added to the set (mutated in
    place) and False is returned.

    :param img: PIL Image
    :param existing_hashes: set of previously seen hashes (updated in place)
    :param hash_size: phash size (default 16 for extra precision)
    :param tolerance: allowable Hamming distance for near-duplicates
    :return: True if duplicate, False otherwise
    """
    candidate = imagehash.phash(img, hash_size=hash_size)
    if any(abs(candidate - seen) <= tolerance for seen in existing_hashes):
        return True
    existing_hashes.add(candidate)
    return False
def extract_metadata_filters(query: str):
    """Derive (gender, category, subcategory, color) filters from a raw query.

    Terms are matched on word boundaries (tolerating a plural "s"/"es"
    suffix) instead of raw substrings. This fixes the previous behavior where
    "men" matched inside "women" (so every women's query was filtered to
    Men), "red" matched inside words like "bored", and "top" matched inside
    "laptop". It also lets multi-word subcategories ("Apparel Set") match,
    which a split-into-words membership test could never do.

    :param query: free-text search query
    :return: 4-tuple (gender, category, subcategory, color); any may be None
    """
    import re

    query_lower = query.lower()

    def has_term(term: str) -> bool:
        # Word-boundary match with optional plural suffix ("shirt" -> "shirts",
        # "watch" -> "watches"); avoids substring false positives.
        return re.search(rf"\b{re.escape(term)}(e?s)?\b", query_lower) is not None

    gender = None
    category = None
    subcategory = None
    color = None
    # --- Gender Mapping ---
    gender_map = {
        "men": "Men", "man": "Men", "mens": "Men", "mans": "Men", "male": "Men",
        "women": "Women", "woman": "Women", "womens": "Women", "female": "Women",
        "boys": "Boys", "boy": "Boys",
        "girls": "Girls", "girl": "Girls",
        "kids": "Kids", "kid": "Kids",
        "unisex": "Unisex"
    }
    for term, mapped_value in gender_map.items():
        if has_term(term):
            gender = mapped_value
            break
    # --- Category Mapping --- (maps query words to dataset articleType values)
    category_map = {
        "shirt": "Shirts",
        "tshirt": "Tshirts",
        "t-shirt": "Tshirts",
        "jeans": "Jeans",
        "watch": "Watches",
        "kurta": "Kurtas",
        "dress": "Dresses",
        "trousers": "Trousers", "pants": "Trousers",
        "shorts": "Shorts",
        "footwear": "Footwear",
        "shoes": "Shoes",
        "fashion": "Apparel"
    }
    for term, mapped_value in category_map.items():
        if has_term(term):
            category = mapped_value
            break
    # --- SubCategory Mapping --- (dataset's subCategory vocabulary)
    subCategory_list = [
        "Accessories", "Apparel Set", "Bags", "Bath and Body", "Beauty Accessories",
        "Belts", "Bottomwear", "Cufflinks", "Dress", "Eyes", "Eyewear", "Flip Flops",
        "Fragrance", "Free Gifts", "Gloves", "Hair", "Headwear", "Home Furnishing",
        "Innerwear", "Jewellery", "Lips", "Loungewear and Nightwear", "Makeup",
        "Mufflers", "Nails", "Perfumes", "Sandal", "Saree", "Scarves", "Shoe Accessories",
        "Shoes", "Skin", "Skin Care", "Socks", "Sports Accessories", "Sports Equipment",
        "Stoles", "Ties", "Topwear", "Umbrellas", "Vouchers", "Wallets", "Watches",
        "Water Bottle", "Wristbands"
    ]
    if has_term("topwear") or has_term("top"):
        subcategory = "Topwear"
    else:
        for subcat in subCategory_list:
            if has_term(subcat.lower()):
                subcategory = subcat
                break
    # --- Color Extraction ---
    color_list = [
        "red", "blue", "green", "yellow", "black", "white",
        "orange", "pink", "purple", "brown", "grey", "beige"
    ]
    for c in color_list:
        if has_term(c):
            color = c.capitalize()
            break
    # --- Invalid pairs: drop gender rather than return zero results ---
    invalid_pairs = {
        ("Men", "Dresses"), ("Men", "Sarees"), ("Men", "Skirts"),
        ("Boys", "Dresses"), ("Boys", "Sarees"),
        ("Girls", "Boxers"), ("Men", "Heels")
    }
    if (gender, category) in invalid_pairs:
        print(f"⚠️ Invalid pair: {gender} + {category}, dropping gender")
        gender = None
    # --- Fallback for missing category ---
    if gender and not category:
        category = "Apparel"
    # --- Refine subcategory for party/wedding-related queries ---
    if "party" in query_lower or "wedding" in query_lower or "cocktail" in query_lower:
        if subcategory in ["Loungewear and Nightwear", "Nightdress", "Innerwear"]:
            subcategory = None  # reset it to avoid filtering into wrong items
    return gender, category, subcategory, color
| # ------------------- Search Functions ------------------- | |
def search_fashion(query: str, alpha: float, start: int = 0, end: int = 12, gender_override: str = None):
    """Hybrid (dense + sparse) text search over the Pinecone fashion index.

    Results are deduplicated by perceptual hash and the [start:end) window of
    the unique results is returned as (PIL.Image, caption) pairs, so the
    "Load More" button pages through new items instead of repeating page 1.

    :param query: free-text query
    :param alpha: hybrid weight (0 = sparse only, 1 = dense only)
    :param start: index of the first unique result to return
    :param end: index one past the last unique result to return
    :param gender_override: UI dropdown value; wins over the LLM-extracted gender
    """
    intent = extract_intent_from_openai(query)
    include = intent.get("include", {})
    exclude = intent.get("exclude", {})
    gender = include.get("gender")
    category = include.get("category")
    subcategory = include.get("subcategory")
    color = include.get("color")
    # Apply override from dropdown
    if gender_override:
        gender = gender_override
    # --- Inclusion filters ---
    filter = {}
    if gender:
        filter["gender"] = gender
    if category:
        if category in ["Footwear", "Shoes"]:
            # Pinecone metadata filtering has no $regex operator; enumerate
            # the shoe-like articleType values with $in instead.
            filter["articleType"] = {"$in": [
                "Casual Shoes", "Sports Shoes", "Formal Shoes", "Training Shoes",
                "Sneakers", "Sandals", "Slippers", "Boots", "Flip Flops"
            ]}
        else:
            filter["articleType"] = category
    if subcategory:
        filter["subCategory"] = subcategory
    # Exclude sleep/innerwear for occasion-style queries (only when no plain
    # string subCategory filter is already set).
    query_lower = query.lower()
    if any(word in query_lower for word in ["party", "wedding", "cocktail", "traditional", "reception"]):
        filter.setdefault("subCategory", {})
        if isinstance(filter["subCategory"], dict):
            filter["subCategory"]["$nin"] = [
                "Loungewear and Nightwear", "Nightdress", "Innerwear", "Sleepwear", "Vests", "Boxers"
            ]
    if color:
        filter["baseColour"] = color
    # --- Exclusion filters from LLM intent ---
    exclude_filter = {}
    if exclude.get("color"):
        exclude_filter["baseColour"] = {"$ne": exclude["color"]}
    if exclude.get("subcategory"):
        exclude_filter["subCategory"] = {"$ne": exclude["subcategory"]}
    if exclude.get("category"):
        exclude_filter["articleType"] = {"$ne": exclude["category"]}
    # --- Combine inclusion and exclusion filters ---
    if filter and exclude_filter:
        final_filter = {"$and": [filter, exclude_filter]}
    elif filter:
        final_filter = filter
    elif exclude_filter:
        final_filter = exclude_filter
    else:
        final_filter = None
    print(f"🔍 Using filter: {final_filter} (showing {start} to {end})")
    # --- Hybrid encoding & query ---
    sparse = bm25.encode_queries(query)
    dense = model.encode(query).tolist()
    hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha)
    result = index.query(
        top_k=100,  # over-fetch so dedup + pagination have enough candidates
        vector=hdense,
        sparse_vector=hsparse,
        include_metadata=True,
        filter=final_filter
    )
    # Retry sparse-only when the hybrid query returns nothing.
    if len(result["matches"]) == 0:
        print("⚠️ No results, retrying with alpha=0 sparse only")
        hdense, hsparse = hybrid_scale(dense, sparse, alpha=0)
        result = index.query(
            top_k=100,
            vector=hdense,
            sparse_vector=hsparse,
            include_metadata=True,
            filter=final_filter
        )
    # --- Dedup, then return the requested window of unique results ---
    imgs_with_captions = []
    seen_hashes = set()
    for r in result["matches"]:
        idx = int(r["id"])  # vector ids are dataset row indices
        img = images[idx]
        meta = r.get("metadata", {})
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.array(img))
        padded = ImageOps.pad(img, (256, 256), color="white")
        caption = str(meta.get("productDisplayName", "Unknown Product"))
        if not is_duplicate(padded, seen_hashes):
            imgs_with_captions.append((padded, caption))
            if len(imgs_with_captions) >= end:
                break
    # Previously the first `end` uniques were always returned and `start` was
    # ignored, so "Load More" repeated the first page; slice honors the window.
    return imgs_with_captions[start:end]
def search_by_image(uploaded_image, alpha=0.5, start=0, end=12):
    """Visual similarity search over the index using a CLIP image embedding.

    Deduplicates results by perceptual hash and returns the [start:end)
    window of unique results as (PIL.Image, caption) pairs.

    :param uploaded_image: PIL image from the UI
    :param alpha: accepted for interface parity with search_fashion; unused
        here because image search is dense-only
    :param start: index of the first unique result to return
    :param end: index one past the last unique result to return
    """
    # Step 1: embed the query image with the CLIP vision tower.
    processed = clip_processor(images=uploaded_image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_vec = clip_model.get_image_features(**processed)
    image_vec = image_vec.cpu().numpy().flatten().tolist()
    # Step 2: over-fetch so dedup + pagination have enough candidates.
    result = index.query(
        top_k=100,
        vector=image_vec,
        include_metadata=True
    )
    # Step 3: dedup, then return the requested window of unique results.
    imgs_with_captions = []
    seen_hashes = set()
    for r in result["matches"]:
        idx = int(r["id"])  # vector ids are dataset row indices
        img = images[idx]
        meta = r.get("metadata", {})
        caption = str(meta.get("productDisplayName", "Unknown Product"))
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.array(img))
        padded = ImageOps.pad(img, (256, 256), color="white")
        if not is_duplicate(padded, seen_hashes):
            imgs_with_captions.append((padded, caption))
            if len(imgs_with_captions) >= end:
                break
    # Previously `start` was ignored so "Load More" repeated the first page.
    return imgs_with_captions[start:end]
import gradio as gr
import whisper
# Whisper "base" ASR model for the voice-search tab; loaded once at startup.
asr_model = whisper.load_model("base")
def handle_voice_search(vf_path, a, offset, gender_ui):
    """Transcribe recorded audio and run a text search with the transcript.

    Returns the 7-tuple of outputs the voice-search click handler is wired to:
    (gallery items, next offset, transcript, cleared image, gender override,
    shown results, set of shown captions).

    :param vf_path: filepath of the recorded audio clip
    :param a: hybrid alpha weight
    :param offset: current pagination offset (unused; search restarts at 0)
    :param gender_ui: gender dropdown value; wins over the extracted gender
    """
    try:
        transcription = asr_model.transcribe(vf_path)["text"].strip()
    except Exception as e:
        # Narrowed from a bare `except:`; a failed transcription simply
        # yields an empty query (and therefore no results).
        print(f"⚠️ Transcription failed: {e}")
        transcription = ""
    filters = extract_intent_from_openai(transcription) if transcription else {}
    # Intent comes back as {"include": {...}, "exclude": {...}}; the previous
    # top-level filters.get("gender") could never find the extracted gender.
    gender_override = gender_ui if gender_ui else filters.get("include", {}).get("gender")
    results = search_fashion(transcription, a, 0, 12, gender_override)
    # results are (image, caption) pairs; remember which captions were shown.
    seen_captions = {pair[1] for pair in results}
    return results, 12, transcription, None, gender_override, results, seen_captions
# CSS theme for the Gradio UI: dark radial background, white card container
# with an orange border, gradient buttons, gallery hover effects, and orange
# tab highlighting. Passed to gr.Blocks(css=custom_css) below.
custom_css = """
/* === Global Styling === */
/* === Override Gradio default background === */
html, body {
    height: 100% !important;
    margin: 0 !important;
    padding: 0 !important;
    background: radial-gradient(circle at center, #0b1f36 0%, #033e3e 100%) !important;
    background-attachment: fixed;
}
.gr-root, .gr-block {
    background: transparent !important;
}
body::before {
    content: "";
    position: fixed;
    top: 0; left: 0;
    width: 100%; height: 100%;
    background: radial-gradient(circle at center, rgba(0, 255, 255, 0.08), transparent);
    z-index: -1;
}
#app-bg {
    min-height: 100vh;
    padding: 0;
    margin: 0;
    background: radial-gradient(circle at center, #0b1f36 0%, #033e3e 100%);
    display: flex;
    justify-content: center;
    align-items: flex-start;
    background-attachment: fixed;
    position: relative;
    overflow: hidden;
}
#app-bg::before {
    content: "";
    position: absolute;
    top: 0; left: 0;
    width: 100%; height: 100%;
    background: radial-gradient(circle at center, rgba(0, 255, 255, 0.08), transparent);
    z-index: 0;
}
#main-container {
    z-index: 1;
    position: relative;
}
/* === Heading Style === */
h1, .gr-markdown h1 {
    font-size: 2.2rem !important;
    font-weight: bold;
    color: #000000;
    text-align: center;
    margin-bottom: 1rem;
}
/* === Tabs === */
.gr-tab {
    border-radius: 12px !important;
    background-color: #ffffff !important;
    box-shadow: 0 3px 10px rgba(0, 0, 0, 0.08);
    padding: 16px !important;
    margin-top: 12px;
}
/* === Textbox, Dropdown, Slider === */
input[type="text"], .gr-textbox textarea, .gr-dropdown, .gr-slider {
    border-radius: 8px !important;
    border: 1px solid #ccc !important;
    padding: 10px !important;
    font-size: 16px;
    box-shadow: 0 1px 3px rgba(0,0,0,0.05);
}
/* === Image Upload === */
.gr-image {
    width: 100% !important;
    max-width: 100% !important;
    border-radius: 12px;
    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
/* === Buttons (custom style .button-36) === */
.gr-button {
    background-color: #DBDBDB !important;
    background-image: linear-gradient(92.88deg, #455EB5 9.16%, #5643CC 43.89%, #673FD7 64.72%);
    border-radius: 8px !important;
    border-style: none !important;
    box-sizing: border-box;
    color: #FFFFFF !important;
    cursor: pointer;
    flex-shrink: 0;
    font-family: "Inter UI","SF Pro Display",-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,Oxygen,Ubuntu,Cantarell,"Open Sans","Helvetica Neue",sans-serif;
    font-size: 16px;
    font-weight: 500;
    height: 4rem;
    padding: 0 1.6rem;
    text-align: center;
    text-shadow: rgba(0, 0, 0, 0.25) 0 3px 8px;
    transition: all .5s;
    user-select: none;
    -webkit-user-select: none;
    touch-action: manipulation;
}
.gr-button:hover {
    box-shadow: rgba(80, 63, 205, 0.5) 0 1px 30px;
    transition-duration: .1s;
}
/* === Responsive padding === */
@media (min-width: 768px) {
    .gr-button {
        padding: 0 2.6rem;
    }
}
/* === Gallery Grid === */
.gr-gallery {
    padding-top: 12px;
}
.gr-gallery-item {
    width: 128px !important;
    height: 128px !important;
    transition: transform 0.3s ease-in-out;
    border-radius: 8px;
    overflow: hidden;
}
.gr-gallery-item:hover {
    transform: scale(1.06);
    box-shadow: 0 3px 12px rgba(0,0,0,0.15);
}
.gr-gallery-item img {
    object-fit: cover !important;
    width: 100% !important;
    height: 100% !important;
    border-radius: 8px;
}
/* === Audio Upload === */
.gr-audio {
    width: 100% !important;
    border-radius: 12px;
    background-color: #fff !important;
    box-shadow: 0 1px 5px rgba(0,0,0,0.1);
}
/* === Footer === */
.gr-markdown:last-child {
    text-align: center;
    font-size: 14px;
    color: #666;
    padding-top: 1rem;
}
/* === Main Container Centered and Wide === */
#main-container {
    max-width: 90%;
    width: 1100px;
    margin: 40px auto !important;
    padding: 24px;
    background: #ffffff;
    border-radius: 18px;
    box-shadow: 0 10px 30px rgba(0,0,0,0.08);
    border: 3px solid orange; /* Orange border */
}
/* === Tab Label Styling === */
button[role="tab"] {
    color: #000000 !important; /* Default tab text color: black */
    font-weight: 500;
    transition: color 0.3s ease-in-out;
    font-size: 16px;
}
/* Active tab title */
button[role="tab"][aria-selected="true"] {
    color: #f57c00 !important; /* Active tab text color: orange */
    font-weight: bold !important;
}
/* Hover effect on tab titles */
button[role="tab"]:hover {
    color: #f57c00 !important; /* Orange on hover */
    font-weight: 600;
    cursor: pointer;
}
/* === Uniform Input Sizes for Text, Audio, Image === */
.gr-textbox, .gr-audio, .gr-image {
    max-width: 100% !important;
    width: 100% !important;
}
.gr-audio, .gr-image {
    max-width: 500px !important;
    margin: 0 auto;
}
.gr-image {
    height: 256px !important;
}
"""
# Build the Gradio UI: a centered card with three search tabs (text / voice /
# image), a shared hybrid-weight slider, a results gallery, and per-session state.
with gr.Blocks(css=custom_css) as demo:
    with gr.Column(elem_id="app-bg"):
        with gr.Column(elem_id="main-container"):
            gr.Markdown("# 🛍️ Fashion Product Hybrid Search")
            # Blend factor between the sparse and dense retrieval signals,
            # applied at query time by the search helpers.
            alpha = gr.Slider(0, 1, value=0.5, label="Hybrid Weight (alpha: 0=sparse, 1=dense)")
            with gr.Tabs():
                with gr.Tab("Text Search"):
                    query = gr.Textbox(
                        label="Text Query",
                        placeholder="e.g., floral summer dress for women"
                    )
                    gender_dropdown = gr.Dropdown(
                        ["", "Men", "Women", "Boys", "Girls", "Kids", "Unisex"],
                        label="Gender Filter (optional)"
                    )
                    text_search_btn = gr.Button("Search by Text", elem_classes="search-btn")
                with gr.Tab("🎙️ Voice Search"):
                    # type="filepath": the handler receives a path to the recording.
                    voice_input = gr.Audio(label="Speak Your Query", type="filepath")
                    voice_gender_dropdown = gr.Dropdown(["", "Men", "Women", "Boys", "Girls", "Kids", "Unisex"], label="Gender")
                    voice_search_btn = gr.Button("Search by Voice")
                with gr.Tab("Image Search"):
                    image_input = gr.Image(
                        type="pil",
                        label="Upload an image",
                        sources=["upload", "clipboard"],
                        height=400
                    )
                    image_gender_dropdown = gr.Dropdown(
                        ["", "Men", "Women", "Boys", "Girls", "Kids", "Unisex"],
                        label="Gender Filter (optional)"
                    )
                    image_search_btn = gr.Button("Search by Image", elem_classes="search-btn")
            # Shared results area for all three search modes.
            gallery = gr.Gallery(label="Search Results", columns=6, height=None)
            load_more_btn = gr.Button("Load More")
            # --- UI State Holders (per-session, consumed by the event handlers) ---
            search_offset = gr.State(0)    # next pagination offset
            current_query = gr.State("")   # last text query that was searched
            current_image = gr.State(None) # last image that was searched
            current_gender = gr.State("")  # gender filter currently in effect
            shown_results = gr.State([])   # accumulated (image, caption) results
            shown_ids = gr.State(set())    # captions already shown, for de-dup on Load More
| def unified_search(q, uploaded_image, a, offset, gender_ui): | |
| start = 0 | |
| end = 12 | |
| filters = extract_intent_from_openai(q) if q.strip() else {} | |
| gender_override = gender_ui if gender_ui else filters.get("gender") | |
| if uploaded_image is not None: | |
| results = search_by_image(uploaded_image, a, start, end) | |
| elif q.strip(): | |
| results = search_fashion(q, a, start, end, gender_override) | |
| else: | |
| results = [] | |
| seen_ids = {r[1] for r in results} | |
| return results, end, q, uploaded_image, gender_override, results, seen_ids | |
            # Wire each search button to its handler. Inline gr.State(...) inputs
            # pin the unused argument (no image for text search, no text for
            # image search) to a constant for the shared unified_search signature.
            # Text Search
            text_search_btn.click(
                unified_search,
                inputs=[query, gr.State(None), alpha, search_offset, gender_dropdown],
                outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results, shown_ids]
            )
            # Voice Search — delegates to handle_voice_search (defined elsewhere
            # in this file); writes to the same outputs as the other modes.
            voice_search_btn.click(
                handle_voice_search,
                inputs=[voice_input, alpha, search_offset, voice_gender_dropdown],
                outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results, shown_ids]
            )
            # Image Search
            image_search_btn.click(
                unified_search,
                inputs=[gr.State(""), image_input, alpha, search_offset, image_gender_dropdown],
                outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results, shown_ids]
            )
| # --- Load More Button --- | |
| def load_more_fn(a, offset, q, img, gender_ui, prev_results, prev_ids): | |
| start = offset | |
| end = offset + 12 | |
| gender_override = gender_ui | |
| if img is not None: | |
| new_results = search_by_image(img, a, start, end) | |
| elif q.strip(): | |
| new_results = search_fashion(q, a, start, end, gender_override) | |
| else: | |
| new_results = [] | |
| filtered_new = [] | |
| new_ids = set() | |
| for item in new_results: | |
| img_obj, caption = item | |
| if caption not in prev_ids: | |
| filtered_new.append(item) | |
| new_ids.add(caption) | |
| combined = prev_results + filtered_new | |
| updated_ids = prev_ids.union(new_ids) | |
| return combined, end, combined, updated_ids | |
            # Paginate: reuse the stored query/image/filter state for the next page.
            load_more_btn.click(
                load_more_fn,
                inputs=[alpha, search_offset, current_query, current_image, current_gender, shown_results, shown_ids],
                outputs=[gallery, search_offset, shown_results, shown_ids]
            )
            # gr.Markdown("🧠 Powered by OpenAI + Hybrid AI Fashion Search")
# Start the Gradio server (blocking call).
demo.launch()