Spaces:
Paused
Paused
| import os | |
| import json | |
| import random | |
| import argparse | |
| from pathlib import Path | |
| from typing import Dict, Any, List, Set, Union | |
def _extract_item_ids(items: Any) -> Union[List[str], None]:
    """Return the item IDs held in one outfit's raw ``items`` value.

    Each entry is either a dict carrying its ID under ``item_id``, ``id`` or
    ``itemId`` (checked in that order) or a bare ID. Entries are inspected one
    by one, so a list that mixes both shapes is handled instead of crashing or
    producing a stringified dict as an "ID".

    Args:
        items: The raw value found for an outfit; anything but a list is
            rejected.

    Returns:
        The IDs as strings, or ``None`` when ``items`` is not a list or when a
        non-empty list yields no usable ID. An empty list is returned as an
        empty list.
    """
    if not isinstance(items, list):
        return None

    item_ids: List[str] = []
    for entry in items:
        if not isinstance(entry, dict):
            item_ids.append(str(entry))
            continue
        for key in ("item_id", "id", "itemId"):
            value = entry.get(key)
            # Only None / "" count as missing, so a numeric ID of 0 is kept.
            if value is None or value == "":
                continue
            item_ids.append(str(value))
            break

    if items and not item_ids:
        return None
    return item_ids


