Spaces:
Runtime error
Runtime error
| # import os | |
| # from pinecone import Pinecone, ServerlessSpec | |
| # from PIL import Image, ImageOps | |
| # import numpy as np | |
| # from datasets import load_dataset | |
| # from pinecone_text.sparse import BM25Encoder | |
| # from sentence_transformers import SentenceTransformer | |
| # import torch | |
| # from tqdm.auto import tqdm | |
| # import gradio as gr | |
| # # ------------------- Pinecone Setup ------------------- | |
| # os.environ["PINECONE_API_KEY"] = "<REDACTED: never commit API keys; set PINECONE_API_KEY as an environment secret>" | |
| # api_key = os.environ.get('PINECONE_API_KEY') | |
| # pc = Pinecone(api_key=api_key) | |
| # cloud = os.environ.get('PINECONE_CLOUD') or 'aws' | |
| # region = os.environ.get('PINECONE_REGION') or 'us-east-1' | |
| # spec = ServerlessSpec(cloud=cloud, region=region) | |
| # index_name = "hybrid-image-search" | |
| # spec = ServerlessSpec(cloud="aws", region="us-east-1") | |
| # # choose a name for your index | |
| # index_name = "hybrid-image-search" | |
| # import time | |
| # # check if index already exists (it shouldn't if this is first time) | |
| # if index_name not in pc.list_indexes().names(): | |
| # # if does not exist, create index | |
| # pc.create_index( | |
| # index_name, | |
| # dimension=512, | |
| # metric='dotproduct', | |
| # spec=spec | |
| # ) | |
| # # wait for index to be initialized | |
| # while not pc.describe_index(index_name).status['ready']: | |
| # time.sleep(1) | |
| # # connect to index | |
| # index = pc.Index(index_name) | |
| # # view index stats | |
| # index.describe_index_stats() | |
| # # ------------------- Dataset Loading ------------------- | |
| # fashion = load_dataset("ashraq/fashion-product-images-small", split="train") | |
| # images = fashion["image"] | |
| # metadata = fashion.remove_columns("image").to_pandas() | |
| # # ------------------- Encoders ------------------- | |
| # bm25 = BM25Encoder() | |
| # bm25.fit(metadata["productDisplayName"]) | |
| # model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device='cuda' if torch.cuda.is_available() else 'cpu') | |
| # from sentence_transformers import SentenceTransformer | |
| # import torch | |
| # device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| # # load a CLIP model from huggingface | |
| # model = SentenceTransformer( | |
| # 'sentence-transformers/clip-ViT-B-32', | |
| # device=device | |
| # ) | |
| # model | |
| # # ------------------- Hybrid Scaling ------------------- | |
| # def hybrid_scale(dense, sparse, alpha: float): | |
| # if alpha < 0 or alpha > 1: | |
| # raise ValueError("Alpha must be between 0 and 1") | |
| # # scale sparse and dense vectors to create hybrid search vecs | |
| # hsparse = { | |
| # 'indices': sparse['indices'], | |
| # 'values': [v * (1 - alpha) for v in sparse['values']] | |
| # } | |
| # hdense = [v * alpha for v in dense] | |
| # return hdense, hsparse | |
| # # ------------------- Metadata Filter Extraction ------------------- | |
| # from PIL import Image, ImageOps | |
| # import numpy as np | |
| # from PIL import Image, ImageOps | |
| # import numpy as np | |
| # from PIL import Image, ImageOps | |
| # import numpy as np | |
| # from transformers import CLIPProcessor, CLIPModel | |
| # clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device) | |
| # clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") | |
| # def extract_metadata_filters(query: str): | |
| # query_lower = query.lower() | |
| # gender = None | |
| # category = None | |
| # subcategory = None | |
| # color = None | |
| # # --- Gender Mapping --- | |
| # gender_map = { | |
| # "men": "Men", "man": "Men", "mens": "Men", "mans": "Men", "male": "Men", | |
| # "women": "Women", "woman": "Women", "womens": "Women", "female": "Women", | |
| # "boys": "Boys", "boy": "Boys", | |
| # "girls": "Girls", "girl": "Girls", | |
| # "kids": "Kids","kid": "Kids", | |
| # "unisex": "Unisex" | |
| # } | |
| # for term, mapped_value in gender_map.items(): | |
| # if term in query_lower: | |
| # gender = mapped_value | |
| # break | |
| # # --- Category Mapping --- | |
| # category_map = { | |
| # "shirt": "Shirts", | |
| # "tshirt": "Tshirts", "t-shirt": "Tshirts", | |
| # "jeans": "Jeans", | |
| # "watch": "Watches", | |
| # "kurta": "Kurtas", | |
| # "dress": "Dresses", "dresses": "Dresses", | |
| # "trousers": "Trousers", "pants": "Trousers", | |
| # "shorts": "Shorts", | |
| # "footwear": "Footwear", | |
| # "shoes": "Shoes", # note kept as Shoes | |
| # "fashion": "Apparel" | |
| # } | |
| # for term, mapped_value in category_map.items(): | |
| # if term in query_lower: | |
| # category = mapped_value | |
| # break | |
| # # --- SubCategory Mapping --- | |
| # subCategory_list = [ | |
| # "Accessories", "Apparel Set", "Bags", "Bath and Body", "Beauty Accessories", | |
| # "Belts", "Bottomwear", "Cufflinks", "Dress", "Eyes", "Eyewear", "Flip Flops", | |
| # "Fragrance", "Free Gifts", "Gloves", "Hair", "Headwear", "Home Furnishing", | |
| # "Innerwear", "Jewellery", "Lips", "Loungewear and Nightwear", "Makeup", | |
| # "Mufflers", "Nails", "Perfumes", "Sandal", "Saree", "Scarves", "Shoe Accessories", | |
| # "Shoes", "Skin", "Skin Care", "Socks", "Sports Accessories", "Sports Equipment", | |
| # "Stoles", "Ties", "Topwear", "Umbrellas", "Vouchers", "Wallets", "Watches", | |
| # "Water Bottle", "Wristbands" | |
| # ] | |
| # if "topwear" in query_lower or "top" in query_lower: | |
| # subcategory = "Topwear" | |
| # else: | |
| # for subcat in subCategory_list: | |
| # if subcat.lower() in query_lower: | |
| # subcategory = subcat | |
| # break | |
| # # --- Color Extraction --- | |
| # colors = [ | |
| # "red","blue","green","yellow","black","white", | |
| # "orange","pink","purple","brown","grey","beige" | |
| # ] | |
| # for c in colors: | |
| # if c in query_lower: | |
| # color = c.capitalize() | |
| # break | |
| # # --- Invalid pairs --- | |
| # invalid_pairs = { | |
| # ("Men", "Dresses"), ("Men", "Sarees"), ("Men", "Skirts"), | |
| # ("Boys", "Dresses"), ("Boys", "Sarees"), | |
| # ("Girls", "Boxers"), ("Men", "Heels") | |
| # } | |
| # if (gender, category) in invalid_pairs: | |
| # print(f"⚠️ Invalid pair: {gender} + {category}, dropping gender") | |
| # gender = None | |
| # # fallback | |
| # if gender and not category: | |
| # category = "Apparel" | |
| # return gender, category, subcategory, color | |
| # def search_fashion(query: str, alpha: float, start: int = 0, end: int = 12, gender_override: str = None): | |
| # gender, category, subcategory, color = extract_metadata_filters(query) | |
| # # override from dropdown | |
| # if gender_override: | |
| # gender = gender_override | |
| # # --- Pinecone Filter --- | |
| # filter = {} | |
| # if gender: | |
| # filter["gender"] = gender | |
| # if category: | |
| # if category in ["Footwear", "Shoes"]: | |
| # shoe_article_types = [ | |
| # "Casual Shoes", "Sports Shoes", "Formal Shoes", "Training Shoes", | |
| # "Sneakers", "Sandals", "Slippers", "Boots", "Flip Flops" | |
| # ] | |
| # filter["articleType"] = {"$in": shoe_article_types} | |
| # else: | |
| # filter["articleType"] = category | |
| # if subcategory: | |
| # filter["subCategory"] = subcategory | |
| # if color: | |
| # filter["baseColour"] = color | |
| # print(f"🔍 Using filter: {filter} (showing {start} to {end})") | |
| # sparse = bm25.encode_queries(query) | |
| # dense = model.encode(query).tolist() | |
| # hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha) | |
| # result = index.query( | |
| # top_k=end, | |
| # vector=hdense, | |
| # sparse_vector=hsparse, | |
| # include_metadata=True, | |
| # filter=filter if filter else None | |
| # ) | |
| # # fallback if no results | |
| # if len(result["matches"]) == 0: | |
| # print("⚠️ No results, retrying with alpha=0 sparse only") | |
| # hdense, hsparse = hybrid_scale(dense, sparse, alpha=0) | |
| # result = index.query( | |
| # top_k=end, | |
| # vector=hdense, | |
| # sparse_vector=hsparse, | |
| # include_metadata=True, | |
| # filter=filter if filter else None | |
| # ) | |
| # # fallback if no results with gender | |
| # if gender and len(result["matches"]) == 0: | |
| # print(f"⚠️ No results for gender {gender}, relaxing gender filter") | |
| # filter.pop("gender", None) | |
| # result = index.query( | |
| # top_k=end, | |
| # vector=hdense, | |
| # sparse_vector=hsparse, | |
| # include_metadata=True, | |
| # filter=filter if filter else None | |
| # ) | |
| # matches = result["matches"][start:end] | |
| # imgs_with_captions = [] | |
| # for r in matches: | |
| # idx = int(r["id"]) | |
| # img = images[idx] | |
| # meta = r.get("metadata", {}) | |
| # if not isinstance(img, Image.Image): | |
| # img = Image.fromarray(np.array(img)) | |
| # padded = ImageOps.pad(img, (256, 256), color="white") | |
| # caption = str(meta.get("productDisplayName", "Unknown Product")) | |
| # imgs_with_captions.append((padded, caption)) | |
| # return imgs_with_captions | |
| # # this is working code block | |
| # from PIL import Image, ImageOps | |
| # import numpy as np | |
| # def search_by_image(uploaded_image, alpha=0.5, start=0, end=12): | |
| # """ | |
| # Search visually similar products with support for pagination. | |
| # """ | |
| # # Preprocess image for CLIP | |
| # processed = clip_processor(images=uploaded_image, return_tensors="pt").to(device) | |
| # with torch.no_grad(): | |
| # image_vec = clip_model.get_image_features(**processed) | |
| # image_vec = image_vec.cpu().numpy().flatten().tolist() | |
| # # Query a larger top_k so you have enough to paginate | |
| # result = index.query( | |
| # top_k=end, | |
| # vector=image_vec, | |
| # include_metadata=True | |
| # ) | |
| # matches = result["matches"][start:end] # slice for pagination | |
| # imgs_with_captions = [] | |
| # for r in matches: | |
| # idx = int(r["id"]) | |
| # img = images[idx] | |
| # meta = r.get("metadata", {}) | |
| # if not isinstance(img, Image.Image): | |
| # img = Image.fromarray(np.array(img)) | |
| # padded = ImageOps.pad(img, (256, 256), color="white") | |
| # caption = str(meta.get("productDisplayName", "Unknown Product")) | |
| # imgs_with_captions.append((padded, caption)) | |
| # return imgs_with_captions | |
| # # with gr.Blocks(css=custom_css) as demo: | |
| # # gr.Markdown("# 🛍️ Fashion Product Hybrid Search") | |
| # # with gr.Row(equal_height=True): | |
| # # with gr.Column(scale=5, elem_classes="query-slider"): | |
| # # query = gr.Textbox( | |
| # # label="Enter your fashion search query", | |
| # # placeholder="Type something or leave blank to only use the image" | |
| # # ) | |
| # # alpha = gr.Slider( | |
| # # 0, 1, value=0.5, | |
| # # label="Hybrid Weight (alpha: 0=sparse, 1=dense)" | |
| # # ) | |
| # # with gr.Column(scale=1): | |
| # # image_input = gr.Image( | |
| # # type="pil", | |
| # # label="Upload an image (optional)", | |
| # # height=256, | |
| # # width=356, | |
| # # show_label=True | |
| # # ) | |
| # # search_btn = gr.Button("Search", elem_classes="search-btn") | |
| # # gallery = gr.Gallery( | |
| # # label="Search Results", | |
| # # columns=6, | |
| # # height="40vh" | |
| # # ) | |
| # import gradio as gr | |
| # custom_css = """ | |
| # .search-btn { | |
| # width: 100%; | |
| # } | |
| # .gr-row { | |
| # gap: 8px !important; | |
| # } | |
| # .query-slider > div { | |
| # margin-bottom: 4px !important; | |
| # } | |
| # .upload-box .icon-container { | |
| # display: none !important; | |
| # } | |
| # """ | |
| # with gr.Blocks(css=custom_css) as demo: | |
| # gr.Markdown("# 🛍️ Fashion Product Hybrid Search") | |
| # with gr.Row(equal_height=True): | |
| # with gr.Column(scale=5, elem_classes="query-slider"): | |
| # query = gr.Textbox( | |
| # label="Enter your fashion search query", | |
| # placeholder="Type something or leave blank to only use the image" | |
| # ) | |
| # alpha = gr.Slider(0, 1, value=0.5, label="Hybrid Weight (alpha: 0=sparse, 1=dense)") | |
| # gender_dropdown = gr.Dropdown( | |
| # ["", "Men", "Women", "Boys", "Girls", "Kids", "Unisex"], | |
| # label="Gender Filter (optional)" | |
| # ) | |
| # # with gr.Column(scale=1): | |
| # # image_input = gr.Image( | |
| # # type="pil", | |
| # # label="Upload an image (optional)", | |
| # # height=256, | |
| # # width=356 | |
| # # ) | |
| # with gr.Column(scale=1): | |
| # image_input = gr.Image( | |
| # type="pil", | |
| # label="Upload an image (optional)", | |
| # height=256, | |
| # width=356, | |
| # sources=["upload", "clipboard"] # only upload and paste allowed | |
| # ) | |
| # search_btn = gr.Button("Search", elem_classes="search-btn") | |
| # gallery = gr.Gallery(label="Search Results", columns=6, height="50vh") | |
| # load_more_btn = gr.Button("Load More") | |
| # # States to track | |
| # search_offset = gr.State(0) | |
| # current_query = gr.State("") | |
| # current_image = gr.State(None) | |
| # current_gender = gr.State("") | |
| # shown_results = gr.State([]) # new: store the list of shown images | |
| # def unified_search(q, uploaded_image, a, offset, gender_ui): | |
| # start = 0 | |
| # end = 12 | |
| # gender_override = gender_ui if gender_ui else None | |
| # if uploaded_image is not None: | |
| # results = search_by_image(uploaded_image, a, start, end) | |
| # elif q.strip() != "": | |
| # results = search_fashion(q, a, start, end, gender_override) | |
| # else: | |
| # results = [] | |
| # # reset shown_results to just these first 12 | |
| # return results, end, q, uploaded_image, gender_ui, results | |
| # search_btn.click( | |
| # unified_search, | |
| # inputs=[query, image_input, alpha, search_offset, gender_dropdown], | |
| # outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results] | |
| # ) | |
| # def load_more_fn(a, offset, q, img, gender_ui, prev_results): | |
| # start = offset | |
| # end = offset + 12 | |
| # gender_override = gender_ui if gender_ui else None | |
| # if img is not None: | |
| # new_results = search_by_image(img, a, start, end) | |
| # elif q.strip() != "": | |
| # new_results = search_fashion(q, a, start, end, gender_override) | |
| # else: | |
| # new_results = [] | |
| # combined_results = prev_results + new_results | |
| # return combined_results, end, combined_results | |
| # load_more_btn.click( | |
| # load_more_fn, | |
| # inputs=[alpha, search_offset, current_query, current_image, current_gender, shown_results], | |
| # outputs=[gallery, search_offset, shown_results] | |
| # ) | |
| # gr.Markdown("Powered by your hybrid AI search model 🚀") | |
| # demo.launch() | |
| # app.py | |
import ast
import functools
import json
import os
import time

