MohitG012 committed on
Commit
b7aeaae
·
verified ·
1 Parent(s): 5569dae

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +158 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,161 @@
1
- import altair as alt
 
 
2
  import numpy as np
3
  import pandas as pd
4
- import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
import streamlit as st
import os, pickle
import torch
import numpy as np
import pandas as pd
from PIL import Image as PILImage
from sentence_transformers import SentenceTransformer
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, CLIPProcessor, CLIPModel
import faiss

from huggingface_hub import login
# Hugging Face auth: the token is read from Streamlit secrets — never hard-code it here.
login(st.secrets["huggingface"]["token"])

# Prefer GPU when available; every model below is moved to / run on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
# Load assets
@st.cache_resource
def load_assets(asset_dir="Assets"):
    """Load every precomputed artifact the app needs, once per process.

    Cached by st.cache_resource so Streamlit reruns reuse the same objects.

    Parameters:
        asset_dir: directory holding the pickles, the .npy matrix and the
            FAISS index (defaults to "Assets").

    Returns:
        Tuple of (image_embeddings, text_embeddings, ids, combined_vectors,
        faiss_index, df, user_history, trend_string) — same order callers
        already unpack.
    """
    def _unpickle(name):
        # NOTE(review): pickle.load is only safe on trusted, bundled assets —
        # never point asset_dir at user-supplied files.
        with open(os.path.join(asset_dir, name), "rb") as f:
            return pickle.load(f)

    image_embeddings = _unpickle("image_embeddings.pkl")
    text_embeddings = _unpickle("text_embeddings.pkl")
    ids = _unpickle("product_ids.pkl")
    combined_vectors = np.load(os.path.join(asset_dir, "combined_vectors.npy"))
    faiss_index = faiss.read_index(os.path.join(asset_dir, "faiss_index.index"))
    df = pd.read_pickle(os.path.join(asset_dir, "product_metadata_df.pkl"))
    user_history = _unpickle("user_history.pkl")
    trend_string = _unpickle("trend_string.pkl")
    return image_embeddings, text_embeddings, ids, combined_vectors, faiss_index, df, user_history, trend_string
35
+
# Image + text search
def search_similar(image=None, text=None, top_k=5):
    """Return the product ids of the top_k nearest neighbours in the FAISS index.

    Either modality may be omitted; a zero vector stands in for the missing
    one. The zero-vector dims must match what the index was built from
    (768 CLIP-large image dims + 384 MiniLM text dims — TODO confirm against
    the asset-building pipeline).

    Relies on module globals: clip_processor, clip_model, text_model,
    faiss_index, ids, device.
    """
    img_vec = np.zeros(768)
    txt_vec = np.zeros(384)
    # Explicit None check: PIL images should not be truth-tested.
    if image is not None:
        inputs = clip_processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            img_vec = clip_model.get_image_features(**inputs).cpu().numpy()[0]
    # Truthiness is intentional for text: empty string means "no text query".
    if text:
        txt_vec = text_model.encode(text)
    combined = np.concatenate([img_vec, txt_vec]).astype("float32")
    # Distances are not used; only the neighbour indices matter here.
    _, neighbor_idx = faiss_index.search(np.array([combined]), top_k)
    return [ids[i] for i in neighbor_idx[0]]
49
+
# Outfit suggestions
def generate_outfit_gemma(img, row, username, suggestions=5):
    """Ask the Gemma VLM for outfit items that complement one product.

    Builds a chat-style multimodal prompt from the product image, the
    product's metadata row, the user's preference summary and the global
    trend text, then generates and decodes only the new tokens.

    Relies on module globals: gemma_processor, model, trend_string, and
    the sibling function summarize_user_preferences.
    """
    brands, styles, desc = summarize_user_preferences(username)
    messages = [{
        "role": "system",
        "content": [{"type": "text", "text": "You are a highly experienced fashion stylist and personal shopper."}]
    }, {
        "role": "user",
        "content": [
            {"type": "image", "image": img.convert("RGB")},
            {"type": "text", "text": f"""
Suggest {suggestions} stylish outfit items that complement this item:

**Product**:
Name: {row['product_name']}
Brand: {row['brand']}
Style: {row['style_attributes']}
Description: {row['description']}
Price: β‚Ή{row['selling_price']}

**User Likes**:
Brands: {brands}
Styles: {styles}
Liked Items: {desc}

**Trends**:
{trend_string}

Output in bullet list with name + explanation.
"""}
        ]
    }]
    # Render via the model's chat template first, then tokenize the flat prompt.
    prompt = gemma_processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    tokenized = gemma_processor(text=prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(**tokenized, max_new_tokens=300)
    # Slice off the prompt tokens so only the generated continuation is decoded.
    return gemma_processor.decode(output[0][tokenized["input_ids"].shape[-1]:], skip_special_tokens=True)
87
+
# User preference summary
def summarize_user_preferences(user_id, top_k=3):
    """Summarize a user's liked products as comma-joined strings.

    Parameters:
        user_id: key into the global user_history mapping.
        top_k: how many top brands/styles (and liked descriptions) to keep.

    Returns:
        (brands, styles, descriptions) strings, or ("None", "None", "None")
        when the user has no recorded history — callers test for that
        sentinel literally.

    Relies on module globals: user_history, df.
    """
    pids = user_history.get(user_id, [])
    rows = df[df["product_id"].isin(pids)]
    if rows.empty:
        return "None", "None", "None"
    brands = rows["brand"].dropna().astype(str).value_counts().index.tolist()[:top_k]
    # dropna here too — previously missing style_attributes were counted as
    # the literal string "nan" and could leak into the summary (inconsistent
    # with the brands line above).
    styles = rows["style_attributes"].dropna().astype(str).value_counts().index.tolist()[:top_k]
    descs = rows["meta_info"].dropna().astype(str).tolist()
    return ", ".join(brands), ", ".join(styles), " ".join(descs[:top_k])
98
+
# ========== APP STARTS ==========
st.set_page_config("πŸ›οΈ Fashion Visual Search")
st.title("πŸ‘— Fashion Visual Search & Outfit Assistant")

# Precomputed assets (embeddings, FAISS index, metadata, history, trends).
image_embeddings, text_embeddings, ids, _, faiss_index, df, user_history, trend_string = load_assets()

# fp16 only when CUDA is present; CPU inference needs fp32.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32


@st.cache_resource
def _load_models():
    """Instantiate all models exactly once per process.

    st.cache_resource keeps them alive across Streamlit reruns — previously
    CLIP, MiniLM and the 4B Gemma model were rebuilt on every widget
    interaction, which is prohibitively slow.
    """
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device).eval()
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14", use_fast=True)
    text_model = SentenceTransformer('all-MiniLM-L6-v2')
    model_id = "google/gemma-3-4b-it"
    model = Gemma3ForConditionalGeneration.from_pretrained(model_id, torch_dtype=dtype, device_map="auto").eval()
    gemma_processor = AutoProcessor.from_pretrained(model_id)
    return clip_model, clip_processor, text_model, model, gemma_processor


