ashish-sarvam's picture
Upload folder using huggingface_hub
fc1a684 verified
from __future__ import annotations
import hashlib
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from conv_data_gen.generators.conversation.models import SimContext
from conv_data_gen.logger import setup_logger
from .core_simulator import ConversationSimulatorWithTools
from concurrent.futures import ThreadPoolExecutor, as_completed
from conv_data_gen.config import config
# Module-level logger used by the helpers and the simulation runner below.
logger = setup_logger(__name__)
def _read_jsonl(path: str) -> List[Dict[str, Any]]:
rows: List[Dict[str, Any]] = []
p = Path(path)
with p.open("r", encoding="utf-8") as f:
for line in f:
try:
obj = json.loads(line)
if isinstance(obj, dict):
rows.append({str(k): v for k, v in obj.items()})
except Exception:
continue
return rows
def _key(r: Dict[str, Any]) -> str:
return "::".join(
[
str(r.get("company", "")),
str(r.get("agent_type", "")),
str(r.get("use_case", "")),
]
)
def _get_conversation_direction(bundle_data: Optional[Dict[str, Any]]) -> str:
"""Get call direction from bundle data."""
try:
if bundle_data:
conversation_direction = bundle_data.get(
"conversation_direction", {}
)
if conversation_direction:
logger.info(
f"Call direction from bundle_data: {conversation_direction}" # noqa: E501
)
return conversation_direction
return "user_to_agent"
except Exception:
return "user_to_agent"
def _parse_placeholders(text: str) -> List[str]:
"""Parse template placeholders from text."""
try:
return list(
dict.fromkeys(
re.findall(
r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}",
text or "",
)
).keys()
)
except Exception:
return []
def _collect_template_variables(
    bot_prompt: str, persona_data: Optional[Dict[str, Any]]
) -> Dict[str, str]:
    """Build the values used to render ``{{ ... }}`` placeholders in a prompt.

    Each placeholder found in ``bot_prompt`` is looked up in
    ``persona_data["agent_variables"]``. Missing or ``None`` values become
    ``""``. A ``"memory"`` entry is always added from
    ``persona_data["memories"]`` (``None``/missing -> ``""``). It overrides
    any ``agent_variables["memory"]``.

    Args:
        bot_prompt: Prompt template text to scan for placeholders.
        persona_data: Persona record; may be ``None`` or empty.

    Returns:
        Mapping of placeholder name -> string value. Returns ``{}`` when
        there is no persona data, the prompt has no placeholders,
        ``agent_variables`` is not a dict, or anything fails (best-effort).
    """

    def _as_text(value: Any) -> str:
        # ``str(None)`` would inject the literal word "None" into the prompt.
        return "" if value is None else str(value)

    try:
        if not persona_data:
            return {}
        target_vars = _parse_placeholders(bot_prompt)
        if not target_vars:
            return {}
        agent_variables = persona_data.get("agent_variables", {})
        if not isinstance(agent_variables, dict):
            return {}
        persona_values: Dict[str, str] = {
            name: _as_text(agent_variables.get(name)) for name in target_vars
        }
        # Persona memories always win over a same-named agent variable.
        persona_values["memory"] = _as_text(persona_data.get("memories"))
        return persona_values
    except Exception:
        return {}