import gradio as gr
import numpy as np
import openai
import torch
from datasets import load_dataset
from PIL import Image, ImageOps
from pinecone import Pinecone, ServerlessSpec
from pinecone_text.sparse import BM25Encoder
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
from transformers import CLIPProcessor, CLIPModel
# ------------------- Keys & Setup -------------------
# All credentials and the Pinecone location are read from environment variables.
openai.api_key = os.getenv("OPENAI_API_KEY")
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
spec = ServerlessSpec(
    cloud=os.getenv("PINECONE_CLOUD") or "aws",
    region=os.getenv("PINECONE_REGION") or "us-east-1",
)
index_name = "hybrid-image-search"

# Create the 512-dimensional dotproduct index on first run, then poll until ready.
if index_name not in pc.list_indexes().names():
    pc.create_index(index_name, dimension=512, metric='dotproduct', spec=spec)
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)

index = pc.Index(index_name)
# ------------------- Models & Dataset -------------------
# Choose the compute device once; every model below is placed on it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Product catalogue: images are kept apart from the tabular metadata.
fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
images = fashion["image"]
metadata = fashion.remove_columns("image").to_pandas()

# Sparse encoder fitted on the product display names.
bm25 = BM25Encoder()
bm25.fit(metadata["productDisplayName"])

