# Hugging Face upload artifacts (commented out so the module parses as Python):
# ashish-sarvam's picture
# Upload folder using huggingface_hub
# fc1a684 verified
import json
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, Iterable, List
from jinja2 import Template
from conv_data_gen.config import config
from conv_data_gen.logger import setup_logger
from conv_data_gen.generators.structured_use_case.plan_generator import (
generate_plans_for_company,
) # noqa: E402,E401
from conv_data_gen.generators.structured_use_case.narrative_generator import (
generate_narratives,
) # noqa: E402,E401
# ----------------------- Config -----------------------
# Prompt templates (assumed available in your config)
PLAN_PROMPT_FILE = config.paths.PLAN_PROMPT  # Jinja2 source for plan generation
NARRATIVE_PROMPT_FILE = config.paths.NARRATIVE_PROMPT  # Jinja2 source for narratives
# IO
INPUT_FACTSHEETS = config.paths.COMPANY_FACTSHEETS_JSON  # input JSON: list of company factsheets
OUT_DIR = config.paths.USE_CASE_OUTPUT_DIR  # writes per-company outputs here
AGG_OUT = config.paths.USE_CASE_AGG_OUTPUT  # aggregated JSON
# Runtime knobs
MAX_WORKERS = config.generation.USE_CASE_MAX_WORKERS  # thread-pool size for per-company fan-out
logger = setup_logger(__name__)  # module-level logger (project helper)
# -------------------- Utilities -----------------------
def _load_template(path: Path) -> Template:
    """Read *path* as UTF-8 text and compile it into a Jinja2 ``Template``."""
    source = path.read_text(encoding="utf-8")
    return Template(source)
# -------------------- Validation -----------------------
"""
Validation for plans and narratives now lives in:
- plan_generator._validate_plan_item (aligned to plan.j2)
- narrative_generator._validate_narrative_item (aligned to narrative.j2)
"""
# -------------------- Core pipeline -----------------------
def _package_company_output(
factsheet: Dict[str, Any],
plans: List[Dict[str, Any]],
narratives: List[Dict[str, Any]],
) -> Dict[str, Any]:
"""Assemble final per-company payload with plan + narrative outputs.
plans: flat list of plan objects per plan.j2 schema.
narratives: flat list where each item has a matching plan_id and
nested narrative object.
"""
name = factsheet.get("name", "Unknown")
# Index narratives by plan_id for quick join
nar_by_pid = {
n["plan_id"]: n
for n in narratives
if isinstance(n, dict) and "plan_id" in n
}
# Flatten plans with optional narrative
flattened_use_cases: List[Dict[str, Any]] = []
for plan in plans:
pid = plan.get("plan_id", "")
nar = nar_by_pid.get(pid, {})
agent_type = plan.get("agent_type", "")
user_type = plan.get("user_type", "")
flattened_use_cases.append(
{
"plan_id": pid,
"agent_type": agent_type,
"user_type": user_type,
"trigger": plan.get("trigger", ""),
"linked_pain_points": plan.get("linked_pain_points", []),
"linked_processes": plan.get("linked_processes", []),
"linked_policies": plan.get("linked_policies", []),
"linked_metrics": plan.get("linked_metrics", []),
"business_value": plan.get("business_value", ""),
"priority_level": plan.get("priority_level", ""),
"conversation_type": plan.get("conversation_type", ""),
"complexity_hint": plan.get("complexity_hint", ""),
"diversity_level": plan.get("diversity_level", ""),
"notes": plan.get("notes", ""),
# Narrative (may be missing if LLM dropped one – keep robust)
"narrative": nar.get("narrative", {}),
# Dedup signature at plan level
"dedup_signature": (
f"{agent_type.lower()}|{user_type.lower()}|"
f"{str(plan.get('trigger', '')).strip().lower()}"
),
}
)
pkg = {
"company": name,
"archetype": factsheet.get("archetype", "unknown_archetype"),
"description": factsheet.get("description", ""),
"pain_points": factsheet.get("pain_points", []),
"lines_of_business": factsheet.get("lines_of_business", []),
"processes": factsheet.get("processes", []),
"metrics": factsheet.get("metrics", []),
"counts": {
"plans_total": len(plans),
"narratives": len(narratives),
},
"plans": plans,
"narratives": narratives,
"use_cases": flattened_use_cases,
}
return pkg
def _process_company(
    factsheet: Dict[str, Any], plan_tpl: Template, narrative_tpl: Template
) -> Dict[str, Any]:
    """End-to-end for a single company: plans β†’ narratives β†’ package.

    Args:
        factsheet: company factsheet dict; ``name`` is used for logging.
        plan_tpl: compiled Jinja2 template for plan generation.
        narrative_tpl: compiled Jinja2 template for narrative generation.

    Returns:
        The packaged per-company payload from ``_package_company_output``.
    """
    name = factsheet.get("name", "Unknown")
    plans = generate_plans_for_company(factsheet, plan_tpl)
    logger.info("βœ“ %s: %d plans", name, len(plans))
    narratives = generate_narratives(factsheet, plans, narrative_tpl)
    # BUG FIX: original passed the narratives list as the lone %s argument
    # and supplied nothing for %d, which raises a lazy-formatting error in
    # logging. Pass the name and the count, mirroring the plans log line.
    logger.info("βœ“ %s: %d narratives", name, len(narratives))
    return _package_company_output(factsheet, plans, narratives)
