import os
import re
import time

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image, ImageOps
from pinecone import Pinecone, ServerlessSpec
from pinecone_text.sparse import BM25Encoder
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
# ------------------- Pinecone Setup -------------------
# Read the API key from the environment; set PINECONE_API_KEY as a secret on
# the host rather than hardcoding it in source.
api_key = os.environ.get("PINECONE_API_KEY")
pc = Pinecone(api_key=api_key)
cloud = os.environ.get("PINECONE_CLOUD") or "aws"
region = os.environ.get("PINECONE_REGION") or "us-east-1"
spec = ServerlessSpec(cloud=cloud, region=region)

# choose a name for your index
index_name = "hybrid-image-search"
# check if index already exists (it shouldn't if this is the first run)
if index_name not in pc.list_indexes().names():
    # if it does not exist, create the index
    pc.create_index(
        index_name,
        dimension=512,          # CLIP ViT-B/32 embeddings are 512-dimensional
        metric='dotproduct',    # dotproduct is required for hybrid (dense + sparse) queries
        spec=spec
    )
    # wait for index to be initialized
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)

# connect to index
index = pc.Index(index_name)

# view index stats
index.describe_index_stats()
# ------------------- Dataset Loading -------------------
fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
images = fashion["image"]
metadata = fashion.remove_columns("image").to_pandas()
# ------------------- Encoders -------------------
bm25 = BM25Encoder()
bm25.fit(metadata["productDisplayName"])
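# For reference, BM25Encoder returns sparse vectors as dicts of the form
# {"indices": [...], "values": [...]}; the numbers below are made up:
# bm25.encode_queries("blue jeans") -> {"indices": [1203, 887], "values": [0.42, 0.31]}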
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# load a CLIP model from Hugging Face (encodes text and images into the same
# 512-dimensional space)
model = SentenceTransformer(
    'sentence-transformers/clip-ViT-B-32',
    device=device
)
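# --- Population sketch (not executed) -------------------------------------
# This app only queries the index; the upserts presumably happened in a
# separate script. The sketch below shows how records compatible with the
# queries in this file could be written: ids are dataset row indices, and the
# metadata fields mirror the filters used later (gender, articleType,
# subCategory, baseColour, productDisplayName). The batch size and exact
# field set are assumptions, not taken from this repo.
def upsert_fashion_dataset(batch_size: int = 200):
    for start in tqdm(range(0, len(metadata), batch_size)):
        end = min(start + batch_size, len(metadata))
        meta_batch = metadata.iloc[start:end].to_dict(orient="records")
        # sparse BM25 vectors from product titles, dense CLIP vectors from images
        sparse_vecs = bm25.encode_documents(
            [m["productDisplayName"] for m in meta_batch]
        )
        dense_vecs = model.encode(images[start:end]).tolist()
        # real code may need to drop null metadata fields, which Pinecone rejects
        index.upsert(
            vectors=[
                {
                    "id": str(start + i),
                    "values": dv,
                    "sparse_values": sv,
                    "metadata": m,
                }
                for i, (dv, sv, m) in enumerate(zip(dense_vecs, sparse_vecs, meta_batch))
            ]
        )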
# ------------------- Hybrid Scaling -------------------
def hybrid_scale(dense, sparse, alpha: float):
    if alpha < 0 or alpha > 1:
        raise ValueError("Alpha must be between 0 and 1")
    # scale sparse and dense vectors to create hybrid search vectors:
    # alpha=1 is pure dense (semantic), alpha=0 is pure sparse (keyword)
    hsparse = {
        'indices': sparse['indices'],
        'values': [v * (1 - alpha) for v in sparse['values']]
    }
    hdense = [v * alpha for v in dense]
    return hdense, hsparse
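# Toy check of the convex weighting (values here are made up):
# hybrid_scale([0.2, 0.4], {"indices": [3], "values": [1.0]}, alpha=0.75)
# -> dense ~[0.15, 0.3], sparse values ~[0.25]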
# ------------------- Metadata Filter Extraction -------------------
def extract_metadata_filters(query: str):
    query_lower = query.lower()
    gender = None
    category = None
    subcategory = None
    color = None

    # --- Gender Mapping ---
    gender_map = {
        "men": "Men", "man": "Men", "mens": "Men", "mans": "Men", "male": "Men",
        "women": "Women", "woman": "Women", "womens": "Women", "female": "Women",
        "boys": "Boys", "boy": "Boys",
        "girls": "Girls", "girl": "Girls",
        "kids": "Kids", "unisex": "Unisex"
    }
    # match whole words so that e.g. "men" does not fire inside "women"
    for term, mapped_value in gender_map.items():
        if re.search(rf"\b{term}\b", query_lower):
            gender = mapped_value
            break

    # --- Category Mapping ---
    # "tshirt"/"t-shirt" are listed before "shirt" so the more specific term wins
    category_map = {
        "tshirt": "Tshirts", "t-shirt": "Tshirts",
        "shirt": "Shirts",
        "jeans": "Jeans",
        "watch": "Watches",
        "kurta": "Kurtas",
        "dress": "Dresses", "dresses": "Dresses",
        "trousers": "Trousers", "pants": "Trousers",
        "shorts": "Shorts",
        "footwear": "Footwear",
        "shoes": "Footwear",
        "fashion": "Apparel"
    }
    for term, mapped_value in category_map.items():
        if term in query_lower:
            category = mapped_value
            break

    # --- SubCategory Mapping ---
    subCategory_list = [
        "Accessories", "Apparel Set", "Bags", "Bath and Body", "Beauty Accessories",
        "Belts", "Bottomwear", "Cufflinks", "Dress", "Eyes", "Eyewear", "Flip Flops",
        "Fragrance", "Free Gifts", "Gloves", "Hair", "Headwear", "Home Furnishing",
        "Innerwear", "Jewellery", "Lips", "Loungewear and Nightwear", "Makeup",
        "Mufflers", "Nails", "Perfumes", "Sandal", "Saree", "Scarves", "Shoe Accessories",
        "Shoes", "Skin", "Skin Care", "Socks", "Sports Accessories", "Sports Equipment",
        "Stoles", "Ties", "Topwear", "Umbrellas", "Vouchers", "Wallets", "Watches",
        "Water Bottle", "Wristbands"
    ]
    if re.search(r"\btop(wear)?\b", query_lower):
        subcategory = "Topwear"
    else:
        for subcat in subCategory_list:
            if subcat.lower() in query_lower:
                subcategory = subcat
                break

    # --- Color Extraction ---
    colors = [
        "red", "blue", "green", "yellow", "black", "white",
        "orange", "pink", "purple", "brown", "grey", "beige"
    ]
    # whole-word match so that e.g. "bordered" does not register as "red"
    for c in colors:
        if re.search(rf"\b{c}\b", query_lower):
            color = c.capitalize()
            break

    # --- Invalid gender/category pairs: drop the gender rather than return nothing ---
    invalid_pairs = {
        ("Men", "Dresses"), ("Men", "Sarees"), ("Men", "Skirts"),
        ("Boys", "Dresses"), ("Boys", "Sarees"),
        ("Girls", "Boxers"), ("Men", "Heels")
    }
    if (gender, category) in invalid_pairs:
        print(f"⚠️ Invalid pair: {gender} + {category}, dropping gender")
        gender = None

    # fallback: a gender term with no category is too broad, so default to Apparel
    if gender and not category:
        category = "Apparel"

    return gender, category, subcategory, color
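# Illustration (not from the repo): extract_metadata_filters("red kurta for men")
# returns ("Men", "Kurtas", None, "Red"), which becomes the Pinecone filter
# {"gender": "Men", "articleType": "Kurtas", "baseColour": "Red"}.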
def search_fashion(query: str, alpha: float):
    gender, category, subcategory, color = extract_metadata_filters(query)

    # build the Pinecone metadata filter ("meta_filter" avoids shadowing the
    # builtin `filter`)
    meta_filter = {}
    if gender:
        meta_filter["gender"] = gender
    if category:
        meta_filter["articleType"] = category
    if subcategory:
        meta_filter["subCategory"] = subcategory
    if color:
        meta_filter["baseColour"] = color
    print(f"🔍 Using filter: {meta_filter}")

    # hybrid query vectors: BM25 sparse + CLIP dense, weighted by alpha
    sparse = bm25.encode_queries(query)
    dense = model.encode(query).tolist()
    hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha)

    # initial search
    result = index.query(
        top_k=12,
        vector=hdense,
        sparse_vector=hsparse,
        include_metadata=True,
        filter=meta_filter if meta_filter else None
    )

    # fallback: if zero results with gender, relax the gender constraint
    if gender and len(result["matches"]) == 0:
        print(f"⚠️ No results with gender {gender}, relaxing gender filter")
        meta_filter.pop("gender")
        result = index.query(
            top_k=12,
            vector=hdense,
            sparse_vector=hsparse,
            include_metadata=True,
            filter=meta_filter if meta_filter else None
        )

    # collect result images with captions for the gallery
    imgs_with_captions = []
    for r in result["matches"]:
        idx = int(r["id"])  # ids were stored as dataset row indices
        img = images[idx]
        meta = r.get("metadata", {})
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.array(img))
        padded = ImageOps.pad(img, (256, 256), color="white")
        caption = str(meta.get("productDisplayName", "Unknown Product"))
        imgs_with_captions.append((padded, caption))
    return imgs_with_captions
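# Example call (illustrative): a balanced text-only hybrid search.
# search_fashion("blue jeans for men", alpha=0.5)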
# ------------------- Image Search (CLIP) -------------------
from transformers import CLIPProcessor, CLIPModel

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
def search_by_image(uploaded_image, alpha=0.5):
    """
    Given a PIL image from Gradio, find visually similar products.
    """
    # preprocess as CLIP expects
    processed = clip_processor(images=uploaded_image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_vec = clip_model.get_image_features(**processed)
    image_vec = image_vec.cpu().numpy().flatten().tolist()

    # image queries are purely visual, so only the dense vector is used
    # (alpha is accepted for interface symmetry but has no effect here)
    result = index.query(
        top_k=12,
        vector=image_vec,
        include_metadata=True
    )

    imgs_with_captions = []
    for r in result["matches"]:
        idx = int(r["id"])
        img = images[idx]
        meta = r.get("metadata", {})
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.array(img))
        padded = ImageOps.pad(img, (256, 256), color="white")
        caption = str(meta.get("productDisplayName", "Unknown Product"))
        imgs_with_captions.append((padded, caption))
    return imgs_with_captions
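# Example call (illustrative; "query.jpg" is a hypothetical local file):
# search_by_image(Image.open("query.jpg"))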
custom_css = """
.search-btn {
width: 100%;
}
.gr-row {
gap: 8px !important; /* slightly tighter column gap */
}
.query-slider > div {
margin-bottom: 4px !important; /* reduce space between textbox and slider */
}
"""
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# 🛍️ Fashion Product Hybrid Search")

    with gr.Row(equal_height=True):
        with gr.Column(scale=5, elem_classes="query-slider"):
            query = gr.Textbox(
                label="Enter your fashion search query",
                placeholder="Type something or leave blank to only use the image"
            )
            alpha = gr.Slider(
                0, 1, value=0.5,
                label="Hybrid Weight (alpha: 0=sparse, 1=dense)"
            )
        with gr.Column(scale=1):
            image_input = gr.Image(
                source="webcam",  # 👈 enables webcam capture (Gradio 4.x renamed this to `sources`)
                type="pil",
                label="📷 Capture or Upload Image",
                height=256,
                width=356,
                show_label=True
            )
            search_btn = gr.Button("Search", elem_classes="search-btn")

    gallery = gr.Gallery(
        label="Search Results",
        columns=6,
        height="40vh"
    )

    def unified_search(q, uploaded_image, a):
        # a captured/uploaded image takes priority; otherwise fall back to text
        if uploaded_image is not None:
            return search_by_image(uploaded_image, a)
        elif q.strip() != "":
            return search_fashion(q, a)
        else:
            return []

    search_btn.click(
        unified_search,
        inputs=[query, image_input, alpha],
        outputs=gallery
    )

    gr.Markdown("Powered by your hybrid AI search model 🚀")

demo.launch()