# Dense encoders: the sentence-transformers CLIP model embeds text queries,
# the Hugging Face CLIP model embeds uploaded images.
model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device=device)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# ------------------- Helper Functions -------------------
def hybrid_scale(dense, sparse, alpha: float):
    """Weight a dense/sparse query pair for Pinecone hybrid search.

    Dense components are multiplied by ``alpha`` and sparse values by
    ``1 - alpha``, so ``alpha=1`` is dense-only and ``alpha=0`` sparse-only.

    Args:
        dense: Sequence of dense vector components.
        sparse: Mapping with ``'indices'`` and ``'values'`` lists.
        alpha: Weight in the closed interval [0, 1].

    Returns:
        ``(scaled_dense, scaled_sparse)``. The sparse dict reuses the caller's
        ``indices`` object unchanged.

    Raises:
        ValueError: If ``alpha`` is below 0 or above 1.
    """
    if alpha > 1 or alpha < 0:
        raise ValueError("Alpha must be between 0 and 1")
    sparse_weight = 1 - alpha
    scaled_sparse = {
        'indices': sparse['indices'],
        'values': [val * sparse_weight for val in sparse['values']],
    }
    scaled_dense = [component * alpha for component in dense]
    return scaled_dense, scaled_sparse
# Intent fields requested from the model; any other keys in its reply are ignored.
_INTENT_KEYS = ("category", "gender", "subcategory", "color")


