# recomendation / scripts / prepare_polyvore.py
# Ali Mohsin
# final new ultra fixes
# c150284
import os
import json
import random
import argparse
from pathlib import Path
from typing import Dict, Any, List, Set, Union
def _extract_item_ids(items: List[Any]) -> List[str]:
    """Pull string IDs out of a list of item dicts.

    Each dict is probed for "item_id", then "id", then "itemId" (first truthy
    value wins). Dicts without any of those keys, and entries that are not
    dicts at all, are skipped rather than aborting the whole file.
    """
    ids: List[str] = []
    for entry in items:
        if not isinstance(entry, dict):
            continue
        item_id = entry.get("item_id") or entry.get("id") or entry.get("itemId")
        if item_id is not None:
            ids.append(str(item_id))
    return ids


def _outfit_item_ids(outfit: Any) -> Union[List[str], None]:
    """Return the item IDs of one raw outfit entry, or None if it is unusable.

    Accepted shapes:
      * a plain list of IDs;
      * a dict whose "items" value is a list of IDs;
      * a dict whose "items" value is a list of item dicts.
    """
    if isinstance(outfit, list):
        return [str(x) for x in outfit]
    if not isinstance(outfit, dict):
        return None
    items = outfit.get("items")
    if not isinstance(items, list):
        return None
    if items and isinstance(items[0], dict):
        # Dict-style items: drop the outfit entirely when no ID could be found.
        return _extract_item_ids(items) or None
    # Direct list of IDs (an empty list is kept as an empty outfit).
    return [str(x) for x in items]


