"""
Load, filter, and convert personas from the Nemotron-Personas-USA dataset.
Generic loader — filters and field mapping are configurable via CLI args or
as a library. Returns a list of evaluator-ready profile dicts.
Usage:
# Filter by any combination of fields
uv run python scripts/persona_loader.py \
--filters '{"sex": "Female", "state": "IL", "age_min": 25, "age_max": 50}' \
--limit 100 \
--output data/filtered.json
# As a library
from persona_loader import load_personas, filter_personas, to_profile
"""
import json
import random
import argparse
from pathlib import Path
from datasets import load_from_disk
# Repo root: this file lives in scripts/, so go up two levels from it.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Where setup_data.py caches the downloaded dataset (see load_personas()).
DEFAULT_DATA_DIR = PROJECT_ROOT / "data" / "nemotron"
# All narrative fields in the dataset, in order of richness.
# build_persona_text() and extract_name() both iterate these in this order.
NARRATIVE_FIELDS = [
    "persona", "cultural_background", "professional_persona",
    "career_goals_and_ambitions", "hobbies_and_interests",
    "sports_persona", "arts_persona", "travel_persona", "culinary_persona",
    "skills_and_expertise",
]
def load_personas(data_dir=None):
    """Load the cached dataset from disk.

    Args:
        data_dir: Directory holding the saved dataset; falls back to
            DEFAULT_DATA_DIR when omitted.

    Returns:
        The dataset loaded via ``datasets.load_from_disk``.

    Raises:
        FileNotFoundError: if the dataset has not been cached yet
            (run setup_data.py first).
    """
    path = Path(data_dir) if data_dir else Path(DEFAULT_DATA_DIR)
    # dataset_info.json is the marker file save_to_disk leaves behind.
    marker = path / "dataset_info.json"
    if not marker.exists():
        raise FileNotFoundError(
            f"Dataset not found at {path}. Run: uv run python scripts/setup_data.py"
        )
    return load_from_disk(str(path))
def filter_personas(ds, filters: dict, limit: int = None, seed: int = 42):
    """
    Filter dataset by arbitrary field conditions.

    Supported filter keys:
        sex, state, city (substring match), age_min, age_max,
        marital_status (list), education_level (list),
        occupation (substring match)

    Any unrecognized key is treated as an exact match on that column.

    Args:
        ds: Dataset exposing ``.filter(fn, num_proc=...)``, ``.select(indices)``
            and ``len()`` (e.g. a HuggingFace ``Dataset``).
        filters: Mapping of filter keys (see above) to their values.
        limit: If set, randomly subsample to at most this many rows.
        seed: Seed for the subsample RNG, so runs are reproducible.

    Returns:
        The filtered (and possibly subsampled) dataset.
    """
    # Local RNG instead of random.seed(): don't clobber the caller's
    # global random state as a side effect.
    rng = random.Random(seed)

    age_min = filters.get("age_min", 0)
    age_max = filters.get("age_max", 200)
    sex = filters.get("sex")
    state = filters.get("state")
    city = filters.get("city")
    marital = filters.get("marital_status")
    education = filters.get("education_level")
    occupation = filters.get("occupation")
    if isinstance(marital, str):
        marital = [marital]
    if isinstance(education, str):
        education = [education]

    # Bug fix: the docstring promises exact-match on any unrecognized key,
    # but the original implementation silently ignored them.
    known = {"age_min", "age_max", "sex", "state", "city",
             "marital_status", "education_level", "occupation"}
    extra = {k: v for k, v in filters.items() if k not in known}

    def matches(row):
        if sex and row["sex"] != sex:
            return False
        if not (age_min <= row["age"] <= age_max):
            return False
        if state and row["state"] != state:
            return False
        if city and city.lower() not in row["city"].lower():
            return False
        if marital and row["marital_status"] not in marital:
            return False
        if education and row["education_level"] not in education:
            return False
        if occupation and occupation.lower() not in row["occupation"].lower():
            return False
        # Exact match for every unrecognized filter key; a missing column
        # (row.get -> None) never matches.
        for key, expected in extra.items():
            if row.get(key) != expected:
                return False
        return True

    filtered = ds.filter(matches, num_proc=4)
    if limit and len(filtered) > limit:
        indices = rng.sample(range(len(filtered)), limit)
        filtered = filtered.select(indices)
    return filtered
