|
|
# Redirect all cache/config directories to /tmp so the app can run in a
# read-only container (e.g. Hugging Face Spaces) where the default home
# directory is not writable.
import os


# NOTE(review): streamlit is imported *before* STREAMLIT_HOME and
# XDG_CONFIG_HOME are set below — if streamlit reads those variables at
# import time the settings may not take effect; consider moving the
# os.environ assignments above this import.
import streamlit as st


# Hugging Face model/tokenizer cache root.
os.environ["HF_HOME"] = "/tmp/hf_home"

# sentence-transformers model cache.
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/sbert"

# Generic XDG config dir (used by several libraries, incl. streamlit).
os.environ["XDG_CONFIG_HOME"] = "/tmp/.config"

# Streamlit's own config/credentials directory.
os.environ["STREAMLIT_HOME"] = "/tmp/.streamlit"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from huggingface_hub import login

# Authenticate with the Hugging Face Hub so gated models (e.g. Gemma) can be
# downloaded. Only attempt the login when a token is actually configured:
# the original unconditional login(os.environ.get("HF_TOKEN", "")) raises on
# hosts where HF_TOKEN is not set, crashing the app at import time.
hf_token = os.environ.get("HF_TOKEN", "")
if hf_token:
    login(hf_token)
|
|
|
|
|
import streamlit as st |
|
|
import torch |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from PIL import Image as PILImage |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, CLIPProcessor, CLIPModel, BitsAndBytesConfig |
|
|
import faiss |
|
|
import pickle |
|
|
|
|
|
|
|
|
@st.cache_resource
def load_assets(asset_dir = os.path.join(os.path.dirname(__file__), "Assets")):
    """Load all precomputed search assets from *asset_dir*.

    Cached once per process via st.cache_resource, so the pickles, the
    FAISS index and the metadata DataFrame are read from disk only on the
    first run.

    Returns
    -------
    tuple
        (image_embeddings, text_embeddings, ids, combined_vectors,
        faiss_index, df, user_history, trend_string)
    """
    def _unpickle(filename):
        # Read one pickle file from the asset directory.
        with open(os.path.join(asset_dir, filename), "rb") as fh:
            return pickle.load(fh)

    image_embeddings = _unpickle("image_embeddings.pkl")
    text_embeddings = _unpickle("text_embeddings.pkl")
    ids = _unpickle("product_ids.pkl")
    combined_vectors = np.load(os.path.join(asset_dir, "combined_vectors.npy"))
    faiss_index = faiss.read_index(os.path.join(asset_dir, "faiss_index.index"))
    df = pd.read_pickle(os.path.join(asset_dir, "product_metadata_df.pkl"))
    user_history = _unpickle("user_history.pkl")
    trend_string = _unpickle("trend_string.pkl")

    return image_embeddings, text_embeddings, ids, combined_vectors, faiss_index, df, user_history, trend_string
|
|
|
|
|
|
|
|
def search_similar(image=None, text=None, top_k=5):
    """Return the product ids of the ``top_k`` nearest neighbours for a query.

    Either modality may be omitted; a missing one contributes a zero vector,
    so image-only, text-only and combined queries all match the 768 + 384
    layout of the global FAISS index.

    Parameters
    ----------
    image : PIL.Image.Image | None
        Query image, embedded with the global CLIP model (768-dim).
    text : str | None
        Query text, embedded with the global sentence-transformer (384-dim).
    top_k : int
        Number of neighbours to return.

    Returns
    -------
    list
        Product ids (entries of the global ``ids``), nearest first.
    """
    img_vec = np.zeros(768)
    txt_vec = np.zeros(384)
    # `is not None` instead of truthiness so the image branch never depends
    # on how a PIL image happens to evaluate as a boolean.
    if image is not None:
        # Fix: the original moved the inputs to an undefined global `device`
        # (never assigned anywhere in this file), which raised NameError on
        # every image query; use the CLIP model's own device instead.
        inputs = clip_processor(images=image, return_tensors="pt").to(clip_model.device)
        with torch.no_grad():
            img_vec = clip_model.get_image_features(**inputs).cpu().numpy()[0]
    if text:
        txt_vec = text_model.encode(text)
    combined = np.concatenate([img_vec, txt_vec]).astype("float32")
    D, I = faiss_index.search(np.array([combined]), top_k)
    return [ids[i] for i in I[0]]
