# Source: Hugging Face Space by MohitG012 — src/streamlit_app.py (commit 214daca)
import os

# Force every cache/config location into /tmp, the only writable directory
# in a Hugging Face Spaces container (fixes PermissionError on first run).
# These MUST be set before importing streamlit / huggingface_hub /
# sentence_transformers, which read them at import time — the previous
# version imported streamlit before STREAMLIT_HOME was set.
os.environ["HF_HOME"] = "/tmp/hf_home"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/sbert"
os.environ["XDG_CONFIG_HOME"] = "/tmp/.config"
os.environ["STREAMLIT_HOME"] = "/tmp/.streamlit"

# Create the cache directories up front; later code passes e.g.
# cache_dir="/tmp/hf_cache" explicitly, so it must exist as well.
for _writable_dir in ("/tmp/hf_home", "/tmp/hf_cache", "/tmp/sbert", "/tmp/.config", "/tmp/.streamlit"):
    os.makedirs(_writable_dir, exist_ok=True)

# HuggingFace login (requires HF_TOKEN to be added as secret in Hugging Face Spaces).
# login("") raises a ValueError, so only authenticate when a token is present;
# public models still load without authentication.
from huggingface_hub import login

_hf_token = os.environ.get("HF_TOKEN", "")
if _hf_token:
    login(_hf_token)
import streamlit as st
import torch
import numpy as np
import pandas as pd
from PIL import Image as PILImage
from sentence_transformers import SentenceTransformer
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, CLIPProcessor, CLIPModel, BitsAndBytesConfig
import faiss
import pickle
# Load precomputed assets (embeddings, FAISS index, metadata) from disk.
@st.cache_resource
def load_assets(asset_dir = os.path.join(os.path.dirname(__file__), "Assets")):
    """Load every precomputed artifact the app needs, cached for the session.

    Returns a tuple: (image_embeddings, text_embeddings, ids,
    combined_vectors, faiss_index, df, user_history, trend_string).
    NOTE: pickle.load is acceptable here only because these files ship
    with the app itself — never point asset_dir at untrusted data.
    """
    def _unpickle(filename):
        # Read one pickled artifact from the asset directory.
        with open(os.path.join(asset_dir, filename), "rb") as fh:
            return pickle.load(fh)

    image_embeddings = _unpickle("image_embeddings.pkl")
    text_embeddings = _unpickle("text_embeddings.pkl")
    ids = _unpickle("product_ids.pkl")
    combined_vectors = np.load(os.path.join(asset_dir, "combined_vectors.npy"))
    faiss_index = faiss.read_index(os.path.join(asset_dir, "faiss_index.index"))
    df = pd.read_pickle(os.path.join(asset_dir, "product_metadata_df.pkl"))
    user_history = _unpickle("user_history.pkl")
    trend_string = _unpickle("trend_string.pkl")

    return image_embeddings, text_embeddings, ids, combined_vectors, faiss_index, df, user_history, trend_string
# Image + text search
def search_similar(image=None, text=None, top_k=5):
    """Return product ids of the top_k nearest neighbours in the FAISS index.

    image: optional PIL image — embedded with CLIP into a 768-d vector.
    text: optional query string — embedded with MiniLM into a 384-d vector.
    Either modality may be omitted; its slot is filled with zeros so the
    concatenated 1152-d query still matches the index layout.
    """
    img_vec = np.zeros(768, dtype="float32")
    txt_vec = np.zeros(384, dtype="float32")
    if image is not None:
        # BUG FIX: `device` was referenced but never defined anywhere in the
        # module (NameError on any image search). Use the device the CLIP
        # model actually lives on.
        device = next(clip_model.parameters()).device
        inputs = clip_processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            img_vec = clip_model.get_image_features(**inputs).cpu().numpy()[0]
    if text:
        txt_vec = text_model.encode(text)
    combined = np.concatenate([img_vec, txt_vec]).astype("float32")
    # FAISS expects a 2-D batch of queries; take the neighbours of the first.
    D, I = faiss_index.search(np.array([combined]), top_k)
    return [ids[i] for i in I[0]]
