# app.py: FashionGPT, a hybrid (dense + sparse) fashion product search app with text, voice, and image queries.
import os
import time
import json
import torch
import numpy as np
import gradio as gr
from PIL import Image, ImageOps
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone, ServerlessSpec
from pinecone_text.sparse import BM25Encoder
from transformers import CLIPProcessor, CLIPModel
import openai
# ------------------- Keys & Setup -------------------
# Secrets come from the environment; never hardcode API keys in source.
openai.api_key = os.getenv("OPENAI_API_KEY")
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
spec = ServerlessSpec(
    cloud=os.getenv("PINECONE_CLOUD") or "aws",
    region=os.getenv("PINECONE_REGION") or "us-east-1",
)
index_name = "hybrid-image-search"
if index_name not in pc.list_indexes().names():
    # 512 dims matches CLIP ViT-B/32 embeddings; the dotproduct metric is required for hybrid queries
    pc.create_index(index_name, dimension=512, metric='dotproduct', spec=spec)
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)
index = pc.Index(index_name)
# ------------------- Models & Dataset -------------------
fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
images = fashion["image"]
metadata = fashion.remove_columns("image").to_pandas()
bm25 = BM25Encoder()
bm25.fit(metadata["productDisplayName"])
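# The fitted BM25 encoder supplies the sparse half of each hybrid query; the CLIP
# models below supply the dense half.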
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device=device)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
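# Both CLIP wrappers load the ViT-B/32 checkpoint: the SentenceTransformer encodes text
# queries, while CLIPModel/CLIPProcessor embed uploaded images into the same 512-d space.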
# ------------------- Helper Functions -------------------
def hybrid_scale(dense, sparse, alpha: float):
    """Convex weighting of dense vs. sparse query vectors: alpha=1 is dense-only, alpha=0 sparse-only."""
    if alpha < 0 or alpha > 1:
        raise ValueError("Alpha must be between 0 and 1")
    hsparse = {
        'indices': sparse['indices'],
        'values': [v * (1 - alpha) for v in sparse['values']]
    }
    hdense = [v * alpha for v in dense]
    return hdense, hsparse
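# Worked example (illustrative values): with alpha=0.75,
#   hybrid_scale([1.0, 2.0], {'indices': [3], 'values': [4.0]}, 0.75)
# returns ([0.75, 1.5], {'indices': [3], 'values': [1.0]}), i.e. the dense vector keeps
# 75% of its weight and the sparse values the remaining 25%.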
def extract_intent_from_openai(query: str):
    """Ask the LLM for structured search intent: {"include": {...}, "exclude": {...}}."""
    prompt = f"""
You are an assistant for a fashion search engine. Extract the user's intent from the following query.
Return a JSON object with exactly two keys, "include" and "exclude". Each maps to an object with
the keys: category, gender, subcategory, color. Use null for anything not mentioned.
Query: "{query}"
Only return the JSON object.
"""
    try:
        # Legacy (pre-1.0) openai SDK call style, matching the import above
        response = openai.ChatCompletion.create(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )
        raw = response.choices[0].message['content'].strip()
        # Models sometimes wrap JSON in code fences; strip them before parsing
        if raw.startswith("```"):
            raw = raw.strip("`").removeprefix("json").strip()
        # json.loads is safer than eval() and understands the "null" the prompt asks for
        structured = json.loads(raw)
        # Guarantee the shape the callers below rely on
        structured.setdefault("include", {})
        structured.setdefault("exclude", {})
        return structured
    except Exception as e:
        print(f"⚠️ OpenAI intent extraction failed: {e}")
        return {"include": {}, "exclude": {}}
# ------------------- Near-Duplicate Filtering -------------------
import imagehash
def is_duplicate(img, existing_hashes, hash_size=16, tolerance=0):
    """
    Check whether the image is a near-duplicate of anything seen so far, using a
    perceptual hash (pHash). Side effect: unseen hashes are added to `existing_hashes`.
    :param img: PIL Image
    :param existing_hashes: set of previously seen hashes (mutated in place)
    :param hash_size: hash size (default 16 for more precision)
    :param tolerance: maximum Hamming distance still counted as a near-duplicate
    :return: True if the image is a (near-)duplicate, else False
    """
    img_hash = imagehash.phash(img, hash_size=hash_size)
    for h in existing_hashes:
        if abs(img_hash - h) <= tolerance:
            return True
    existing_hashes.add(img_hash)
    return False
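# Usage sketch (hypothetical `candidates` list of PIL images):
#   seen = set()
#   unique = [im for im in candidates if not is_duplicate(im, seen)]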
# Keyword-based intent extraction, kept as a lightweight offline fallback to the
# OpenAI extractor above (not currently wired into the UI search path).
def extract_metadata_filters(query: str):
query_lower = query.lower()
gender = None
category = None
subcategory = None
color = None
# --- Gender Mapping ---
gender_map = {
"men": "Men", "man": "Men", "mens": "Men", "mans": "Men", "male": "Men",
"women": "Women", "woman": "Women", "womens": "Women", "female": "Women",
"boys": "Boys", "boy": "Boys",
"girls": "Girls", "girl": "Girls",
"kids": "Kids", "kid": "Kids",
"unisex": "Unisex"
}
for term, mapped_value in gender_map.items():
if term in query_lower:
gender = mapped_value
break
# --- Category Mapping ---
category_map = {
"shirt": "Shirts",
"tshirt": "Tshirts",
"t-shirt": "Tshirts",
"jeans": "Jeans",
"watch": "Watches",
"kurta": "Kurtas",
"dress": "Dresses",
"trousers": "Trousers", "pants": "Trousers",
"shorts": "Shorts",
"footwear": "Footwear",
"shoes": "Shoes",
"fashion": "Apparel"
}
for term, mapped_value in category_map.items():
if term in query_lower:
category = mapped_value
break
# --- SubCategory Mapping ---
subCategory_list = [
"Accessories", "Apparel Set", "Bags", "Bath and Body", "Beauty Accessories",
"Belts", "Bottomwear", "Cufflinks", "Dress", "Eyes", "Eyewear", "Flip Flops",
"Fragrance", "Free Gifts", "Gloves", "Hair", "Headwear", "Home Furnishing",
"Innerwear", "Jewellery", "Lips", "Loungewear and Nightwear", "Makeup",
"Mufflers", "Nails", "Perfumes", "Sandal", "Saree", "Scarves", "Shoe Accessories",
"Shoes", "Skin", "Skin Care", "Socks", "Sports Accessories", "Sports Equipment",
"Stoles", "Ties", "Topwear", "Umbrellas", "Vouchers", "Wallets", "Watches",
"Water Bottle", "Wristbands"
]
if "topwear" in query_lower or "top" in query_lower:
subcategory = "Topwear"
else:
query_words = query_lower.split()
for subcat in subCategory_list:
if subcat.lower() in query_words:
subcategory = subcat
break
# --- Color Extraction ---
color_list = [
"red", "blue", "green", "yellow", "black", "white",
"orange", "pink", "purple", "brown", "grey", "beige"
]
for c in color_list:
if c in query_lower:
color = c.capitalize()
break
# --- Invalid pairs ---
invalid_pairs = {
("Men", "Dresses"), ("Men", "Sarees"), ("Men", "Skirts"),
("Boys", "Dresses"), ("Boys", "Sarees"),
("Girls", "Boxers"), ("Men", "Heels")
}
if (gender, category) in invalid_pairs:
print(f"⚠️ Invalid pair: {gender} + {category}, dropping gender")
gender = None
# --- Fallback for missing category ---
if gender and not category:
category = "Apparel"
# --- Refine subcategory for party/wedding-related queries ---
if "party" in query_lower or "wedding" in query_lower or "cocktail" in query_lower:
if subcategory in ["Loungewear and Nightwear", "Nightdress", "Innerwear"]:
subcategory = None # reset it to avoid filtering into wrong items
return gender, category, subcategory, color
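# Example: extract_metadata_filters("red kurta for men") -> ("Men", "Kurtas", None, "Red")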
# ------------------- Search Functions -------------------
def search_fashion(query: str, alpha: float, start: int = 0, end: int = 12, gender_override: str = None):
intent = extract_intent_from_openai(query)
include = intent.get("include", {})
exclude = intent.get("exclude", {})
gender = include.get("gender")
category = include.get("category")
subcategory = include.get("subcategory")
color = include.get("color")
# Apply override from dropdown
if gender_override:
gender = gender_override
# Build Pinecone filter
filter = {}
# Inclusion filters
if gender:
filter["gender"] = gender
    if category:
        if category in ["Footwear", "Shoes"]:
            # Pinecone metadata filters have no $regex operator; use an explicit $in list
            filter["articleType"] = {"$in": [
                "Casual Shoes", "Sports Shoes", "Formal Shoes", "Training Shoes",
                "Sneakers", "Sandals", "Slippers", "Boots", "Flip Flops"
            ]}
        else:
            filter["articleType"] = category
if subcategory:
filter["subCategory"] = subcategory
    # Exclude sleep/loungewear subcategories for occasion queries (party, wedding, etc.);
    # skipped when an explicit subCategory string filter is already set
query_lower = query.lower()
if any(word in query_lower for word in ["party", "wedding", "cocktail", "traditional", "reception"]):
filter.setdefault("subCategory", {})
if isinstance(filter["subCategory"], dict):
filter["subCategory"]["$nin"] = [
"Loungewear and Nightwear", "Nightdress", "Innerwear", "Sleepwear", "Vests", "Boxers"
]
if color:
filter["baseColour"] = color
# Exclusion filters
exclude_filter = {}
if exclude.get("color"):
exclude_filter["baseColour"] = {"$ne": exclude["color"]}
if exclude.get("subcategory"):
exclude_filter["subCategory"] = {"$ne": exclude["subcategory"]}
if exclude.get("category"):
exclude_filter["articleType"] = {"$ne": exclude["category"]}
# Combine all filters
if filter and exclude_filter:
final_filter = {"$and": [filter, exclude_filter]}
elif filter:
final_filter = filter
elif exclude_filter:
final_filter = exclude_filter
else:
final_filter = None
print(f"🔍 Using filter: {final_filter} (showing {start} to {end})")
# Hybrid encoding
sparse = bm25.encode_queries(query)
dense = model.encode(query).tolist()
hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha)
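    # Over-fetch (top_k=100) so pagination and near-duplicate filtering still leave enough hits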
result = index.query(
top_k=100,
vector=hdense,
sparse_vector=hsparse,
include_metadata=True,
filter=final_filter
)
# Retry fallback
if len(result["matches"]) == 0:
print("⚠️ No results, retrying with alpha=0 sparse only")
hdense, hsparse = hybrid_scale(dense, sparse, alpha=0)
result = index.query(
top_k=100,
vector=hdense,
sparse_vector=hsparse,
include_metadata=True,
filter=final_filter
)
    # Format results. `start` is unused here by design: "Load More" dedupes against
    # previously shown captions instead of slicing by offset.
    imgs_with_captions = []
    seen_hashes = set()
    for r in result["matches"]:
idx = int(r["id"])
img = images[idx]
meta = r.get("metadata", {})
if not isinstance(img, Image.Image):
img = Image.fromarray(np.array(img))
padded = ImageOps.pad(img, (256, 256), color="white")
caption = str(meta.get("productDisplayName", "Unknown Product"))
if not is_duplicate(padded, seen_hashes):
imgs_with_captions.append((padded, caption))
if len(imgs_with_captions) >= end:
break
return imgs_with_captions
def search_by_image(uploaded_image, alpha=0.5, start=0, end=12):
# Step 1: Preprocess image for CLIP model
processed = clip_processor(images=uploaded_image, return_tensors="pt").to(device)
with torch.no_grad():
image_vec = clip_model.get_image_features(**processed)
image_vec = image_vec.cpu().numpy().flatten().tolist()
# Step 2: Query Pinecone index for similar images
result = index.query(
top_k=100, # fetch more to allow deduplication
vector=image_vec,
include_metadata=True
)
matches = result["matches"]
imgs_with_captions = []
seen_hashes = set()
# Step 3: Deduplicate based on image hash
for r in matches:
idx = int(r["id"])
img = images[idx]
meta = r.get("metadata", {})
caption = str(meta.get("productDisplayName", "Unknown Product"))
if not isinstance(img, Image.Image):
img = Image.fromarray(np.array(img))
padded = ImageOps.pad(img, (256, 256), color="white")
if not is_duplicate(padded, seen_hashes):
imgs_with_captions.append((padded, caption))
if len(imgs_with_captions) >= end:
break
return imgs_with_captions
# ------------------- Voice Search (Whisper ASR) -------------------
import whisper

asr_model = whisper.load_model("base")
def handle_voice_search(vf_path, a, offset, gender_ui):
    try:
        transcription = asr_model.transcribe(vf_path)["text"].strip()
    except Exception as e:
        print(f"⚠️ Transcription failed: {e}")
        transcription = ""
    intent = extract_intent_from_openai(transcription) if transcription else {}
    # Gender lives under "include" in the extractor's {"include": ..., "exclude": ...} output
    gender_override = gender_ui if gender_ui else intent.get("include", {}).get("gender")
    results = search_fashion(transcription, a, 0, 12, gender_override)
    seen_ids = {caption for _, caption in results}
    return results, 12, transcription, None, gender_override, results, seen_ids
custom_css = """
/* === Global Styling === */
/* === Override Gradio default background === */
html, body {
height: 100% !important;
margin: 0 !important;
padding: 0 !important;
background: radial-gradient(circle at center, #0b1f36 0%, #033e3e 100%) !important;
background-attachment: fixed;
}
.gr-root, .gr-block {
background: transparent !important;
}
body::before {
content: "";
position: fixed;
top: 0; left: 0;
width: 100%; height: 100%;
background: radial-gradient(circle at center, rgba(0, 255, 255, 0.08), transparent);
z-index: -1;
}
#app-bg {
min-height: 100vh;
padding: 0;
margin: 0;
background: radial-gradient(circle at center, #0b1f36 0%, #033e3e 100%);
display: flex;
justify-content: center;
align-items: flex-start;
background-attachment: fixed;
position: relative;
overflow: hidden;
}
#app-bg::before {
content: "";
position: absolute;
top: 0; left: 0;
width: 100%; height: 100%;
background: radial-gradient(circle at center, rgba(0, 255, 255, 0.08), transparent);
z-index: 0;
}
#main-container {
z-index: 1;
position: relative;
}
/* === Heading Style === */
h1, .gr-markdown h1 {
font-size: 2.2rem !important;
font-weight: bold;
color: #000000;
text-align: center;
margin-bottom: 1rem;
}
/* === Tabs === */
.gr-tab {
border-radius: 12px !important;
background-color: #ffffff !important;
box-shadow: 0 3px 10px rgba(0, 0, 0, 0.08);
padding: 16px !important;
margin-top: 12px;
}
/* === Textbox, Dropdown, Slider === */
input[type="text"], .gr-textbox textarea, .gr-dropdown, .gr-slider {
border-radius: 8px !important;
border: 1px solid #ccc !important;
padding: 10px !important;
font-size: 16px;
box-shadow: 0 1px 3px rgba(0,0,0,0.05);
}
/* === Image Upload === */
.gr-image {
width: 100% !important;
max-width: 100% !important;
border-radius: 12px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
/* === Buttons (custom style .button-36) === */
.gr-button {
background-color: #DBDBDB !important;
background-image: linear-gradient(92.88deg, #455EB5 9.16%, #5643CC 43.89%, #673FD7 64.72%);
border-radius: 8px !important;
border-style: none !important;
box-sizing: border-box;
color: #FFFFFF !important;
cursor: pointer;
flex-shrink: 0;
font-family: "Inter UI","SF Pro Display",-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,Oxygen,Ubuntu,Cantarell,"Open Sans","Helvetica Neue",sans-serif;
font-size: 16px;
font-weight: 500;
height: 4rem;
padding: 0 1.6rem;
text-align: center;
text-shadow: rgba(0, 0, 0, 0.25) 0 3px 8px;
transition: all .5s;
user-select: none;
-webkit-user-select: none;
touch-action: manipulation;
}
.gr-button:hover {
box-shadow: rgba(80, 63, 205, 0.5) 0 1px 30px;
transition-duration: .1s;
}
/* === Responsive padding === */
@media (min-width: 768px) {
.gr-button {
padding: 0 2.6rem;
}
}
/* === Gallery Grid === */
.gr-gallery {
padding-top: 12px;
}
.gr-gallery-item {
width: 128px !important;
height: 128px !important;
transition: transform 0.3s ease-in-out;
border-radius: 8px;
overflow: hidden;
}
.gr-gallery-item:hover {
transform: scale(1.06);
box-shadow: 0 3px 12px rgba(0,0,0,0.15);
}
.gr-gallery-item img {
object-fit: cover !important;
width: 100% !important;
height: 100% !important;
border-radius: 8px;
}
/* === Audio Upload === */
.gr-audio {
width: 100% !important;
border-radius: 12px;
background-color: #fff !important;
box-shadow: 0 1px 5px rgba(0,0,0,0.1);
}
/* === Footer === */
.gr-markdown:last-child {
text-align: center;
font-size: 14px;
color: #666;
padding-top: 1rem;
}
/* === Main Container Centered and Wide === */
#main-container {
max-width: 90%;
width: 1100px;
margin: 40px auto !important;
padding: 24px;
background: #ffffff;
border-radius: 18px;
box-shadow: 0 10px 30px rgba(0,0,0,0.08);
border: 3px solid orange; /* Orange border */
}
/* === Tab Label Styling === */
button[role="tab"] {
color: #000000 !important; /* Default tab text color: black */
font-weight: 500;
transition: color 0.3s ease-in-out;
font-size: 16px;
}
/* Active tab title */
button[role="tab"][aria-selected="true"] {
color: #f57c00 !important; /* Active tab text color: orange */
font-weight: bold !important;
}
/* Hover effect on tab titles */
button[role="tab"]:hover {
color: #f57c00 !important; /* Orange on hover */
font-weight: 600;
cursor: pointer;
}
/* === Uniform Input Sizes for Text, Audio, Image === */
.gr-textbox, .gr-audio, .gr-image {
max-width: 100% !important;
width: 100% !important;
}
.gr-audio, .gr-image {
max-width: 500px !important;
margin: 0 auto;
}
.gr-image {
height: 256px !important;
}
"""
with gr.Blocks(css=custom_css) as demo:
with gr.Column(elem_id="app-bg"):
with gr.Column(elem_id="main-container"):
gr.Markdown("# 🛍️ Fashion Product Hybrid Search")
alpha = gr.Slider(0, 1, value=0.5, label="Hybrid Weight (alpha: 0=sparse, 1=dense)")
with gr.Tabs():
with gr.Tab("Text Search"):
query = gr.Textbox(
label="Text Query",
placeholder="e.g., floral summer dress for women"
)
gender_dropdown = gr.Dropdown(
["", "Men", "Women", "Boys", "Girls", "Kids", "Unisex"],
label="Gender Filter (optional)"
)
text_search_btn = gr.Button("Search by Text", elem_classes="search-btn")
with gr.Tab("🎙️ Voice Search"):
voice_input = gr.Audio(label="Speak Your Query", type="filepath")
voice_gender_dropdown = gr.Dropdown(["", "Men", "Women", "Boys", "Girls", "Kids", "Unisex"], label="Gender")
voice_search_btn = gr.Button("Search by Voice")
with gr.Tab("Image Search"):
                    image_input = gr.Image(
                        type="pil",
                        label="Upload an image",
                        sources=["upload", "clipboard"],
                        height=400
                    )
image_gender_dropdown = gr.Dropdown(
["", "Men", "Women", "Boys", "Girls", "Kids", "Unisex"],
label="Gender Filter (optional)"
)
image_search_btn = gr.Button("Search by Image", elem_classes="search-btn")
gallery = gr.Gallery(label="Search Results", columns=6, height=None)
load_more_btn = gr.Button("Load More")
# --- UI State Holders ---
search_offset = gr.State(0)
current_query = gr.State("")
current_image = gr.State(None)
current_gender = gr.State("")
shown_results = gr.State([])
shown_ids = gr.State(set())
# --- Unified Search Function ---
            def unified_search(q, uploaded_image, a, offset, gender_ui):
                start = 0
                end = 12
                intent = extract_intent_from_openai(q) if q.strip() else {}
                # The extractor returns {"include": {...}, "exclude": {...}}, so gender sits under "include"
                gender_override = gender_ui if gender_ui else intent.get("include", {}).get("gender")
                if uploaded_image is not None:
                    results = search_by_image(uploaded_image, a, start, end)
                elif q.strip():
                    results = search_fashion(q, a, start, end, gender_override)
                else:
                    results = []
                seen_ids = {caption for _, caption in results}
                return results, end, q, uploaded_image, gender_override, results, seen_ids
            # Text Search
text_search_btn.click(
unified_search,
inputs=[query, gr.State(None), alpha, search_offset, gender_dropdown],
outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results, shown_ids]
)
voice_search_btn.click(
handle_voice_search,
inputs=[voice_input, alpha, search_offset, voice_gender_dropdown],
outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results, shown_ids]
)
# Image Search
image_search_btn.click(
unified_search,
inputs=[gr.State(""), image_input, alpha, search_offset, image_gender_dropdown],
outputs=[gallery, search_offset, current_query, current_image, current_gender, shown_results, shown_ids]
)
# --- Load More Button ---
def load_more_fn(a, offset, q, img, gender_ui, prev_results, prev_ids):
start = offset
end = offset + 12
gender_override = gender_ui
if img is not None:
new_results = search_by_image(img, a, start, end)
elif q.strip():
new_results = search_fashion(q, a, start, end, gender_override)
else:
new_results = []
filtered_new = []
new_ids = set()
for item in new_results:
img_obj, caption = item
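                    # Captions double as stable IDs for cross-page deduplication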
if caption not in prev_ids:
filtered_new.append(item)
new_ids.add(caption)
combined = prev_results + filtered_new
updated_ids = prev_ids.union(new_ids)
return combined, end, combined, updated_ids
load_more_btn.click(
load_more_fn,
inputs=[alpha, search_offset, current_query, current_image, current_gender, shown_results, shown_ids],
outputs=[gallery, search_offset, shown_results, shown_ids]
)
demo.launch()