import json
import random
import re
from pathlib import Path
from typing import Callable, Dict, List

from counsel_env.models import CounselAction
from counsel_env.server.counsel_env_environment import CounselEnvironment


def _question_pattern(question: str) -> str:
    """Reduce *question* to a coarse pattern used to count distinct questions.

    The text is lower-cased, every character other than ``a-z``, ``0-9`` and
    whitespace is replaced by a space, and the result is split on whitespace.
    Purely numeric tokens are blanked out (they still occupy a slot, so the
    joined pattern keeps an empty field where the number was).  Only the first
    12 tokens are kept.
    """
    text = re.sub(r"[^a-z0-9\s]", " ", question.lower())
    tokens = ["" if token.isdigit() else token for token in text.split()]
    return " ".join(tokens[:12])


def _scripted_policy(env: CounselEnvironment, rng: random.Random) -> List[str]:
    """Drive *env* through its first two contradictions in a fixed way.

    For each of the first two entries of ``env.witness.contradictions`` the
    policy asks ``"<first trigger keyword>?"`` and then presents that
    contradiction's ``disprover_evidence_id``.  It stops as soon as the
    environment reports ``done``.

    Args:
        env: Environment to act on; it is stepped in place.
        rng: Unused; accepted so all policies share one call shape.

    Returns:
        The questions that were asked, in order.

    Note:
        Assumes every contradiction has at least one trigger keyword; an empty
        ``trigger_keywords`` list raises ``IndexError``.
    """
    questions: List[str] = []
    for contradiction in env.witness.contradictions[:2]:
        question = f"{contradiction.trigger_keywords[0]}?"
        questions.append(question)
        env.step(CounselAction(tool="ask_question", text=question))
        # The question itself may end the episode (e.g. a question budget);
        # do not keep stepping an environment that has already finished.
        if env.done:
            break
        env.step(
            CounselAction(
                tool="present_evidence",
                exhibit_id=contradiction.disprover_evidence_id,
            )
        )
        if env.done:
            break
    return questions


def _random_policy(env: CounselEnvironment, rng: random.Random) -> List[str]:
    """Take up to eight random actions on *env*.

    Each turn draws ``rng.random()``: below 0.75 a generic question is asked,
    otherwise a random exhibit from ``env.case["evidence"]`` is presented.  If
    the case has no evidence, a generic question is asked instead of failing.
    The loop ends early once the environment reports ``done``.

    Args:
        env: Environment to act on; it is stepped in place.
        rng: Source of randomness for every choice made here.

    Returns:
        The questions that were asked, in order (evidence turns add nothing).
    """
    questions: List[str] = []
    generic_questions = [
        "What happened that day?",
        "Do you remember the weather?",
        "Who else was nearby?",
        "Can you describe your routine?",
        "What did you eat for dinner?",
        "Did anything unusual happen?",
    ]
    for _ in range(8):
        # Only look at the evidence on an "evidence" turn, so the sequence of
        # rng calls matches the plain random()/choice() pattern.
        exhibit_ids: List[str] = []
        if rng.random() >= 0.75:
            exhibit_ids = list(env.case["evidence"])
        if exhibit_ids:
            exhibit = rng.choice(exhibit_ids)
            env.step(CounselAction(tool="present_evidence", exhibit_id=exhibit))
        else:
            # Either a question turn, or an evidence turn with nothing to
            # present (rng.choice on an empty list would raise IndexError).
            question = rng.choice(generic_questions)
            questions.append(question)
            env.step(CounselAction(tool="ask_question", text=question))
        if env.done:
            break
    return questions


def _mixed_policy(env: CounselEnvironment, rng: random.Random, episode: int) -> List[str]:
    """Use the scripted policy on episodes 0, 4 and 8, the random one otherwise.

    Returns:
        The questions asked by whichever policy was selected.
    """
    if episode in {0, 4, 8}:
        return _scripted_policy(env, rng)
    return _random_policy(env, rng)


# Policy name -> callable(env, rng, episode) returning the questions asked.
POLICIES: Dict[str, Callable[[CounselEnvironment, random.Random, int], List[str]]] = {
    "mixed": _mixed_policy,
    "scripted": lambda env, rng, episode: _scripted_policy(env, rng),
    "random": lambda env, rng, episode: _random_policy(env, rng),
}


