galsaar committed
Commit bc1d54b · verified · 1 Parent(s): 0ec0e2d

Create app.py

Files changed (1)
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
+ import numpy as np
+ import pandas as pd
+ import gradio as gr
+ import torch
+ from PIL import Image
+ from datasets import load_dataset
+ from transformers import CLIPModel, CLIPProcessor
+
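+ # Precomputed CLIP image embeddings for the catalog. The parquet is assumed
+ # to hold one row per item, with an "embedding" vector produced by the same
+ # checkpoint as MODEL_NAME so that text and image queries share its space.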
+ EMB_PATH = "deepfashion_clip_image_embeddings.parquet"
+ MODEL_NAME = "openai/clip-vit-base-patch32"
+
+ df_emb = pd.read_parquet(EMB_PATH)
+
+ # Fail fast if the parquet is missing any of the expected columns.
+ required_cols = {"item_ID", "category1", "category2", "embedding"}
+ missing = required_cols - set(df_emb.columns)
+ if missing:
+     raise ValueError(f"Missing columns in {EMB_PATH}: {missing}")
+
+ # Stack the per-row embedding lists into an (N, D) float32 matrix and
+ # L2-normalize each row, so a dot product with a unit-norm query vector
+ # is exactly cosine similarity.
+ X = np.stack(df_emb["embedding"].apply(lambda x: np.asarray(x, dtype=np.float32)).to_numpy())
+ norms = np.linalg.norm(X, axis=1, keepdims=True)
+ norms = np.clip(norms, 1e-12, None)
+ Xn = X / norms
+
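+ # Metadata aligned row-for-row with Xn; used to label the results.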
+ meta = df_emb[["item_ID", "category1", "category2"]].copy().reset_index(drop=True)
+
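+ # Load the source dataset once and map item_ID -> row index so the UI can
+ # fetch the matching image for each recommendation.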
+ ds = load_dataset("Marqo/deepfashion-multimodal", split="data")
+ id_to_idx = {ds[i]["item_ID"]: i for i in range(len(ds))}
+
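+ # Shared CLIP model/processor; inference only, hence eval() and no_grad.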
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ clip_model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
+ clip_processor = CLIPProcessor.from_pretrained(MODEL_NAME)
+ clip_model.eval()
+
+
+ def l2_normalize(v: np.ndarray) -> np.ndarray:
+     v = v.astype(np.float32)
+     n = np.linalg.norm(v)
+     if n < 1e-12:
+         return v
+     return v / n
+
+
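+ # Text and image queries are embedded by the same CLIP model and
+ # L2-normalized, so both score against Xn with a plain dot product.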
+ @torch.no_grad()
+ def embed_text(text: str) -> np.ndarray:
+     inputs = clip_processor(text=[text], return_tensors="pt", padding=True, truncation=True)
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+     feats = clip_model.get_text_features(**inputs)
+     vec = feats[0].detach().cpu().numpy().astype(np.float32)
+     return l2_normalize(vec)
+
+
+ @torch.no_grad()
+ def embed_image_pil(img: Image.Image) -> np.ndarray:
+     inputs = clip_processor(images=img, return_tensors="pt")
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+     feats = clip_model.get_image_features(**inputs)
+     vec = feats[0].detach().cpu().numpy().astype(np.float32)
+     return l2_normalize(vec)
+
+
+ def topk_recommendations(query_vec: np.ndarray, k: int = 3, exclude_item_id: str | None = None) -> pd.DataFrame:
+     q = l2_normalize(query_vec).astype(np.float32)
+     sims = Xn @ q
+
+     # Optionally remove the query item itself from the candidate pool.
+     if exclude_item_id is not None:
+         mask = (meta["item_ID"].to_numpy() == exclude_item_id)
+         sims = sims.copy()
+         sims[mask] = -np.inf
+
+     # argpartition selects the k largest in O(N); argsort then orders only those k.
+     k = min(k, len(sims))
+     idx = np.argpartition(-sims, kth=k - 1)[:k]
+     idx = idx[np.argsort(-sims[idx])]
+
+     out = meta.iloc[idx].copy()
+     out["similarity"] = sims[idx]
+     return out.reset_index(drop=True)
+
+
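+ # Example (hypothetical): topk_recommendations(embed_text("red dress"), k=3)
+ # yields a 3-row DataFrame sorted by descending cosine similarity.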
+ def fetch_images(rec_df: pd.DataFrame):
+     gallery = []
+     for _, row in rec_df.iterrows():
+         item_id = row["item_ID"]
+         idx = id_to_idx.get(item_id, None)
+         if idx is None:
+             # Skip recommendations whose item_ID is missing from the dataset.
+             continue
+         ex = ds[idx]
+         img = ex["image"]
+         caption = f'{row["category1"]}/{row["category2"]} ({row["similarity"]:.3f})'
+         gallery.append((img, caption))
+     return gallery
+
+
+ def recommend_from_text_ui(query: str):
+     if query is None or not query.strip():
+         return pd.DataFrame(columns=["item_ID", "category1", "category2", "similarity"]), []
+     q = embed_text(query.strip())
+     rec = topk_recommendations(q, k=3)
+     return rec, fetch_images(rec)
+
+
+ def recommend_from_image_ui(img: Image.Image):
+     if img is None:
+         return pd.DataFrame(columns=["item_ID", "category1", "category2", "similarity"]), []
+     q = embed_image_pil(img)
+     rec = topk_recommendations(q, k=3)
+     return rec, fetch_images(rec)
+
+
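+ # Two-tab UI: free-text query or uploaded image, each returning the top-3
+ # matches as a table plus an image gallery.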
+ with gr.Blocks() as demo:
+     gr.Markdown("# DeepFashion CLIP Recommender (Top-3)")
+
+     with gr.Tab("Text → Top-3"):
+         txt = gr.Textbox(label="Describe an item", placeholder="e.g., a sleeveless summer dress with a floral pattern")
+         btn1 = gr.Button("Recommend")
+         out_table1 = gr.Dataframe(label="Top-3 results", interactive=False)
+         out_gallery1 = gr.Gallery(label="Top-3 images", columns=3, height=320)
+         btn1.click(recommend_from_text_ui, inputs=txt, outputs=[out_table1, out_gallery1])
+
+     with gr.Tab("Image → Top-3"):
+         img_in = gr.Image(type="pil", label="Upload an image")
+         btn2 = gr.Button("Recommend")
+         out_table2 = gr.Dataframe(label="Top-3 results", interactive=False)
+         out_gallery2 = gr.Gallery(label="Top-3 images", columns=3, height=320)
+         btn2.click(recommend_from_image_ui, inputs=img_in, outputs=[out_table2, out_gallery2])
+
+ demo.launch()
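
Note: app.py expects deepfashion_clip_image_embeddings.parquet to already exist alongside it. As a minimal sketch (a hypothetical companion script, not part of this commit), a parquet with the expected schema could be produced roughly as follows, assuming the dataset rows expose item_ID, category1, and category2 fields, and embedding images one at a time for simplicity:

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from transformers import CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

ds = load_dataset("Marqo/deepfashion-multimodal", split="data")

rows = []
with torch.no_grad():
    for ex in ds:
        inputs = processor(images=ex["image"], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        vec = model.get_image_features(**inputs)[0].cpu().numpy().astype(np.float32)
        rows.append({
            "item_ID": ex["item_ID"],        # assumed dataset field
            "category1": ex["category1"],    # assumed dataset field
            "category2": ex["category2"],    # assumed dataset field
            "embedding": vec.tolist(),
        })

pd.DataFrame(rows).to_parquet("deepfashion_clip_image_embeddings.parquet")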