def _parse_intent(raw: str) -> dict:
    """Parse the model's reply into ``{key: str}`` for the non-empty intent fields.

    Accepts a JSON object (``null`` for missing values) or a Python dict literal
    (``None``), optionally wrapped in a Markdown code fence. The reply is never
    executed: ``json.loads`` and ``ast.literal_eval`` only accept literal data,
    whereas ``eval`` would run any code the model (or a prompt-injected query)
    produced.

    Raises:
        ValueError or SyntaxError: If the reply is not a dict literal.
    """
    text = raw.strip()
    if text.startswith("```"):
        # Drop the opening fence line (e.g. ```json) and the closing fence.
        text = text.split("\n", 1)[1] if "\n" in text else ""
        text = text.rsplit("```", 1)[0]
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Fall back to Python literal syntax (single quotes, None).
        parsed = ast.literal_eval(text)
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a dict, got {type(parsed).__name__}")
    intent = {}
    for key in _INTENT_KEYS:
        value = parsed.get(key)
        # Skip null, empty and non-string values so callers' .get() yields None for them.
        if isinstance(value, str) and value.strip():
            intent[key] = value.strip()
    return intent


@functools.lru_cache(maxsize=256)
def _fetch_intent(query: str) -> tuple:
    """Ask the chat model for the query's intent and return it as a tuple of items.

    The result is cached so that repeated calls for the same query (initial search,
    then each "Load More") cost a single API call. ``lru_cache`` does not cache
    raised exceptions, so a failed call is retried the next time.
    """
    prompt = f'''
You are an assistant for a fashion search engine. Extract the user's intent from the following query.
Return a JSON object with keys: category, gender, subcategory, color.
If something is missing, use null.
Query: {json.dumps(query, ensure_ascii=False)}
Only return the JSON object.
'''
    messages = [{"role": "user", "content": prompt}]
    if hasattr(openai, "OpenAI"):
        # openai>=1.0 removed openai.ChatCompletion; use the client API instead.
        client = openai.OpenAI(api_key=openai.api_key)
        response = client.chat.completions.create(model="gpt-4", messages=messages, temperature=0)
        raw = response.choices[0].message.content
    else:
        # Legacy openai<1.0 interface.
        response = openai.ChatCompletion.create(model="gpt-4", messages=messages, temperature=0)
        raw = response.choices[0].message['content']
    return tuple(_parse_intent(raw or "").items())