def build_persona_text(row: dict) -> str:
    """Join every non-empty narrative field of *row* into one description.

    The first field (the main persona) is emitted bare; every subsequent
    field is prefixed with a short label such as ``Background:``.
    """
    labels = ["", "Background", "Career", "Ambitions", "Hobbies",
              "Sports", "Arts", "Travel", "Food", "Skills"]
    pieces = []
    for label, field in zip(labels, NARRATIVE_FIELDS):
        value = row.get(field)
        if not value:
            continue
        pieces.append(f"{label}: {value}" if label else value)
    return " ".join(pieces)
def extract_name(row: dict) -> str:
    """Heuristically pull a "First Last" name from the narratives.

    Scans NARRATIVE_FIELDS in order and returns the first two words of the
    first text that opens with two capitalized words, with trailing commas
    and periods stripped. Returns "Unknown" when no field qualifies.
    """
    for field in NARRATIVE_FIELDS:
        narrative = row.get(field, "")
        if not narrative:
            continue
        tokens = narrative.split()
        if len(tokens) < 2:
            continue
        first, second = tokens[0], tokens[1]
        if first[0].isupper() and second[0].isupper():
            return f"{first} {second}".rstrip(",.")
    return "Unknown"
def parse_json_list(raw) -> list:
    """Best-effort conversion of *raw* into a list.

    Strings are JSON-decoded; anything that is not (or does not decode to)
    a list — including invalid JSON — yields an empty list.
    """
    if isinstance(raw, str):
        try:
            candidate = json.loads(raw)
        except json.JSONDecodeError:
            return []
    else:
        candidate = raw
    return candidate if isinstance(candidate, list) else []
def to_profile(row: dict, user_id: int) -> dict:
    """Map one raw Nemotron row onto the generic evaluator profile schema.

    Args:
        row: A single dataset record (dict of column -> value).
        user_id: Sequential id assigned by the caller.

    Returns:
        A flat profile dict with name, combined persona text, demographics,
        a merged interests list, and the source row's uuid.
    """
    interests = (
        parse_json_list(row.get("hobbies_and_interests_list", "[]"))
        + parse_json_list(row.get("skills_and_expertise_list", "[]"))
    )
    # Occupations are stored snake_case; present them as title-cased words.
    raw_occupation = row.get("occupation") or ""
    return {
        "user_id": user_id,
        "name": extract_name(row),
        "persona": build_persona_text(row),
        "age": row.get("age", 30),
        "sex": row.get("sex", ""),
        "city": row.get("city", ""),
        "state": row.get("state", ""),
        "country": row.get("country", "USA"),
        "education_level": row.get("education_level", ""),
        "marital_status": row.get("marital_status", ""),
        "occupation": raw_occupation.replace("_", " ").title(),
        "interests": interests,
        "source_uuid": row.get("uuid", ""),
    }
if __name__ == "__main__":
    # CLI entry point: filter the cached dataset and dump evaluator profiles.
    parser = argparse.ArgumentParser()
    parser.add_argument("--filters", type=json.loads, default={})
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output", default="data/filtered.json")
    args = parser.parse_args()

    dataset = load_personas()
    print(f"Loaded {len(dataset)} total personas")

    subset = filter_personas(dataset, args.filters, limit=args.limit, seed=args.seed)
    print(f"Filtered: {len(subset)} personas")

    profiles = [to_profile(record, idx) for idx, record in enumerate(subset)]

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w") as fh:
        json.dump(profiles, fh, ensure_ascii=False, indent=2)
    print(f"Saved to {args.output}")
|