import os, io, math, itertools, textwrap, json
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
import torch
from torch.nn.functional import normalize
from transformers import CLIPProcessor, CLIPModel
import datasets as hfds
import gradio as gr
# --- Load pre-saved metadata and embeddings (fast startup) ---
import glob, random  # numpy, pandas, torch, os are already imported above
from datasets import load_dataset

print("Loading cached outfit data...")
df_items = pd.read_csv("closet_items.csv")
img_embs = torch.tensor(np.load("closet_clip_embs.npy"))
print(f"Loaded {len(df_items)} items, embeddings: {img_embs.shape}")

# --- Correct image folder ---
img_dir = "closet_small"  # your unzipped image folder
print(f"Using image folder: {img_dir}")
# --- Unzip automatically if needed ---
if os.path.exists("closet_small.zip") and not os.path.isdir(img_dir):
    import zipfile
    with zipfile.ZipFile("closet_small.zip", "r") as zip_ref:
        zip_ref.extractall(img_dir)
    print(f"✅ Unzipped images into: {img_dir}")
# --- Build full paths for images ---
if "path" not in df_items.columns or df_items["path"].isnull().all():
    # if the CSV doesn't already have a 'path' column
    img_files = sorted(glob.glob(os.path.join(img_dir, "*.png")) + glob.glob(os.path.join(img_dir, "*.jpg")))
    if len(img_files) >= len(df_items):
        df_items["path"] = img_files[:len(df_items)]
    else:
        df_items["path"] = [os.path.join(img_dir, f"image_{i}.png") for i in range(len(df_items))]
# --- Verify file existence (and keep the embedding rows aligned with the surviving items) ---
exists_mask = df_items["path"].apply(os.path.exists)
df_items = df_items[exists_mask].reset_index(drop=True)
if len(img_embs) == len(exists_mask):
    img_embs = img_embs[exists_mask.values]
print(f"✅ Closet items ready: {len(df_items)} images found in '{img_dir}'.")
print(df_items.head(3))
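# NOTE (assumed schema): the filtering/scoring code below expects closet_items.csv
# to carry at least these columns: path, category (one of top/bottom/shoes/outer/onepiece),
# color, gender, seasonality, aesthetic, and garment_type. If your CSV uses other
# names, rename the columns here before going further.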
# 2.1) Optional: Load Letterboxd dataset and helpers
# If the dataset is private, authenticate first:
#   !huggingface-cli login
try:
    lb_ds = load_dataset("pkchwy/letterboxd-all-movie-data")
    if "train" in lb_ds:
        lb = lb_ds["train"]
    elif "default" in lb_ds:
        lb = lb_ds["default"]
    else:
        lb = next(iter(lb_ds.values()))
    print("Loaded Letterboxd dataset:", lb)
except Exception as e:
    lb = None
    print("Letterboxd dataset not available:", e)
_LBOX_TEXT_FIELDS = ["review_text", "review", "content", "text", "body", "description"]
_LBOX_TITLE_FIELDS = ["movie_title", "title", "film_name", "name"]

def _extract_text(example):
    for k in _LBOX_TEXT_FIELDS:
        if k in example and isinstance(example[k], str) and example[k].strip():
            return example[k].strip()
    if "tags" in example and isinstance(example["tags"], list) and example["tags"]:
        return " ".join(map(str, example["tags"]))
    return None

def _extract_title(example):
    for k in _LBOX_TITLE_FIELDS:
        if k in example and isinstance(example[k], str) and example[k].strip():
            return example[k].strip()
    return "(untitled)"

def sample_letterboxd_review(max_tries=50):
    if lb is None or len(lb) == 0:
        return None, None
    for _ in range(max_tries):
        ex = lb[random.randrange(0, len(lb))]
        txt = _extract_text(ex)
        if txt:
            return _extract_title(ex), txt
    return None, None
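# Quick smoke test (a minimal sketch, safe to delete): print one sampled review so
# the field mapping above can be eyeballed. Only runs if the dataset loaded.
if lb is not None:
    _t, _txt = sample_letterboxd_review()
    if _txt:
        print(f"Sample review for '{_t}': {_txt[:120]}…")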
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(MODEL_ID).to(device)
clip_proc = CLIPProcessor.from_pretrained(MODEL_ID)

# Use a local folder that always exists in Spaces or Colab
EMB_CACHE = os.path.join(os.getcwd(), "closet_clip_embeds.pt")

def pil_open(path):
    return Image.open(path).convert("RGB")
def embed_images(paths, batch_size=32):
    if not paths:  # handle an empty list of paths
        return torch.empty(0, clip_model.config.projection_dim)
    embs = []
    for i in tqdm(range(0, len(paths), batch_size), desc="Embedding images"):
        batch_paths = paths[i:i+batch_size]
        images = [pil_open(p) for p in batch_paths]
        inp = clip_proc(images=images, return_tensors="pt", padding=True).to(device)
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            feats = clip_model.get_image_features(**inp)
        feats = normalize(feats, dim=-1)
        embs.append(feats.cpu())
    return torch.cat(embs, dim=0)  # (N, D)
# Build (or load) the embedding cache. Note this .pt cache supersedes the .npy
# load above: if it exists (or items need re-embedding), img_embs is replaced.
def _embed_or_empty():
    if df_items.empty:
        print("No items in the closet, no embeddings to compute.")
        return torch.empty(0, clip_model.config.projection_dim)
    embs = embed_images(df_items["path"].tolist())
    torch.save({"img_embs": embs}, EMB_CACHE)
    print("Saved cache:", EMB_CACHE)
    return embs

if os.path.exists(EMB_CACHE):
    try:
        img_embs = torch.load(EMB_CACHE)["img_embs"]
        print("Loaded cached embeddings:", img_embs.shape)
    except Exception as e:
        print(f"Error loading cache: {e}. Re-embedding.")
        img_embs = _embed_or_empty()
else:
    img_embs = _embed_or_empty()
def embed_text(s: str):
    # Simple: use the raw review text; you can add templates later.
    # truncation=True keeps long reviews within CLIP's 77-token text limit.
    inputs = clip_proc(text=[s], return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        txt = clip_model.get_text_features(**inputs)
    return normalize(txt, dim=-1).cpu()[0]  # (D,)

def cos_sim(a: torch.Tensor, b: torch.Tensor):
    # a: (D,), b: (N, D); a plain matmul is cosine similarity because both are normalized
    return b @ a
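# Sanity check (a sketch with a made-up vibe string, safe to delete): rank the
# whole closet against one query and print the top matches. Assumes df_items and
# img_embs are row-aligned, as arranged above.
if len(img_embs) > 0:
    _probe = embed_text("moody rainy-day noir")
    _top = torch.topk(cos_sim(_probe, img_embs), k=min(3, len(img_embs)))
    for _s, _i in zip(_top.values.tolist(), _top.indices.tolist()):
        print(f"  {_s:.3f}  {df_items.iloc[_i]['path']}")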
# --- Filtering helpers ---
def filter_index(df, gender_pref=None, season_pref=None):
    idx = np.ones(len(df), dtype=bool)
    if gender_pref and gender_pref.lower() != "any":
        idx &= df["gender"].fillna("").str.lower().str.contains(gender_pref.lower())
    if season_pref and season_pref.lower() != "any":
        idx &= df["seasonality"].fillna("").str.lower().str.contains(season_pref.lower())
    # Keep the original index so callers can map filtered rows back to
    # positions in the full df (and the row-aligned embedding matrix).
    return df[idx]
# --- pick top-k per category by text-image similarity ---
def topk_by_category(txt_emb, df, categories_k, item_embs):
    """
    categories_k: dict like {"top": 20, "bottom": 20, "shoes": 10, "outer": 10, "onepiece": 20}
    returns dict: cat -> dataframe with columns [idx, score], sorted descending by score
    """
    out = {}
    sims = cos_sim(txt_emb, item_embs)  # (N,)
    for cat, k in categories_k.items():
        mask = (df["category"] == cat).values
        if mask.sum() == 0:
            out[cat] = pd.DataFrame(columns=["idx", "score"])
            continue
        cat_idx = np.where(mask)[0]
        cat_scores = sims[cat_idx].numpy()
        order = np.argsort(-cat_scores)[:k]
        out[cat] = pd.DataFrame({
            "idx": cat_idx[order],
            "score": cat_scores[order]
        })
    return out
# --- simple color compatibility ---
PALETTES = {
    "neutrals": {"black", "white", "gray", "beige", "silver"},
    "earthy": {"brown", "beige", "olive", "gold"},
    "cool": {"blue", "navy", "teal", "green"},
    "warm": {"red", "orange", "yellow", "pink", "burgundy"}
}

def color_score(*colors, c_outer=None):
    """
    Flexible color compatibility scorer.
    Accepts 2 or 3 colors (e.g., top/bottom/shoes or onepiece/shoes),
    plus an optional outer layer.
    """
    def bucket(c):
        if not isinstance(c, str):
            return "neutrals"
        x = c.lower()
        for b, s in PALETTES.items():
            if any(k in x for k in s):
                return b
        return "neutrals"

    # Handle missing or extra colors gracefully
    bs = [bucket(c) for c in colors if c is not None]
    if c_outer:
        bs.append(bucket(c_outer))
    if len(bs) == 0:
        return 0.0
    # reward when all-but-one colors share a bucket, plus a little extra for neutrals
    majority = max(bs, key=bs.count)
    maj_ct = bs.count(majority)
    bonus = 0.04 if maj_ct >= len(bs) - 1 else 0.0
    if any(b == "neutrals" for b in bs):
        bonus += 0.02
    return bonus
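# Worked examples of the bonus, following the bucket rules above:
#   color_score("black", "blue", "white") -> 0.06  (neutrals majority, neutral present)
#   color_score("red", "blue")            -> 0.04  (with two colors the majority test always passes)
#   color_score("olive", "navy", "red")   -> 0.00  (three different buckets, no neutrals)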
# --- Season & aesthetic nudges from text ---
def season_nudge(text, seasonality):
    if not isinstance(seasonality, str):  # guards against NaN/None from pandas
        return 0.0
    t = text.lower()
    maps = {"summer": ["summer", "heat", "beach"],
            "winter": ["winter", "cold", "snow", "coat"],
            "transitional": ["fall", "spring", "breeze", "layer"],
            "all-season": []}
    for k, cues in maps.items():
        if k in seasonality.lower() and any(w in t for w in cues):
            return 0.03
    return 0.0

def aesthetic_nudge(text, aesthetic):
    # isinstance guard: a NaN cell is truthy, so `if aesthetic` alone would crash on .lower()
    if isinstance(aesthetic, str) and any(w in text.lower() for w in aesthetic.lower().split()):
        return 0.03
    return 0.0
# --- Score a combination ---
def score_combo(txt_sim_scores, items_meta, user_text, outer=None, weights=None):
    """
    txt_sim_scores: list of per-item CLIP sims [top, bottom, shoes] or [onepiece, shoes]
    items_meta: list of metadata rows matching the above
    outer: (sim, meta) or None
    weights: dict of weights
    """
    W = {
        "sim": 1.0,            # average CLIP similarity
        "color": 1.0,
        "season": 1.0,
        "aesthetic": 1.0,
        "outer_penalty": 0.0   # not used for now
    }
    if weights:
        W.update(weights)
    base_sim = np.mean(txt_sim_scores)
    cols = [m["color"] for m in items_meta]
    cols_outer = outer[1]["color"] if outer else None
    c_bonus = color_score(*cols, c_outer=cols_outer)
    s_bonus = np.mean([season_nudge(user_text, m["seasonality"]) for m in items_meta])
    a_bonus = np.mean([aesthetic_nudge(user_text, m["aesthetic"]) for m in items_meta])
    total = W["sim"]*base_sim + W["color"]*c_bonus + W["season"]*s_bonus + W["aesthetic"]*a_bonus
    if outer:
        total = (total + outer[0]) / 2.0  # simple blend if an outer layer is included
    return float(total)
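# Worked example (hypothetical metadata rows shaped like df_items records):
# two items with CLIP sims 0.28 and 0.24 average to 0.26, then bonuses stack.
#   _metas = [{"color": "black", "seasonality": "Winter", "aesthetic": "minimal"},
#             {"color": "white", "seasonality": "Winter", "aesthetic": "minimal"}]
#   score_combo([0.28, 0.24], _metas, "a cold minimal winter film")
#   -> 0.26 + 0.06 (color) + 0.03 (season) + 0.03 (aesthetic) = 0.38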
def build_outfits(
    review_text: str,
    df: pd.DataFrame,
    item_embs: torch.Tensor,
    mode: str = "separates",
    gender_pref: str = "any",
    season_pref: str = "any",
    k_tops=20, k_bottoms=20, k_shoes=10, k_outer=10, k_onepiece=20,
    include_outer=False,
    max_results=20
):
    df_filt = filter_index(df, gender_pref=gender_pref, season_pref=season_pref)
    # filter_index preserves the original index, so this maps filtered rows
    # back to row positions in df / item_embs
    map_to_orig = df_filt.index.values
    df_filt = df_filt.reset_index(drop=True)
    # slice embeddings for the filtered subset
    emb_filt = item_embs[map_to_orig]
    t_emb = embed_text(review_text)
    cat_k = {"top": k_tops, "bottom": k_bottoms, "shoes": k_shoes, "outer": k_outer, "onepiece": k_onepiece}
    topcats = topk_by_category(t_emb, df_filt, cat_k, emb_filt)

    results = []
    if mode == "separates":
        A = topcats["top"].to_dict("records")
        B = topcats["bottom"].to_dict("records")
        C = topcats["shoes"].to_dict("records")
        O = topcats["outer"].to_dict("records") if include_outer else []
        for a, b, c in itertools.product(A, B, C):
            # optional outer pick: take the single best one (records are already sorted)
            outer_choice = O[0] if include_outer and len(O) > 0 else None
            idxs = [a["idx"], b["idx"], c["idx"]]
            metas = [df_filt.iloc[i].to_dict() for i in idxs]
            sims = [a["score"], b["score"], c["score"]]
            if outer_choice:
                out_meta = df_filt.iloc[outer_choice["idx"]].to_dict()
                total = score_combo(sims, metas, review_text, outer=(outer_choice["score"], out_meta))
                outfit = idxs + [outer_choice["idx"]]
            else:
                total = score_combo(sims, metas, review_text)
                outfit = idxs
            results.append({"score": total, "idxs": outfit})
    elif mode == "onepiece":
        A = topcats["onepiece"].to_dict("records")
        C = topcats["shoes"].to_dict("records")
        O = topcats["outer"].to_dict("records") if include_outer else []
        for a, c in itertools.product(A, C):
            outer_choice = O[0] if include_outer and len(O) > 0 else None
            idxs = [a["idx"], c["idx"]]
            metas = [df_filt.iloc[i].to_dict() for i in idxs]
            sims = [a["score"], c["score"]]
            if outer_choice:
                out_meta = df_filt.iloc[outer_choice["idx"]].to_dict()
                total = score_combo(sims, metas, review_text, outer=(outer_choice["score"], out_meta))
                outfit = idxs + [outer_choice["idx"]]
            else:
                total = score_combo(sims, metas, review_text)
                outfit = idxs
            results.append({"score": total, "idxs": outfit})
    else:
        raise ValueError("mode must be 'separates' or 'onepiece'")

    # Sort and keep the top N
    results = sorted(results, key=lambda x: -x["score"])[:max_results]
    # Convert back to original global indices (into df_items / item_embs)
    final = []
    for r in results:
        orig = map_to_orig[r["idxs"]].tolist()
        final.append({"score": r["score"], "global_idxs": orig})
    return final, df_filt
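# Example call (commented out to keep startup fast; the review string is made up).
# Returns up to max_results dicts like {"score": float, "global_idxs": [...]}
# whose indices point back into df_items / img_embs:
#   res, _ = build_outfits("dreamy pastel coming-of-age summer", df_items, img_embs,
#                          mode="separates", season_pref="Summer", max_results=5)
#   for r in res:
#       print(round(r["score"], 3), [df_items.iloc[i]["path"] for i in r["global_idxs"]])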
def stitch_images(paths, size=(256, 256), pad=6, bg=(245, 245, 245)):
    imgs = [Image.open(p).convert("RGB").resize(size) for p in paths]
    W = size[0]*len(imgs) + pad*(len(imgs)-1)
    H = size[1]
    canvas = Image.new("RGB", (W, H), bg)
    x = 0
    for im in imgs:
        canvas.paste(im, (x, 0))
        x += size[0] + pad
    return canvas
def explain_outfit(rows, score, review_text):
    parts = []
    for r in rows:
        parts.append(f"{r['garment_type']} ({r['color']}, {r['aesthetic']})")
    txt = f"Score {score:.3f}: " + " • ".join(parts)
    return textwrap.fill(txt, width=90)
def recommend_ui(review_text, outfit_mode, gender_pref, season_pref,
                 include_outer, n_results, k_tops, k_bottoms, k_shoes, k_outer, k_onepiece):
    n_results = int(n_results)
    k_tops, k_bottoms = int(k_tops), int(k_bottoms)
    k_shoes, k_outer, k_onepiece = int(k_shoes), int(k_outer), int(k_onepiece)
    res, df_used = build_outfits(
        review_text=review_text,
        df=df_items,
        item_embs=img_embs,
        mode=outfit_mode,
        gender_pref=gender_pref,
        season_pref=season_pref,
        k_tops=k_tops, k_bottoms=k_bottoms, k_shoes=k_shoes, k_outer=k_outer, k_onepiece=k_onepiece,
        include_outer=bool(include_outer),
        max_results=n_results
    )
    images, captions = [], []
    for r in res:
        rows = [df_items.iloc[i].to_dict() for i in r["global_idxs"]]
        paths = [df_items.iloc[i]["path"] for i in r["global_idxs"]]
        images.append(stitch_images(paths))
        captions.append(explain_outfit(rows, r["score"], review_text))
    return images, captions
with gr.Blocks() as demo:
    gr.Markdown("## 🎬 Closet to Main-Character Outfits")
    gr.Markdown("Paste a Letterboxd review (or vibe text), and get ranked outfit suggestions from your closet.")
    with gr.Row():
        review = gr.Textbox(label="Letterboxd review (or vibe text)", lines=5, placeholder="Paste review text here…")
    with gr.Row():
        outfit_mode = gr.Radio(choices=["separates", "onepiece"], value="separates", label="Outfit mode")
        gender_pref = gr.Dropdown(choices=["any", "Women", "Men", "Neutral"], value="any", label="Gender preference")
        season_pref = gr.Dropdown(choices=["any", "Summer", "Winter", "Transitional", "All-season"], value="any", label="Season preference")
        include_outer = gr.Checkbox(value=False, label="Include outer layer")
    with gr.Accordion("Advanced (top-k per category)", open=False):
        with gr.Row():
            n_results = gr.Slider(5, 30, value=12, step=1, label="How many results?")
        with gr.Row():
            k_tops = gr.Slider(5, 40, value=20, step=1, label="k_tops")
            k_bottoms = gr.Slider(5, 40, value=20, step=1, label="k_bottoms")
            k_shoes = gr.Slider(5, 30, value=10, step=1, label="k_shoes")
            k_outer = gr.Slider(0, 30, value=10, step=1, label="k_outer")
            k_onepiece = gr.Slider(5, 40, value=20, step=1, label="k_onepiece")
    go = gr.Button("Recommend Outfits")
    gallery = gr.Gallery(label="Ranked outfits", columns=2, height=600, show_label=True)
    captions = gr.Textbox(label="Explanations", lines=10)

    def run_and_format(*args):
        imgs, caps = recommend_ui(*args)
        # Gradio's Gallery accepts [(img, caption), ...] pairs
        pairs = [(im, c) for im, c in zip(imgs, caps)]
        # Also return the concatenated captions for the textbox
        return pairs, "\n\n".join(caps)

    go.click(run_and_format,
             inputs=[review, outfit_mode, gender_pref, season_pref,
                     include_outer, n_results, k_tops, k_bottoms, k_shoes, k_outer, k_onepiece],
             outputs=[gallery, captions])

demo.launch(share=True)
# End