def extract_intent_from_openai(query: str):
    """Extract structured search intent from a free-text query via the OpenAI chat API.

    Args:
        query: The user's search text.

    Returns:
        A dict that may contain ``category``, ``gender``, ``subcategory`` and
        ``color``, each mapped to a non-empty string. Fields the model left
        null are absent, so ``.get`` returns None for them. Returns ``{}`` for a
        blank query, or when the API call or parsing fails. Failures are
        printed, not raised.
    """
    if not query or not query.strip():
        return {}
    try:
        # Build a fresh dict from the cached (immutable) tuple for each caller.
        return dict(_fetch_intent(query.strip()))
    except Exception as e:  # best effort: search still works without intent filters
        print(f"⚠️ OpenAI intent extraction failed: {e}")
        return {}
def is_duplicate(img, seen_hashes):
    """Report whether ``img``'s raw bytes were already seen, recording them if not.

    Payloads are compared by ``hash(img.tobytes())``, so hash-equal payloads
    count as duplicates.

    Args:
        img: Any object exposing ``tobytes()`` (the callers pass PIL images).
        seen_hashes: Set of hashes seen so far; it is updated in place.

    Returns:
        True if the hash was already in ``seen_hashes``, otherwise False (after adding it).
    """
    fingerprint = hash(img.tobytes())
    seen_before = fingerprint in seen_hashes
    if not seen_before:
        seen_hashes.add(fingerprint)
    return seen_before
