Spaces:
Runtime error
Runtime error
| import json | |
| import sys | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from pathlib import Path | |
| from typing import Any, Dict, Iterable, List | |
| from jinja2 import Template | |
| from conv_data_gen.config import config | |
| from conv_data_gen.logger import setup_logger | |
| from conv_data_gen.generators.structured_use_case.plan_generator import ( | |
| generate_plans_for_company, | |
| ) # noqa: E402,E401 | |
| from conv_data_gen.generators.structured_use_case.narrative_generator import ( | |
| generate_narratives, | |
| ) # noqa: E402,E401 | |
# ----------------------- Config -----------------------
# Prompt templates (assumed available in your config).
PLAN_PROMPT_FILE = config.paths.PLAN_PROMPT
NARRATIVE_PROMPT_FILE = config.paths.NARRATIVE_PROMPT
# IO
INPUT_FACTSHEETS = config.paths.COMPANY_FACTSHEETS_JSON  # input: list of company factsheet dicts (JSON)
OUT_DIR = config.paths.USE_CASE_OUTPUT_DIR  # writes per-company outputs here
AGG_OUT = config.paths.USE_CASE_AGG_OUTPUT  # aggregated JSON
# Runtime knobs
MAX_WORKERS = config.generation.USE_CASE_MAX_WORKERS  # thread-pool size for per-company fan-out
logger = setup_logger(__name__)
| # -------------------- Utilities ----------------------- | |
def _load_template(path: Path) -> Template:
    """Read the file at *path* (UTF-8) and compile it into a Jinja2 Template."""
    source = path.read_text(encoding="utf-8")
    return Template(source)
| # -------------------- Validation ----------------------- | |
| """ | |
| Validation for plans and narratives now lives in: | |
| - plan_generator._validate_plan_item (aligned to plan.j2) | |
| - narrative_generator._validate_narrative_item (aligned to narrative.j2) | |
| """ | |
| # -------------------- Core pipeline ----------------------- | |
| def _package_company_output( | |
| factsheet: Dict[str, Any], | |
| plans: List[Dict[str, Any]], | |
| narratives: List[Dict[str, Any]], | |
| ) -> Dict[str, Any]: | |
| """Assemble final per-company payload with plan + narrative outputs. | |
| plans: flat list of plan objects per plan.j2 schema. | |
| narratives: flat list where each item has a matching plan_id and | |
| nested narrative object. | |
| """ | |
| name = factsheet.get("name", "Unknown") | |
| # Index narratives by plan_id for quick join | |
| nar_by_pid = { | |
| n["plan_id"]: n | |
| for n in narratives | |
| if isinstance(n, dict) and "plan_id" in n | |
| } | |
| # Flatten plans with optional narrative | |
| flattened_use_cases: List[Dict[str, Any]] = [] | |
| for plan in plans: | |
| pid = plan.get("plan_id", "") | |
| nar = nar_by_pid.get(pid, {}) | |
| agent_type = plan.get("agent_type", "") | |
| user_type = plan.get("user_type", "") | |
| flattened_use_cases.append( | |
| { | |
| "plan_id": pid, | |
| "agent_type": agent_type, | |
| "user_type": user_type, | |
| "trigger": plan.get("trigger", ""), | |
| "linked_pain_points": plan.get("linked_pain_points", []), | |
| "linked_processes": plan.get("linked_processes", []), | |
| "linked_policies": plan.get("linked_policies", []), | |
| "linked_metrics": plan.get("linked_metrics", []), | |
| "business_value": plan.get("business_value", ""), | |
| "priority_level": plan.get("priority_level", ""), | |
| "conversation_type": plan.get("conversation_type", ""), | |
| "complexity_hint": plan.get("complexity_hint", ""), | |
| "diversity_level": plan.get("diversity_level", ""), | |
| "notes": plan.get("notes", ""), | |
| # Narrative (may be missing if LLM dropped one β keep robust) | |
| "narrative": nar.get("narrative", {}), | |
| # Dedup signature at plan level | |
| "dedup_signature": ( | |
| f"{agent_type.lower()}|{user_type.lower()}|" | |
| f"{str(plan.get('trigger', '')).strip().lower()}" | |
| ), | |
| } | |
| ) | |
| pkg = { | |
| "company": name, | |
| "archetype": factsheet.get("archetype", "unknown_archetype"), | |
| "description": factsheet.get("description", ""), | |
| "pain_points": factsheet.get("pain_points", []), | |
| "lines_of_business": factsheet.get("lines_of_business", []), | |
| "processes": factsheet.get("processes", []), | |
| "metrics": factsheet.get("metrics", []), | |
| "counts": { | |
| "plans_total": len(plans), | |
| "narratives": len(narratives), | |
| }, | |
| "plans": plans, | |
| "narratives": narratives, | |
| "use_cases": flattened_use_cases, | |
| } | |
| return pkg | |
def _process_company(
    factsheet: Dict[str, Any], plan_tpl: Template, narrative_tpl: Template
) -> Dict[str, Any]:
    """End-to-end for a single company: plans β narratives β package.

    Args:
        factsheet: Company factsheet dict (uses "name" for logging).
        plan_tpl: Rendered-prompt template for plan generation.
        narrative_tpl: Rendered-prompt template for narrative generation.

    Returns:
        Packaged per-company payload from _package_company_output.
    """
    name = factsheet.get("name", "Unknown")
    plans = generate_plans_for_company(factsheet, plan_tpl)
    logger.info("β %s: %d plans", name, len(plans))
    narratives = generate_narratives(factsheet, plans, narrative_tpl)
    # FIX: original passed only `narratives` (the list) for a two-placeholder
    # format string, so logging raised an internal formatting error and the
    # line was never rendered. Pass the name and the count explicitly.
    logger.info("β %s: %d narratives", name, len(narratives))
    return _package_company_output(factsheet, plans, narratives)
