| """ |
| Load, filter, and convert personas from the Nemotron-Personas-USA dataset. |
| |
| Generic loader — filters and field mapping are configurable via CLI args or |
| as a library. Returns a list of evaluator-ready profile dicts. |
| |
| Usage: |
| # Filter by any combination of fields |
| uv run python scripts/persona_loader.py \ |
| --filters '{"sex": "Female", "state": "IL", "age_min": 25, "age_max": 50}' \ |
| --limit 100 \ |
| --output data/filtered.json |
| |
| # As a library |
| from persona_loader import load_personas, filter_personas, to_profile |
| """ |
|
|
| import json |
| import random |
| import argparse |
| from pathlib import Path |
| from datasets import load_from_disk |
|
|
# Repository root: this file lives in scripts/, so go up two directory levels.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default on-disk location of the cached dataset (created by scripts/setup_data.py).
DEFAULT_DATA_DIR = PROJECT_ROOT / "data" / "nemotron"



# Narrative text columns of a dataset row, in priority order.
# NOTE: build_persona_text() zips this list against a parallel list of section
# labels, and extract_name() scans it front-to-back — keep the ordering stable.
NARRATIVE_FIELDS = [
    "persona", "cultural_background", "professional_persona",
    "career_goals_and_ambitions", "hobbies_and_interests",
    "sports_persona", "arts_persona", "travel_persona", "culinary_persona",
    "skills_and_expertise",
]
|
|
|
|
def load_personas(data_dir=None):
    """Load the locally cached dataset from disk.

    Args:
        data_dir: Directory holding the saved dataset. Defaults to
            DEFAULT_DATA_DIR.

    Raises:
        FileNotFoundError: If no cached dataset exists at the location;
            the message points at scripts/setup_data.py.
    """
    target = Path(data_dir or DEFAULT_DATA_DIR)
    # A saved HF dataset always carries dataset_info.json; use it as a marker.
    marker = target / "dataset_info.json"
    if not marker.exists():
        raise FileNotFoundError(
            f"Dataset not found at {target}. Run: uv run python scripts/setup_data.py"
        )
    return load_from_disk(str(target))
|
|
|
|
def filter_personas(ds, filters: dict, limit: int = None, seed: int = 42):
    """
    Filter dataset by arbitrary field conditions.

    Supported filter keys:
        sex, state, city (substring match), age_min, age_max,
        marital_status (list), education_level (list),
        occupation (substring match)

    Any unrecognized key is treated as an exact match on that column.

    Args:
        ds: A dataset exposing .filter() and .select() (e.g. datasets.Dataset).
        filters: Mapping of filter keys to values (see above).
        limit: If set, randomly down-sample the result to at most this many rows.
        seed: RNG seed used for the reproducible down-sample.

    Returns:
        The filtered (and possibly down-sampled) dataset.
    """
    # Private RNG: same sample sequence as random.seed(seed) but without
    # clobbering the caller's global `random` state.
    rng = random.Random(seed)

    age_min = filters.get("age_min", 0)
    age_max = filters.get("age_max", 200)
    sex = filters.get("sex")
    state = filters.get("state")
    city = filters.get("city")
    marital = filters.get("marital_status")
    education = filters.get("education_level")
    occupation = filters.get("occupation")

    # Accept a bare string wherever a list of options is expected.
    if isinstance(marital, str):
        marital = [marital]
    if isinstance(education, str):
        education = [education]

    # Fix: the docstring promises exact-match semantics for unrecognized keys,
    # but they were previously ignored. Collect them and enforce them below.
    _KNOWN = {"sex", "state", "city", "age_min", "age_max",
              "marital_status", "education_level", "occupation"}
    extras = {k: v for k, v in filters.items() if k not in _KNOWN}

    def matches(row):
        if sex and row["sex"] != sex:
            return False
        if not (age_min <= row["age"] <= age_max):
            return False
        if state and row["state"] != state:
            return False
        # `or ""` guards .lower() against null city/occupation values.
        if city and city.lower() not in (row["city"] or "").lower():
            return False
        if marital and row["marital_status"] not in marital:
            return False
        if education and row["education_level"] not in education:
            return False
        if occupation and occupation.lower() not in (row["occupation"] or "").lower():
            return False
        for key, expected in extras.items():
            if row.get(key) != expected:
                return False
        return True

    filtered = ds.filter(matches, num_proc=4)

    # Reproducible random down-sample when more rows matched than requested.
    if limit and len(filtered) > limit:
        filtered = filtered.select(rng.sample(range(len(filtered)), limit))

    return filtered