# ------------------- Search Functions -------------------
# articleType values a "Footwear"/"Shoes" intent expands to. Pinecone metadata
# filters support $in but have no $regex operator, so the types are listed explicitly.
_SHOE_ARTICLE_TYPES = [
    "Casual Shoes", "Sports Shoes", "Formal Shoes", "Training Shoes", "Sneakers",
    "Heels", "Flats", "Sandals", "Sports Sandals", "Slippers", "Boots", "Flip Flops",
]
# Filter keys dropped one at a time, in this order, while a filtered query returns nothing.
_FILTER_RELAX_ORDER = ("subCategory", "articleType", "baseColour", "gender")
# Candidate pool fetched from Pinecone before de-duplication and pagination.
_SEARCH_TOP_K = 100


def search_fashion(query: str, alpha: float, start: int = 0, end: int = 12, gender_override: str = None):
    """Hybrid (BM25 sparse + CLIP dense) text search over the product index.

    Args:
        query: Free-text search query.
        alpha: Dense/sparse weighting for ``hybrid_scale`` (1 = dense only, 0 = sparse only).
        start: Index of the first result to return, counted after de-duplication.
        end: Index one past the last result to return.
        gender_override: If truthy, replaces the gender extracted from the query.

    Returns:
        List of ``(PIL.Image, caption)`` tuples for results ``start:end``.
        Returns an empty list for a blank query.
    """
    if not query or not query.strip():
        return []

    intent = extract_intent_from_openai(query)
    gender = gender_override or intent.get("gender")
    category = intent.get("category")
    subcategory = intent.get("subcategory")
    color = intent.get("color")

    metadata_filter = {}
    if gender:
        metadata_filter["gender"] = gender
    if category:
        if category in ("Footwear", "Shoes"):
            metadata_filter["articleType"] = {"$in": _SHOE_ARTICLE_TYPES}
        else:
            metadata_filter["articleType"] = category
    if subcategory:
        metadata_filter["subCategory"] = subcategory
    if color:
        metadata_filter["baseColour"] = color

    sparse = bm25.encode_queries(query)
    dense = model.encode(query).tolist()
    hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha)

    def run_query(current_filter):
        # An empty filter is sent as None (no filtering).
        return index.query(
            top_k=_SEARCH_TOP_K,
            vector=hdense,
            sparse_vector=hsparse,
            include_metadata=True,
            filter=current_filter or None,
        )

    result = run_query(metadata_filter)
    # Zero matches means the filter excluded everything (re-weighting alpha cannot
    # change that), so relax the filter one key at a time until something matches.
    for key in _FILTER_RELAX_ORDER:
        if result["matches"]:
            break
        if key in metadata_filter:
            print(f"⚠️ No results, relaxing '{key}' filter")
            del metadata_filter[key]
            result = run_query(metadata_filter)

    deduped = []
    seen_hashes = set()
    for match in result["matches"]:
        img = images[int(match["id"])]
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.array(img))
        padded = ImageOps.pad(img, (256, 256), color="white")
        if is_duplicate(padded, seen_hashes):
            continue
        meta = match.get("metadata") or {}
        deduped.append((padded, str(meta.get("productDisplayName", "Unknown Product"))))
        if len(deduped) >= end:
            break
    # Return only the requested page.
    return deduped[start:end]
