"""Prepare Polyvore outfit data for training.

Loads outfit definitions from a dataset root (per-split JSON files, or as a
fallback item metadata / outfit titles), derives item triplets, labelled
outfit pairs and outfit-level triplets, and writes them as JSON files.
"""

import os
import json
import random
import argparse
from pathlib import Path
from typing import Dict, Any, List, Optional, Sequence, Set, Tuple, Union

# Split names processed by this script, in processing order.
_SPLIT_NAMES = ("train", "valid", "test")


def _extract_item_ids(items: Any) -> Optional[List[str]]:
    """Convert one outfit's ``items`` value into a list of string item IDs.

    Args:
        items: The raw ``items`` value of an outfit record.

    Returns:
        The item IDs as strings, or ``None`` when the outfit must be skipped:
        ``items`` is not a list, or it is a list of item dicts from which no
        ID could be extracted. An empty list is returned as ``[]``.
    """
    if not isinstance(items, list):
        return None
    if items and isinstance(items[0], dict):
        item_ids: List[str] = []
        for entry in items:
            if not isinstance(entry, dict):
                # Mixed list: a bare value among item dicts is taken as an ID
                # instead of failing on the missing ``.get``.
                item_ids.append(str(entry))
                continue
            item_id = entry.get("item_id") or entry.get("id") or entry.get("itemId")
            if item_id is not None:
                item_ids.append(str(item_id))
        return item_ids or None
    # Direct list of item IDs.
    return [str(x) for x in items]


def _outfit_item_ids(outfit: Any) -> Optional[List[str]]:
    """Return the string item IDs of one raw outfit record, or ``None`` to skip it.

    A dict record is read through its ``"items"`` key; a list record is taken
    as a direct list of item IDs; anything else is skipped.
    """
    if isinstance(outfit, dict):
        if "items" not in outfit:
            return None
        return _extract_item_ids(outfit["items"])
    if isinstance(outfit, list):
        return [str(x) for x in outfit]
    return None