def generate_plan_and_narrative(
    factsheets: Iterable[Dict[str, Any]],
    max_workers: int = MAX_WORKERS,
) -> List[Dict[str, Any]]:
    """Generate plans, then narratives; return packaged results per company.

    Companies are processed concurrently in a thread pool; a failure in one
    company is logged and skipped without aborting the others.
    """
    plan_template = _load_template(PLAN_PROMPT_FILE)
    narrative_template = _load_template(NARRATIVE_PROMPT_FILE)

    packaged: List[Dict[str, Any]] = []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Map each future back to its company name for log messages.
        pending: Dict[Any, str] = {}
        for sheet in list(factsheets):
            future = pool.submit(
                _process_company, sheet, plan_template, narrative_template
            )
            pending[future] = sheet.get("name", "Unknown")

        for done in as_completed(pending):
            company = pending[done]
            try:
                payload = done.result()
            except Exception as exc:  # pragma: no cover
                logger.error("βœ— %s: %s", company, exc)
            else:
                packaged.append(payload)
                logger.info(
                    "βœ“ %s: packaged %d use_cases",
                    company,
                    len(payload.get("use_cases", [])),
                )
    return packaged
# ---------------------- Main ------------------------
def main():
    """CLI entry point: load factsheets, generate, and persist artifacts.

    Exits with status 1 when the input factsheets file is missing. Writes one
    JSON file per company into OUT_DIR plus a single aggregated JSON at AGG_OUT.
    """
    if not INPUT_FACTSHEETS.exists():
        logger.error("Factsheets file not found: %s", INPUT_FACTSHEETS)
        sys.exit(1)
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    raw_text = INPUT_FACTSHEETS.read_text(encoding="utf-8")
    factsheets: List[Dict[str, Any]] = json.loads(raw_text)

    results = generate_plan_and_narrative(factsheets, max_workers=MAX_WORKERS)

    # Persist one artifact per company, then the aggregate.
    for pkg in results:
        safe_name = pkg["company"].replace(" ", "_")
        per_company_path = OUT_DIR / f"{safe_name}_plans_narratives.json"
        per_company_path.write_text(
            json.dumps(pkg, indent=2, ensure_ascii=False), encoding="utf-8"
        )
        logger.info("Saved β†’ %s", per_company_path)

    AGG_OUT.write_text(
        json.dumps(results, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    logger.info("Saved aggregated outputs β†’ %s", AGG_OUT)


if __name__ == "__main__":
    main()