def _normalize_outfits(obj: Union[List[Any], Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Normalize various Polyvore JSON formats into a list of {"items": [id,...]} dicts.

    Args:
        obj: Either a mapping ``outfit_id -> outfit_data`` or a list of outfit
            entries. Each outfit may be a list of IDs or a dict with an
            "items" list (of IDs or of item dicts).

    Returns:
        One dict per usable outfit with an "items" list of string IDs. When
        the input is a mapping, each result also carries its "outfit_id".
        Any other input type yields an empty list.
    """
    result: List[Dict[str, Any]] = []
    if isinstance(obj, dict):
        for outfit_id, outfit_data in obj.items():
            item_ids = _outfit_item_ids(outfit_data)
            if item_ids is not None:
                result.append({"items": item_ids, "outfit_id": outfit_id})
    elif isinstance(obj, list):
        for outfit_data in obj:
            item_ids = _outfit_item_ids(outfit_data)
            if item_ids is not None:
                result.append({"items": item_ids})
    return result
def load_outfits_json(root: str, split: str) -> List[Dict[str, Any]]:
    """Try to load outfit data from various possible locations and formats.

    Args:
        root: Dataset root directory.
        split: Split name used to build the candidate file names
            (e.g. "train").

    Returns:
        The normalized outfits from the first candidate file that exists,
        parses, and yields at least one outfit.

    Raises:
        FileNotFoundError: If no candidate file produced any outfits.
    """
    candidates = [
        os.path.join(root, f"{split}.json"),
        os.path.join(root, f"{split}_no_dup.json"),
        os.path.join(root, "splits", f"{split}.json"),
        os.path.join(root, "splits", f"{split}_no_dup.json"),
        # Official Polyvore often ships splits under nondisjoint/ or disjoint/
        os.path.join(root, "nondisjoint", f"{split}.json"),
        os.path.join(root, "disjoint", f"{split}.json"),
    ]
    for p in candidates:
        if not os.path.exists(p):
            continue
        try:
            # JSON is UTF-8; do not depend on the platform's locale encoding.
            with open(p, "r", encoding="utf-8") as f:
                raw = json.load(f)
            data = _normalize_outfits(raw)
            if data:
                print(f"βœ… Loaded {len(data)} outfits from {p}")
                return data
        except Exception as e:
            # Best effort: a broken candidate must not stop the search.
            print(f"⚠️ Failed to load {p}: {e}")
    raise FileNotFoundError(f"Could not find usable {split} split in {root} or {root}/splits")
def extract_outfits_from_metadata(root: str) -> List[Dict[str, Any]]:
    """Extract outfit information from polyvore_item_metadata.json using set_id grouping.

    Args:
        root: Dataset root directory containing the metadata file.

    Returns:
        One dict per set_id that groups at least two items, with keys
        "items" (string item IDs), "set_id" and "outfit_id" ("set_<set_id>").
        An empty list is returned when the file is missing, is not a JSON
        object, or cannot be parsed.
    """
    print("πŸ” Extracting outfits from metadata using set_id grouping...")
    metadata_path = os.path.join(root, "polyvore_item_metadata.json")
    if not os.path.exists(metadata_path):
        print(f"❌ Metadata file not found: {metadata_path}")
        return []
    try:
        # JSON is UTF-8; do not depend on the platform's locale encoding.
        with open(metadata_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        if not isinstance(metadata, dict):
            print("❌ Metadata is not a dictionary")
            return []
        # Group items by set_id to create outfits
        outfits_by_set: Dict[Any, List[str]] = {}
        for item_id, item_data in metadata.items():
            if isinstance(item_data, dict) and "set_id" in item_data:
                outfits_by_set.setdefault(item_data["set_id"], []).append(str(item_id))
        # Convert to outfit format, keeping only groups of two or more items
        outfits = [
            {"items": item_ids, "set_id": set_id, "outfit_id": f"set_{set_id}"}
            for set_id, item_ids in outfits_by_set.items()
            if len(item_ids) >= 2
        ]
        print(f"βœ… Extracted {len(outfits)} outfits from metadata (set_id grouping)")
        return outfits
    except Exception as e:
        # Best effort: any parsing problem degrades to "no outfits found".
        print(f"❌ Failed to parse metadata: {e}")
        return []
def extract_outfits_from_titles(root: str) -> List[Dict[str, Any]]:
    """Extract outfit information from polyvore_outfit_titles.json.

    Args:
        root: Dataset root directory containing the titles file.

    Returns:
        One dict per entry whose value is a dict with an "items" list of at
        least two elements, with keys "items" (string item IDs) and
        "outfit_id". An empty list is returned when the file is missing, is
        not a JSON object, or cannot be parsed.
    """
    print("πŸ” Extracting outfits from outfit titles...")
    titles_path = os.path.join(root, "polyvore_outfit_titles.json")
    if not os.path.exists(titles_path):
        print(f"❌ Titles file not found: {titles_path}")
        return []
    try:
        # JSON is UTF-8; do not depend on the platform's locale encoding.
        with open(titles_path, "r", encoding="utf-8") as f:
            titles = json.load(f)
        if not isinstance(titles, dict):
            print("❌ Titles is not a dictionary")
            return []
        outfits = []
        for outfit_id, outfit_data in titles.items():
            if not (isinstance(outfit_data, dict) and "items" in outfit_data):
                continue
            items = outfit_data["items"]
            if isinstance(items, list) and len(items) >= 2:
                # Convert all items to strings
                outfits.append({
                    "items": [str(x) for x in items],
                    "outfit_id": outfit_id,
                })
        print(f"βœ… Extracted {len(outfits)} outfits from titles")
        return outfits
    except Exception as e:
        # Best effort: any parsing problem degrades to "no outfits found".
        print(f"❌ Failed to parse titles: {e}")
        return []
def try_load_any_outfits(root: str) -> List[Dict[str, Any]]:
    """Gather outfits from whichever source is available under ``root``.

    The train/valid/test split files are tried first and their outfits are
    concatenated. Only when none of them yields anything are the item
    metadata file and then the outfit titles file consulted, in that order.
    Returns an empty list when every source comes up empty.
    """
    print("πŸ” Looking for official splits...")
    merged: List[Dict[str, Any]] = []
    for split in ("train", "valid", "test"):
        try:
            data = load_outfits_json(root, split)
        except FileNotFoundError:
            print(f"⚠️ No {split} split found")
        else:
            merged.extend(data)
            print(f"βœ… Found {split} split with {len(data)} outfits")

    if merged:
        print(f"βœ… Total: {len(merged)} outfits from official splits")
        return merged

    # Fallbacks, most reliable first: set_id grouping, then outfit titles.
    print("πŸ”§ No official splits found, extracting from metadata...")
    for extractor in (extract_outfits_from_metadata, extract_outfits_from_titles):
        extracted = extractor(root)
        if extracted:
            return extracted

    print("❌ No outfits could be extracted from any source")
    return []
def collect_all_items(outfits: List[Dict[str, Any]]) -> List[str]:
    """Return every distinct item ID found in ``outfits``, stringified and sorted."""
    unique_ids: Set[str] = {
        str(item_id)
        for outfit in outfits
        for item_id in outfit.get("items", [])
    }
    ordered = list(unique_ids)
    ordered.sort()
    return ordered
def build_triplets(outfits: List[Dict[str, Any]], all_items: List[str], max_triplets: int = 200000) -> List[Dict[str, str]]:
    """Build training triplets from outfits.

    For every pair of adjacent items in an outfit, the first is the anchor,
    the second the positive, and the negative is drawn uniformly from the
    items of ``all_items`` that do not belong to that outfit.

    Args:
        outfits: Outfit dicts; each one's "items" are stringified before use.
        all_items: Pool of candidate negatives. Its order determines the
            sampling sequence, so pass a stably ordered list.
        max_triplets: Generation stops as soon as this many triplets exist.

    Returns:
        A list of {"anchor", "positive", "negative"} dicts. Outfits with
        fewer than two items, or for which no negative exists, contribute
        nothing.
    """
    rng = random.Random(42)
    # Sample from an ordered, de-duplicated list instead of list(set(...)):
    # set iteration order of strings changes between interpreter runs, which
    # would make the seeded RNG pick different negatives each time.
    candidates = list(dict.fromkeys(all_items))
    candidate_set = set(candidates)
    triplets: List[Dict[str, str]] = []
    for o in outfits:
        items = [str(i) for i in o.get("items", [])]
        if len(items) < 2:
            continue
        local_set = set(items)
        if candidate_set <= local_set:
            # Every candidate is part of this outfit: no negative available.
            continue
        for a, p in zip(items, items[1:]):
            # Rejection sampling: uniform over candidates outside the outfit.
            # Terminates because at least one such candidate exists.
            while True:
                n = rng.choice(candidates)
                if n not in local_set:
                    break
            triplets.append({"anchor": a, "positive": p, "negative": n})
            if len(triplets) >= max_triplets:
                return triplets
    return triplets
def build_outfit_pairs(outfits: List[Dict[str, Any]], num_negatives_per_pos: int = 1) -> List[Dict[str, Any]]:
    """Build outfit pairs for training.

    Every outfit with at least two items is emitted as a positive sample
    (label 1), followed by up to ``num_negatives_per_pos`` negatives
    (label 0). A negative is a copy of the outfit in which one randomly
    chosen position is replaced by an item that is not part of the outfit.

    Args:
        outfits: Outfit dicts; each one's "items" are stringified before use.
        num_negatives_per_pos: Corrupted copies to generate per positive.

    Returns:
        A list of {"items": [...], "label": 0 or 1} dicts. No negatives are
        produced for an outfit that already contains every known item.
    """
    rng = random.Random(123)
    # Sorted list of unique IDs: sampling from it (rather than from
    # list(set(...)), whose order varies between interpreter runs) keeps the
    # seeded RNG reproducible.
    all_items = collect_all_items(outfits)
    all_set = set(all_items)
    pairs: List[Dict[str, Any]] = []
    for o in outfits:
        items = [str(i) for i in o.get("items", [])]
        if len(items) < 2:
            continue
        # Positive sample
        pairs.append({"items": items, "label": 1})
        item_set = set(items)
        if all_set <= item_set:
            # Nothing outside this outfit to substitute in.
            continue
        # Negatives by corrupting one item
        for _ in range(num_negatives_per_pos):
            idx = rng.randrange(len(items))
            # Rejection sampling: uniform over items outside the outfit.
            while True:
                neg_item = rng.choice(all_items)
                if neg_item not in item_set:
                    break
            neg_items = items.copy()
            neg_items[idx] = neg_item
            pairs.append({"items": neg_items, "label": 0})
    return pairs
def build_outfit_triplets(outfits: List[Dict[str, Any]], num_triplets: int = 200000) -> List[Dict[str, Any]]:
    """Build outfit-level triplets for ViT training.

    Each triplet holds two different real outfits ("good_a", "good_b") and a
    "bad" outfit, which is "good_a" with one randomly chosen item replaced
    by an item that does not belong to it.

    Args:
        outfits: Outfit dicts; only those with at least three items are used
            as "good" outfits, but replacement items come from all of them.
        num_triplets: Upper bound on sampling attempts; the actual bound is
            ``min(num_triplets, 10 * number_of_valid_outfits)``.

    Returns:
        A list of {"good_a", "good_b", "bad"} dicts of string item IDs.
        Attempts that draw the same outfit twice are discarded, so the
        result may be shorter than the attempt bound. Returns an empty list
        when fewer than two valid outfits exist.
    """
    rng = random.Random(999)
    # Collect only valid positive outfits (len >= 3)
    pos = [o for o in outfits if len(o.get("items", [])) >= 3]
    if len(pos) < 2:
        print(f"⚠️ Only {len(pos)} valid outfits found, need at least 2 for triplets")
        return []
    # Sorted list of unique IDs: sampling from it (rather than from
    # list(set(...)), whose order varies between interpreter runs) keeps the
    # seeded RNG reproducible.
    all_items = collect_all_items(outfits)
    all_set = set(all_items)
    triplets: List[Dict[str, Any]] = []
    for _ in range(min(num_triplets, len(pos) * 10)):  # Limit based on available outfits
        ga = rng.choice(pos)
        gb = rng.choice(pos)
        # Ensure ga != gb
        if ga is gb:
            continue
        items_ga = [str(i) for i in ga.get("items", [])]
        ga_set = set(items_ga)
        if all_set <= ga_set:
            # Nothing outside this outfit to substitute in.
            continue
        # Create bad by corrupting one item in ga
        corrupt_idx = rng.randrange(len(items_ga))
        # Rejection sampling: uniform over items outside the outfit.
        while True:
            neg_item = rng.choice(all_items)
            if neg_item not in ga_set:
                break
        bad = items_ga.copy()
        bad[corrupt_idx] = neg_item
        triplets.append({
            "good_a": items_ga,
            "good_b": [str(i) for i in gb.get("items", [])],
            "bad": bad,
        })
    return triplets
def main() -> None:
    """Command-line entry point.

    Loads the train/valid/test splits from ``--root`` (or, with
    ``--force_random_split``, builds a 70/10/20 random split from whatever
    outfits can be found), optionally truncates each split to
    ``--max_samples`` outfits, and writes three JSON files per non-empty
    split into the output directory: item triplets, labelled outfit pairs,
    and outfit-level triplets.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, required=True, help="Polyvore dataset root")
    parser.add_argument("--out", type=str, default=None, help="Output directory for splits (default: <root>/splits)")
    parser.add_argument("--max_samples", type=int, default=None, help="Maximum number of samples to use (for testing)")
    parser.add_argument("--max_triplets", type=int, default=200000)
    parser.add_argument("--neg_per_pos", type=int, default=1)
    parser.add_argument("--force_random_split", action="store_true", help="Force random split creation (not recommended)")
    args = parser.parse_args()

    out_dir = args.out or os.path.join(args.root, "splits")
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    print(f"πŸ” Preparing Polyvore dataset from {args.root}")
    print(f"πŸ“ Output directory: {out_dir}")

    # Official splits always take precedence.
    splits = {}
    found_any_official = False
    print("🎯 Looking for official splits...")
    for split in ["train", "valid", "test"]:
        try:
            data = load_outfits_json(args.root, split)
        except FileNotFoundError as e:
            print(f"⚠️ Skipping {split}: {e}")
            splits[split] = []
            continue
        splits[split] = data
        if data:
            found_any_official = True
            print(f"βœ… Loaded {split} split: {len(data)} outfits")

    if found_any_official:
        print("πŸŽ‰ Using official splits from dataset!")
    else:
        print("⚠️ No official splits found")
        if not args.force_random_split:
            print("❌ Random split creation disabled. Use --force_random_split if needed.")
            print("πŸ”§ Please ensure official splits are available in nondisjoint/ or disjoint/ folders.")
            return

        print("πŸ”§ Creating random split (not recommended for production)...")
        all_outfits = try_load_any_outfits(args.root)
        if not all_outfits:
            print("❌ No outfits found to split. Please check dataset structure.")
            print("πŸ“ Expected files:")
            print(" - train.json, valid.json, test.json")
            print(" - nondisjoint/train.json, etc.")
            print(" - polyvore_item_metadata.json")
            print(" - polyvore_outfit_titles.json")
            return

        print(f"🎯 Creating random split from {len(all_outfits)} outfits")
        shuffler = random.Random(2024)
        shuffler.shuffle(all_outfits)
        n = len(all_outfits)
        n_train = int(0.7 * n)
        n_valid = int(0.1 * n)
        valid_end = n_train + n_valid
        splits = {
            "train": all_outfits[:n_train],
            "valid": all_outfits[n_train:valid_end],
            "test": all_outfits[valid_end:],
        }
        print(f"πŸ“Š Split created: train={n_train}, valid={n_valid}, test={n-n_train-n_valid}")

    # Optional cap on the number of outfits per split.
    if args.max_samples:
        print(f"🎯 Limiting dataset to {args.max_samples} samples for testing...")
        for split in splits:
            if not splits[split]:
                continue
            # Keep only the leading max_samples outfits.
            splits[split] = splits[split][:args.max_samples]
            print(f" πŸ“Š {split}: Limited to {len(splits[split])} outfits")

    # Build and persist the training data split by split.
    for split, outfits in splits.items():
        if not outfits:
            print(f"⚠️ No outfits for {split} split, skipping")
            continue
        print(f"\nπŸ”§ Processing {split} split ({len(outfits)} outfits)...")
        all_items = collect_all_items(outfits)
        print(f" πŸ“¦ Total unique items: {len(all_items)}")
        triplets = build_triplets(outfits, all_items, max_triplets=args.max_triplets)
        print(f" πŸ”— Generated {len(triplets)} item triplets")
        pairs = build_outfit_pairs(outfits, num_negatives_per_pos=args.neg_per_pos)
        print(f" πŸ‘• Generated {len(pairs)} outfit pairs")
        outfit_triplets = build_outfit_triplets(outfits)
        print(f" 🎭 Generated {len(outfit_triplets)} outfit triplets")

        # Written in this order: item triplets, outfit pairs, outfit triplets.
        outputs = (
            (f"{split}.json", triplets),
            (f"outfits_{split}.json", pairs),
            (f"outfit_triplets_{split}.json", outfit_triplets),
        )
        for filename, payload in outputs:
            with open(os.path.join(out_dir, filename), "w") as handle:
                json.dump(payload, handle, indent=2)
        print(f" πŸ’Ύ Saved {split} data to {out_dir}")

    print(f"\nπŸŽ‰ Dataset preparation complete!")
    print(f"πŸ“ All files saved to: {out_dir}")
    if found_any_official:
        print("βœ… Used official dataset splits - production ready!")
    else:
        print("⚠️ Used random splits - not recommended for production")
# Run the dataset preparation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()