def _normalize_outfits(obj: Union[List[Any], Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Normalize various Polyvore JSON formats into a list of {"items": [id,...]} dicts.

    Accepted shapes:
      * a dict mapping outfit_id -> outfit, where an outfit is either a dict
        with an "items" list or a bare list of items;
      * a list of outfits of those same two forms.

    Args:
        obj: The decoded JSON document.

    Returns:
        One dict per usable outfit with "items" as a list of string IDs.
        Outfits coming from a mapping also carry their "outfit_id". Outfits
        without an "items" list are skipped; any other input type yields an
        empty list.
    """
    result: List[Dict[str, Any]] = []

    if isinstance(obj, dict):
        for outfit_id, outfit_data in obj.items():
            raw_items = outfit_data.get("items") if isinstance(outfit_data, dict) else outfit_data
            item_ids = _extract_item_ids(raw_items)
            if item_ids is not None:
                result.append({"items": item_ids, "outfit_id": outfit_id})
    elif isinstance(obj, list):
        for outfit in obj:
            raw_items = outfit.get("items") if isinstance(outfit, dict) else outfit
            item_ids = _extract_item_ids(raw_items)
            if item_ids is not None:
                result.append({"items": item_ids})

    return result
def load_outfits_json(root: str, split: str) -> List[Dict[str, Any]]:
    """Try to load outfit data from various possible locations and formats.

    The candidate files are probed in a fixed order and the first one that
    exists, parses, and normalizes to at least one outfit wins. A file that
    cannot be read or parsed is reported on stdout and skipped.

    Args:
        root: Dataset root directory.
        split: Split name used to build the file names (e.g. "train").

    Returns:
        The normalized outfits of the first usable candidate file.

    Raises:
        FileNotFoundError: If no candidate yields any outfit.
    """
    candidates = [
        os.path.join(root, f"{split}.json"),
        os.path.join(root, f"{split}_no_dup.json"),
        os.path.join(root, "splits", f"{split}.json"),
        os.path.join(root, "splits", f"{split}_no_dup.json"),
        # Official Polyvore often ships splits under nondisjoint/ or disjoint/
        os.path.join(root, "nondisjoint", f"{split}.json"),
        os.path.join(root, "disjoint", f"{split}.json"),
    ]
    for path in candidates:
        if not os.path.exists(path):
            continue
        try:
            # JSON is UTF-8; do not depend on the platform's default encoding.
            with open(path, "r", encoding="utf-8") as f:
                raw = json.load(f)
            data = _normalize_outfits(raw)
        except Exception as e:
            # Best effort: a broken candidate must not stop the search.
            print(f"β οΈ Failed to load {path}: {e}")
            continue
        if data:
            print(f"β Loaded {len(data)} outfits from {path}")
            return data
    raise FileNotFoundError(f"Could not find usable {split} split in {root} or {root}/splits")
def extract_outfits_from_metadata(root: str) -> List[Dict[str, Any]]:
    """Extract outfit information from polyvore_item_metadata.json using set_id grouping.

    The metadata file is expected to be a JSON object mapping item_id to a
    per-item dict. Items whose dict has a "set_id" key are grouped by that
    value, and every group with at least two items becomes one outfit.

    Args:
        root: Dataset root directory containing the metadata file.

    Returns:
        A list of {"items", "set_id", "outfit_id"} dicts, or an empty list when
        the file is missing, is not a JSON object, or cannot be parsed.
        Progress and failures are reported on stdout.
    """
    print("π Extracting outfits from metadata using set_id grouping...")
    metadata_path = os.path.join(root, "polyvore_item_metadata.json")
    if not os.path.exists(metadata_path):
        print(f"β Metadata file not found: {metadata_path}")
        return []
    try:
        # JSON is UTF-8; do not depend on the platform's default encoding.
        with open(metadata_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        if not isinstance(metadata, dict):
            print("β Metadata is not a dictionary")
            return []

        # Group items by set_id to create outfits
        outfits_by_set: Dict[Any, List[str]] = {}
        for item_id, item_data in metadata.items():
            if isinstance(item_data, dict) and "set_id" in item_data:
                outfits_by_set.setdefault(item_data["set_id"], []).append(str(item_id))

        # Convert to outfit format, keeping only groups of at least two items
        outfits = [
            {"items": item_ids, "set_id": set_id, "outfit_id": f"set_{set_id}"}
            for set_id, item_ids in outfits_by_set.items()
            if len(item_ids) >= 2
        ]
        print(f"β Extracted {len(outfits)} outfits from metadata (set_id grouping)")
        return outfits
    except Exception as e:
        # Best effort: any read/parse problem means "no outfits from here".
        print(f"β Failed to parse metadata: {e}")
        return []
def extract_outfits_from_titles(root: str) -> List[Dict[str, Any]]:
    """Extract outfit information from polyvore_outfit_titles.json.

    The titles file is expected to be a JSON object mapping outfit_id to a
    per-outfit dict. Only entries that carry an "items" list with at least two
    elements are turned into outfits.

    Args:
        root: Dataset root directory containing the titles file.

    Returns:
        A list of {"items", "outfit_id"} dicts with string item IDs, or an
        empty list when the file is missing, is not a JSON object, or cannot
        be parsed. Progress and failures are reported on stdout.
    """
    print("π Extracting outfits from outfit titles...")
    titles_path = os.path.join(root, "polyvore_outfit_titles.json")
    if not os.path.exists(titles_path):
        print(f"β Titles file not found: {titles_path}")
        return []
    try:
        # JSON is UTF-8; do not depend on the platform's default encoding.
        with open(titles_path, "r", encoding="utf-8") as f:
            titles = json.load(f)
        if not isinstance(titles, dict):
            print("β Titles is not a dictionary")
            return []

        outfits = [
            # Convert all items to strings
            {"items": [str(x) for x in outfit_data["items"]], "outfit_id": outfit_id}
            for outfit_id, outfit_data in titles.items()
            if isinstance(outfit_data, dict)
            and isinstance(outfit_data.get("items"), list)
            and len(outfit_data["items"]) >= 2
        ]
        print(f"β Extracted {len(outfits)} outfits from titles")
        return outfits
    except Exception as e:
        # Best effort: any read/parse problem means "no outfits from here".
        print(f"β Failed to parse titles: {e}")
        return []
def try_load_any_outfits(root: str) -> List[Dict[str, Any]]:
    """Return outfits from whichever source is available under ``root``.

    The train/valid/test split files are tried first and concatenated in that
    order. Only when none of them loads does the function fall back to the
    item metadata file and then to the outfit titles file. An empty list means
    that no source produced anything.
    """
    print("π Looking for official splits...")
    combined: List[Dict[str, Any]] = []
    for split_name in ("train", "valid", "test"):
        try:
            loaded = load_outfits_json(root, split_name)
        except FileNotFoundError:
            print(f"β οΈ No {split_name} split found")
        else:
            combined.extend(loaded)
            print(f"β Found {split_name} split with {len(loaded)} outfits")

    if combined:
        print(f"β Total: {len(combined)} outfits from official splits")
        return combined

    print("π§ No official splits found, extracting from metadata...")
    # Metadata is consulted before titles; the first non-empty result wins.
    for extractor in (extract_outfits_from_metadata, extract_outfits_from_titles):
        extracted = extractor(root)
        if extracted:
            return extracted

    print("β No outfits could be extracted from any source")
    return []
def collect_all_items(outfits: List[Dict[str, Any]]) -> List[str]:
    """Return every distinct item ID found in ``outfits``, as sorted strings.

    Outfits without an "items" key contribute nothing.
    """
    unique_ids: Set[str] = {
        str(item_id)
        for outfit in outfits
        for item_id in outfit.get("items", [])
    }
    return sorted(unique_ids)
def build_triplets(outfits: List[Dict[str, Any]], all_items: List[str], max_triplets: int = 200000) -> List[Dict[str, str]]:
    """Build training triplets from outfits.

    For every pair of consecutive items in an outfit, the first is the anchor,
    the second the positive, and the negative is a random item from
    ``all_items`` that does not belong to that outfit.

    Negatives are drawn from a sorted pool rather than from a set turned into
    a list: set iteration order for strings changes between interpreter runs,
    which would make the fixed seed useless for reproducibility. Drawing from
    the pool also avoids rebuilding an "all items minus this outfit" list for
    every single pair.

    Args:
        outfits: Outfit dicts; only the "items" key is read.
        all_items: Candidate item IDs for negatives.
        max_triplets: Upper bound on the number of triplets returned.

    Returns:
        A list of {"anchor", "positive", "negative"} dicts, at most
        ``max_triplets`` long. Outfits with fewer than two items, or for which
        no negative exists, contribute nothing.
    """
    if max_triplets <= 0:
        return []

    rng = random.Random(42)
    pool = sorted(set(all_items), key=str)

    def draw_negative(exclude: Set[str]) -> Union[str, None]:
        """Return a random pool item outside ``exclude`` (None if there is none)."""
        if len(exclude) * 2 < len(pool):
            # More than half of the pool is eligible, so retrying is cheap.
            while True:
                candidate = rng.choice(pool)
                if candidate not in exclude:
                    return candidate
        remaining = [c for c in pool if c not in exclude]
        return rng.choice(remaining) if remaining else None

    triplets: List[Dict[str, str]] = []
    for outfit in outfits:
        items = [str(i) for i in outfit.get("items", [])]
        if len(items) < 2:
            continue
        local_set = set(items)
        for anchor, positive in zip(items, items[1:]):
            # Pick a negative not in this outfit
            negative = draw_negative(local_set)
            if negative is None:
                break  # no negative exists for this outfit at all
            triplets.append({"anchor": anchor, "positive": positive, "negative": negative})
            if len(triplets) >= max_triplets:
                return triplets
    return triplets
def build_outfit_pairs(outfits: List[Dict[str, Any]], num_negatives_per_pos: int = 1) -> List[Dict[str, Any]]:
    """Build outfit pairs for training.

    Every outfit with at least two items is emitted once with label 1,
    followed by ``num_negatives_per_pos`` corrupted copies with label 0. A
    corrupted copy has one randomly chosen position replaced by a random item
    that does not occur in the outfit.

    Replacement items are drawn from the sorted list of all item IDs, so the
    fixed seed gives the same output on every run (a set turned into a list
    would not, because set order for strings varies between interpreter
    runs), and no per-sample "all items minus this outfit" list is built.

    Args:
        outfits: Outfit dicts; only the "items" key is read.
        num_negatives_per_pos: Corrupted copies to emit per positive outfit.

    Returns:
        A list of {"items": [...], "label": 0 or 1} dicts. Negatives are
        omitted for an outfit when no replacement item exists.
    """
    rng = random.Random(123)
    pool = collect_all_items(outfits)

    def draw_negative(exclude: Set[str]) -> Union[str, None]:
        """Return a random pool item outside ``exclude`` (None if there is none)."""
        if len(exclude) * 2 < len(pool):
            # More than half of the pool is eligible, so retrying is cheap.
            while True:
                candidate = rng.choice(pool)
                if candidate not in exclude:
                    return candidate
        remaining = [c for c in pool if c not in exclude]
        return rng.choice(remaining) if remaining else None

    pairs: List[Dict[str, Any]] = []
    for outfit in outfits:
        items = [str(i) for i in outfit.get("items", [])]
        if len(items) < 2:
            continue
        # Positive sample
        pairs.append({"items": items, "label": 1})

        # Negatives by corrupting one item
        members = set(items)
        for _ in range(num_negatives_per_pos):
            idx = rng.randrange(len(items))
            neg_item = draw_negative(members)
            if neg_item is None:
                break  # nothing outside this outfit to substitute
            neg_items = items.copy()
            neg_items[idx] = neg_item
            pairs.append({"items": neg_items, "label": 0})
    return pairs
def build_outfit_triplets(outfits: List[Dict[str, Any]], num_triplets: int = 200000) -> List[Dict[str, Any]]:
    """Build outfit-level triplets for ViT training.

    Each triplet holds two different outfits of at least three items
    ("good_a", "good_b") plus "bad", a copy of "good_a" with one random
    position replaced by an item that does not occur in "good_a".

    The two outfits are drawn with ``rng.sample`` so that an iteration is not
    wasted whenever the same outfit would be picked twice. Replacement items
    come from the sorted list of all item IDs, which keeps the seeded output
    identical across runs (a set turned into a list would not, because set
    order for strings varies between interpreter runs).

    Args:
        outfits: Outfit dicts; only the "items" key is read.
        num_triplets: Requested number of triplets; the actual budget is
            capped at ten times the number of eligible outfits.

    Returns:
        A list of {"good_a", "good_b", "bad"} dicts, or an empty list when
        fewer than two eligible outfits exist (a message is printed then).
    """
    rng = random.Random(999)
    # Collect only valid positive outfits (len >= 3)
    pos = [o for o in outfits if len(o.get("items", [])) >= 3]
    if len(pos) < 2:
        print(f"β οΈ Only {len(pos)} valid outfits found, need at least 2 for triplets")
        return []

    pool = collect_all_items(outfits)

    def draw_negative(exclude: Set[str]) -> Union[str, None]:
        """Return a random pool item outside ``exclude`` (None if there is none)."""
        if len(exclude) * 2 < len(pool):
            # More than half of the pool is eligible, so retrying is cheap.
            while True:
                candidate = rng.choice(pool)
                if candidate not in exclude:
                    return candidate
        remaining = [c for c in pool if c not in exclude]
        return rng.choice(remaining) if remaining else None

    triplets: List[Dict[str, Any]] = []
    for _ in range(min(num_triplets, len(pos) * 10)):  # Limit based on available outfits
        ga, gb = rng.sample(pos, 2)
        if ga is gb:
            # Only possible when the very same dict object is listed twice.
            continue
        # Create bad by corrupting one item in ga
        items_ga = [str(i) for i in ga.get("items", [])]
        corrupt_idx = rng.randrange(len(items_ga))
        neg_item = draw_negative(set(items_ga))
        if neg_item is None:
            continue
        bad = items_ga.copy()
        bad[corrupt_idx] = neg_item
        triplets.append({
            "good_a": items_ga,
            "good_b": [str(i) for i in gb.get("items", [])],
            "bad": bad,
        })
    return triplets
def main() -> None:
    """Command-line entry point: load the splits and write the training files.

    Steps:
      1. Parse the command line and create the output directory.
      2. Load the train/valid/test splits; if none is found, either build a
         seeded random 70/10/20 split (``--force_random_split``) or stop.
      3. Optionally truncate every split to ``--max_samples`` outfits.
      4. For each non-empty split write three JSON files into the output
         directory: item triplets, labelled outfit pairs, outfit triplets.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, required=True, help="Polyvore dataset root")
    parser.add_argument("--out", type=str, default=None, help="Output directory for splits (default: <root>/splits)")
    parser.add_argument("--max_samples", type=int, default=None, help="Maximum number of samples to use (for testing)")
    parser.add_argument("--max_triplets", type=int, default=200000)
    parser.add_argument("--neg_per_pos", type=int, default=1)
    parser.add_argument("--force_random_split", action="store_true", help="Force random split creation (not recommended)")
    args = parser.parse_args()

    # NOTE(review): the default output directory <root>/splits is also one of
    # the places load_outfits_json searches, and "<split>.json" is written
    # there below -- confirm that official split files are never kept there.
    out_dir = args.out if args.out else os.path.join(args.root, "splits")
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    print(f"π Preparing Polyvore dataset from {args.root}")
    print(f"π Output directory: {out_dir}")

    # Always try to use official splits first
    splits = {}
    found_any_official = False
    print("π― Looking for official splits...")
    for split_name in ("train", "valid", "test"):
        try:
            loaded = load_outfits_json(args.root, split_name)
        except FileNotFoundError as err:
            print(f"β οΈ Skipping {split_name}: {err}")
            splits[split_name] = []
            continue
        splits[split_name] = loaded
        if loaded:
            found_any_official = True
            print(f"β Loaded {split_name} split: {len(loaded)} outfits")

    if found_any_official:
        print("π Using official splits from dataset!")
    else:
        print("β οΈ No official splits found")
        if not args.force_random_split:
            print("β Random split creation disabled. Use --force_random_split if needed.")
            print("π§ Please ensure official splits are available in nondisjoint/ or disjoint/ folders.")
            return

        print("π§ Creating random split (not recommended for production)...")
        all_outfits = try_load_any_outfits(args.root)
        if not all_outfits:
            print("β No outfits found to split. Please check dataset structure.")
            print("π Expected files:")
            print(" - train.json, valid.json, test.json")
            print(" - nondisjoint/train.json, etc.")
            print(" - polyvore_item_metadata.json")
            print(" - polyvore_outfit_titles.json")
            return

        print(f"π― Creating random split from {len(all_outfits)} outfits")
        shuffler = random.Random(2024)
        shuffler.shuffle(all_outfits)
        total = len(all_outfits)
        train_count = int(0.7 * total)
        valid_count = int(0.1 * total)
        valid_end = train_count + valid_count
        splits = {
            "train": all_outfits[:train_count],
            "valid": all_outfits[train_count:valid_end],
            "test": all_outfits[valid_end:],
        }
        print(f"π Split created: train={train_count}, valid={valid_count}, test={total - valid_end}")

    # Apply dataset size limit if specified
    if args.max_samples:
        print(f"π― Limiting dataset to {args.max_samples} samples for testing...")
        for split_name, subset in splits.items():
            if subset:
                # Keep only the leading max_samples outfits
                splits[split_name] = subset[:args.max_samples]
                print(f" π {split_name}: Limited to {len(splits[split_name])} outfits")

    # Generate training data for each split
    for split_name, outfits in splits.items():
        if not outfits:
            print(f"β οΈ No outfits for {split_name} split, skipping")
            continue
        print(f"\nπ§ Processing {split_name} split ({len(outfits)} outfits)...")

        all_items = collect_all_items(outfits)
        print(f" π¦ Total unique items: {len(all_items)}")
        triplets = build_triplets(outfits, all_items, max_triplets=args.max_triplets)
        print(f" π Generated {len(triplets)} item triplets")
        pairs = build_outfit_pairs(outfits, num_negatives_per_pos=args.neg_per_pos)
        print(f" π Generated {len(pairs)} outfit pairs")
        outfit_triplets = build_outfit_triplets(outfits)
        print(f" π Generated {len(outfit_triplets)} outfit triplets")

        # Save files
        outputs = (
            (f"{split_name}.json", triplets),
            (f"outfits_{split_name}.json", pairs),
            (f"outfit_triplets_{split_name}.json", outfit_triplets),
        )
        for file_name, payload in outputs:
            with open(os.path.join(out_dir, file_name), "w") as f:
                json.dump(payload, f, indent=2)
        print(f" πΎ Saved {split_name} data to {out_dir}")

    print("\nπ Dataset preparation complete!")
    print(f"π All files saved to: {out_dir}")
    if found_any_official:
        print("β Used official dataset splits - production ready!")
    else:
        print("β οΈ Used random splits - not recommended for production")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()