# Same module-level names as before, so the helper functions above still work.
clip_model, clip_processor, text_model, model, gemma_processor = _load_models()
115
+
username = st.text_input("πŸ‘€ Enter your username:")
if username:
    uploaded_image = st.file_uploader("πŸ“· Upload a fashion image", type=["jpg", "png"])
    text_query = st.text_input("πŸ“ Optional: Describe what you're looking for")
    num_results = st.slider("πŸ”’ Number of similar items", 1, 20, 5)
    num_suggestions = st.slider("πŸ’‘ Number of outfit suggestions", 1, 10, 3)

    if uploaded_image:
        st.image(uploaded_image, caption="Uploaded Image", width=300)

        img = PILImage.open(uploaded_image)
        similar_ids = search_similar(image=img, text=text_query, top_k=num_results)
        st.subheader("🎯 Similar Products")
        for pid in similar_ids:
            row = df[df["product_id"] == pid].iloc[0]
            st.image(row["feature_image_s3"], width=200)
            st.write(f"**{row['product_name']}** β€” β‚Ή{row['selling_price']}")
            st.write(f"Brand: {row['brand']}")
            # Record the interaction for personalization. Guarded against
            # duplicates: Streamlit reruns this script on every interaction,
            # so unconditional appends grow the history without bound.
            user_history.setdefault(username, [])
            if pid not in user_history[username]:
                user_history[username].append(pid)
            # NOTE(review): user_history is only mutated in memory — it is
            # never written back to Assets/user_history.pkl, so it resets on
            # restart. Confirm whether persistence is intended.

        st.subheader("🧠 Outfit Suggestions")
        top_row = df[df["product_id"] == similar_ids[0]].iloc[0]
        suggestions = generate_outfit_gemma(img, top_row, username, suggestions=num_suggestions)
        st.markdown(suggestions)

        st.subheader("🧾 Inventory Text Search")
        # Only search when there is a real query: with an empty string the
        # old code searched a pure zero vector, returning arbitrary products.
        if text_query:
            text_only_ids = search_similar(image=None, text=text_query, top_k=num_results)
            for pid in text_only_ids:
                row = df[df["product_id"] == pid].iloc[0]
                st.image(row["feature_image_s3"], width=200)
                st.write(f"{row['product_name']} β€” β‚Ή{row['selling_price']}")
                st.write(f"Brand: {row['brand']}")

    st.subheader("πŸ“¦ Personalized History-Based Suggestions")
    brands, styles, desc = summarize_user_preferences(username)
    if brands == "None":
        st.warning("⚠️ No history found yet. Try uploading images first!")
    else:
        # BUG FIX: the previous code did `b in text_embeddings[pid]` — a
        # substring test against an embedding *vector*, which is meaningless
        # (or raises on numpy arrays). Match liked brands against the
        # metadata table instead.
        liked_brands = set(brands.split(", "))
        hist_ids = df[df["brand"].astype(str).isin(liked_brands)]["product_id"].tolist()[:num_results]
        for pid in hist_ids:
            row = df[df["product_id"] == pid].iloc[0]
            st.image(row["feature_image_s3"], width=200)
            st.write(f"{row['product_name']} β€” β‚Ή{row['selling_price']}")
            st.write(f"Brand: {row['brand']}")