def generate_plan_and_narrative(
    factsheets: Iterable[Dict[str, Any]],
    max_workers: int = MAX_WORKERS,
) -> List[Dict[str, Any]]:
    """Generate plans, then narratives; return packaged results per company.

    Companies are processed concurrently; a failure for one company is
    logged and skipped without aborting the rest.
    """
    plan_tpl = _load_template(PLAN_PROMPT_FILE)
    narrative_tpl = _load_template(NARRATIVE_PROMPT_FILE)

    results: List[Dict[str, Any]] = []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Map each future back to its company name for logging.
        pending = {}
        for sheet in list(factsheets):
            future = pool.submit(_process_company, sheet, plan_tpl, narrative_tpl)
            pending[future] = sheet.get("name", "Unknown")

        for future in as_completed(pending):
            company = pending[future]
            try:
                package = future.result()
                results.append(package)
                logger.info(
                    "β %s: packaged %d use_cases",
                    company,
                    len(package.get("use_cases", [])),
                )
            except Exception as exc:  # pragma: no cover
                logger.error("β %s: %s", company, exc)
    return results
| # ---------------------- Main ------------------------ | |
def main():
    """CLI entry point: load factsheets, generate, and write JSON artifacts."""
    if not INPUT_FACTSHEETS.exists():
        logger.error("Factsheets file not found: %s", INPUT_FACTSHEETS)
        sys.exit(1)

    OUT_DIR.mkdir(parents=True, exist_ok=True)
    raw = INPUT_FACTSHEETS.read_text(encoding="utf-8")
    factsheets: List[Dict[str, Any]] = json.loads(raw)

    results = generate_plan_and_narrative(factsheets, max_workers=MAX_WORKERS)

    # Write per-company and aggregate artifacts.
    for package in results:
        safe_name = package["company"].replace(" ", "_")
        per_company_path = OUT_DIR / f"{safe_name}_plans_narratives.json"
        per_company_path.write_text(
            json.dumps(package, indent=2, ensure_ascii=False), encoding="utf-8"
        )
        logger.info("Saved β %s", per_company_path)

    AGG_OUT.write_text(
        json.dumps(results, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    logger.info("Saved aggregated outputs β %s", AGG_OUT)


if __name__ == "__main__":
    main()