def _render_template(text: str, values: Dict[str, str]) -> str:
"""Render template variables in text."""
from jinja2 import Environment, select_autoescape
try:
env = Environment(
autoescape=select_autoescape(disabled_extensions=(".j2",))
)
template = env.from_string(text or "")
return template.render(**values)
except Exception:
out = text or ""
for k, v in (values or {}).items():
try:
pattern = rf"\{{\{{\s*{re.escape(k)}\s*\}}\}}"
out = re.sub(pattern, v or "", out)
except Exception:
pass
return out
def _get_max_turns(persona_data: Optional[Dict[str, Any]]) -> Optional[int]:
"""Get max turns from persona data."""
try:
if not persona_data:
return None
interaction = persona_data.get("interaction", {})
if isinstance(interaction, dict):
turn_max_str = interaction.get("turn_max", "")
if turn_max_str and str(turn_max_str).isdigit():
return int(turn_max_str)
return None
except Exception:
return None
def run_conversations_from_artifacts(
    bundles_uri: str,
    personas_uri: str,
    output_dir: Path,
    *,
    max_randomizer_usage: int,
) -> Tuple[Dict[str, Any], List[str]]:
    """Simulate one conversation per persona that has a matching bot bundle.

    Bundles and personas are read from JSONL files and joined on
    ``company::agent_type::use_case`` (see ``_key``). When several bundles
    share a key, the last one wins. Each matched persona is simulated with
    ``ConversationSimulatorWithTools`` on a thread pool. The pool size is
    ``config.concurrency.CONVERSATIONS_MAX_WORKERS`` when positive, else
    ``DEFAULT_MAX_WORKERS`` (default 4), capped at the number of tasks.

    Args:
        bundles_uri: Path of the bundles JSONL file.
        personas_uri: Path of the personas JSONL file.
        output_dir: Directory for simulator output; created if missing.
            Per-persona tool definitions are written under ``_tools/``.
        max_randomizer_usage: Forwarded unchanged to the simulator.

    Returns:
        ``(metrics, summaries)``. ``metrics`` holds the counts ``bundles``,
        ``personas``, ``simulated`` and ``successes``. ``summaries`` holds
        one label per simulated persona, ordered by persona index.

    A failure in one persona's simulation is logged and counted as
    unsuccessful; it never aborts the remaining simulations.
    """
    bundles = _read_jsonl(bundles_uri)
    personas = _read_jsonl(personas_uri)
    output_dir.mkdir(parents=True, exist_ok=True)
    tmp_tools_dir = output_dir / "_tools"
    tmp_tools_dir.mkdir(parents=True, exist_ok=True)
    logger.info(
        "[09-conv] Loaded bundles=%d personas=%d",
        len(bundles),
        len(personas),
    )

    # Later bundles with the same key overwrite earlier ones.
    bot_map: Dict[str, Dict[str, Any]] = {_key(b): b for b in bundles}

    tasks: List[Tuple[int, Dict[str, Any], Dict[str, Any]]] = []
    for idx, p in enumerate(personas, start=1):
        bot = bot_map.get(_key(p))
        if bot:
            tasks.append((idx, p, bot))
    unmatched = len(personas) - len(tasks)
    if unmatched:
        logger.info(
            "[09-conv] Skipping %d persona(s) with no matching bundle",
            unmatched,
        )

    if not tasks:
        metrics = {
            "bundles": len(bundles),
            "personas": len(personas),
            "simulated": 0,
            "successes": 0,
        }
        return metrics, []

    configured = getattr(config.concurrency, "CONVERSATIONS_MAX_WORKERS", None)
    default_workers = getattr(config.concurrency, "DEFAULT_MAX_WORKERS", 4)
    desired = (
        configured if (configured and configured > 0) else default_workers
    )
    max_workers = min(desired, max(1, len(tasks)))
    logger.info(
        "[09-conv] Starting simulations with max_workers=%d (tasks=%d)",
        max_workers,
        len(tasks),
    )

    def _run_one(
        idx: int, p: Dict[str, Any], bot: Dict[str, Any]
    ) -> Tuple[int, bool, str]:
        """Simulate persona ``idx`` against ``bot``; return (idx, ok, label)."""
        key_hash = hashlib.md5(_key(p).encode("utf-8")).hexdigest()[:8]
        bot_prompt = str(bot.get("bot_prompt", ""))
        user_prompt = str(p.get("persona_text", ""))
        template_vars = _collect_template_variables(bot_prompt, p)
        if template_vars:
            bot_prompt = _render_template(bot_prompt, template_vars)
        # Bind unconditionally: prompts without placeholders produce no
        # template vars. Leaving ``memory`` unbound made the simulator call
        # below fail with UnboundLocalError for those personas.
        memory = template_vars.get("memory")
        conversation_direction = _get_conversation_direction(bot)
        max_turns = _get_max_turns(p)
        tools = bot.get("tools") if isinstance(bot.get("tools"), list) else []
        kbs = (
            bot.get("knowledge_bases")
            if isinstance(bot.get("knowledge_bases"), list)
            else []
        )
        tools_json = tmp_tools_dir / f"tools_{idx:04d}.json"
        try:
            tools_json.write_text(
                json.dumps(
                    {"tools": tools, "knowledge_bases": kbs},
                    ensure_ascii=False,
                ),
                encoding="utf-8",
            )
        except Exception as e:
            logger.warning(
                "[09-conv] Failed to write tools JSON for idx=%d: %s", idx, e
            )
        base = f"{p.get('company', '')}_{key_hash}_p{idx:04d}"
        persona_company = str(p.get("company", ""))
        persona_agent_type = str(p.get("agent_type", ""))
        persona_user_type = str(p.get("user_type", ""))
        persona_use_case = str(p.get("use_case", ""))
        conversation_metadata = {
            "company": persona_company,
            "agent_type": persona_agent_type,
            "user_type": persona_user_type,
            "use_case": persona_use_case,
            "persona_id": f"p{idx:04d}",
            "key_hash": key_hash,
        }
        sim_context = SimContext(
            company=persona_company,
            agent_type=persona_agent_type,
            use_case=persona_use_case,
            user_id=f"p{idx:04d}",
            user_type=persona_user_type,
        )
        try:
            sim = ConversationSimulatorWithTools(
                user_prompt_text=user_prompt,
                bot_prompt_text=bot_prompt,
                sim_context=sim_context,
                max_randomizer_usage=max_randomizer_usage,
                conversation_direction=conversation_direction,
                max_turns=max_turns,
                llm_client=None,
                tools_json_path=str(tools_json),
                output_dir=str(output_dir),
                base_filename=base,
                memory=memory or {},
            )
            sim._conversation_metadata = conversation_metadata
            tr = sim.generate()
            label = f"{base}: {len(tr)} turns" if tr else f"{base}: 0"
            return idx, bool(tr), label
        except Exception as e:
            logger.warning(
                "[09-conv] Simulation failed for %s (idx=%d): %s", base, idx, e
            )
            return idx, False, f"{base}: 0"

    successes = 0
    total_runs = 0
    summaries_by_idx: Dict[int, str] = {}
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(_run_one, idx, p, bot): idx
            for (idx, p, bot) in tasks
        }
        for fut in as_completed(futures):
            try:
                idx, ok, label = fut.result()
            except Exception as e:
                # _run_one only guards the simulator call itself. Anything
                # raised earlier (e.g. building SimContext) must not abort
                # the remaining simulations.
                idx = futures[fut]
                logger.warning(
                    "[09-conv] Simulation task crashed (idx=%d): %s", idx, e
                )
                ok, label = False, f"p{idx:04d}: 0"
            summaries_by_idx[idx] = label
            total_runs += 1
            if ok:
                successes += 1
    summaries = [summaries_by_idx[i] for i in sorted(summaries_by_idx)]
    metrics = {
        "bundles": len(bundles),
        "personas": len(personas),
        "simulated": total_runs,
        "successes": successes,
    }
    return metrics, summaries