import os, io, math, itertools, textwrap, json
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
import torch
from torch.nn.functional import normalize
from transformers import CLIPProcessor, CLIPModel
import datasets as hfds
import gradio as gr
# --- Load pre-saved metadata and embeddings (fast startup) ---
import glob, random
from datasets import load_dataset
print("Loading cached outfit data...")
df_items = pd.read_csv("closet_items.csv")
img_embs = torch.tensor(np.load("closet_clip_embs.npy"))
print(f"Loaded {len(df_items)} items, embeddings: {img_embs.shape}")
# --- Correct image folder ---
img_dir = "closet_small" # your unzipped image folder
print(f"Using image folder: {img_dir}")
# --- Unzip automatically if needed ---
if os.path.exists("closet_small.zip") and not os.path.isdir(img_dir):
import zipfile
with zipfile.ZipFile("closet_small.zip", "r") as zip_ref:
zip_ref.extractall(img_dir)
print(f"✅ Unzipped images into: {img_dir}")
# --- Build full paths for images ---
if "path" not in df_items.columns or df_items["path"].isnull().all():
# if CSV doesn’t already have a 'path' column
img_files = sorted(glob.glob(os.path.join(img_dir, "*.png")) + glob.glob(os.path.join(img_dir, "*.jpg")))
if len(img_files) >= len(df_items):
df_items["path"] = img_files[:len(df_items)]
else:
df_items["path"] = [os.path.join(img_dir, f"image_{i}.png") for i in range(len(df_items))]
# --- Keep only rows whose image file exists (and keep the fast-startup embeddings aligned) ---
exists_mask = df_items["path"].apply(os.path.exists).to_numpy(dtype=bool)
if img_embs.shape[0] == len(df_items):
    img_embs = img_embs[torch.from_numpy(exists_mask)]
df_items = df_items[exists_mask].reset_index(drop=True)
print(f"✅ Closet items ready: {len(df_items)} images found in '{img_dir}'.")
print(df_items.head(3))
# 2.1) Optional: Load Letterboxd dataset and helpers
# If the dataset is private, authenticate first (run `huggingface-cli login`
# in a terminal, or set the HF_TOKEN environment variable).
try:
lb_ds = load_dataset("pkchwy/letterboxd-all-movie-data")
if "train" in lb_ds:
lb = lb_ds["train"]
elif "default" in lb_ds:
lb = lb_ds["default"]
else:
lb = next(iter(lb_ds.values()))
print("Loaded Letterboxd dataset:", lb)
except Exception as e:
lb = None
print("Letterboxd dataset not available:", e)
_LBOX_TEXT_FIELDS = ["review_text", "review", "content", "text", "body", "description"]
_LBOX_TITLE_FIELDS = ["movie_title", "title", "film_name", "name"]
def _extract_text(example):
for k in _LBOX_TEXT_FIELDS:
if k in example and isinstance(example[k], str) and example[k].strip():
return example[k].strip()
if "tags" in example and isinstance(example["tags"], list) and example["tags"]:
return " ".join(map(str, example["tags"]))
return None
def _extract_title(example):
for k in _LBOX_TITLE_FIELDS:
if k in example and isinstance(example[k], str) and example[k].strip():
return example[k].strip()
return "(untitled)"
def sample_letterboxd_review(max_tries=50):
if lb is None or len(lb) == 0:
return None, None
for _ in range(max_tries):
ex = lb[random.randrange(0, len(lb))]
txt = _extract_text(ex)
if txt:
return _extract_title(ex), txt
return None, None
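# Note: sample_letterboxd_review() is a convenience helper for pulling a random
# review (useful for quick manual tests); the Gradio UI below takes review text directly.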
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(MODEL_ID).to(device)
clip_proc = CLIPProcessor.from_pretrained(MODEL_ID)
# Cache file in the working directory (a location that exists in both Spaces and Colab)
EMB_CACHE = os.path.join(os.getcwd(), "closet_clip_embeds.pt")
def pil_open(path):
img = Image.open(path).convert("RGB")
return img
@torch.no_grad()
def embed_images(paths, batch_size=32):
if not paths: # Handle empty list of paths
return torch.empty(0, clip_model.config.projection_dim) # Return an empty tensor of the correct dimension
embs = []
for i in tqdm(range(0, len(paths), batch_size), desc="Embedding images"):
batch_paths = paths[i:i+batch_size]
images = [pil_open(p) for p in batch_paths]
inp = clip_proc(images=images, return_tensors="pt", padding=True).to(device)
feats = clip_model.get_image_features(**inp)
feats = normalize(feats, dim=-1)
embs.append(feats.cpu())
return torch.cat(embs, dim=0) # (N, D)
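# Because the image features are L2-normalized above, cosine similarity against a
# (likewise normalized) text embedding reduces to a plain dot product (see cos_sim below).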
# Build (or load) the embedding cache.
# Keep the fast-startup embeddings loaded above if they already match the item table;
# otherwise try the on-disk cache, and re-embed the images only as a last resort.
need_embed = img_embs.shape[0] != len(df_items)
if need_embed and os.path.exists(EMB_CACHE):
    try:
        cached = torch.load(EMB_CACHE)["img_embs"]
        if cached.shape[0] == len(df_items):
            img_embs = cached
            need_embed = False
            print("Loaded cached embeddings:", img_embs.shape)
        else:
            print("Cached embeddings do not match item count. Re-embedding.")
    except Exception as e:
        print(f"Error loading cache: {e}. Re-embedding.")
if need_embed:
    if not df_items.empty:
        img_embs = embed_images(df_items["path"].tolist())
        torch.save({"img_embs": img_embs}, EMB_CACHE)
        print("Saved cache:", EMB_CACHE)
    else:
        img_embs = torch.empty(0, clip_model.config.projection_dim)
        print("No items in the closet, no embeddings to compute.")
@torch.no_grad()
def embed_text(s: str):
    # Use the raw review text; CLIP's text encoder accepts at most 77 tokens, so long
    # reviews are truncated. Prompt templates could be added later.
    inputs = clip_proc(text=[s], return_tensors="pt", padding=True, truncation=True).to(device)
txt = clip_model.get_text_features(**inputs)
return normalize(txt, dim=-1).cpu()[0] # (D,)
def cos_sim(a: torch.Tensor, b: torch.Tensor):
# a: (D,), b: (N, D)
return (b @ a) # because both are normalized
# --- Filtering helpers ---
def filter_index(df, gender_pref=None, season_pref=None):
idx = np.ones(len(df), dtype=bool)
if gender_pref and gender_pref.lower() != "any":
idx &= df["gender"].fillna("").str.lower().str.contains(gender_pref.lower())
if season_pref and season_pref.lower() != "any":
idx &= df["seasonality"].fillna("").str.lower().str.contains(season_pref.lower())
    # Keep the original row positions in the index so build_outfits can map the
    # filtered rows back to df_items / the embedding matrix.
    return df[idx]
# --- pick top-k per category by text–image similarity ---
def topk_by_category(txt_emb, df, categories_k, item_embs):
"""
categories_k: dict like {"top":20, "bottom":20, "shoes":10, "outer":10, "onepiece":20}
returns dict: cat -> dataframe with columns [idx_in_df, score] sorted desc
"""
out = {}
sims = cos_sim(txt_emb, item_embs) # (N,)
for cat, k in categories_k.items():
mask = (df["category"] == cat).values
if mask.sum() == 0:
out[cat] = pd.DataFrame(columns=["idx", "score"])
continue
cat_idx = np.where(mask)[0]
cat_scores = sims[cat_idx].numpy()
order = np.argsort(-cat_scores)[:k]
out[cat] = pd.DataFrame({
"idx": cat_idx[order],
"score": cat_scores[order]
})
return out
# --- simple color compatibility ---
PALETTES = {
"neutrals": {"black","white","gray","beige","silver"},
"earthy": {"brown","beige","olive","gold"},
"cool": {"blue","navy","teal","green"},
"warm": {"red","orange","yellow","pink","burgundy"}
}
def color_score(*colors, c_outer=None):
"""
Flexible color compatibility scorer.
Accepts 2 or 3 colors (e.g., top/bottom/shoes or onepiece/shoes),
plus optional outer layer.
"""
def bucket(c):
if not isinstance(c, str):
return "neutrals"
x = c.lower()
for b, s in PALETTES.items():
if any(k in x for k in s):
return b
return "neutrals"
# Handle missing or extra colors gracefully
bs = [bucket(c) for c in colors if c is not None]
if c_outer:
bs.append(bucket(c_outer))
if len(bs) == 0:
return 0.0
# reward if majority colors are same bucket OR contain neutrals
majority = max(bs, key=bs.count)
maj_ct = bs.count(majority)
bonus = 0.04 if maj_ct >= len(bs) - 1 else 0.0
if any(b == "neutrals" for b in bs):
bonus += 0.02
return bonus
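# Example: color_score("black", "navy", "white") -> 0.06
#   (two "neutrals" + one "cool": majority bonus 0.04, plus 0.02 for containing neutrals)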
# --- Season & aesthetic nudges from text ---
def season_nudge(text, seasonality):
t = text.lower()
maps = {"summer":["summer","heat","beach"], "winter":["winter","cold","snow","coat"],
"transitional":["fall","spring","breeze","layer"], "all-season":[]}
    if not isinstance(seasonality, str):  # handles None/NaN metadata
        return 0.0
for k, cues in maps.items():
if k in seasonality.lower() and any(w in t for w in cues):
return 0.03
return 0.0
def aesthetic_nudge(text, aesthetic):
    # Handle None/NaN metadata; reward items whose aesthetic keywords appear in the review.
    if isinstance(aesthetic, str) and any(w in text.lower() for w in aesthetic.lower().split()):
        return 0.03
    return 0.0
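# Example: season_nudge("a cozy winter night in", "Winter") -> 0.03, since the review
# mentions a winter cue; aesthetic_nudge rewards aesthetic keywords the same way.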
# --- Score a combination ---
def score_combo(txt_sim_scores, items_meta, user_text, outer=None, weights=None):
"""
txt_sim_scores: list of per-item CLIP sims [top, bottom, shoes] or [onepiece, shoes]
items_meta: list of metadata rows matching above
outer: (sim, meta) or None
weights: dict of weights
"""
W = {
"sim": 1.0, # average CLIP similarity
"color": 1.0,
"season": 1.0,
"aesthetic": 1.0,
"outer_penalty": 0.0 # not used now
}
if weights: W.update(weights)
base_sim = np.mean(txt_sim_scores)
cols = [m["color"] for m in items_meta]
cols_outer = outer[1]["color"] if outer else None
c_bonus = color_score(*cols, c_outer=cols_outer)
s_bonus = np.mean([season_nudge(user_text, m["seasonality"]) for m in items_meta])
a_bonus = np.mean([aesthetic_nudge(user_text, m["aesthetic"]) for m in items_meta])
total = W["sim"]*base_sim + W["color"]*c_bonus + W["season"]*s_bonus + W["aesthetic"]*a_bonus
if outer:
total = (total + outer[0]) / 2.0 # simple blend if outer included
return float(total)
def build_outfits(
review_text: str,
df: pd.DataFrame,
item_embs: torch.Tensor,
mode: str = "separates",
gender_pref: str = "any",
season_pref: str = "any",
k_tops=20, k_bottoms=20, k_shoes=10, k_outer=10, k_onepiece=20,
include_outer=False,
max_results=20
):
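    """
    Rank outfit combinations from the closet against a review / vibe text.

    mode="separates" pairs top x bottom x shoes (up to k_tops*k_bottoms*k_shoes
    combinations; 20*20*10 = 4000 with the defaults), while mode="onepiece" pairs
    onepiece x shoes. Returns (final, df_filt), where final is a list of
    {"score": float, "global_idxs": [row positions into df]} sorted by score
    and df_filt is the filtered item table used for ranking.
    """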
    df_filt = filter_index(df, gender_pref=gender_pref, season_pref=season_pref)
    # Map from the filtered df back to original row positions in df / item_embs
    map_to_orig = df_filt.index.values
    df_filt = df_filt.reset_index(drop=True)
    # slice embeddings for the filtered subset
    emb_filt = item_embs[map_to_orig]
t_emb = embed_text(review_text)
cat_k = {"top":k_tops, "bottom":k_bottoms, "shoes":k_shoes, "outer":k_outer, "onepiece":k_onepiece}
topcats = topk_by_category(t_emb, df_filt, cat_k, emb_filt)
results = []
if mode == "separates":
A = topcats["top"].to_dict("records")
B = topcats["bottom"].to_dict("records")
C = topcats["shoes"].to_dict("records")
O = topcats["outer"].to_dict("records") if include_outer else [None]
for a, b, c in itertools.product(A, B, C):
# optional outer picks: choose best 1 quickly (or None)
outer_choice = None
if include_outer and len(O) > 0:
outer_choice = O[0] # already sorted, first is best
idxs = [a["idx"], b["idx"], c["idx"]]
metas = [df_filt.iloc[i].to_dict() for i in idxs]
sims = [a["score"], b["score"], c["score"]]
if outer_choice:
out_meta = df_filt.iloc[outer_choice["idx"]].to_dict()
total = score_combo(sims, metas, review_text, outer=(outer_choice["score"], out_meta))
outfit = idxs + [outer_choice["idx"]]
else:
total = score_combo(sims, metas, review_text)
outfit = idxs
results.append({"score": total, "idxs": outfit})
elif mode == "onepiece":
A = topcats["onepiece"].to_dict("records")
C = topcats["shoes"].to_dict("records")
O = topcats["outer"].to_dict("records") if include_outer else [None]
for a, c in itertools.product(A, C):
outer_choice = None
if include_outer and len(O) > 0:
outer_choice = O[0]
idxs = [a["idx"], c["idx"]]
metas = [df_filt.iloc[i].to_dict() for i in idxs]
sims = [a["score"], c["score"]]
if outer_choice:
out_meta = df_filt.iloc[outer_choice["idx"]].to_dict()
total = score_combo(sims, metas, review_text, outer=(outer_choice["score"], out_meta))
outfit = idxs + [outer_choice["idx"]]
else:
total = score_combo(sims, metas, review_text)
outfit = idxs
results.append({"score": total, "idxs": outfit})
else:
raise ValueError("mode must be 'separates' or 'onepiece'")
# Sort and keep top N
results = sorted(results, key=lambda x: -x["score"])[:max_results]
# Convert back to original global indices (for df_items / emb index)
final = []
for r in results:
orig = map_to_orig[r["idxs"]].tolist()
final.append({"score": r["score"], "global_idxs": orig})
return final, df_filt
def stitch_images(paths, size=(256,256), pad=6, bg=(245,245,245)):
imgs = [Image.open(p).convert("RGB").resize(size) for p in paths]
W = size[0]*len(imgs) + pad*(len(imgs)-1)
H = size[1]
canvas = Image.new("RGB", (W, H), bg)
x = 0
for im in imgs:
canvas.paste(im, (x,0))
x += size[0] + pad
return canvas
def explain_outfit(rows, score, review_text):
parts = []
for r in rows:
parts.append(f"{r['garment_type']} ({r['color']}, {r['aesthetic']})")
txt = f"Score {score:.3f}: " + " • ".join(parts)
return textwrap.fill(txt, width=90)
def recommend_ui(review_text, outfit_mode, gender_pref, season_pref,
include_outer, n_results, k_tops, k_bottoms, k_shoes, k_outer, k_onepiece):
n_results = int(n_results)
k_tops, k_bottoms = int(k_tops), int(k_bottoms)
k_shoes, k_outer, k_onepiece = int(k_shoes), int(k_outer), int(k_onepiece)
res, df_used = build_outfits(
review_text=review_text,
df=df_items,
item_embs=img_embs,
mode=outfit_mode,
gender_pref=gender_pref,
season_pref=season_pref,
k_tops=k_tops, k_bottoms=k_bottoms, k_shoes=k_shoes, k_outer=k_outer, k_onepiece=k_onepiece,
include_outer=bool(include_outer),
max_results=n_results
)
images, captions = [], []
for r in res:
rows = [df_items.iloc[i].to_dict() for i in r["global_idxs"]]
paths = [df_items.iloc[i]["path"] for i in r["global_idxs"]]
comp = stitch_images(paths)
images.append(comp)
captions.append(explain_outfit(rows, r["score"], review_text))
return images, captions
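# recommend_ui returns two parallel lists (stitched outfit images, caption strings);
# run_and_format below zips them into (image, caption) pairs for the Gradio gallery.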
with gr.Blocks() as demo:
gr.Markdown("## 🎬 Closet to Main-Character Outfits")
gr.Markdown("Paste a Letterboxd review (or vibe text), and get ranked outfit suggestions from your closet.")
with gr.Row():
review = gr.Textbox(label="Letterboxd review (or vibe text)", lines=5, placeholder="Paste review text here…")
with gr.Row():
outfit_mode = gr.Radio(choices=["separates","onepiece"], value="separates", label="Outfit mode")
gender_pref = gr.Dropdown(choices=["any","Women","Men","Neutral"], value="any", label="Gender preference")
season_pref = gr.Dropdown(choices=["any","Summer","Winter","Transitional","All-season"], value="any", label="Season preference")
include_outer = gr.Checkbox(value=False, label="Include outer layer")
with gr.Accordion("Advanced (top-k per category)", open=False):
with gr.Row():
n_results = gr.Slider(5, 30, value=12, step=1, label="How many results?")
with gr.Row():
k_tops = gr.Slider(5, 40, value=20, step=1, label="k_tops")
k_bottoms = gr.Slider(5, 40, value=20, step=1, label="k_bottoms")
k_shoes = gr.Slider(5, 30, value=10, step=1, label="k_shoes")
k_outer = gr.Slider(0, 30, value=10, step=1, label="k_outer")
k_onepiece = gr.Slider(5, 40, value=20, step=1, label="k_onepiece")
go = gr.Button("Recommend Outfits")
gallery = gr.Gallery(label="Ranked outfits", columns=2, height=600, show_label=True)
captions = gr.Textbox(label="Explanations", lines=10)
def run_and_format(*args):
imgs, caps = recommend_ui(*args)
# Gradio gallery wants [(img, caption), ...]
pairs = [(im, c) for im, c in zip(imgs, caps)]
# Also return the concatenated captions
return pairs, "\n\n".join(caps)
go.click(run_and_format, inputs=[review, outfit_mode, gender_pref, season_pref,
include_outer, n_results, k_tops, k_bottoms, k_shoes, k_outer, k_onepiece],
outputs=[gallery, captions])
demo.launch(share=True)
#End