|
|
|
|
|
|
|
|
def generate_outfit_gemma(img, row, username, suggestions=5):
    """Generate outfit suggestions for a product using the global Gemma model.

    Parameters
    ----------
    img : PIL.Image.Image
        Anchor product image included in the chat messages.
    row : pandas.Series
        Metadata row of the anchor product (name, brand, style, ...).
    username : str
        Key into the global ``user_history`` used to personalize the prompt.
    suggestions : int
        Number of outfit items the model is asked to propose.

    Returns
    -------
    str
        The model's decoded reply (expected: a bullet list of items).
    """
    # Personalization context derived from the user's interaction history.
    brands, styles, desc = summarize_user_preferences(username)
    # Chat-format messages: a system persona plus a multimodal user turn.
    messages = [{
        "role": "system",
        "content": [{"type": "text", "text": "You are a highly experienced fashion stylist and personal shopper."}]
    }, {
        "role": "user",
        "content": [
            {"type": "image", "image": img.convert("RGB")},
            {"type": "text", "text": f"""
Suggest {suggestions} stylish outfit items that complement this item:

**Product**:
Name: {row['product_name']}
Brand: {row['brand']}
Style: {row['style_attributes']}
Description: {row['description']}
Price: βΉ{row['selling_price']}

**User Likes**:
Brands: {brands}
Styles: {styles}
Liked Items: {desc}

**Trends**:
{trend_string}

Output in bullet list with name + explanation.
"""}
        ]
    }]
    # Render the chat template to a plain prompt string, then tokenize it.
    # NOTE(review): the processor is called with text only — the image placed
    # in `messages` may never actually reach the model; verify whether
    # `images=` must be passed to gemma_processor here.
    prompt = gemma_processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    tokenized = gemma_processor(text=prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(**tokenized, max_new_tokens=300)
    # Slice off the prompt tokens so only newly generated text is decoded.
    return gemma_processor.decode(output[0][tokenized["input_ids"].shape[-1]:], skip_special_tokens=True)
|
|
|
|
|
|
|
|
def summarize_user_preferences(user_id, top_k=3):
    """Summarize a user's liked products into prompt-ready strings.

    Looks up the user's product ids in the global ``user_history`` and
    aggregates the matching rows of the global metadata DataFrame ``df``.

    Parameters
    ----------
    user_id : str
        Username/key in ``user_history``.
    top_k : int
        How many top brands/styles (and liked-item descriptions) to keep.

    Returns
    -------
    tuple[str, str, str]
        (comma-joined top brands, comma-joined top styles, space-joined
        descriptions); the sentinel ``("None", "None", "None")`` when the
        user has no history. Callers test ``brands == "None"``.
    """
    pids = user_history.get(user_id, [])
    rows = df[df["product_id"].isin(pids)]
    if rows.empty:
        return "None", "None", "None"
    # value_counts() sorts by frequency, so the head is the user's favorites.
    brands = rows["brand"].dropna().astype(str).value_counts().index.tolist()[:top_k]
    # Fix: dropna() before astype(str) — otherwise missing styles are counted
    # as the literal string "nan" (the brand column above already drops them).
    styles = rows["style_attributes"].dropna().astype(str).value_counts().index.tolist()[:top_k]
    descs = rows["meta_info"].dropna().astype(str).tolist()
    return ", ".join(brands), ", ".join(styles), " ".join(descs[:top_k])
|
|
|
|
|
|
|
|
import streamlit.runtime.metrics_util

# HACK: point streamlit's metrics/config machinery at a writable location
# under /tmp (same reason as the env vars at the top of the file).
# NOTE(review): `_config_file` is a private attribute of an internal module —
# this may silently break across streamlit versions; verify it is still needed.
streamlit.runtime.metrics_util._config_file = "/tmp/.config/streamlit/config.toml"
|
|
|
|
|
|
|
|
# ---- Page chrome ----
st.set_page_config("ποΈ Fashion Visual Search")
st.title("π Fashion Visual Search & Outfit Assistant")

# Precomputed assets: embeddings, FAISS index, product metadata, user
# history and the trend summary string. combined_vectors is unused here,
# hence the underscore.
image_embeddings, text_embeddings, ids, _, faiss_index, df, user_history, trend_string = load_assets()

# CLIP image encoder used for visual similarity search (768-dim features).
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14", cache_dir="/tmp/hf_cache")
clip_model.eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14", use_fast=True, cache_dir="/tmp/hf_cache")

# Sentence-transformer text encoder for the 384-dim half of the query vector.
text_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', cache_folder="/tmp/sbert", device="cpu")

# Gemma 3 vision-language model used to generate outfit suggestions.
# NOTE(review): these models are re-loaded on every streamlit rerun because
# they are not wrapped in @st.cache_resource (unlike load_assets) — consider
# caching them the same way.
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    device_map="auto",
    cache_dir="/tmp/hf_cache"
).eval()
gemma_processor = AutoProcessor.from_pretrained(model_id, cache_dir="/tmp/hf_cache")
|
|
|
|
|
username = st.text_input("π€ Enter your username:")

