# DeepFashion / app.py — CLIP-based top-3 fashion recommender (Gradio app)
import numpy as np
import pandas as pd
import gradio as gr
import torch
from PIL import Image
from datasets import load_dataset
from transformers import CLIPModel, CLIPProcessor
# -------------------------
# Config
# -------------------------
EMB_PATH = "deepfashion_clip_image_embeddings.parquet"
MODEL_NAME = "openai/clip-vit-base-patch32"
LOOM_SHARE = "https://www.loom.com/share/cc9bac8c26104c7caaca2a10968deada"
LOOM_EMBED = "https://www.loom.com/embed/cc9bac8c26104c7caaca2a10968deada"
VIDEO_HTML = f"""
<div style="width:100%; height:420px;">
  <iframe src="{LOOM_EMBED}"
          style="width:100%; height:100%; border:0;"
          allowfullscreen></iframe>
</div>
<p>
  <a href="{LOOM_SHARE}" target="_blank" rel="noopener noreferrer">
    Open video in Loom
  </a>
</p>
"""
# -------------------------
# Load embeddings + metadata
# -------------------------
df_emb = pd.read_parquet(EMB_PATH)

required_cols = {"item_ID", "category1", "category2", "embedding"}
missing = required_cols - set(df_emb.columns)
if missing:
    raise ValueError(f"Missing columns in {EMB_PATH}: {missing}")

# Stack the per-row embedding lists into an (N, D) float32 matrix and
# L2-normalize each row, so a dot product against a unit-norm query vector
# equals cosine similarity.
X = np.stack(df_emb["embedding"].apply(lambda x: np.asarray(x, dtype=np.float32)).to_numpy())
norms = np.linalg.norm(X, axis=1, keepdims=True)
norms = np.clip(norms, 1e-12, None)  # guard against division by zero
Xn = X / norms

meta = df_emb[["item_ID", "category1", "category2"]].copy().reset_index(drop=True)
# -------------------------
# Load dataset (for images) + ID mapping
# -------------------------
ds = load_dataset("Marqo/deepfashion-multimodal", split="data")

# Map item_ID -> dataset row index. Reading the "item_ID" column directly
# avoids decoding every image, which per-row access via ds[i] would do.
id_to_idx = {item_id: i for i, item_id in enumerate(ds["item_ID"])}
# -------------------------
# Load CLIP
# -------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
clip_processor = CLIPProcessor.from_pretrained(MODEL_NAME)
clip_model.eval()
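
# Optional sanity check (a suggested sketch, not part of the original app):
# CLIP ViT-B/32 projects to 512-d, so the precomputed embedding matrix should
# match the model's projection dimensionality:
#   assert X.shape[1] == clip_model.config.projection_dim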
def l2_normalize(v: np.ndarray) -> np.ndarray:
    """Return v scaled to unit L2 norm; near-zero vectors are returned unchanged."""
    v = v.astype(np.float32)
    n = np.linalg.norm(v)
    if n < 1e-12:
        return v
    return v / n
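
# e.g. l2_normalize(np.array([3.0, 4.0])) -> array([0.6, 0.8], dtype=float32)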

@torch.no_grad()
def embed_text(text: str) -> np.ndarray:
    """Encode a text query into a unit-norm CLIP embedding."""
    inputs = clip_processor(text=[text], return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    feats = clip_model.get_text_features(**inputs)
    vec = feats[0].detach().cpu().numpy().astype(np.float32)
    return l2_normalize(vec)

@torch.no_grad()
def embed_image_pil(img: Image.Image) -> np.ndarray:
    """Encode a PIL image into a unit-norm CLIP embedding."""
    inputs = clip_processor(images=img, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    feats = clip_model.get_image_features(**inputs)
    vec = feats[0].detach().cpu().numpy().astype(np.float32)
    return l2_normalize(vec)

def topk_recommendations(query_vec: np.ndarray, k: int = 3, exclude_item_id: str | None = None) -> pd.DataFrame:
    """Return the k catalog items most similar to query_vec by cosine similarity."""
    q = l2_normalize(query_vec).astype(np.float32)
    sims = Xn @ q  # rows of Xn are unit-norm, so this is cosine similarity
    if exclude_item_id is not None:
        mask = (meta["item_ID"].to_numpy() == exclude_item_id)
        sims = sims.copy()
        sims[mask] = -np.inf  # never recommend the query item itself
    k = min(k, len(sims))
    # argpartition finds the k largest in O(N); only that slice is then sorted.
    idx = np.argpartition(-sims, kth=k - 1)[:k]
    idx = idx[np.argsort(-sims[idx])]
    out = meta.iloc[idx].copy()
    out["similarity"] = sims[idx]
    return out.reset_index(drop=True)
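
# Example usage outside the UI (a sketch; the query text is illustrative):
#   rec = topk_recommendations(embed_text("red floral summer dress"), k=3)
# `rec` has columns item_ID, category1, category2, similarity, best match first.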

def fetch_images(rec_df: pd.DataFrame):
    """Resolve recommended item_IDs to (PIL image, caption) pairs for gr.Gallery."""
    gallery = []
    for _, row in rec_df.iterrows():
        item_id = row["item_ID"]
        idx = id_to_idx.get(item_id)
        if idx is None:
            continue  # item missing from the dataset; skip rather than fail
        img = ds[idx]["image"]
        caption = f'{row["category1"]}/{row["category2"]} ({row["similarity"]:.3f})'
        gallery.append((img, caption))
    return gallery

def recommend_from_text_ui(query: str):
    if query is None or not query.strip():
        return pd.DataFrame(columns=["item_ID", "category1", "category2", "similarity"]), []
    q = embed_text(query.strip())
    rec = topk_recommendations(q, k=3)
    return rec, fetch_images(rec)


def recommend_from_image_ui(img: Image.Image):
    if img is None:
        return pd.DataFrame(columns=["item_ID", "category1", "category2", "similarity"]), []
    q = embed_image_pil(img)
    rec = topk_recommendations(q, k=3)
    return rec, fetch_images(rec)
# -------------------------
# Gradio UI (single Blocks)
# -------------------------
with gr.Blocks() as demo:
    gr.Markdown("# DeepFashion CLIP Recommender (Top-3)")
    with gr.Tab("Text → Top-3"):
        txt = gr.Textbox(label="Describe an item", placeholder="e.g., a sleeveless summer dress with a floral pattern")
        btn1 = gr.Button("Recommend")
        out_table1 = gr.Dataframe(label="Top-3 results", interactive=False)
        out_gallery1 = gr.Gallery(label="Top-3 images", columns=3, height=320)
        btn1.click(recommend_from_text_ui, inputs=txt, outputs=[out_table1, out_gallery1])
    with gr.Tab("Image → Top-3"):
        img_in = gr.Image(type="pil", label="Upload an image")
        btn2 = gr.Button("Recommend")
        out_table2 = gr.Dataframe(label="Top-3 results", interactive=False)
        out_gallery2 = gr.Gallery(label="Top-3 images", columns=3, height=320)
        btn2.click(recommend_from_image_ui, inputs=img_in, outputs=[out_table2, out_gallery2])
    with gr.Tab("Video"):
        gr.Markdown("## Presentation video")
        gr.Markdown(f"🎥 Loom link: {LOOM_SHARE}")
        gr.HTML(VIDEO_HTML)

demo.queue()
demo.launch()