# Outfit suggestions
def generate_outfit_gemma(img, row, username, suggestions=5):
    """Ask Gemma 3 for outfit items that complement one anchor product.

    img: PIL image of the anchor product (converted to RGB for the model).
    row: metadata row for the anchor product — assumes it carries
         product_name/brand/style_attributes/description/selling_price
         (TODO confirm against product_metadata_df schema).
    username: key into user_history; used to personalize the prompt.
    suggestions: how many outfit items the model is asked to propose.
    Returns the decoded generation (expected: a bullet list, per the prompt).
    """
    brands, styles, desc = summarize_user_preferences(username)
    # Chat-formatted multimodal prompt: a system persona plus one user turn
    # that interleaves the product image with a text brief.
    messages = [{
        "role": "system",
        "content": [{"type": "text", "text": "You are a highly experienced fashion stylist and personal shopper."}]
    }, {
        "role": "user",
        "content": [
            {"type": "image", "image": img.convert("RGB")},
            {"type": "text", "text": f"""
Suggest {suggestions} stylish outfit items that complement this item:
**Product**:
Name: {row['product_name']}
Brand: {row['brand']}
Style: {row['style_attributes']}
Description: {row['description']}
Price: β‚Ή{row['selling_price']}
**User Likes**:
Brands: {brands}
Styles: {styles}
Liked Items: {desc}
**Trends**:
{trend_string}
Output in bullet list with name + explanation.
"""}
        ]
    }]
    # Render the chat template into a plain prompt string, then tokenize it
    # (and the image) onto whatever device the model was mapped to.
    prompt = gemma_processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    tokenized = gemma_processor(text=prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(**tokenized, max_new_tokens=300)
    # Decode only the newly generated tokens — slice off the echoed prompt.
    return gemma_processor.decode(output[0][tokenized["input_ids"].shape[-1]:], skip_special_tokens=True)
# User preference summary
def summarize_user_preferences(user_id, top_k=3):
    """Summarize a user's interaction history from module-level user_history/df.

    Returns a 3-tuple of comma/space-joined strings:
    (top brands, top style attributes, first top_k liked-item descriptions).
    Returns ("None", "None", "None") when the user has no recorded history.
    """
    liked_ids = user_history.get(user_id, [])
    liked_rows = df[df["product_id"].isin(liked_ids)]
    if liked_rows.empty:
        return "None", "None", "None"
    # Most frequent brands/styles among liked products, most common first.
    top_brands = liked_rows["brand"].dropna().astype(str).value_counts().index[:top_k].tolist()
    top_styles = liked_rows["style_attributes"].astype(str).value_counts().index[:top_k].tolist()
    liked_descs = liked_rows["meta_info"].dropna().astype(str).tolist()[:top_k]
    return ", ".join(top_brands), ", ".join(top_styles), " ".join(liked_descs)
# ========== APP STARTS ==========
# NOTE(review): _config_file is a private Streamlit attribute — this works
# around read-only home dirs on Spaces but may break across Streamlit
# versions; verify on upgrade.
import streamlit.runtime.metrics_util
streamlit.runtime.metrics_util._config_file = "/tmp/.config/streamlit/config.toml"
# os.makedirs("/tmp/.config/streamlit", exist_ok=True)
st.set_page_config("πŸ›οΈ Fashion Visual Search")
st.title("πŸ‘— Fashion Visual Search & Outfit Assistant")
# Pull all precomputed artifacts; combined_vectors (4th element) is unused here.
image_embeddings, text_embeddings, ids, _, faiss_index, df, user_history, trend_string = load_assets()
# Load models
# CLIP ViT-L/14 supplies the 768-d image embeddings used by search_similar.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14", cache_dir="/tmp/hf_cache")
clip_model.eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14", use_fast=True, cache_dir="/tmp/hf_cache")
# MiniLM supplies the 384-d text embeddings; CPU is sufficient for short queries.
text_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', cache_folder="/tmp/sbert", device="cpu")
# Gemma 3 4B (instruction-tuned, multimodal) generates the outfit suggestions.
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    device_map="auto",
    cache_dir="/tmp/hf_cache"
).eval()
gemma_processor = AutoProcessor.from_pretrained(model_id, cache_dir="/tmp/hf_cache")
username = st.text_input("πŸ‘€ Enter your username:")
if username:
    # BUG FIX: initialise history only once a non-empty username was typed.
    # Previously the check ran before this guard, so every page run
    # registered an empty-string "" user in user_history.
    if username not in user_history:
        user_history[username] = []
    uploaded_image = st.file_uploader("πŸ“· Upload a fashion image", type=["jpg", "png"])
    text_query = st.text_input("πŸ“ Optional: Describe what you're looking for")
    num_results = st.slider("πŸ”’ Number of similar items", 1, 20, 5)
    num_suggestions = st.slider("πŸ’‘ Number of outfit suggestions", 1, 10, 3)
    if uploaded_image:
        st.image(uploaded_image, caption="Uploaded Image", width=300)
        # st.image may have read the buffer; rewind before PIL parses it.
        uploaded_image.seek(0)
        img = PILImage.open(uploaded_image)
        similar_ids = search_similar(image=img, text=text_query, top_k=num_results)
        st.subheader("🎯 Similar Products")
        for pid in similar_ids:
            row = df[df["product_id"] == pid].iloc[0]
            st.image(row["feature_image_s3"], width=200)
            st.write(f"**{row['product_name']}** β€” β‚Ή{row['selling_price']}")
            st.write(f"Brand: {row['brand']}")
            # Record the interaction for future personalized suggestions.
            user_history[username].append(pid)
        st.subheader("🧠 Outfit Suggestions")
        top_row = df[df["product_id"] == similar_ids[0]].iloc[0]
        suggestions = generate_outfit_gemma(img, top_row, username, suggestions=num_suggestions)
        st.markdown(suggestions)
        # BUG FIX: only run the text-only search when a query was entered —
        # searching with an all-zero text vector returned arbitrary items.
        if text_query:
            st.subheader("🧾 Inventory Text Search")
            text_only_ids = search_similar(image=None, text=text_query, top_k=num_results)
            for pid in text_only_ids:
                row = df[df["product_id"] == pid].iloc[0]
                st.image(row["feature_image_s3"], width=200)
                st.write(f"{row['product_name']} β€” β‚Ή{row['selling_price']}")
                st.write(f"Brand: {row['brand']}")
    st.subheader("πŸ“¦ Personalized History-Based Suggestions")
    brands, styles, desc = summarize_user_preferences(username)
    if brands == "None":
        st.warning("⚠️ No history found yet. Try uploading images first!")
    else:
        # NOTE(review): `b in text_embeddings[pid]` tests brand-substring
        # membership against what is presumably an embedding vector — this
        # looks like it was meant to match against product text/metadata.
        # Verify text_embeddings' structure before trusting these results.
        hist_ids = [pid for pid in ids if any(b in text_embeddings[pid] for b in brands.split(", "))][:num_results]
        for pid in hist_ids:
            row = df[df["product_id"] == pid].iloc[0]
            st.image(row["feature_image_s3"], width=200)
            st.write(f"{row['product_name']} β€” β‚Ή{row['selling_price']}")
            st.write(f"Brand: {row['brand']}")