|
|
|
|
def build_persona_text(row: dict) -> str:
    """Combine all narrative dimensions into a single rich description.

    The leading field (the core persona) is included verbatim; every other
    non-empty narrative field is prefixed with a short section label.
    """
    section_labels = ["", "Background", "Career", "Ambitions", "Hobbies",
                      "Sports", "Arts", "Travel", "Food", "Skills"]
    segments = []
    for label, field_name in zip(section_labels, NARRATIVE_FIELDS):
        text = row.get(field_name)
        if not text:
            continue
        segments.append(text if not label else f"{label}: {text}")
    return " ".join(segments)
|
|
|
|
def extract_name(row: dict, fields=None) -> str:
    """Extract a "First Last" name from the first field whose text opens
    with two capitalized words.

    Args:
        row: Persona record mapping field names to narrative text.
        fields: Field names to scan, in priority order. Defaults to
            NARRATIVE_FIELDS (backward-compatible generalization).

    Returns:
        "First Last" with trailing ","/"." stripped, or "Unknown" when no
        field starts with two capitalized words.
    """
    if fields is None:
        fields = NARRATIVE_FIELDS
    for field in fields:
        # `or ""` also covers fields that are present but null, where
        # row.get(field, "") would still hand back None and crash .split().
        words = (row.get(field) or "").split()
        if len(words) >= 2 and words[0][0].isupper() and words[1][0].isupper():
            return f"{words[0]} {words[1]}".rstrip(",.")
    return "Unknown"
|
|
|
|
def parse_json_list(raw) -> list:
    """Best-effort parse of a JSON-array field.

    Strings are decoded as JSON; anything that is not (or does not decode
    to) a list yields an empty list rather than raising.
    """
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except json.JSONDecodeError:
            return []
    return raw if isinstance(raw, list) else []
|
|
|
|
def to_profile(row: dict, user_id: int) -> dict:
    """Convert a Nemotron row into a generic evaluator profile dict.

    Args:
        row: Raw dataset record.
        user_id: Identifier assigned by the caller.

    Returns:
        Flat dict with demographics, a combined persona narrative, and an
        ``interests`` list merging hobbies and skills.
    """
    # `or ""` shields against a null occupation before string cleanup.
    raw_occupation = row.get("occupation") or ""
    interests = (
        parse_json_list(row.get("hobbies_and_interests_list", "[]"))
        + parse_json_list(row.get("skills_and_expertise_list", "[]"))
    )
    return {
        "user_id": user_id,
        "name": extract_name(row),
        "persona": build_persona_text(row),
        "age": row.get("age", 30),
        "sex": row.get("sex", ""),
        "city": row.get("city", ""),
        "state": row.get("state", ""),
        "country": row.get("country", "USA"),
        "education_level": row.get("education_level", ""),
        "marital_status": row.get("marital_status", ""),
        "occupation": raw_occupation.replace("_", " ").title(),
        "interests": interests,
        "source_uuid": row.get("uuid", ""),
    }
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: load -> filter -> convert -> dump to JSON.
    cli = argparse.ArgumentParser()
    cli.add_argument("--filters", type=json.loads, default={})
    cli.add_argument("--limit", type=int, default=None)
    cli.add_argument("--seed", type=int, default=42)
    cli.add_argument("--output", default="data/filtered.json")
    opts = cli.parse_args()

    dataset = load_personas()
    print(f"Loaded {len(dataset)} total personas")

    selected = filter_personas(dataset, opts.filters, limit=opts.limit, seed=opts.seed)
    print(f"Filtered: {len(selected)} personas")

    profiles = [to_profile(record, idx) for idx, record in enumerate(selected)]
    out_path = Path(opts.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w") as fh:
        json.dump(profiles, fh, ensure_ascii=False, indent=2)
    print(f"Saved to {opts.output}")
|
|