File size: 5,930 Bytes
9415028
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74851ea
 
9415028
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4bb654
9415028
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""
Load, filter, and convert personas from the Nemotron-Personas-USA dataset.

Generic loader — filters are configurable via CLI args (a JSON object) or when
used as a library. Matching rows are converted to evaluator-ready profile dicts;
the CLI writes them out as a JSON list.

Usage:
    # Filter by any combination of fields
    uv run python scripts/persona_loader.py \
      --filters '{"sex": "Female", "state": "IL", "age_min": 25, "age_max": 50}' \
      --limit 100 \
      --output data/filtered.json

    # As a library
    from persona_loader import load_personas, filter_personas, to_profile
"""

import argparse
import ast
import json
import random
from pathlib import Path

from datasets import load_from_disk

# Repository root: two levels up from this file (the usage notes place it in scripts/).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DEFAULT_DATA_DIR = PROJECT_ROOT.joinpath("data", "nemotron")

# Free-text narrative columns read by this module, in the author's order of
# descriptive richness. Order is significant: build_persona_text concatenates
# them in this order and extract_name searches them in this order.
NARRATIVE_FIELDS = [
    "persona",
    "cultural_background",
    "professional_persona",
    "career_goals_and_ambitions",
    "hobbies_and_interests",
    "sports_persona",
    "arts_persona",
    "travel_persona",
    "culinary_persona",
    "skills_and_expertise",
]


def load_personas(data_dir=None):
    """Load the cached persona dataset with ``datasets.load_from_disk``.

    Args:
        data_dir: directory holding the saved dataset. Falls back to
            DEFAULT_DATA_DIR when falsy (None, "").

    Returns:
        Whatever ``load_from_disk`` returns for that directory.

    Raises:
        FileNotFoundError: if ``dataset_info.json`` is absent from the
            directory. The message tells the user to run
            ``scripts/setup_data.py``, which presumably creates the cache.
    """
    target = Path(data_dir) if data_dir else Path(DEFAULT_DATA_DIR)
    marker = target / "dataset_info.json"
    if not marker.exists():
        raise FileNotFoundError(
            f"Dataset not found at {target}. Run: uv run python scripts/setup_data.py"
        )
    return load_from_disk(str(target))


def filter_personas(ds, filters: dict, limit: int = None, seed: int = 42,
                    num_proc: int = 4):
    """
    Filter dataset by arbitrary field conditions.

    Supported filter keys:
        sex, state (exact match), city (case-insensitive substring match),
        age_min, age_max (inclusive; default 0 and 200),
        marital_status, education_level (a string or a list of accepted values),
        occupation (case-insensitive substring match)

    For the keys above, a falsy value (None, "", []) means "no constraint".
    Any unrecognized key is treated as an exact match on that column.

    Args:
        ds: dataset to filter. Must support ``filter(fn, num_proc=...)``,
            ``select(indices)`` and ``len()``, as the object returned by
            ``load_personas`` does.
        filters: mapping of filter key -> value. None is treated as {}.
        limit: if truthy and smaller than the number of matches, randomly
            sample this many rows. None or 0 means no limit.
        seed: seed for the sampling RNG. A private ``random.Random`` is used,
            so the global ``random`` state is left untouched.
        num_proc: worker count forwarded to ``ds.filter``.

    Returns:
        The filtered, and possibly sampled, dataset.

    Raises:
        ValueError: if an unrecognized key is not a column of ``ds``. This is
            only checked when ``ds`` exposes ``column_names``.
    """
    filters = filters or {}
    known_keys = {"sex", "state", "city", "age_min", "age_max",
                  "marital_status", "education_level", "occupation"}

    # An explicit null (e.g. from CLI JSON) means "not given", not a crash.
    age_min = filters.get("age_min")
    age_max = filters.get("age_max")
    age_min = 0 if age_min is None else age_min
    age_max = 200 if age_max is None else age_max

    sex = filters.get("sex")
    state = filters.get("state")
    city = filters.get("city")
    marital = filters.get("marital_status")
    education = filters.get("education_level")
    occupation = filters.get("occupation")

    if isinstance(marital, str):
        marital = [marital]
    if isinstance(education, str):
        education = [education]

    # Lower-case the substring needles once instead of once per row.
    city_q = city.lower() if city else None
    occupation_q = occupation.lower() if occupation else None

    # Unrecognized keys become exact-match column filters, as documented.
    extra = {k: v for k, v in filters.items() if k not in known_keys}
    columns = getattr(ds, "column_names", None)
    if columns is not None:
        unknown = sorted(k for k in extra if k not in columns)
        if unknown:
            raise ValueError(f"Unknown filter column(s): {unknown}")

    def matches(row):
        if sex and row["sex"] != sex:
            return False
        if not (age_min <= row["age"] <= age_max):
            return False
        if state and row["state"] != state:
            return False
        # Missing (None) text values are treated as "" so they fail the
        # substring test instead of raising AttributeError.
        if city_q and city_q not in (row["city"] or "").lower():
            return False
        if marital and row["marital_status"] not in marital:
            return False
        if education and row["education_level"] not in education:
            return False
        if occupation_q and occupation_q not in (row["occupation"] or "").lower():
            return False
        return all(row[col] == want for col, want in extra.items())

    filtered = ds.filter(matches, num_proc=num_proc)

    if limit and len(filtered) > limit:
        # A local RNG yields the same sample as seeding the global one, without
        # clobbering the caller's global random state.
        rng = random.Random(seed)
        indices = rng.sample(range(len(filtered)), limit)
        filtered = filtered.select(indices)

    return filtered


def build_persona_text(row: dict) -> str:
    """Join the row's non-empty narrative fields into one description.

    Fields are visited in NARRATIVE_FIELDS order and separated by single
    spaces. The first field ("persona") is emitted bare. Every other field
    is prefixed with a short label, e.g. "Background: ...". Fields that are
    missing or falsy are skipped.
    """
    # One label per NARRATIVE_FIELDS entry, paired positionally.
    section_labels = (
        "",            # persona
        "Background",  # cultural_background
        "Career",      # professional_persona
        "Ambitions",   # career_goals_and_ambitions
        "Hobbies",     # hobbies_and_interests
        "Sports",      # sports_persona
        "Arts",        # arts_persona
        "Travel",      # travel_persona
        "Food",        # culinary_persona
        "Skills",      # skills_and_expertise
    )
    pieces = []
    for field_name, prefix in zip(NARRATIVE_FIELDS, section_labels):
        text = row.get(field_name)
        if not text:
            continue
        pieces.append(f"{prefix}: {text}" if prefix else text)
    return " ".join(pieces)


def extract_name(row: dict) -> str:
    """Guess a display name from the opening words of the narrative text.

    Scans NARRATIVE_FIELDS in order. The first non-empty field whose first
    two whitespace-separated words both start with an uppercase letter wins.
    Those two words are returned joined by a space, with trailing ',' / '.'
    stripped. Returns "Unknown" if no field qualifies.

    Heuristic only: two leading capitalized words that are not a name
    (e.g. a place) are returned as-is.
    """
    for field_name in NARRATIVE_FIELDS:
        narrative = row.get(field_name, "")
        if not narrative:
            continue
        tokens = narrative.split()
        if len(tokens) < 2:
            continue
        first, second = tokens[0], tokens[1]
        if first[0].isupper() and second[0].isupper():
            return f"{first} {second}".rstrip(",.")
    return "Unknown"


def parse_json_list(raw) -> list:
    """Coerce a serialized list column into a Python list.

    Accepted inputs:
      * an already-decoded list, returned as-is;
      * a JSON array string such as '["a", "b"]';
      * as a fallback, a Python-literal list string such as "['a', 'b']".

    The fallback uses ast.literal_eval, which evaluates literals only and
    never executes code, so it is safe on untrusted text. Anything else
    yields [], including None, malformed text, or text that decodes to a
    non-list.
    """
    if not isinstance(raw, str):
        return raw if isinstance(raw, list) else []
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        # Python repr() lists (single quotes) are not valid JSON.
        # NOTE(review): confirm which format the dataset's *_list columns use.
        try:
            parsed = ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return []
    return parsed if isinstance(parsed, list) else []


def to_profile(row: dict, user_id: int) -> dict:
    """Map one dataset row onto the flat profile dict consumed by evaluators.

    Args:
        row: a dataset row. Missing keys fall back to defaults: age -> 30,
            country -> "USA", other text fields -> "".
        user_id: identifier assigned by the caller, stored as "user_id".

    Returns:
        A dict with keys user_id, name, persona, age, sex, city, state,
        country, education_level, marital_status, occupation, interests,
        source_uuid, in that order. "interests" is the parsed hobbies list
        followed by the parsed skills list. "occupation" has underscores
        replaced by spaces and is title-cased.
    """
    display_name = extract_name(row)
    hobby_items = parse_json_list(row.get("hobbies_and_interests_list", "[]"))
    skill_items = parse_json_list(row.get("skills_and_expertise_list", "[]"))
    # `or ""` also covers an explicit None in the occupation column.
    occupation_label = (row.get("occupation") or "").replace("_", " ").title()

    return dict(
        user_id=user_id,
        name=display_name,
        persona=build_persona_text(row),
        age=row.get("age", 30),
        sex=row.get("sex", ""),
        city=row.get("city", ""),
        state=row.get("state", ""),
        country=row.get("country", "USA"),
        education_level=row.get("education_level", ""),
        marital_status=row.get("marital_status", ""),
        occupation=occupation_label,
        interests=hobby_items + skill_items,
        source_uuid=row.get("uuid", ""),
    )


def main(argv=None):
    """CLI entry point: load, filter, and convert personas, then save them as JSON.

    Args:
        argv: argument list to parse. None means sys.argv[1:] (argparse default).
    """
    parser = argparse.ArgumentParser(
        description="Filter Nemotron-Personas-USA personas and save evaluator profiles."
    )
    parser.add_argument(
        "--filters", type=json.loads, default={},
        help='JSON object of filters, e.g. \'{"sex": "Female", "age_min": 25}\'',
    )
    parser.add_argument("--limit", type=int, default=None,
                        help="randomly sample at most this many matches")
    parser.add_argument("--seed", type=int, default=42, help="sampling seed")
    parser.add_argument("--data-dir", default=None,
                        help="dataset directory (default: data/nemotron under the project root)")
    parser.add_argument("--output", default="data/filtered.json",
                        help="output JSON path")
    args = parser.parse_args(argv)
    # filter_personas calls .get() on the filters, so reject JSON arrays and
    # scalars up front with a usage error instead of a traceback.
    if not isinstance(args.filters, dict):
        parser.error("--filters must be a JSON object")

    ds = load_personas(args.data_dir)
    print(f"Loaded {len(ds)} total personas")

    filtered = filter_personas(ds, args.filters, limit=args.limit, seed=args.seed)
    print(f"Filtered: {len(filtered)} personas")

    profiles = [to_profile(row, i) for i, row in enumerate(filtered)]
    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False writes raw non-ASCII characters, so pin UTF-8 rather
    # than relying on the platform's default encoding.
    with open(output, "w", encoding="utf-8") as f:
        json.dump(profiles, f, ensure_ascii=False, indent=2)
    print(f"Saved to {args.output}")


if __name__ == "__main__":
    main()