# Initialize history only for a real username. st.text_input returns "" until
# the user types something, and the original unconditional check inserted an
# empty-string key into user_history on every first run.
if username and username not in user_history:
    user_history[username] = []
|
|
|
|
|
if username:
    # ---- Query inputs ----
    uploaded_image = st.file_uploader("π· Upload a fashion image", type=["jpg", "png"])
    text_query = st.text_input("π Optional: Describe what you're looking for")
    num_results = st.slider("π’ Number of similar items", 1, 20, 5)
    num_suggestions = st.slider("π‘ Number of outfit suggestions", 1, 10, 3)

    if uploaded_image:
        st.image(uploaded_image, caption="Uploaded Image", width=300)

        # ---- Visual (image [+ text]) similarity search ----
        img = PILImage.open(uploaded_image)
        similar_ids = search_similar(image=img, text=text_query, top_k=num_results)
        st.subheader("π― Similar Products")
        for pid in similar_ids:
            row = df[df["product_id"] == pid].iloc[0]
            st.image(row["feature_image_s3"], width=200)
            st.write(f"**{row['product_name']}** β βΉ{row['selling_price']}")
            st.write(f"Brand: {row['brand']}")
            # Record the shown product in the in-memory user history.
            # (Redundant guard: the key was already created above when the
            # username was entered.)
            if username not in user_history:
                user_history[username] = []
            user_history[username].append(pid)

        # ---- LLM outfit suggestions anchored on the best match ----
        st.subheader("π§ Outfit Suggestions")
        top_row = df[df["product_id"] == similar_ids[0]].iloc[0]
        suggestions = generate_outfit_gemma(img, top_row, username, suggestions=num_suggestions)
        st.markdown(suggestions)

    # NOTE(review): the source's indentation was lost; the two sections below
    # are placed at the `if username:` level because the text search passes
    # image=None explicitly (i.e. it is meant to work without an upload) —
    # confirm against the original layout.
    # ---- Text-only inventory search ----
    st.subheader("π§Ύ Inventory Text Search")
    text_only_ids = search_similar(image=None, text=text_query, top_k=num_results)
    for pid in text_only_ids:
        row = df[df["product_id"] == pid].iloc[0]
        st.image(row["feature_image_s3"], width=200)
        st.write(f"{row['product_name']} β βΉ{row['selling_price']}")
        st.write(f"Brand: {row['brand']}")

    # ---- History-based recommendations ----
    st.subheader("π¦ Personalized History-Based Suggestions")
    brands, styles, desc = summarize_user_preferences(username)
    if brands == "None":
        st.warning("β οΈ No history found yet. Try uploading images first!")
    else:
        # NOTE(review): this tests brand-name membership with
        # `b in text_embeddings[pid]` — if text_embeddings holds numeric
        # vectors (as its name and use alongside image_embeddings suggest)
        # rather than strings, this is a bug; verify the value type.
        hist_ids = [pid for pid in ids if any(b in text_embeddings[pid] for b in brands.split(", "))][:num_results]
        for pid in hist_ids:
            row = df[df["product_id"] == pid].iloc[0]
            st.image(row["feature_image_s3"], width=200)
            st.write(f"{row['product_name']} β βΉ{row['selling_price']}")
            st.write(f"Brand: {row['brand']}")