# File: data-gen/conv_data_gen/generators/user/persona_randomizer.py
# Hosting-page residue: "ashish-sarvam's picture"
# Commit message: "Upload folder using huggingface_hub"
# Commit: fc1a684 (verified)
from __future__ import annotations
import random
from typing import Any, Dict, List, Tuple
class UserPersonaRandomizer:
    """
    Encapsulates all randomness for sampling persona attributes
    from YAML knobs.

    Exposes a single sampler, ``sample_persona``, that returns the persona
    sections as plain dicts so that the caller can construct its own data
    structures.

    Every draw goes through the module-level ``random`` functions, so
    seeding ``random`` beforehand makes the output reproducible.

    Missing, null, or wrongly-typed config sections are treated as empty
    rather than raising, matching the fallback already applied to
    ``language.codes`` and ``interaction.complexity_tier``.
    """

    def __init__(self, knobs_yaml: Dict[str, Any]) -> None:
        """
        Store the parsed YAML configuration.

        Args:
            knobs_yaml: Parsed configuration mapping. Any falsy value
                (``None``, ``{}``) is replaced by an empty dict.
        """
        self.knobs_yaml = knobs_yaml or {}

    # ------------------------------------------------------------------
    # Generic helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _as_dict(value: Any) -> Dict[Any, Any]:
        """Return *value* unchanged if it is a dict, else an empty dict."""
        return value if isinstance(value, dict) else {}

    def _section(self, name: str) -> Dict[Any, Any]:
        """
        Return the top-level config section *name* as a dict.

        A section that is absent, null (``key:`` with no value in YAML)
        or not a mapping yields ``{}`` so callers can always use
        ``.get`` / ``.items`` on the result.
        """
        return self._as_dict(self._as_dict(self.knobs_yaml).get(name))

    @staticmethod
    def _range_bounds(value: Any) -> Tuple[str, str]:
        """
        Split a ``[min, max]`` style sequence into stringified bounds.

        Returns ``("", "")`` positions for whatever is missing: a
        one-item sequence gives only the minimum, and anything that is
        not a list/tuple (null, scalar, string) gives two empty strings.
        """
        if not isinstance(value, (list, tuple)):
            return "", ""
        low = str(value[0]) if len(value) > 0 else ""
        high = str(value[1]) if len(value) > 1 else ""
        return low, high

    @staticmethod
    def _regionalism_pool(value: Any) -> List[Any]:
        """
        Normalise the regionalisms configured for one language to a list.

        ``None`` gives ``[]``; lists, tuples and dicts are converted with
        ``list(...)`` (dicts contribute their keys); any other scalar is
        wrapped as a single-item list so that a lone string is not split
        into individual characters.
        """
        if value is None:
            return []
        if isinstance(value, (list, tuple, dict)):
            return list(value)
        return [value]

    @staticmethod
    def _sample_from_map_or_list(options: Any) -> Tuple[str, str]:
        """
        Returns (choice, description). Supports:
        - list of options (no description)
        - dict option->description
        - scalar value (stringified, no description)

        ``None`` and empty containers return ``("", "")``; a dict entry
        whose description is ``None`` returns an empty description
        instead of the literal string ``"None"``.
        """
        if options is None:
            return "", ""
        if isinstance(options, list):
            if not options:
                return "", ""
            return str(random.choice(options)), ""
        if isinstance(options, dict):
            if not options:
                return "", ""
            choice = random.choice(list(options))
            description = options[choice]
            return str(choice), "" if description is None else str(description)
        return str(options), ""

    def _sample_section(
        self, section: str
    ) -> Tuple[Dict[str, str], Dict[str, str]]:
        """
        Sample one value per entry of the top-level *section*.

        Returns:
            ``(selected, descriptions)``, both keyed by entry name, in
            the section's iteration order.
        """
        selected: Dict[str, str] = {}
        descs: Dict[str, str] = {}
        for name, opts in self._section(section).items():
            selected[name], descs[name] = self._sample_from_map_or_list(opts)
        return selected, descs

    # ------------------------------------------------------------------
    # Per-section samplers
    # ------------------------------------------------------------------
    def _sample_knobs(self) -> Tuple[Dict[str, str], Dict[str, str]]:
        """Sample every entry under the ``knobs`` section."""
        return self._sample_section("knobs")

    def _sample_language(self) -> Tuple[Dict[str, str], Dict[str, str]]:
        """
        Sample the language style from the ``language`` section.

        Draw order is: language code, regionalisms, directness,
        formality, code-switch ratio. Fallbacks when nothing is
        configured: ``"en-IN"``, ``"neutral"``, ``"medium"`` and
        ``"0.2"`` respectively, with no regionalisms.

        Returns:
            ``(language_style, language_descriptions)``.
        """
        language_cfg = self._section("language")

        codes = language_cfg.get("codes")
        if isinstance(codes, dict) and codes:
            lang_code = random.choice(list(codes))
            raw_desc = codes[lang_code]
            lang_desc = "" if raw_desc is None else str(raw_desc)
        else:
            lang_code, lang_desc = "en-IN", ""

        # Up to two regionalisms for the chosen language, comma-joined.
        reg_map = self._as_dict(language_cfg.get("regionalisms"))
        reg_list = self._regionalism_pool(reg_map.get(lang_code))
        sampled_regs = ", ".join(
            str(reg)
            for reg in random.sample(reg_list, k=min(2, len(reg_list)))
        )

        direct_choice, direct_desc = self._sample_from_map_or_list(
            language_cfg.get("directness") or {}
        )
        form_choice, form_desc = self._sample_from_map_or_list(
            language_cfg.get("formality") or {}
        )
        cs_choice, cs_desc = self._sample_from_map_or_list(
            language_cfg.get("code_switch_ratio") or {}
        )

        lang = {
            "language": lang_code,
            "directness": direct_choice or "neutral",
            "formality": form_choice or "medium",
            "code_switch_ratio": str(cs_choice or "0.2"),
            "regionalisms": sampled_regs,
        }
        ldesc = {
            "language_desc": lang_desc,
            "directness_desc": direct_desc,
            "formality_desc": form_desc,
            "code_switch_desc": cs_desc,
        }
        return lang, ldesc

    def _sample_demographics(self) -> Tuple[Dict[str, str], Dict[str, str]]:
        """Sample every entry under the ``demographics`` section."""
        return self._sample_section("demographics")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def sample_persona(self) -> Dict[str, Dict[str, str]]:
        """
        Sample one complete persona.

        Returns a dict with seven keys:
        - knobs, knob_descriptions
        - language_style, language_descriptions
        - demographics, demographic_descriptions
        - interaction

        ``interaction`` is built from one randomly chosen entry of
        ``interaction.complexity_tier``; its min/max fields are empty
        strings when no tier is configured or a range is incomplete.
        """
        knobs, knob_descs = self._sample_knobs()
        language_style, language_descriptions = self._sample_language()
        demographics, demo_descs = self._sample_demographics()

        tiers: List[Dict[str, Any]] = self._section("interaction").get(
            "complexity_tier"
        )
        picked: Any = (
            random.choice(tiers) if isinstance(tiers, list) and tiers else {}
        )
        # A malformed (non-mapping) tier entry is treated as "no tier".
        chosen = self._as_dict(picked)

        turn_min, turn_max = self._range_bounds(chosen.get("turn_range"))
        tool_min, tool_max = self._range_bounds(
            chosen.get("tool_calls_budget")
        )
        kb_min, kb_max = self._range_bounds(chosen.get("kb_queries_budget"))

        interaction = {
            "tier_name": chosen.get("name", ""),
            "turn_min": turn_min,
            "turn_max": turn_max,
            "tool_calls_min": tool_min,
            "tool_calls_max": tool_max,
            "kb_queries_min": kb_min,
            "kb_queries_max": kb_max,
        }

        return {
            "knobs": knobs,
            "knob_descriptions": knob_descs,
            "language_style": language_style,
            "language_descriptions": language_descriptions,
            "demographics": demographics,
            "demographic_descriptions": demo_descs,
            "interaction": interaction,
        }