def search_by_image(uploaded_image, alpha=0.5, start=0, end=12):
    """Find products visually similar to an uploaded image.

    Embeds the image with the Hugging Face CLIP model and runs a dense-only
    Pinecone query.

    Args:
        uploaded_image: Image passed to ``clip_processor`` (the UI supplies a PIL image).
            ``None`` yields an empty result.
        alpha: Not used here; the UI passes it for both search modes.
        start: Index of the first result to return, counted after de-duplication.
        end: Index one past the last result to return.

    Returns:
        List of ``(PIL.Image, caption)`` tuples for results ``start:end``.
    """
    if uploaded_image is None:
        return []
    processed = clip_processor(images=uploaded_image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_vec = clip_model.get_image_features(**processed)
    image_vec = image_vec.cpu().numpy().flatten().tolist()

    # Fetch a fixed candidate pool; pagination is applied after de-duplication.
    result = index.query(top_k=100, vector=image_vec, include_metadata=True)

    deduped = []
    seen_hashes = set()
    for match in result["matches"]:
        img = images[int(match["id"])]
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.array(img))
        padded = ImageOps.pad(img, (256, 256), color="white")
        if is_duplicate(padded, seen_hashes):
            continue
        meta = match.get("metadata") or {}
        deduped.append((padded, str(meta.get("productDisplayName", "Unknown Product"))))
        if len(deduped) >= end:
            break
    # Return only the requested page.
    return deduped[start:end]
| # ------------------- UI ------------------- | |
| # custom_css = """ | |
| # .search-btn { width: 100%; } | |
| # .gr-row { gap: 8px !important; } | |
| # .query-slider > div { margin-bottom: 4px !important; } | |
| # .gr-gallery-item { width: 256px !important; height: 256px !important; } | |
| # .gr-gallery-item img { width: 100% !important; height: 100% !important; object-fit: cover !important; } | |
| # """ | |
| # with gr.Blocks(css=custom_css) as demo: | |
| # gr.Markdown("# 🛍️ Fashion Product Hybrid Search (with GPT-4 powered query parsing)") | |
| # with gr.Row(equal_height=True): | |
| # with gr.Column(scale=5, elem_classes="query-slider"): | |
| # query = gr.Textbox(label="Enter your fashion search query", placeholder="e.g., black sneakers for women") | |
| # alpha = gr.Slider(0, 1, value=0.5, label="Hybrid Weight (alpha: 0=sparse, 1=dense)") | |
| # gender_dropdown = gr.Dropdown(["", "Men", "Women", "Boys", "Girls", "Kids", "Unisex"], label="Gender Filter (optional)") | |
| # with gr.Column(scale=1): | |
| # image_input = gr.Image(type="pil", label="Upload an image (optional)", sources=["upload", "clipboard"], height=256, width=356) | |
| # search_btn = gr.Button("Search", elem_classes="search-btn") | |
| # gallery = gr.Gallery(label="Search Results", columns=6, height=None) | |
| # load_more_btn = gr.Button("Load More") | |
| # search_offset = gr.State(0) | |
| # current_query = gr.State("") | |
| # current_image = gr.State(None) | |
| # current_gender = gr.State("") | |
| # shown_results = gr.State([]) | |
| # shown_ids = gr.State(set()) | |
| # def unified_search(q, uploaded_image, a, offset, gender_ui): | |
| # start = 0 | |
| # end = 12 | |
| # filters = extract_intent_from_openai(q) if q.strip() else {} | |
| # gender_override = gender_ui if gender_ui else filters.get("gender") | |
| # if uploaded_image is not None: | |
| # results = search_by_image(uploaded_image, a, start, end) | |
| # elif q.strip(): | |
| # results = search_fashion(q, a, start, end, gender_override) | |
| # else: | |
| # results = [] | |
| # seen_ids = {r[1] for r in results} | |
| # return results, end, q, uploaded_image, gender_override, results, seen_ids | |
| # search_btn.click(unified_search, inputs=[query, image_input, alpha, search_offset, gender_dropdown], | |
| # outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results, shown_ids]) | |
| # def load_more_fn(a, offset, q, img, gender_ui, prev_results, prev_ids): | |
| # start = offset | |
| # end = offset + 12 | |
| # gender_override = gender_ui | |
| # if img is not None: | |
| # new_results = search_by_image(img, a, start, end) | |
| # elif q.strip(): | |
| # new_results = search_fashion(q, a, start, end, gender_override) | |
| # else: | |
| # new_results = [] | |
| # filtered_new = [] | |
| # new_ids = set() | |
| # for item in new_results: | |
| # img_obj, caption = item | |
| # if caption not in prev_ids: | |
| # filtered_new.append(item) | |
| # new_ids.add(caption) | |
| # combined = prev_results + filtered_new | |
| # updated_ids = prev_ids.union(new_ids) | |
| # return combined, end, combined, updated_ids | |
| # load_more_btn.click(load_more_fn, inputs=[alpha, search_offset, current_query, current_image, current_gender, shown_results, shown_ids], | |
| # outputs=[gallery, search_offset, shown_results, shown_ids]) | |
| # gr.Markdown("🧠 Powered by OpenAI + Hybrid AI Fashion Search") | |
| # demo.launch() | |
import gradio as gr

