#!/usr/bin/env python3 """ Print full expert + brief text from one AutoDataLab++ episode (Kaggle / local). Does not emphasize scores — use for demos, reports, and qualitative inspection. python3 -m training.kaggle_agent_answers --task expert_brief --use-rag If run from Kaggle, clone the repo and `pip install -e .` first, then: !python3 training/kaggle_agent_answers.py --task risk_brief """ from __future__ import annotations import argparse import json import sys from pathlib import Path REPO = Path(__file__).resolve().parents[2] if str(REPO) not in sys.path: sys.path.insert(0, str(REPO)) from ceo_brief_env.environment import oracle_action_for_observation from inference import _roundrobin_baseline, _single_baseline, _trained_action, run_episode_collect def _line(title: str, body: str, width: int = 88) -> str: sep = "=" * width return f"\n{sep}\n{title}\n{sep}\n{body.rstrip()}\n" def format_episode_answers(data: dict, *, show_scores: bool = False) -> str: """Format expert reports and CEO brief for human reading.""" out: list[str] = [] out.append( f"Task: {data.get('task')}\n" f"Policy: {data.get('policy_label')}\n" f"RAG: {data.get('use_rag')}\n" ) if show_scores: out.append( f"Terminal score (grader): {data.get('terminal_score')} " f"success: {data.get('success')}\n" ) inst = data.get("final_instruction") or "" if inst: out.append(_line("INSTRUCTION (from metadata)", inst)) reports: dict = data.get("expert_reports") or {} order = ("analyst", "finance", "strategy", "hr") labels = { "analyst": "DATA ANALYST", "finance": "FINANCE", "strategy": "STRATEGIST", "hr": "HR / COMMS", } for eid in order: r = reports.get(eid) if not r: continue title = f"{labels.get(eid, eid).upper()} — {r.get('title', eid)}" chunks = [r.get("summary", "").strip() or "(no summary)"] bps = r.get("bullet_points") or [] if bps: chunks.append("\nBullets:\n" + "\n".join(f" • {b}" for b in bps)) issues = r.get("issues") or [] if issues and issues != ["(none)"]: chunks.append("\nIssues:\n" + "\n".join(f" - {i}" for i in issues)) m_c = r.get("memory_citations") or [] m_s = r.get("memory_snippets") or [] n = min(len(m_c), len(m_s), 5) if n: tape = [f"\nTape & citations (from strategist / RAG) — first {n}:"] for i in range(n): tape.append(f" [{m_c[i]}] {m_s[i][:500]}{'…' if len(m_s[i]) > 500 else ''}") chunks.append("\n".join(tape)) if eid == "hr" and r.get("memo"): chunks.append("\nHR memo field:\n" + str(r["memo"])) out.append(_line(title, "\n".join(chunks))) brief = data.get("current_brief") if brief: parts = [brief.get("summary", "") or ""] recs = brief.get("recommendations") or [] if recs: parts.append("\nRecommendations:\n" + "\n".join(f" • {x}" for x in recs)) m = brief.get("hr_memo") if m: parts.append(f"\nHR memo (in brief object):\n{m}") cons = brief.get("consulted_experts") if cons is not None: parts.append(f"\nConsulted in brief object: {', '.join(cons)}") out.append(_line("COMPOSED BRIEF (to CEO — merged from reports)", "\n".join(parts))) return "\n".join(out) def main() -> int: p = argparse.ArgumentParser() p.add_argument("--task", default="expert_brief", help="e.g. expert_brief, risk_brief, crisis_brief, easy_brief") p.add_argument("--use-rag", action="store_true", help="enable RAG in experts + grader grounding") p.add_argument( "--policy", choices=("oracle", "single", "roundrobin", "trained"), default="oracle", help="oracle = consult all required experts in order; best for a full 'office' readout", ) p.add_argument("--json-out", type=Path, help="write full run_episode_collect() dict to this path") p.add_argument("--show-scores", action="store_true", help="include terminal grader line in header") args = p.parse_args() if args.policy == "oracle": picker, label = oracle_action_for_observation, "oracle" elif args.policy == "single": picker, label = _single_baseline, "single-baseline" elif args.policy == "roundrobin": picker, label = _roundrobin_baseline, "roundrobin-baseline" else: picker, label = _trained_action, "trained-cos" data = run_episode_collect(args.task, picker, label, use_rag=bool(args.use_rag), quiet=True) print(format_episode_answers(data, show_scores=bool(args.show_scores))) if args.json_out: args.json_out.write_text(json.dumps(data, indent=2, default=str), encoding="utf-8") print(f"\n[written] {args.json_out}", file=sys.stderr) return 0 if __name__ == "__main__": raise SystemExit(main())