def _normalize_outfits(obj: Union[List[Any], Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Normalize various Polyvore JSON formats into a list of {"items": [id,...]} dicts.

    Args:
        obj: Either a mapping ``outfit_id -> outfit`` or a list of outfits.
            Each outfit is a dict with an ``"items"`` list (of IDs, or of
            dicts carrying ``item_id`` / ``id`` / ``itemId``) or a plain list
            of item IDs.

    Returns:
        One ``{"items": [...]}`` dict per usable outfit, with all IDs converted
        to ``str``. Outfits read from a mapping also carry ``"outfit_id"``.
        Any other input type yields an empty list.
    """
    result: List[Dict[str, Any]] = []
    if isinstance(obj, dict):
        # The file contains an outfit_id -> outfit_data mapping.
        for outfit_id, outfit_data in obj.items():
            item_ids = _outfit_item_ids(outfit_data)
            if item_ids is not None:
                result.append({"items": item_ids, "outfit_id": outfit_id})
    elif isinstance(obj, list):
        for outfit_data in obj:
            item_ids = _outfit_item_ids(outfit_data)
            if item_ids is not None:
                result.append({"items": item_ids})
    return result


def load_outfits_json(root: str, split: str) -> List[Dict[str, Any]]:
    """Try to load outfit data from various possible locations and formats.

    Args:
        root: Dataset root directory.
        split: Split name used to build the candidate file names.

    Returns:
        The normalized outfits of the first candidate file that exists, parses
        and yields at least one outfit.

    Raises:
        FileNotFoundError: If no candidate file yields any outfit.
    """
    candidates = [
        os.path.join(root, f"{split}.json"),
        os.path.join(root, f"{split}_no_dup.json"),
        os.path.join(root, "splits", f"{split}.json"),
        os.path.join(root, "splits", f"{split}_no_dup.json"),
        # Official Polyvore often ships splits under nondisjoint/ or disjoint/
        os.path.join(root, "nondisjoint", f"{split}.json"),
        os.path.join(root, "disjoint", f"{split}.json"),
    ]
    for p in candidates:
        if not os.path.exists(p):
            continue
        try:
            with open(p, "r", encoding="utf-8") as f:
                raw = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"⚠️ Failed to load {p}: {e}")
            continue
        data = _normalize_outfits(raw)
        if data:
            print(f"✅ Loaded {len(data)} outfits from {p}")
            return data
        # A readable file without usable outfits is skipped so that the
        # remaining candidates are still tried.
    raise FileNotFoundError(
        f"Could not find usable {split} split in {root} "
        f"(also searched splits/, nondisjoint/ and disjoint/)"
    )


def _load_json_mapping(path: str, label: str) -> Optional[Dict[str, Any]]:
    """Load ``path`` as a JSON object, printing a diagnostic and returning ``None`` on failure.

    ``label`` is the capitalized name used in the messages ("Metadata", "Titles").
    """
    if not os.path.exists(path):
        print(f"❌ {label} file not found: {path}")
        return None
    try:
        with open(path, "r", encoding="utf-8") as f:
            content = json.load(f)
    except (OSError, ValueError) as e:
        print(f"❌ Failed to parse {label.lower()}: {e}")
        return None
    if not isinstance(content, dict):
        print(f"❌ {label} is not a dictionary")
        return None
    return content


def extract_outfits_from_metadata(root: str) -> List[Dict[str, Any]]:
    """Extract outfit information from polyvore_item_metadata.json using set_id grouping.

    Items whose metadata dict has a ``"set_id"`` are grouped by that value;
    every group with at least two items becomes one outfit.

    Args:
        root: Dataset root directory containing the metadata file.

    Returns:
        A list of ``{"items", "set_id", "outfit_id"}`` dicts, or an empty list
        when the file is missing, unreadable or not a JSON object.
    """
    print("🔍 Extracting outfits from metadata using set_id grouping...")

    metadata = _load_json_mapping(os.path.join(root, "polyvore_item_metadata.json"), "Metadata")
    if metadata is None:
        return []

    # Group items by set_id to create outfits (insertion order is preserved).
    outfits_by_set: Dict[Any, List[str]] = {}
    for item_id, item_data in metadata.items():
        if not isinstance(item_data, dict) or "set_id" not in item_data:
            continue
        set_id = item_data["set_id"]
        if isinstance(set_id, (list, dict)):
            # Unhashable JSON values cannot serve as a grouping key; skip the
            # item rather than discarding the whole file.
            continue
        outfits_by_set.setdefault(set_id, []).append(str(item_id))

    outfits = [
        {"items": item_ids, "set_id": set_id, "outfit_id": f"set_{set_id}"}
        for set_id, item_ids in outfits_by_set.items()
        if len(item_ids) >= 2  # Minimum outfit size
    ]
    print(f"✅ Extracted {len(outfits)} outfits from metadata (set_id grouping)")
    return outfits


def extract_outfits_from_titles(root: str) -> List[Dict[str, Any]]:
    """Extract outfit information from polyvore_outfit_titles.json.

    Only entries that are dicts with an ``"items"`` list of at least two
    elements are used.

    Args:
        root: Dataset root directory containing the titles file.

    Returns:
        A list of ``{"items", "outfit_id"}`` dicts, or an empty list when the
        file is missing, unreadable or not a JSON object.
    """
    print("🔍 Extracting outfits from outfit titles...")

    titles = _load_json_mapping(os.path.join(root, "polyvore_outfit_titles.json"), "Titles")
    if titles is None:
        return []

    outfits: List[Dict[str, Any]] = []
    for outfit_id, outfit_data in titles.items():
        if not isinstance(outfit_data, dict) or "items" not in outfit_data:
            continue
        items = outfit_data["items"]
        if isinstance(items, list) and len(items) >= 2:
            outfits.append({"items": [str(x) for x in items], "outfit_id": outfit_id})

    print(f"✅ Extracted {len(outfits)} outfits from titles")
    return outfits


def try_load_any_outfits(root: str) -> List[Dict[str, Any]]:
    """Try to load outfits from any available source, prioritizing official splits.

    Order of attempts: the train/valid/test split files (merged into one
    list), then set_id grouping of the item metadata, then the outfit titles.

    Args:
        root: Dataset root directory.

    Returns:
        The outfits from the first source that yields any, else an empty list.
    """
    merged: List[Dict[str, Any]] = []

    # First try official splits (nondisjoint and disjoint)
    print("🔍 Looking for official splits...")
    for split in _SPLIT_NAMES:
        try:
            data = load_outfits_json(root, split)
        except FileNotFoundError:
            print(f"⚠️ No {split} split found")
            continue
        merged.extend(data)
        print(f"✅ Found {split} split with {len(data)} outfits")

    if merged:
        print(f"✅ Total: {len(merged)} outfits from official splits")
        return merged

    # If no official splits, try to extract from metadata
    print("🔧 No official splits found, extracting from metadata...")

    # Metadata first, titles as fallback.
    outfits = extract_outfits_from_metadata(root) or extract_outfits_from_titles(root)
    if outfits:
        return outfits

    print("❌ No outfits could be extracted from any source")
    return []


def collect_all_items(outfits: List[Dict[str, Any]]) -> List[str]:
    """Collect all unique item IDs from outfits.

    Returns:
        The distinct item IDs (as strings) in sorted order.
    """
    unique: Set[str] = {str(it) for o in outfits for it in o.get("items", [])}
    return sorted(unique)


def _pick_negative(rng: random.Random, pool: Sequence[str], exclude: Set[str]) -> Optional[str]:
    """Draw one element of ``pool`` that is not in ``exclude``.

    ``pool`` must be a sorted sequence of unique IDs: drawing from an ordered
    sequence (rather than from ``list(some_set)``) keeps results reproducible
    for a seeded ``rng``, since the iteration order of a set of strings
    changes between processes under hash randomization.

    Returns:
        The drawn ID, or ``None`` when every element of ``pool`` is excluded.
    """
    if len(exclude) * 2 >= len(pool):
        # The exclusion set may cover most of the pool: filter explicitly so
        # the draw terminates (and detects an empty remainder).
        remaining = [c for c in pool if c not in exclude]
        return rng.choice(remaining) if remaining else None
    # Fewer than half of the unique pool entries can be excluded, so each
    # draw succeeds with probability > 1/2.
    while True:
        candidate = rng.choice(pool)
        if candidate not in exclude:
            return candidate


def build_triplets(outfits: List[Dict[str, Any]], all_items: List[str], max_triplets: int = 200000) -> List[Dict[str, str]]:
    """Build training triplets from outfits.

    For every pair of consecutive items in an outfit, one triplet is emitted
    with the first item as anchor, the second as positive and a randomly
    drawn item that is not part of the outfit as negative.

    Args:
        outfits: Outfit dicts with an ``"items"`` list.
        all_items: Item IDs negatives are drawn from.
        max_triplets: Upper bound on the number of triplets returned.

    Returns:
        At most ``max_triplets`` dicts with keys ``anchor``, ``positive`` and
        ``negative``. The result is deterministic (fixed seed 42).
    """
    triplets: List[Dict[str, str]] = []
    if max_triplets <= 0:
        return triplets

    rng = random.Random(42)
    pool = sorted(set(all_items))

    for o in outfits:
        items = [str(i) for i in o.get("items", [])]
        if len(items) < 2:
            continue
        local_set = set(items)
        for anchor, positive in zip(items, items[1:]):
            # Pick a negative not in this outfit
            negative = _pick_negative(rng, pool, local_set)
            if negative is None:
                break  # no item outside this outfit exists
            triplets.append({"anchor": anchor, "positive": positive, "negative": negative})
            if len(triplets) >= max_triplets:
                return triplets
    return triplets


def build_outfit_pairs(outfits: List[Dict[str, Any]], num_negatives_per_pos: int = 1) -> List[Dict[str, Any]]:
    """Build outfit pairs for training.

    Every outfit with at least two items yields one positive sample
    (``label`` 1) followed by ``num_negatives_per_pos`` negative samples
    (``label`` 0), each made by replacing one randomly chosen item with an
    item that does not occur in the outfit.

    Args:
        outfits: Outfit dicts with an ``"items"`` list.
        num_negatives_per_pos: Number of corrupted copies per outfit.

    Returns:
        A list of ``{"items": [...], "label": 0 | 1}`` dicts. The result is
        deterministic (fixed seed 123).
    """
    rng = random.Random(123)
    pool = collect_all_items(outfits)  # sorted and unique
    pairs: List[Dict[str, Any]] = []

    for o in outfits:
        items = [str(i) for i in o.get("items", [])]
        if len(items) < 2:
            continue
        # Positive sample
        pairs.append({"items": items, "label": 1})
        # Negative by corrupting one item
        outfit_set = set(items)
        for _ in range(num_negatives_per_pos):
            idx = rng.randrange(len(items))
            neg_item = _pick_negative(rng, pool, outfit_set)
            if neg_item is None:
                break  # no item outside this outfit exists
            neg_items = items.copy()
            neg_items[idx] = neg_item
            pairs.append({"items": neg_items, "label": 0})
    return pairs


def build_outfit_triplets(outfits: List[Dict[str, Any]], num_triplets: int = 200000) -> List[Dict[str, Any]]:
    """Build outfit-level triplets for ViT training.

    Each triplet holds two distinct outfits with at least three items
    (``good_a``, ``good_b``) and ``bad``, a copy of ``good_a`` in which one
    randomly chosen item is replaced by an item that does not occur in it.

    Args:
        outfits: Outfit dicts with an ``"items"`` list.
        num_triplets: Requested number of triplets; the actual number is
            capped at ten times the number of valid outfits.

    Returns:
        A list of ``{"good_a", "good_b", "bad"}`` dicts, empty when fewer
        than two valid outfits exist. Deterministic (fixed seed 999).
    """
    rng = random.Random(999)

    # Collect only valid positive outfits (len >= 3)
    pos = [o for o in outfits if len(o.get("items", [])) >= 3]
    if len(pos) < 2:
        print(f"⚠️ Only {len(pos)} valid outfits found, need at least 2 for triplets")
        return []

    pool = collect_all_items(outfits)  # sorted and unique
    triplets: List[Dict[str, Any]] = []

    for _ in range(min(num_triplets, len(pos) * 10)):  # Limit based on available outfits
        # sample() returns two different list positions, so no iteration is
        # wasted on drawing the same outfit twice.
        ga, gb = rng.sample(pos, 2)

        # Create bad by corrupting one item in ga
        items_ga = [str(i) for i in ga.get("items", [])]
        corrupt_idx = rng.randrange(len(items_ga))
        neg_item = _pick_negative(rng, pool, set(items_ga))
        if neg_item is None:
            continue
        bad = items_ga.copy()
        bad[corrupt_idx] = neg_item

        triplets.append({
            "good_a": items_ga,
            "good_b": [str(i) for i in gb.get("items", [])],
            "bad": bad,
        })
    return triplets


def _load_official_splits(root: str) -> Tuple[Dict[str, List[Dict[str, Any]]], bool]:
    """Load each split via load_outfits_json; return (splits, whether any was found)."""
    splits: Dict[str, List[Dict[str, Any]]] = {}
    found_any = False
    for split in _SPLIT_NAMES:
        try:
            data = load_outfits_json(root, split)
        except FileNotFoundError as e:
            print(f"⚠️ Skipping {split}: {e}")
            splits[split] = []
            continue
        splits[split] = data
        found_any = True
        print(f"✅ Loaded {split} split: {len(data)} outfits")
    return splits, found_any


def _random_split(outfits: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """Shuffle ``outfits`` with a fixed seed and cut them 70% / 10% / remainder."""
    shuffled = list(outfits)
    random.Random(2024).shuffle(shuffled)
    n = len(shuffled)
    n_train = int(0.7 * n)
    n_valid = int(0.1 * n)
    return {
        "train": shuffled[:n_train],
        "valid": shuffled[n_train:n_train + n_valid],
        "test": shuffled[n_train + n_valid:],
    }


def _write_json(path: str, payload: Any) -> None:
    """Write ``payload`` to ``path`` as indented JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Command-line entry point: load splits, build training samples, write JSON.

    For every non-empty split three files are written to the output
    directory: ``<split>.json`` (item triplets), ``outfits_<split>.json``
    (labelled outfit pairs) and ``outfit_triplets_<split>.json``.

    NOTE: the default output directory ``<root>/splits`` is also one of the
    locations ``load_outfits_json`` searches, and the item triplets are
    written as ``<split>.json``; split files stored there are overwritten.
    Pass ``--out`` to write elsewhere.

    Args:
        argv: Argument list to parse; ``None`` (default) uses ``sys.argv``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", type=str, required=True, help="Polyvore dataset root")
    ap.add_argument("--out", type=str, default=None, help="Output directory for splits (default: <root>/splits)")
    ap.add_argument("--max_samples", type=int, default=None, help="Maximum number of samples to use (for testing)")
    ap.add_argument("--max_triplets", type=int, default=200000)
    ap.add_argument("--neg_per_pos", type=int, default=1)
    ap.add_argument("--force_random_split", action="store_true", help="Force random split creation (not recommended)")
    args = ap.parse_args(argv)

    out_dir = args.out or os.path.join(args.root, "splits")
    Path(out_dir).mkdir(parents=True, exist_ok=True)

    print(f"🔍 Preparing Polyvore dataset from {args.root}")
    print(f"📁 Output directory: {out_dir}")

    # Always try to use official splits first
    print("🎯 Looking for official splits...")
    splits, found_any_official = _load_official_splits(args.root)

    if found_any_official:
        print("🎉 Using official splits from dataset!")
    else:
        print("⚠️ No official splits found")
        if not args.force_random_split:
            print("❌ Random split creation disabled. Use --force_random_split if needed.")
            print("🔧 Please ensure official splits are available in nondisjoint/ or disjoint/ folders.")
            return

        print("🔧 Creating random split (not recommended for production)...")
        all_outfits = try_load_any_outfits(args.root)
        if not all_outfits:
            print("❌ No outfits found to split. Please check dataset structure.")
            print("📁 Expected files:")
            print(" - train.json, valid.json, test.json")
            print(" - nondisjoint/train.json, etc.")
            print(" - polyvore_item_metadata.json")
            print(" - polyvore_outfit_titles.json")
            return

        print(f"🎯 Creating random split from {len(all_outfits)} outfits")
        splits = _random_split(all_outfits)
        sizes = {name: len(part) for name, part in splits.items()}
        print(f"📊 Split created: train={sizes['train']}, valid={sizes['valid']}, test={sizes['test']}")

    # Apply dataset size limit if specified
    if args.max_samples is not None and args.max_samples > 0:
        print(f"🎯 Limiting dataset to {args.max_samples} samples for testing...")
        for split in splits:
            if splits[split]:
                # Take only the first max_samples outfits
                splits[split] = splits[split][:args.max_samples]
                print(f" 📊 {split}: Limited to {len(splits[split])} outfits")

    # Generate training data for each split
    for split, outfits in splits.items():
        if not outfits:
            print(f"⚠️ No outfits for {split} split, skipping")
            continue

        print(f"\n🔧 Processing {split} split ({len(outfits)} outfits)...")

        all_items = collect_all_items(outfits)
        print(f" 📦 Total unique items: {len(all_items)}")

        triplets = build_triplets(outfits, all_items, max_triplets=args.max_triplets)
        print(f" 🔗 Generated {len(triplets)} item triplets")

        pairs = build_outfit_pairs(outfits, num_negatives_per_pos=args.neg_per_pos)
        print(f" 👕 Generated {len(pairs)} outfit pairs")

        # --max_triplets bounds both kinds of triplets; its default equals
        # the default of build_outfit_triplets.
        outfit_triplets = build_outfit_triplets(outfits, num_triplets=args.max_triplets)
        print(f" 🎭 Generated {len(outfit_triplets)} outfit triplets")

        # Save files
        _write_json(os.path.join(out_dir, f"{split}.json"), triplets)
        _write_json(os.path.join(out_dir, f"outfits_{split}.json"), pairs)
        _write_json(os.path.join(out_dir, f"outfit_triplets_{split}.json"), outfit_triplets)

        print(f" 💾 Saved {split} data to {out_dir}")

    print("\n🎉 Dataset preparation complete!")
    print(f"📁 All files saved to: {out_dir}")

    if found_any_official:
        print("✅ Used official dataset splits - production ready!")
    else:
        print("⚠️ Used random splits - not recommended for production")


if __name__ == "__main__":
    main()