# Number of gallery items fetched per search and per "Load More" click.
PAGE_SIZE = 12

custom_css = """
/* Container */
.gr-gallery {
    display: flex !important;
    flex-wrap: wrap;
    gap: 10px;
    justify-content: center;
}
/* Each item */
.gr-gallery-item {
    flex: 0 0 calc(16.66% - 10px); /* 6 columns on desktop */
    max-width: calc(16.66% - 10px);
    height: 256px !important;
    overflow: hidden;
}
/* Image inside each item */
.gr-gallery-item img {
    width: 100% !important;
    height: 100% !important;
    object-fit: cover !important;
}
/* For mobile: 3 columns */
@media (max-width: 768px) {
    .gr-gallery-item {
        flex: 0 0 calc(33.33% - 10px); /* 3 columns on mobile */
        max-width: calc(33.33% - 10px);
    }
}
"""

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("## 🛍️ Responsive Fashion Product Search")
    with gr.Row():
        with gr.Column(scale=5, elem_classes="query-slider"):
            query = gr.Textbox(label="Search", placeholder="e.g. black dress for women")
            alpha = gr.Slider(0, 1, value=0.5, label="Hybrid Weight")
            gender_dropdown = gr.Dropdown(["", "Men", "Women", "Unisex"], label="Gender (optional)")
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard"], height=256)
            search_btn = gr.Button("Search", elem_classes="search-btn")
    gallery = gr.Gallery(label="Results", columns=6, height=None, allow_preview=True)
    load_more_btn = gr.Button("Load More")

    # Per-session state shared between "Search" and "Load More".
    search_offset = gr.State(0)    # index of the next result to fetch
    current_query = gr.State("")
    current_image = gr.State(None)
    current_gender = gr.State("")  # explicit gender override; ""/None means "use query intent"
    shown_results = gr.State([])   # (image, caption) tuples currently in the gallery
    shown_ids = gr.State(set())    # captions already shown, used to skip repeats

    def unified_search(q, uploaded_image, a, offset, gender_ui):
        """Run a fresh search (an uploaded image takes priority over text) and reset paging."""
        start, end = 0, PAGE_SIZE
        q = q or ""
        # search_fashion extracts the query's intent (including gender) itself, so
        # only the explicit dropdown choice is passed; no extra OpenAI call here.
        gender_override = gender_ui or None
        if uploaded_image is not None:
            results = search_by_image(uploaded_image, a, start, end)
        elif q.strip():
            results = search_fashion(q, a, start, end, gender_override)
        else:
            results = []
        seen_ids = {caption for _, caption in results}
        return results, end, q, uploaded_image, gender_override, results, seen_ids

    search_btn.click(
        unified_search,
        inputs=[query, image_input, alpha, search_offset, gender_dropdown],
        outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results, shown_ids],
    )

    def load_more_fn(a, offset, q, img, gender_ui, prev_results, prev_ids):
        """Fetch the next page for the stored search and append unseen items."""
        start = offset
        end = offset + PAGE_SIZE
        q = q or ""
        gender_override = gender_ui or None
        if img is not None:
            new_results = search_by_image(img, a, start, end)
        elif q.strip():
            new_results = search_fashion(q, a, start, end, gender_override)
        else:
            new_results = []
        # Track captions across both previous and new items so repeats inside the
        # new page are skipped too.
        seen = set(prev_ids or ())
        fresh = []
        for item in new_results:
            caption = item[1]
            if caption not in seen:
                fresh.append(item)
                seen.add(caption)
        combined = list(prev_results or []) + fresh
        return combined, end, combined, seen

    load_more_btn.click(
        load_more_fn,
        inputs=[alpha, search_offset, current_query, current_image, current_gender, shown_results, shown_ids],
        outputs=[gallery, search_offset, shown_results, shown_ids],
    )
    gr.Markdown("🧠 Powered by OpenAI + Hybrid AI Fashion Search")

demo.launch()