import json
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, Iterable, List

from jinja2 import Template

from conv_data_gen.config import config
from conv_data_gen.logger import setup_logger
from conv_data_gen.generators.structured_use_case.plan_generator import (
    generate_plans_for_company,
)  # noqa: E402,E401
from conv_data_gen.generators.structured_use_case.narrative_generator import (
    generate_narratives,
)  # noqa: E402,E401

# ----------------------- Config -----------------------

# Prompt templates (assumed available in your config)
PLAN_PROMPT_FILE = config.paths.PLAN_PROMPT
NARRATIVE_PROMPT_FILE = config.paths.NARRATIVE_PROMPT

# IO
INPUT_FACTSHEETS = config.paths.COMPANY_FACTSHEETS_JSON
OUT_DIR = config.paths.USE_CASE_OUTPUT_DIR  # writes per-company outputs here
AGG_OUT = config.paths.USE_CASE_AGG_OUTPUT  # aggregated JSON

# Runtime knobs
MAX_WORKERS = config.generation.USE_CASE_MAX_WORKERS

logger = setup_logger(__name__)


# -------------------- Utilities -----------------------


def _load_template(path: Path) -> Template:
    """Load a Jinja2 template from *path* (read as UTF-8 text)."""
    with open(path, "r", encoding="utf-8") as f:
        return Template(f.read())


# -------------------- Validation -----------------------
"""
Validation for plans and narratives now lives in:
- plan_generator._validate_plan_item (aligned to plan.j2)
- narrative_generator._validate_narrative_item (aligned to narrative.j2)
"""


# -------------------- Core pipeline -----------------------


def _package_company_output(
    factsheet: Dict[str, Any],
    plans: List[Dict[str, Any]],
    narratives: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Assemble final per-company payload with plan + narrative outputs.

    Args:
        factsheet: Company factsheet dict (name, archetype, pain_points, ...).
        plans: Flat list of plan objects per plan.j2 schema.
        narratives: Flat list where each item has a matching ``plan_id`` and a
            nested ``narrative`` object.

    Returns:
        A dict containing the company metadata, raw ``plans`` and
        ``narratives`` lists, and ``use_cases`` — one flattened record per
        plan with its narrative joined in by ``plan_id``.
    """
    name = factsheet.get("name", "Unknown")

    # Index narratives by plan_id for an O(1) join onto plans.
    # Non-dict items or items without a plan_id are silently skipped.
    nar_by_pid = {
        n["plan_id"]: n for n in narratives if isinstance(n, dict) and "plan_id" in n
    }

    # Flatten plans with optional narrative
    flattened_use_cases: List[Dict[str, Any]] = []
    for plan in plans:
        pid = plan.get("plan_id", "")
        nar = nar_by_pid.get(pid, {})
        agent_type = plan.get("agent_type", "")
        user_type = plan.get("user_type", "")
        flattened_use_cases.append(
            {
                "plan_id": pid,
                "agent_type": agent_type,
                "user_type": user_type,
                "trigger": plan.get("trigger", ""),
                "linked_pain_points": plan.get("linked_pain_points", []),
                "linked_processes": plan.get("linked_processes", []),
                "linked_policies": plan.get("linked_policies", []),
                "linked_metrics": plan.get("linked_metrics", []),
                "business_value": plan.get("business_value", ""),
                "priority_level": plan.get("priority_level", ""),
                "conversation_type": plan.get("conversation_type", ""),
                "complexity_hint": plan.get("complexity_hint", ""),
                "diversity_level": plan.get("diversity_level", ""),
                "notes": plan.get("notes", ""),
                # Narrative (may be missing if LLM dropped one – keep robust)
                "narrative": nar.get("narrative", {}),
                # Dedup signature at plan level
                "dedup_signature": (
                    f"{agent_type.lower()}|{user_type.lower()}|"
                    f"{str(plan.get('trigger', '')).strip().lower()}"
                ),
            }
        )

    pkg = {
        "company": name,
        "archetype": factsheet.get("archetype", "unknown_archetype"),
        "description": factsheet.get("description", ""),
        "pain_points": factsheet.get("pain_points", []),
        "lines_of_business": factsheet.get("lines_of_business", []),
        "processes": factsheet.get("processes", []),
        "metrics": factsheet.get("metrics", []),
        "counts": {
            "plans_total": len(plans),
            "narratives": len(narratives),
        },
        "plans": plans,
        "narratives": narratives,
        "use_cases": flattened_use_cases,
    }
    return pkg


def _process_company(
    factsheet: Dict[str, Any], plan_tpl: Template, narrative_tpl: Template
) -> Dict[str, Any]:
    """End-to-end for a single company: plans → narratives → package."""
    name = factsheet.get("name", "Unknown")

    plans = generate_plans_for_company(factsheet, plan_tpl)
    logger.info("✓ %s: %d plans", name, len(plans))

    narratives = generate_narratives(factsheet, plans, narrative_tpl)
    # BUGFIX: original passed the narratives list as the only format arg for a
    # two-placeholder message, causing a logging formatting error at runtime.
    logger.info("✓ %s: %d narratives", name, len(narratives))

    return _package_company_output(factsheet, plans, narratives)


def generate_plan_and_narrative(
    factsheets: Iterable[Dict[str, Any]],
    max_workers: int = MAX_WORKERS,
) -> List[Dict[str, Any]]:
    """Generate plans, then narratives; return packaged results per company.

    Companies are processed concurrently with a thread pool; a failure for one
    company is logged and skipped, so the result list may be shorter than the
    input. Result order follows task completion, not input order.
    """
    plan_tpl = _load_template(PLAN_PROMPT_FILE)
    narrative_tpl = _load_template(NARRATIVE_PROMPT_FILE)

    factsheets_list = list(factsheets)
    results: List[Dict[str, Any]] = []

    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        future_map = {
            ex.submit(_process_company, fs, plan_tpl, narrative_tpl): fs.get(
                "name", "Unknown"
            )
            for fs in factsheets_list
        }
        for fut in as_completed(future_map):
            name = future_map[fut]
            try:
                pkg = fut.result()
                results.append(pkg)
                logger.info(
                    "✓ %s: packaged %d use_cases",
                    name,
                    len(pkg.get("use_cases", [])),
                )
            except Exception as exc:  # pragma: no cover
                logger.error("✗ %s: %s", name, exc)

    return results


# ---------------------- Main ------------------------


def main():
    """Load factsheets, run the pipeline, write per-company + aggregate JSON."""
    if not INPUT_FACTSHEETS.exists():
        logger.error("Factsheets file not found: %s", INPUT_FACTSHEETS)
        sys.exit(1)

    OUT_DIR.mkdir(parents=True, exist_ok=True)

    factsheets: List[Dict[str, Any]] = json.loads(
        INPUT_FACTSHEETS.read_text(encoding="utf-8")
    )

    results = generate_plan_and_narrative(
        factsheets,
        max_workers=MAX_WORKERS,
    )

    # Write per-company and aggregate artifacts
    for pkg in results:
        per_company_path = (
            OUT_DIR / f"{pkg['company'].replace(' ', '_')}_plans_narratives.json"
        )
        per_company_path.write_text(
            json.dumps(pkg, indent=2, ensure_ascii=False), encoding="utf-8"
        )
        logger.info("Saved → %s", per_company_path)

    AGG_OUT.write_text(
        json.dumps(results, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    logger.info("Saved aggregated outputs → %s", AGG_OUT)


if __name__ == "__main__":
    main()