class _LastObservationRecorder:
    """Proxy around an environment that remembers the latest ``step`` result.

    The policies discard the observations returned by ``env.step``.  This
    wrapper keeps the most recent one so the caller can still read the final
    observation when an episode ends inside a policy.  Every other attribute
    access is forwarded to the wrapped environment unchanged.
    """

    def __init__(self, env: CounselEnvironment, initial_observation) -> None:
        self._env = env
        self.last_observation = initial_observation

    def step(self, action):
        """Forward *action* to the wrapped environment and remember the result."""
        self.last_observation = self._env.step(action)
        return self.last_observation

    def __getattr__(self, name):
        # Only reached for names not defined on the proxy itself.
        return getattr(self._env, name)


def run_rollout_diagnostics(
    output_path: str | Path = "rollout_debug.jsonl",
    num_episodes: int = 10,
    policy: str = "mixed",
    seed: int = 13,
) -> List[dict]:
    """Run diagnostic episodes and write one JSON line of metrics per episode.

    Each episode creates a fresh ``CounselEnvironment``, resets it with a
    curriculum stage cycling through ``easy``/``medium``/``hard``, runs the
    selected policy, and sends ``rest_case`` if the policy did not already end
    the episode.  Metrics are read from the final observation's
    ``reward_components`` and ``reward`` and from the environment's counters.

    Args:
        output_path: JSONL file to write; an existing file is overwritten.
        num_episodes: Number of episodes to run.
        policy: Key into ``POLICIES`` (``"mixed"``, ``"scripted"``, ``"random"``).
        seed: Seed for the ``random.Random`` shared by all episodes.

    Returns:
        The per-episode metric dictionaries, in episode order.

    Raises:
        KeyError: If *policy* is not a key of ``POLICIES`` (raised before the
            output file is opened).

    Side effects:
        Prints a summary line, preceded by a warning when the average reward
        is exactly 0 or above 0.7.
    """
    rng = random.Random(seed)
    path = Path(output_path)
    rows: List[dict] = []
    stages = ["easy", "medium", "hard"]
    policy_fn = POLICIES[policy]

    with path.open("w", encoding="utf-8") as handle:
        for episode in range(num_episodes):
            env = CounselEnvironment()
            stage = stages[episode % len(stages)]
            obs = env.reset(curriculum_stage=stage)

            # Policies throw away step results, so record them here; otherwise
            # an episode that ends inside the policy would be scored from the
            # reset observation instead of its final one.
            recorder = _LastObservationRecorder(env, obs)
            questions = policy_fn(recorder, rng, episode)
            if env.done:
                obs = recorder.last_observation
            else:
                obs = env.step(CounselAction(tool="rest_case"))

            components = obs.reward_components
            patterns = {_question_pattern(question) for question in questions}
            # Transcript lines containing ":" contribute the text before the
            # first ":" as their action type.
            action_types = [
                line.split(":", 1)[0] for line in env.transcript if ":" in line
            ]
            row = {
                "episode": episode,
                "stage": stage,
                "difficulty": env.case.get("difficulty"),
                "case_id": env.case.get("case_id"),
                "contradictions_total": int(components["contradictions_total"]),
                "contradictions_triggered": int(components["contradictions_triggered"]),
                "contradictions_surfaced": int(components["contradictions_surfaced"]),
                "questions_asked": env.questions_used,
                "evidence_presented": env.evidence_presented_count,
                "useless_questions_ratio": components["useless_questions_ratio"],
                "avg_question_length": components["avg_question_length"],
                "unique_question_patterns": len(patterns),
                # Fraction in [0, 1] despite the "_pct" key name.
                "unique_question_patterns_pct": (len(patterns) / max(1, len(questions))),
                "action_diversity": len(set(action_types)),
                "primary_reward": components["primary_reward"],
                "auxiliary_reward": components["auxiliary_reward_raw"],
                "total_reward": obs.reward,
                "success": obs.reward > 0.0,
            }
            rows.append(row)
            handle.write(json.dumps(row, sort_keys=True) + "\n")

    # max(1, ...) keeps the average defined when no episodes were run.
    avg_reward = sum(row["total_reward"] for row in rows) / max(1, len(rows))
    if avg_reward == 0:
        print('WARNING: ENV TOO HARD')
    if avg_reward > 0.7:
        print('WARNING: ENV TOO EASY')
    print(f"Wrote {len(rows)} rollout diagnostics to {path} (avg_reward={avg_reward:.3f})")
    return rows


if __name__ == "__main__":
    run_rollout_diagnostics()