# FashionGPT / app.py
# import os
# from pinecone import Pinecone, ServerlessSpec
# from PIL import Image, ImageOps
# import numpy as np
# from datasets import load_dataset
# from pinecone_text.sparse import BM25Encoder
# from sentence_transformers import SentenceTransformer
# import torch
# from tqdm.auto import tqdm
# import gradio as gr
# # ------------------- Pinecone Setup -------------------
# os.environ["PINECONE_API_KEY"] = "pcsk_TMCYK_LrbmZMTDhkxTjUXcr8iTcQ8LxurwKBFDvv4ahFis8SVob7QexVPPEt6g2zW6d3g"
# api_key = os.environ.get('PINECONE_API_KEY')
# pc = Pinecone(api_key=api_key)
# cloud = os.environ.get('PINECONE_CLOUD') or 'aws'
# region = os.environ.get('PINECONE_REGION') or 'us-east-1'
# spec = ServerlessSpec(cloud=cloud, region=region)
# index_name = "hybrid-image-search"
# spec = ServerlessSpec(cloud="aws", region="us-east-1")
# # choose a name for your index
# index_name = "hybrid-image-search"
# import time
# # check if index already exists (it shouldn't if this is first time)
# if index_name not in pc.list_indexes().names():
#     # if does not exist, create index
#     pc.create_index(
#         index_name,
#         dimension=512,
#         metric='dotproduct',
#         spec=spec
#     )
#     # wait for index to be initialized
#     while not pc.describe_index(index_name).status['ready']:
#         time.sleep(1)
# # connect to index
# index = pc.Index(index_name)
# # view index stats
# index.describe_index_stats()
# # ------------------- Dataset Loading -------------------
# fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
# images = fashion["image"]
# metadata = fashion.remove_columns("image").to_pandas()
# # ------------------- Encoders -------------------
# bm25 = BM25Encoder()
# bm25.fit(metadata["productDisplayName"])
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# # load a CLIP model from Hugging Face
# model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device=device)
# # ------------------- Hybrid Scaling -------------------
# def hybrid_scale(dense, sparse, alpha: float):
# if alpha < 0 or alpha > 1:
# raise ValueError("Alpha must be between 0 and 1")
# # scale sparse and dense vectors to create hybrid search vecs
# hsparse = {
# 'indices': sparse['indices'],
# 'values': [v * (1 - alpha) for v in sparse['values']]
# }
# hdense = [v * alpha for v in dense]
# return hdense, hsparse
# # ------------------- Metadata Filter Extraction -------------------
# from transformers import CLIPProcessor, CLIPModel
# clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
# clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# def extract_metadata_filters(query: str):
#     query_lower = query.lower()
#     gender = None
#     category = None
#     subcategory = None
#     color = None
#     # --- Gender Mapping ---
#     gender_map = {
#         "men": "Men", "man": "Men", "mens": "Men", "mans": "Men", "male": "Men",
#         "women": "Women", "woman": "Women", "womens": "Women", "female": "Women",
#         "boys": "Boys", "boy": "Boys",
#         "girls": "Girls", "girl": "Girls",
#         "kids": "Kids", "kid": "Kids",
#         "unisex": "Unisex"
#     }
#     for term, mapped_value in gender_map.items():
#         if term in query_lower:
#             gender = mapped_value
#             break
#     # --- Category Mapping ---
#     category_map = {
#         "shirt": "Shirts",
#         "tshirt": "Tshirts", "t-shirt": "Tshirts",
#         "jeans": "Jeans",
#         "watch": "Watches",
#         "kurta": "Kurtas",
#         "dress": "Dresses", "dresses": "Dresses",
#         "trousers": "Trousers", "pants": "Trousers",
#         "shorts": "Shorts",
#         "footwear": "Footwear",
#         "shoes": "Shoes",  # note kept as Shoes
#         "fashion": "Apparel"
#     }
#     for term, mapped_value in category_map.items():
#         if term in query_lower:
#             category = mapped_value
#             break
#     # --- SubCategory Mapping ---
#     subCategory_list = [
#         "Accessories", "Apparel Set", "Bags", "Bath and Body", "Beauty Accessories",
#         "Belts", "Bottomwear", "Cufflinks", "Dress", "Eyes", "Eyewear", "Flip Flops",
#         "Fragrance", "Free Gifts", "Gloves", "Hair", "Headwear", "Home Furnishing",
#         "Innerwear", "Jewellery", "Lips", "Loungewear and Nightwear", "Makeup",
#         "Mufflers", "Nails", "Perfumes", "Sandal", "Saree", "Scarves", "Shoe Accessories",
#         "Shoes", "Skin", "Skin Care", "Socks", "Sports Accessories", "Sports Equipment",
#         "Stoles", "Ties", "Topwear", "Umbrellas", "Vouchers", "Wallets", "Watches",
#         "Water Bottle", "Wristbands"
#     ]
#     if "topwear" in query_lower or "top" in query_lower:
#         subcategory = "Topwear"
#     else:
#         for subcat in subCategory_list:
#             if subcat.lower() in query_lower:
#                 subcategory = subcat
#                 break
#     # --- Color Extraction ---
#     colors = [
#         "red", "blue", "green", "yellow", "black", "white",
#         "orange", "pink", "purple", "brown", "grey", "beige"
#     ]
#     for c in colors:
#         if c in query_lower:
#             color = c.capitalize()
#             break
#     # --- Invalid pairs ---
#     invalid_pairs = {
#         ("Men", "Dresses"), ("Men", "Sarees"), ("Men", "Skirts"),
#         ("Boys", "Dresses"), ("Boys", "Sarees"),
#         ("Girls", "Boxers"), ("Men", "Heels")
#     }
#     if (gender, category) in invalid_pairs:
#         print(f"⚠️ Invalid pair: {gender} + {category}, dropping gender")
#         gender = None
#     # fallback
#     if gender and not category:
#         category = "Apparel"
#     return gender, category, subcategory, color
# def search_fashion(query: str, alpha: float, start: int = 0, end: int = 12, gender_override: str = None):
#     gender, category, subcategory, color = extract_metadata_filters(query)
#     # override from dropdown
#     if gender_override:
#         gender = gender_override
#     # --- Pinecone Filter ---
#     filter = {}
#     if gender:
#         filter["gender"] = gender
#     if category:
#         if category in ["Footwear", "Shoes"]:
#             shoe_article_types = [
#                 "Casual Shoes", "Sports Shoes", "Formal Shoes", "Training Shoes",
#                 "Sneakers", "Sandals", "Slippers", "Boots", "Flip Flops"
#             ]
#             filter["articleType"] = {"$in": shoe_article_types}
#         else:
#             filter["articleType"] = category
#     if subcategory:
#         filter["subCategory"] = subcategory
#     if color:
#         filter["baseColour"] = color
#     print(f"🔍 Using filter: {filter} (showing {start} to {end})")
#     sparse = bm25.encode_queries(query)
#     dense = model.encode(query).tolist()
#     hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha)
#     result = index.query(
#         top_k=end,
#         vector=hdense,
#         sparse_vector=hsparse,
#         include_metadata=True,
#         filter=filter if filter else None
#     )
#     # fallback if no results
#     if len(result["matches"]) == 0:
#         print("⚠️ No results, retrying with alpha=0 sparse only")
#         hdense, hsparse = hybrid_scale(dense, sparse, alpha=0)
#         result = index.query(
#             top_k=end,
#             vector=hdense,
#             sparse_vector=hsparse,
#             include_metadata=True,
#             filter=filter if filter else None
#         )
#     # fallback if no results with gender
#     if gender and len(result["matches"]) == 0:
#         print(f"⚠️ No results for gender {gender}, relaxing gender filter")
#         filter.pop("gender", None)
#         result = index.query(
#             top_k=end,
#             vector=hdense,
#             sparse_vector=hsparse,
#             include_metadata=True,
#             filter=filter if filter else None
#         )
#     matches = result["matches"][start:end]
#     imgs_with_captions = []
#     for r in matches:
#         idx = int(r["id"])
#         img = images[idx]
#         meta = r.get("metadata", {})
#         if not isinstance(img, Image.Image):
#             img = Image.fromarray(np.array(img))
#         padded = ImageOps.pad(img, (256, 256), color="white")
#         caption = str(meta.get("productDisplayName", "Unknown Product"))
#         imgs_with_captions.append((padded, caption))
#     return imgs_with_captions
# def search_by_image(uploaded_image, alpha=0.5, start=0, end=12):
#     """
#     Search visually similar products with support for pagination.
#     """
#     # Preprocess image for CLIP
#     processed = clip_processor(images=uploaded_image, return_tensors="pt").to(device)
#     with torch.no_grad():
#         image_vec = clip_model.get_image_features(**processed)
#     image_vec = image_vec.cpu().numpy().flatten().tolist()
#     # Query a larger top_k so you have enough to paginate
#     result = index.query(
#         top_k=end,
#         vector=image_vec,
#         include_metadata=True
#     )
#     matches = result["matches"][start:end]  # slice for pagination
#     imgs_with_captions = []
#     for r in matches:
#         idx = int(r["id"])
#         img = images[idx]
#         meta = r.get("metadata", {})
#         if not isinstance(img, Image.Image):
#             img = Image.fromarray(np.array(img))
#         padded = ImageOps.pad(img, (256, 256), color="white")
#         caption = str(meta.get("productDisplayName", "Unknown Product"))
#         imgs_with_captions.append((padded, caption))
#     return imgs_with_captions
# custom_css = """
# .search-btn { width: 100%; }
# .gr-row { gap: 8px !important; }
# .query-slider > div { margin-bottom: 4px !important; }
# .upload-box .icon-container { display: none !important; }
# """
# with gr.Blocks(css=custom_css) as demo:
#     gr.Markdown("# 🛍️ Fashion Product Hybrid Search")
#     with gr.Row(equal_height=True):
#         with gr.Column(scale=5, elem_classes="query-slider"):
#             query = gr.Textbox(
#                 label="Enter your fashion search query",
#                 placeholder="Type something or leave blank to only use the image"
#             )
#             alpha = gr.Slider(0, 1, value=0.5, label="Hybrid Weight (alpha: 0=sparse, 1=dense)")
#             gender_dropdown = gr.Dropdown(
#                 ["", "Men", "Women", "Boys", "Girls", "Kids", "Unisex"],
#                 label="Gender Filter (optional)"
#             )
#         with gr.Column(scale=1):
#             image_input = gr.Image(
#                 type="pil",
#                 label="Upload an image (optional)",
#                 height=256,
#                 width=356,
#                 sources=["upload", "clipboard"]  # only upload and paste allowed
#             )
#     search_btn = gr.Button("Search", elem_classes="search-btn")
#     gallery = gr.Gallery(label="Search Results", columns=6, height="50vh")
#     load_more_btn = gr.Button("Load More")
#     # States to track
#     search_offset = gr.State(0)
#     current_query = gr.State("")
#     current_image = gr.State(None)
#     current_gender = gr.State("")
#     shown_results = gr.State([])  # store the list of shown images
#     def unified_search(q, uploaded_image, a, offset, gender_ui):
#         start = 0
#         end = 12
#         gender_override = gender_ui if gender_ui else None
#         if uploaded_image is not None:
#             results = search_by_image(uploaded_image, a, start, end)
#         elif q.strip() != "":
#             results = search_fashion(q, a, start, end, gender_override)
#         else:
#             results = []
#         # reset shown_results to just these first 12
#         return results, end, q, uploaded_image, gender_ui, results
#     search_btn.click(
#         unified_search,
#         inputs=[query, image_input, alpha, search_offset, gender_dropdown],
#         outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results]
#     )
#     def load_more_fn(a, offset, q, img, gender_ui, prev_results):
#         start = offset
#         end = offset + 12
#         gender_override = gender_ui if gender_ui else None
#         if img is not None:
#             new_results = search_by_image(img, a, start, end)
#         elif q.strip() != "":
#             new_results = search_fashion(q, a, start, end, gender_override)
#         else:
#             new_results = []
#         combined_results = prev_results + new_results
#         return combined_results, end, combined_results
#     load_more_btn.click(
#         load_more_fn,
#         inputs=[alpha, search_offset, current_query, current_image, current_gender, shown_results],
#         outputs=[gallery, search_offset, shown_results]
#     )
#     gr.Markdown("Powered by your hybrid AI search model 🚀")
# demo.launch()
# ------------------- Current version (app.py) -------------------
import os
import time
import torch
import numpy as np
import gradio as gr
from PIL import Image, ImageOps
from tqdm.auto import tqdm
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone, ServerlessSpec
from pinecone_text.sparse import BM25Encoder
from transformers import CLIPProcessor, CLIPModel
import ast
import openai
# ------------------- Keys & Setup -------------------
openai.api_key = os.getenv("OPENAI_API_KEY")
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
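# Both API keys are read from environment variables (e.g. Hugging Face Spaces
# secrets); real keys should never be hard-coded in source.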
spec = ServerlessSpec(cloud=os.getenv("PINECONE_CLOUD") or "aws", region=os.getenv("PINECONE_REGION") or "us-east-1")
index_name = "hybrid-image-search"
if index_name not in pc.list_indexes().names():
    pc.create_index(index_name, dimension=512, metric='dotproduct', spec=spec)
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)
index = pc.Index(index_name)
# ------------------- Models & Dataset -------------------
fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
images = fashion["image"]
metadata = fashion.remove_columns("image").to_pandas()
bm25 = BM25Encoder()
bm25.fit(metadata["productDisplayName"])
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device=device)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
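# Both CLIP encoders produce 512-dimensional embeddings, which is why the
# Pinecone index above is created with dimension=512: text queries, image
# queries, and indexed products all live in the same vector space.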
# ------------------- Helper Functions -------------------
def hybrid_scale(dense, sparse, alpha: float):
    if alpha < 0 or alpha > 1:
        raise ValueError("Alpha must be between 0 and 1")
    hsparse = {
        'indices': sparse['indices'],
        'values': [v * (1 - alpha) for v in sparse['values']]
    }
    hdense = [v * alpha for v in dense]
    return hdense, hsparse
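# A quick illustration with made-up numbers (not from the dataset): calling
# hybrid_scale([0.2, 0.4], {'indices': [3, 7], 'values': [1.0, 0.5]}, alpha=0.7)
# yields hdense ≈ [0.14, 0.28] and hsparse values ≈ [0.3, 0.15], i.e. a convex
# combination that weights the dense (semantic) signal at 0.7 and the sparse
# (keyword) signal at 0.3.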
def extract_intent_from_openai(query: str):
    prompt = f'''
You are an assistant for a fashion search engine. Extract the user's intent from the following query.
Return a Python dictionary with keys: category, gender, subcategory, color.
If something is missing, use None.
Query: "{query}"
Only return the dictionary.
'''
    try:
        response = openai.ChatCompletion.create(  # legacy SDK interface (openai<1.0)
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )
        raw = response.choices[0].message['content']
        # Parse the reply as a Python literal rather than eval(): eval() would
        # execute arbitrary code if the model returned something malformed.
        structured = ast.literal_eval(raw.strip())
        return structured if isinstance(structured, dict) else {}
    except Exception as e:
        print(f"⚠️ OpenAI intent extraction failed: {e}")
        return {}
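# For a query like "red kurta for women" the model is expected (though not
# guaranteed) to return something like:
#   {"category": "Kurtas", "gender": "Women", "subcategory": None, "color": "Red"}
# Downstream code treats any missing or None key as "no filter".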
def is_duplicate(img, seen_hashes):
    h = hash(img.tobytes())
    if h in seen_hashes:
        return True
    seen_hashes.add(h)
    return False
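# Note: hashing raw pixel bytes only catches byte-identical images. Catching
# near-duplicates that were re-encoded or resized would need a perceptual hash
# (e.g. the third-party `imagehash` package) instead.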
# ------------------- Search Functions -------------------
def search_fashion(query: str, alpha: float, start: int = 0, end: int = 12, gender_override: str = None):
    intent = extract_intent_from_openai(query)
    gender = intent.get("gender")
    category = intent.get("category")
    subcategory = intent.get("subcategory")
    color = intent.get("color")
    if gender_override:
        gender = gender_override
    filter = {}
    if gender:
        filter["gender"] = gender
    if category:
        if category in ["Footwear", "Shoes"]:
            # Pinecone metadata filters do not support "$regex"; match the
            # shoe-related article types explicitly with "$in" instead.
            shoe_article_types = [
                "Casual Shoes", "Sports Shoes", "Formal Shoes", "Training Shoes",
                "Sneakers", "Sandals", "Slippers", "Boots", "Flip Flops"
            ]
            filter["articleType"] = {"$in": shoe_article_types}
        else:
            filter["articleType"] = category
    if subcategory:
        filter["subCategory"] = subcategory
    if color:
        filter["baseColour"] = color
    sparse = bm25.encode_queries(query)
    dense = model.encode(query).tolist()
    hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha)
    result = index.query(
        top_k=100,
        vector=hdense,
        sparse_vector=hsparse,
        include_metadata=True,
        filter=filter if filter else None
    )
    if len(result["matches"]) == 0:
        print("⚠️ No results, retrying with alpha=0 sparse only")
        hdense, hsparse = hybrid_scale(dense, sparse, alpha=0)
        result = index.query(top_k=100, vector=hdense, sparse_vector=hsparse,
                             include_metadata=True, filter=filter if filter else None)
    imgs_with_captions = []
    seen_hashes = set()
    for r in result["matches"]:
        idx = int(r["id"])
        img = images[idx]
        meta = r.get("metadata", {})
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.array(img))
        padded = ImageOps.pad(img, (256, 256), color="white")
        caption = str(meta.get("productDisplayName", "Unknown Product"))
        if not is_duplicate(padded, seen_hashes):
            imgs_with_captions.append((padded, caption))
        if len(imgs_with_captions) >= end:
            break
    # Slice so that `start` acts as a pagination offset for "Load More".
    return imgs_with_captions[start:end]
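# Example usage (assumes the index is populated and OPENAI_API_KEY is set):
#   search_fashion("blue jeans for men", alpha=0.5, start=0, end=12)
# returns up to 12 (PIL.Image, caption) tuples, ready to feed to gr.Gallery.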
def search_by_image(uploaded_image, alpha=0.5, start=0, end=12):
    processed = clip_processor(images=uploaded_image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_vec = clip_model.get_image_features(**processed)
    image_vec = image_vec.cpu().numpy().flatten().tolist()
    result = index.query(top_k=100, vector=image_vec, include_metadata=True)
    imgs_with_captions = []
    seen_hashes = set()
    for r in result["matches"]:
        idx = int(r["id"])
        img = images[idx]
        meta = r.get("metadata", {})
        caption = str(meta.get("productDisplayName", "Unknown Product"))
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.array(img))
        padded = ImageOps.pad(img, (256, 256), color="white")
        if not is_duplicate(padded, seen_hashes):
            imgs_with_captions.append((padded, caption))
        if len(imgs_with_captions) >= end:
            break
    # Slice so that `start` acts as a pagination offset for "Load More".
    return imgs_with_captions[start:end]
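# Example: for a PIL image `img`, search_by_image(img, start=0, end=12) returns
# the first page of nearest products by CLIP image embedding. `alpha` is
# accepted for interface symmetry with search_fashion, but image search here is
# dense-only (no sparse/BM25 component).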
# ------------------- UI -------------------
custom_css = """
.search-btn { width: 100%; }
.gr-row { gap: 8px !important; }
.query-slider > div { margin-bottom: 4px !important; }
.gr-gallery-item { width: 256px !important; height: 256px !important; }
.gr-gallery-item img { width: 100% !important; height: 100% !important; object-fit: cover !important; }
"""
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# 🛍️ Fashion Product Hybrid Search (with GPT-4 powered query parsing)")
    with gr.Row(equal_height=True):
        with gr.Column(scale=5, elem_classes="query-slider"):
            query = gr.Textbox(label="Enter your fashion search query", placeholder="e.g., black sneakers for women")
            alpha = gr.Slider(0, 1, value=0.5, label="Hybrid Weight (alpha: 0=sparse, 1=dense)")
            gender_dropdown = gr.Dropdown(["", "Men", "Women", "Boys", "Girls", "Kids", "Unisex"], label="Gender Filter (optional)")
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload an image (optional)", sources=["upload", "clipboard"], height=256, width=356)
    search_btn = gr.Button("Search", elem_classes="search-btn")
    gallery = gr.Gallery(label="Search Results", columns=6, height=None)
    load_more_btn = gr.Button("Load More")

    # Session state: pagination offset, last query/image/gender, and what has
    # already been shown (used to de-duplicate "Load More" batches).
    search_offset = gr.State(0)
    current_query = gr.State("")
    current_image = gr.State(None)
    current_gender = gr.State("")
    shown_results = gr.State([])
    shown_ids = gr.State(set())

    def unified_search(q, uploaded_image, a, offset, gender_ui):
        start = 0
        end = 12
        filters = extract_intent_from_openai(q) if q.strip() else {}
        gender_override = gender_ui if gender_ui else filters.get("gender")
        if uploaded_image is not None:
            results = search_by_image(uploaded_image, a, start, end)
        elif q.strip():
            results = search_fashion(q, a, start, end, gender_override)
        else:
            results = []
        seen_ids = {caption for _, caption in results}
        return results, end, q, uploaded_image, gender_override, results, seen_ids

    search_btn.click(unified_search,
                     inputs=[query, image_input, alpha, search_offset, gender_dropdown],
                     outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results, shown_ids])

    def load_more_fn(a, offset, q, img, gender_ui, prev_results, prev_ids):
        start = offset
        end = offset + 12
        gender_override = gender_ui
        if img is not None:
            new_results = search_by_image(img, a, start, end)
        elif q.strip():
            new_results = search_fashion(q, a, start, end, gender_override)
        else:
            new_results = []
        # Drop anything already shown (tracked by caption) before appending.
        filtered_new = []
        new_ids = set()
        for img_obj, caption in new_results:
            if caption not in prev_ids:
                filtered_new.append((img_obj, caption))
                new_ids.add(caption)
        combined = prev_results + filtered_new
        updated_ids = prev_ids.union(new_ids)
        return combined, end, combined, updated_ids

    load_more_btn.click(load_more_fn,
                        inputs=[alpha, search_offset, current_query, current_image, current_gender, shown_results, shown_ids],
                        outputs=[gallery, search_offset, shown_results, shown_ids])

    gr.Markdown("🧠 Powered by OpenAI + Hybrid AI Fashion Search")

demo.launch()
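# To run locally (assuming the dependencies imported above are installed): set
# the OPENAI_API_KEY and PINECONE_API_KEY environment variables, then run
# `python app.py`. Gradio serves the UI on http://127.0.0.1:7860 by default.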