# import os # from pinecone import Pinecone, ServerlessSpec # from PIL import Image, ImageOps # import numpy as np # from datasets import load_dataset # from pinecone_text.sparse import BM25Encoder # from sentence_transformers import SentenceTransformer # import torch # from tqdm.auto import tqdm # import gradio as gr # # ------------------- Pinecone Setup ------------------- # os.environ["PINECONE_API_KEY"] = "pcsk_TMCYK_LrbmZMTDhkxTjUXcr8iTcQ8LxurwKBFDvv4ahFis8SVob7QexVPPEt6g2zW6d3g" # api_key = os.environ.get('PINECONE_API_KEY') # pc = Pinecone(api_key=api_key) # cloud = os.environ.get('PINECONE_CLOUD') or 'aws' # region = os.environ.get('PINECONE_REGION') or 'us-east-1' # spec = ServerlessSpec(cloud=cloud, region=region) # index_name = "hybrid-image-search" # spec = ServerlessSpec(cloud="aws", region="us-east-1") # # choose a name for your index # index_name = "hybrid-image-search" # import time # # check if index already exists (it shouldn't if this is first time) # if index_name not in pc.list_indexes().names(): # # if does not exist, create index # pc.create_index( # index_name, # dimension=512, # metric='dotproduct', # spec=spec # ) # # wait for index to be initialized # while not pc.describe_index(index_name).status['ready']: # time.sleep(1) # # connect to index # index = pc.Index(index_name) # # view index stats # index.describe_index_stats() # # ------------------- Dataset Loading ------------------- # fashion = load_dataset("ashraq/fashion-product-images-small", split="train") # images = fashion["image"] # metadata = fashion.remove_columns("image").to_pandas() # # ------------------- Encoders ------------------- # bm25 = BM25Encoder() # bm25.fit(metadata["productDisplayName"]) # model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device='cuda' if torch.cuda.is_available() else 'cpu') # from sentence_transformers import SentenceTransformer # import torch # device = 'cuda' if torch.cuda.is_available() else 'cpu' # # load a CLIP model from huggingface # model = SentenceTransformer( # 'sentence-transformers/clip-ViT-B-32', # device=device # ) # model # # ------------------- Hybrid Scaling ------------------- # def hybrid_scale(dense, sparse, alpha: float): # if alpha < 0 or alpha > 1: # raise ValueError("Alpha must be between 0 and 1") # # scale sparse and dense vectors to create hybrid search vecs # hsparse = { # 'indices': sparse['indices'], # 'values': [v * (1 - alpha) for v in sparse['values']] # } # hdense = [v * alpha for v in dense] # return hdense, hsparse # # ------------------- Metadata Filter Extraction ------------------- # from PIL import Image, ImageOps # import numpy as np # from PIL import Image, ImageOps # import numpy as np # from PIL import Image, ImageOps # import numpy as np # from transformers import CLIPProcessor, CLIPModel # clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device) # clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") # def extract_metadata_filters(query: str): # query_lower = query.lower() # gender = None # category = None # subcategory = None # color = None # # --- Gender Mapping --- # gender_map = { # "men": "Men", "man": "Men", "mens": "Men", "mans": "Men", "male": "Men", # "women": "Women", "woman": "Women", "womens": "Women", "female": "Women", # "boys": "Boys", "boy": "Boys", # "girls": "Girls", "girl": "Girls", # "kids": "Kids","kid": "Kids", # "unisex": "Unisex" # } # for term, mapped_value in gender_map.items(): # if term in query_lower: # gender = mapped_value # break # # --- Category Mapping --- # category_map = { # "shirt": "Shirts", # "tshirt": "Tshirts", "t-shirt": "Tshirts", # "jeans": "Jeans", # "watch": "Watches", # "kurta": "Kurtas", # "dress": "Dresses", "dresses": "Dresses", # "trousers": "Trousers", "pants": "Trousers", # "shorts": "Shorts", # "footwear": "Footwear", # "shoes": "Shoes", # note kept as Shoes # "fashion": "Apparel" # } # for term, mapped_value in category_map.items(): # if term in query_lower: # category = mapped_value # break # # --- SubCategory Mapping --- # subCategory_list = [ # "Accessories", "Apparel Set", "Bags", "Bath and Body", "Beauty Accessories", # "Belts", "Bottomwear", "Cufflinks", "Dress", "Eyes", "Eyewear", "Flip Flops", # "Fragrance", "Free Gifts", "Gloves", "Hair", "Headwear", "Home Furnishing", # "Innerwear", "Jewellery", "Lips", "Loungewear and Nightwear", "Makeup", # "Mufflers", "Nails", "Perfumes", "Sandal", "Saree", "Scarves", "Shoe Accessories", # "Shoes", "Skin", "Skin Care", "Socks", "Sports Accessories", "Sports Equipment", # "Stoles", "Ties", "Topwear", "Umbrellas", "Vouchers", "Wallets", "Watches", # "Water Bottle", "Wristbands" # ] # if "topwear" in query_lower or "top" in query_lower: # subcategory = "Topwear" # else: # for subcat in subCategory_list: # if subcat.lower() in query_lower: # subcategory = subcat # break # # --- Color Extraction --- # colors = [ # "red","blue","green","yellow","black","white", # "orange","pink","purple","brown","grey","beige" # ] # for c in colors: # if c in query_lower: # color = c.capitalize() # break # # --- Invalid pairs --- # invalid_pairs = { # ("Men", "Dresses"), ("Men", "Sarees"), ("Men", "Skirts"), # ("Boys", "Dresses"), ("Boys", "Sarees"), # ("Girls", "Boxers"), ("Men", "Heels") # } # if (gender, category) in invalid_pairs: # print(f"⚠️ Invalid pair: {gender} + {category}, dropping gender") # gender = None # # fallback # if gender and not category: # category = "Apparel" # return gender, category, subcategory, color # def search_fashion(query: str, alpha: float, start: int = 0, end: int = 12, gender_override: str = None): # gender, category, subcategory, color = extract_metadata_filters(query) # # override from dropdown # if gender_override: # gender = gender_override # # --- Pinecone Filter --- # filter = {} # if gender: # filter["gender"] = gender # if category: # if category in ["Footwear", "Shoes"]: # shoe_article_types = [ # "Casual Shoes", "Sports Shoes", "Formal Shoes", "Training Shoes", # "Sneakers", "Sandals", "Slippers", "Boots", "Flip Flops" # ] # filter["articleType"] = {"$in": shoe_article_types} # else: # filter["articleType"] = category # if subcategory: # filter["subCategory"] = subcategory # if color: # filter["baseColour"] = color # print(f"🔍 Using filter: {filter} (showing {start} to {end})") # sparse = bm25.encode_queries(query) # dense = model.encode(query).tolist() # hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha) # result = index.query( # top_k=end, # vector=hdense, # sparse_vector=hsparse, # include_metadata=True, # filter=filter if filter else None # ) # # fallback if no results # if len(result["matches"]) == 0: # print("⚠️ No results, retrying with alpha=0 sparse only") # hdense, hsparse = hybrid_scale(dense, sparse, alpha=0) # result = index.query( # top_k=end, # vector=hdense, # sparse_vector=hsparse, # include_metadata=True, # filter=filter if filter else None # ) # # fallback if no results with gender # if gender and len(result["matches"]) == 0: # print(f"⚠️ No results for gender {gender}, relaxing gender filter") # filter.pop("gender", None) # result = index.query( # top_k=end, # vector=hdense, # sparse_vector=hsparse, # include_metadata=True, # filter=filter if filter else None # ) # matches = result["matches"][start:end] # imgs_with_captions = [] # for r in matches: # idx = int(r["id"]) # img = images[idx] # meta = r.get("metadata", {}) # if not isinstance(img, Image.Image): # img = Image.fromarray(np.array(img)) # padded = ImageOps.pad(img, (256, 256), color="white") # caption = str(meta.get("productDisplayName", "Unknown Product")) # imgs_with_captions.append((padded, caption)) # return imgs_with_captions # # this is working code block # from PIL import Image, ImageOps # import numpy as np # def search_by_image(uploaded_image, alpha=0.5, start=0, end=12): # """ # Search visually similar products with support for pagination. # """ # # Preprocess image for CLIP # processed = clip_processor(images=uploaded_image, return_tensors="pt").to(device) # with torch.no_grad(): # image_vec = clip_model.get_image_features(**processed) # image_vec = image_vec.cpu().numpy().flatten().tolist() # # Query a larger top_k so you have enough to paginate # result = index.query( # top_k=end, # vector=image_vec, # include_metadata=True # ) # matches = result["matches"][start:end] # slice for pagination # imgs_with_captions = [] # for r in matches: # idx = int(r["id"]) # img = images[idx] # meta = r.get("metadata", {}) # if not isinstance(img, Image.Image): # img = Image.fromarray(np.array(img)) # padded = ImageOps.pad(img, (256, 256), color="white") # caption = str(meta.get("productDisplayName", "Unknown Product")) # imgs_with_captions.append((padded, caption)) # return imgs_with_captions # # with gr.Blocks(css=custom_css) as demo: # # gr.Markdown("# 🛍️ Fashion Product Hybrid Search") # # with gr.Row(equal_height=True): # # with gr.Column(scale=5, elem_classes="query-slider"): # # query = gr.Textbox( # # label="Enter your fashion search query", # # placeholder="Type something or leave blank to only use the image" # # ) # # alpha = gr.Slider( # # 0, 1, value=0.5, # # label="Hybrid Weight (alpha: 0=sparse, 1=dense)" # # ) # # with gr.Column(scale=1): # # image_input = gr.Image( # # type="pil", # # label="Upload an image (optional)", # # height=256, # # width=356, # # show_label=True # # ) # # search_btn = gr.Button("Search", elem_classes="search-btn") # # gallery = gr.Gallery( # # label="Search Results", # # columns=6, # # height="40vh" # # ) # import gradio as gr # custom_css = """ # .search-btn { # width: 100%; # } # .gr-row { # gap: 8px !important; # } # .query-slider > div { # margin-bottom: 4px !important; # } # .upload-box .icon-container { # display: none !important; # } # """ # with gr.Blocks(css=custom_css) as demo: # gr.Markdown("# 🛍️ Fashion Product Hybrid Search") # with gr.Row(equal_height=True): # with gr.Column(scale=5, elem_classes="query-slider"): # query = gr.Textbox( # label="Enter your fashion search query", # placeholder="Type something or leave blank to only use the image" # ) # alpha = gr.Slider(0, 1, value=0.5, label="Hybrid Weight (alpha: 0=sparse, 1=dense)") # gender_dropdown = gr.Dropdown( # ["", "Men", "Women", "Boys", "Girls", "Kids", "Unisex"], # label="Gender Filter (optional)" # ) # # with gr.Column(scale=1): # # image_input = gr.Image( # # type="pil", # # label="Upload an image (optional)", # # height=256, # # width=356 # # ) # with gr.Column(scale=1): # image_input = gr.Image( # type="pil", # label="Upload an image (optional)", # height=256, # width=356, # sources=["upload", "clipboard"] # only upload and paste allowed # ) # search_btn = gr.Button("Search", elem_classes="search-btn") # gallery = gr.Gallery(label="Search Results", columns=6, height="50vh") # load_more_btn = gr.Button("Load More") # # States to track # search_offset = gr.State(0) # current_query = gr.State("") # current_image = gr.State(None) # current_gender = gr.State("") # shown_results = gr.State([]) # new: store the list of shown images # def unified_search(q, uploaded_image, a, offset, gender_ui): # start = 0 # end = 12 # gender_override = gender_ui if gender_ui else None # if uploaded_image is not None: # results = search_by_image(uploaded_image, a, start, end) # elif q.strip() != "": # results = search_fashion(q, a, start, end, gender_override) # else: # results = [] # # reset shown_results to just these first 12 # return results, end, q, uploaded_image, gender_ui, results # search_btn.click( # unified_search, # inputs=[query, image_input, alpha, search_offset, gender_dropdown], # outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results] # ) # def load_more_fn(a, offset, q, img, gender_ui, prev_results): # start = offset # end = offset + 12 # gender_override = gender_ui if gender_ui else None # if img is not None: # new_results = search_by_image(img, a, start, end) # elif q.strip() != "": # new_results = search_fashion(q, a, start, end, gender_override) # else: # new_results = [] # combined_results = prev_results + new_results # return combined_results, end, combined_results # load_more_btn.click( # load_more_fn, # inputs=[alpha, search_offset, current_query, current_image, current_gender, shown_results], # outputs=[gallery, search_offset, shown_results] # ) # gr.Markdown("Powered by your hybrid AI search model 🚀") # demo.launch() # app.py import os import time import torch import numpy as np import gradio as gr from PIL import Image, ImageOps from tqdm.auto import tqdm from datasets import load_dataset from sentence_transformers import SentenceTransformer from pinecone import Pinecone, ServerlessSpec from pinecone_text.sparse import BM25Encoder from transformers import CLIPProcessor, CLIPModel import openai # ------------------- Keys & Setup ------------------- openai.api_key = os.getenv("OPENAI_API_KEY") pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY")) spec = ServerlessSpec(cloud=os.getenv("PINECONE_CLOUD") or "aws", region=os.getenv("PINECONE_REGION") or "us-east-1") index_name = "hybrid-image-search" if index_name not in pc.list_indexes().names(): pc.create_index(index_name, dimension=512, metric='dotproduct', spec=spec) while not pc.describe_index(index_name).status['ready']: time.sleep(1) index = pc.Index(index_name) # ------------------- Models & Dataset ------------------- fashion = load_dataset("ashraq/fashion-product-images-small", split="train") images = fashion["image"] metadata = fashion.remove_columns("image").to_pandas() bm25 = BM25Encoder() bm25.fit(metadata["productDisplayName"]) device = 'cuda' if torch.cuda.is_available() else 'cpu' model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device=device) clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device) clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") # ------------------- Helper Functions ------------------- def hybrid_scale(dense, sparse, alpha: float): if alpha < 0 or alpha > 1: raise ValueError("Alpha must be between 0 and 1") hsparse = { 'indices': sparse['indices'], 'values': [v * (1 - alpha) for v in sparse['values']] } hdense = [v * alpha for v in dense] return hdense, hsparse def extract_intent_from_openai(query: str): prompt = f''' You are an assistant for a fashion search engine. Extract the user's intent from the following query. Return a Python dictionary with keys: category, gender, subcategory, color. If something is missing, use null. Query: "{query}" Only return the dictionary. ''' try: response = openai.ChatCompletion.create( model="gpt-4", messages=[{"role": "user", "content": prompt}], temperature=0 ) raw = response.choices[0].message['content'] structured = eval(raw) return structured except Exception as e: print(f"⚠️ OpenAI intent extraction failed: {e}") return {} def is_duplicate(img, seen_hashes): h = hash(img.tobytes()) if h in seen_hashes: return True seen_hashes.add(h) return False # ------------------- Search Functions ------------------- def search_fashion(query: str, alpha: float, start: int = 0, end: int = 12, gender_override: str = None): intent = extract_intent_from_openai(query) gender = intent.get("gender") category = intent.get("category") subcategory = intent.get("subcategory") color = intent.get("color") if gender_override: gender = gender_override filter = {} if gender: filter["gender"] = gender if category: if category in ["Footwear", "Shoes"]: filter["articleType"] = {"$regex": ".*(Shoe|Footwear).*"} else: filter["articleType"] = category if subcategory: filter["subCategory"] = subcategory if color: filter["baseColour"] = color sparse = bm25.encode_queries(query) dense = model.encode(query).tolist() hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha) result = index.query( top_k=100, vector=hdense, sparse_vector=hsparse, include_metadata=True, filter=filter if filter else None ) if len(result["matches"]) == 0: print("⚠️ No results, retrying with alpha=0 sparse only") hdense, hsparse = hybrid_scale(dense, sparse, alpha=0) result = index.query(top_k=100, vector=hdense, sparse_vector=hsparse, include_metadata=True, filter=filter) imgs_with_captions = [] seen_hashes = set() for r in result["matches"]: idx = int(r["id"]) img = images[idx] meta = r.get("metadata", {}) if not isinstance(img, Image.Image): img = Image.fromarray(np.array(img)) padded = ImageOps.pad(img, (256, 256), color="white") caption = str(meta.get("productDisplayName", "Unknown Product")) if not is_duplicate(padded, seen_hashes): imgs_with_captions.append((padded, caption)) if len(imgs_with_captions) >= end: break return imgs_with_captions def search_by_image(uploaded_image, alpha=0.5, start=0, end=12): processed = clip_processor(images=uploaded_image, return_tensors="pt").to(device) with torch.no_grad(): image_vec = clip_model.get_image_features(**processed) image_vec = image_vec.cpu().numpy().flatten().tolist() result = index.query(top_k=100, vector=image_vec, include_metadata=True) imgs_with_captions = [] seen_hashes = set() for r in result["matches"]: idx = int(r["id"]) img = images[idx] meta = r.get("metadata", {}) caption = str(meta.get("productDisplayName", "Unknown Product")) if not isinstance(img, Image.Image): img = Image.fromarray(np.array(img)) padded = ImageOps.pad(img, (256, 256), color="white") if not is_duplicate(padded, seen_hashes): imgs_with_captions.append((padded, caption)) if len(imgs_with_captions) >= end: break return imgs_with_captions # ------------------- UI ------------------- # custom_css = """ # .search-btn { width: 100%; } # .gr-row { gap: 8px !important; } # .query-slider > div { margin-bottom: 4px !important; } # .gr-gallery-item { width: 256px !important; height: 256px !important; } # .gr-gallery-item img { width: 100% !important; height: 100% !important; object-fit: cover !important; } # """ # with gr.Blocks(css=custom_css) as demo: # gr.Markdown("# 🛍️ Fashion Product Hybrid Search (with GPT-4 powered query parsing)") # with gr.Row(equal_height=True): # with gr.Column(scale=5, elem_classes="query-slider"): # query = gr.Textbox(label="Enter your fashion search query", placeholder="e.g., black sneakers for women") # alpha = gr.Slider(0, 1, value=0.5, label="Hybrid Weight (alpha: 0=sparse, 1=dense)") # gender_dropdown = gr.Dropdown(["", "Men", "Women", "Boys", "Girls", "Kids", "Unisex"], label="Gender Filter (optional)") # with gr.Column(scale=1): # image_input = gr.Image(type="pil", label="Upload an image (optional)", sources=["upload", "clipboard"], height=256, width=356) # search_btn = gr.Button("Search", elem_classes="search-btn") # gallery = gr.Gallery(label="Search Results", columns=6, height=None) # load_more_btn = gr.Button("Load More") # search_offset = gr.State(0) # current_query = gr.State("") # current_image = gr.State(None) # current_gender = gr.State("") # shown_results = gr.State([]) # shown_ids = gr.State(set()) # def unified_search(q, uploaded_image, a, offset, gender_ui): # start = 0 # end = 12 # filters = extract_intent_from_openai(q) if q.strip() else {} # gender_override = gender_ui if gender_ui else filters.get("gender") # if uploaded_image is not None: # results = search_by_image(uploaded_image, a, start, end) # elif q.strip(): # results = search_fashion(q, a, start, end, gender_override) # else: # results = [] # seen_ids = {r[1] for r in results} # return results, end, q, uploaded_image, gender_override, results, seen_ids # search_btn.click(unified_search, inputs=[query, image_input, alpha, search_offset, gender_dropdown], # outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results, shown_ids]) # def load_more_fn(a, offset, q, img, gender_ui, prev_results, prev_ids): # start = offset # end = offset + 12 # gender_override = gender_ui # if img is not None: # new_results = search_by_image(img, a, start, end) # elif q.strip(): # new_results = search_fashion(q, a, start, end, gender_override) # else: # new_results = [] # filtered_new = [] # new_ids = set() # for item in new_results: # img_obj, caption = item # if caption not in prev_ids: # filtered_new.append(item) # new_ids.add(caption) # combined = prev_results + filtered_new # updated_ids = prev_ids.union(new_ids) # return combined, end, combined, updated_ids # load_more_btn.click(load_more_fn, inputs=[alpha, search_offset, current_query, current_image, current_gender, shown_results, shown_ids], # outputs=[gallery, search_offset, shown_results, shown_ids]) # gr.Markdown("🧠 Powered by OpenAI + Hybrid AI Fashion Search") # demo.launch() import gradio as gr custom_css = """ .search-btn { width: 100%; } .gr-row { gap: 8px !important; flex-wrap: wrap; } .query-slider > div { margin-bottom: 4px !important; } /* Default: 6 per row */ .gr-gallery-item { flex: 1 1 calc(16.66% - 10px); /* 6 per row with some gap */ max-width: calc(16.66% - 10px); box-sizing: border-box; height: auto !important; margin-bottom: 10px; } .gr-gallery-item img { width: 100% !important; height: auto !important; object-fit: cover !important; } /* On small screens: 3 per row */ @media (max-width: 768px) { .gr-gallery-item { flex: 1 1 calc(33.33% - 10px); /* 3 per row */ max-width: calc(33.33% - 10px); } } """ with gr.Blocks(css=custom_css) as demo: gr.Markdown("## 🛍️ Fashion Product Hybrid Search (with GPT-4 powered query parsing)") with gr.Row(equal_height=True): with gr.Column(scale=5, elem_classes="query-slider"): query = gr.Textbox(label="Enter your fashion search query", placeholder="e.g., black sneakers for women") alpha = gr.Slider(0, 1, value=0.5, label="Hybrid Weight (alpha: 0=sparse, 1=dense)") gender_dropdown = gr.Dropdown(["", "Men", "Women", "Boys", "Girls", "Kids", "Unisex"], label="Gender Filter (optional)") with gr.Column(scale=1): image_input = gr.Image(type="pil", label="Upload an image (optional)", sources=["upload", "clipboard"], height=256) search_btn = gr.Button("Search", elem_classes="search-btn") gallery = gr.Gallery(label="Search Results", columns=6, height=None, allow_preview=True) # 'columns=6' ignored due to CSS override load_more_btn = gr.Button("Load More") search_offset = gr.State(0) current_query = gr.State("") current_image = gr.State(None) current_gender = gr.State("") shown_results = gr.State([]) shown_ids = gr.State(set()) def unified_search(q, uploaded_image, a, offset, gender_ui): start = 0 end = 12 filters = extract_intent_from_openai(q) if q.strip() else {} gender_override = gender_ui if gender_ui else filters.get("gender") if uploaded_image is not None: results = search_by_image(uploaded_image, a, start, end) elif q.strip(): results = search_fashion(q, a, start, end, gender_override) else: results = [] seen_ids = {r[1] for r in results} return results, end, q, uploaded_image, gender_override, results, seen_ids search_btn.click(unified_search, inputs=[query, image_input, alpha, search_offset, gender_dropdown], outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results, shown_ids]) def load_more_fn(a, offset, q, img, gender_ui, prev_results, prev_ids): start = offset end = offset + 12 gender_override = gender_ui if img is not None: new_results = search_by_image(img, a, start, end) elif q.strip(): new_results = search_fashion(q, a, start, end, gender_override) else: new_results = [] filtered_new = [] new_ids = set() for item in new_results: img_obj, caption = item if caption not in prev_ids: filtered_new.append(item) new_ids.add(caption) combined = prev_results + filtered_new updated_ids = prev_ids.union(new_ids) return combined, end, combined, updated_ids load_more_btn.click(load_more_fn, inputs=[alpha, search_offset, current_query, current_image, current_gender, shown_results, shown_ids], outputs=[gallery, search_offset, shown_results, shown_ids]) gr.Markdown("🧠 Powered by OpenAI + Hybrid AI Fashion Search") demo.launch()