Spaces:
Sleeping
Sleeping
| """Main training loop for the Fake Gang Detection LLM Agent. | |
| Learning mechanism (Reflexion + Episodic Memory): | |
| ┌─────────────────────────────────────────────────────────────────────┐ | |
| │ Episode N │ | |
| │ 1. LLM receives: system_prompt + reflections + few_shot + obs │ | |
| │ 2. LLM reasons and outputs one action │ | |
| │ 3. Action is sent to OpenEnv HTTP server → new obs + reward │ | |
| │ 4. Repeat until done │ | |
| │ 5. Post-episode: │ | |
| │ • If FAIL → Qwen generates a reflection → stored to disk │ | |
| │ • If WIN → trajectory saved as few-shot example │ | |
| │ 6. Both are injected into Episode N+1's prompt → better policy │ | |
| └─────────────────────────────────────────────────────────────────────┘ | |
| Curriculum: | |
| Phase 1 (episodes 1-20) : easy — learn basic signal detection | |
| Phase 2 (episodes 21-35): medium — learn to handle evasion | |
| Phase 3 (episodes 36-50): hard — feature-only detection | |
| Usage: | |
| python train.py # defaults: 50 episodes, curriculum | |
| python train.py --task easy --episodes 20 | |
| python train.py --env-url http://localhost:8000 --episodes 50 --log-dir runs/ | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from collections import defaultdict | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| # Resolve imports | |
| _ROOT = Path(__file__).parent | |
| sys.path.insert(0, str(_ROOT)) | |
| sys.path.insert(0, str(_ROOT / "server")) | |
| from client import FakeGangEnvClient, StepResult | |
| from models import FakeGangAction, FakeGangObservation, ActionType | |
| from agent.memory import AgentMemory | |
| from agent.hybrid_policy import get_hybrid_action, compute_alpha | |
| from agent.reflection import generate_reflection, generate_success_reflection | |
| # --------------------------------------------------------------------------- | |
| # Curriculum schedule | |
| # --------------------------------------------------------------------------- | |
| def _curriculum_task(episode_num: int, override_task: Optional[str]) -> str: | |
| if override_task: | |
| return override_task | |
| if episode_num < 20: | |
| return "easy" | |
| if episode_num < 35: | |
| return "medium" | |
| return "hard" | |
| def _episode_seed(episode_num: int, task: str) -> int: | |
| """Rotate through pre-generated seeds so each task sees varied episodes.""" | |
| offsets = {"easy": 0, "medium": 50, "hard": 100} | |
| return (episode_num + offsets.get(task, 0)) % 50 | |
| # --------------------------------------------------------------------------- | |
| # Episode runner | |
| # --------------------------------------------------------------------------- | |
def run_episode(
    env: FakeGangEnvClient,
    task: str,
    seed: int,
    memory: AgentMemory,
    episode_num: int,
    alpha: float = 0.20,
    temperature: float = 0.4,
    verbose: bool = False,
) -> Dict[str, Any]:
    """Play a single episode with the hybrid policy and return its metrics.

    Args:
        env: Environment client; ``reset`` and ``step`` are called on it.
        task: Task name passed to ``env.reset`` and used for memory lookups.
        seed: Seed passed to ``env.reset``.
        memory: Source of the reflections and best trajectory put in the prompt.
        episode_num: Episode number recorded in the returned metrics.
        alpha: LLM trust weight forwarded to ``get_hybrid_action``
            (0.20 = rules dominate, 1.00 = pure LLM).
        temperature: Sampling temperature forwarded to ``get_hybrid_action``.
        verbose: When true, print one line per step.

    Returns:
        A dict of per-episode metrics: outcome, reward, steps used, recall,
        precision, the action log, the final message, mode counts and a
        timestamp.
    """
    result = env.reset(task=task, seed=seed)
    obs: FakeGangObservation = result.observation

    # Learning context fetched once and reused for every step of the episode.
    reflections = memory.get_reflections(task, n=4)
    few_shot = memory.get_best_trajectory(task)

    action_log: List[str] = []
    mode_log: List[str] = []

    while not obs.done:
        action, _raw_llm, mode = get_hybrid_action(
            obs=obs,
            reflections=reflections,
            few_shot_example=few_shot,
            alpha=alpha,
            temperature=temperature,
        )

        # Human-readable label: upper-cased action type plus optional account id.
        label = action.action_type.value.upper()
        if action.account_id:
            label = f"{label} {action.account_id}"
        action_log.append(label)
        mode_log.append(mode)

        if verbose:
            # Drop any "(...)" parameter suffix from the mode for brevity.
            short_mode = mode.split("(")[0]
            # len(action_log) is the 1-based index of the step just chosen.
            print(f" Step {len(action_log):2d}: {label:35s} [{short_mode}]")

        result = env.step(action)
        obs = result.observation

    # Outcome, recall and precision are parsed out of the final message text.
    final_msg = obs.message
    won = "[WIN]" in final_msg

    # Prefer the reward on the step result; fall back to the observation's.
    if result.reward is not None:
        reward = result.reward
    else:
        reward = obs.reward or 0.0

    step_budget = {"easy": 30, "medium": 50, "hard": 80}.get(task, 30)
    steps_used = step_budget - obs.steps_remaining

    recall = _extract_float(final_msg, "Recall=")
    precision = _extract_float(final_msg, "Precision=")

    # Tally how each step's action was decided.
    agree_count = 0
    rule_count = 0
    llm_count = 0
    for step_mode in mode_log:
        if step_mode == "agree":
            agree_count += 1
        elif step_mode.startswith("rule_override"):
            rule_count += 1
        elif step_mode.startswith("llm"):
            llm_count += 1

    metrics: Dict[str, Any] = {
        "episode": episode_num,
        "task": task,
        "seed": seed,
        "won": won,
        "reward": round(reward, 3),
        "steps_used": steps_used,
        "recall": recall,
        "precision": precision,
        "action_log": action_log,
        "final_message": final_msg,
        "n_reflections_used": len(reflections),
        "had_few_shot": few_shot is not None,
        "alpha_used": round(alpha, 3),
        "mode_agree": agree_count,
        "mode_rule": rule_count,
        "mode_llm": llm_count,
        "timestamp": datetime.utcnow().isoformat(),
    }
    return metrics
| def _extract_float(text: str, key: str) -> float: | |
| import re | |
| m = re.search(re.escape(key) + r"([0-9.]+)", text) | |
| return float(m.group(1)) if m else 0.0 | |
| # --------------------------------------------------------------------------- | |
| # Learning step (post-episode) | |
| # --------------------------------------------------------------------------- | |
def learning_step(
    metrics: Dict[str, Any],
    memory: AgentMemory,
) -> str:
    """Update memory from one finished episode and describe what was done.

    On a loss, a failure reflection is generated and stored. On a win, the
    trajectory is offered to memory as a few-shot example, and a success
    reflection is stored when the trajectory was saved or when the task has
    no reflections yet.

    Args:
        metrics: Episode metrics; reads the ``task``, ``won``, ``action_log``,
            ``final_message``, ``steps_used``, ``episode`` and ``reward`` keys.
        memory: Agent memory that receives the trajectory and/or reflection.

    Returns:
        A short human-readable status string for the progress log.
    """
    task = metrics["task"]
    action_log = metrics["action_log"]
    final_msg = metrics["final_message"]
    steps_used = metrics["steps_used"]
    max_steps = {"easy": 30, "medium": 50, "hard": 80}.get(task, 30)
    ep = metrics["episode"]

    if not metrics["won"]:
        # Failed episode: have the model reflect on what went wrong.
        ref = generate_reflection(
            task=task,
            action_log=action_log,
            final_message=final_msg,
            won=False,
            steps_used=steps_used,
            max_steps=max_steps,
            episode_num=ep,
        )
        memory.add_reflection(task, ref, ep, metrics["reward"])
        return f"reflection generated: \"{ref[:80]}…\""

    # Won episode: offer the trajectory as a few-shot example.
    saved = memory.add_trajectory(
        task=task,
        action_log=action_log,
        final_message=final_msg,
        reward=metrics["reward"],
        episode_num=ep,
    )

    # A success reflection is only generated for a saved trajectory, or to
    # bootstrap a task that has no reflections at all.
    if not (saved or memory.reflection_count(task) == 0):
        return "trajectory kept (not better than current best)"

    ref = generate_success_reflection(task, action_log, final_msg, steps_used, max_steps, ep)
    memory.add_reflection(task, ref, ep, metrics["reward"])
    if saved:
        return "new best trajectory saved + success reflection generated"
    # The reflection was a bootstrap only; do not claim a trajectory was saved.
    return "success reflection generated (trajectory not better than current best)"
| # --------------------------------------------------------------------------- | |
| # Progress printer | |
| # --------------------------------------------------------------------------- | |
class ProgressTracker:
    """Collect per-episode metrics and print progress lines and summaries."""

    def __init__(self) -> None:
        # Metrics dicts grouped by their "task" value.
        self.by_task: Dict[str, List[Dict]] = defaultdict(list)
        # Every recorded metrics dict, in the order it was recorded.
        self.all: List[Dict] = []

    def record(self, m: Dict[str, Any]) -> None:
        """Store one episode's metrics in the per-task and global lists."""
        per_task = self.by_task[m["task"]]
        per_task.append(m)
        self.all.append(m)

    def win_rate(self, task: Optional[str] = None, last_n: int = 10) -> float:
        """Return the fraction of wins among the most recent ``last_n`` records.

        A truthy ``task`` restricts the calculation to that task; otherwise
        all records are used. Returns 0.0 when there are no records.
        """
        history = self.by_task[task] if task else self.all
        if len(history) == 0:
            return 0.0
        recent = history[-last_n:]
        victories = len([entry for entry in recent if entry["won"]])
        return victories / len(recent)

    def print_episode(self, m: Dict[str, Any], learn_status: str) -> None:
        """Print a one-line result for an episode, then its learning status."""
        task_win_rate = self.win_rate(m["task"], last_n=10)
        outcome = "WIN " if m["won"] else "LOSS"
        alpha = m.get("alpha_used", 0.20)
        n_agree = m.get("mode_agree", 0)
        n_rule = m.get("mode_rule", 0)
        n_llm = m.get("mode_llm", 0)

        # Only show the mode breakdown when at least one step was counted.
        if n_agree + n_rule + n_llm > 0:
            mode_str = f"agree={n_agree} rule={n_rule} llm={n_llm}"
        else:
            mode_str = "n/a"

        segments = [
            f" Ep {m['episode']:3d} | {m['task']:6s} | {outcome} | ",
            f"reward={m['reward']:+7.2f} | ",
            f"recall={m['recall']:.2f} prec={m['precision']:.2f} | ",
            f"steps={m['steps_used']:2d} | ",
            f"wr={task_win_rate:.0%} | α={alpha:.2f} | {mode_str}",
        ]
        print("".join(segments))
        print(f" └─ learn: {learn_status}")

    def print_summary(self, phase: str) -> None:
        """Print per-task win counts and averages, then the overall win count."""
        divider = "━" * 70
        print("\n" + divider)
        print(f" {phase} SUMMARY")
        print(divider)

        for task in ("easy", "medium", "hard"):
            rows = self.by_task[task]
            count = len(rows)
            if count == 0:
                # Tasks with no recorded episodes are omitted.
                continue
            wins = sum(1 for row in rows if row["won"])
            avg_r = sum(row["reward"] for row in rows) / count
            avg_recall = sum(row["recall"] for row in rows) / count
            print(
                f" {task:6s}: {wins}/{count} wins ({100*wins/count:.0f}%) | "
                f"avg reward={avg_r:+.2f} | avg recall={avg_recall:.2f}"
            )

        total = len(self.all)
        if total:
            total_wins = sum(1 for row in self.all if row["won"])
            print(f" {'TOTAL':6s}: {total_wins}/{total} wins ({100*total_wins/total:.0f}%)")
        print(divider + "\n")
| # --------------------------------------------------------------------------- | |
| # Metrics persistence | |
| # --------------------------------------------------------------------------- | |
def save_metrics(metrics_list: List[Dict], log_dir: Path) -> None:
    """Append each metrics dict as one JSON line to ``<log_dir>/metrics.jsonl``.

    ``log_dir`` and any missing parent directories are created first. The
    file is opened in append mode, so earlier contents are kept.
    """
    log_dir.mkdir(parents=True, exist_ok=True)
    target = log_dir / "metrics.jsonl"
    with open(target, "a") as out:
        out.writelines(json.dumps(entry) + "\n" for entry in metrics_list)
| # --------------------------------------------------------------------------- | |
| # Main training loop | |
| # --------------------------------------------------------------------------- | |
def train(
    env_url: str = "http://localhost:8000",
    task: Optional[str] = None,
    n_episodes: int = 50,
    temperature: float = 0.4,
    verbose: bool = False,
    log_dir: Path = Path("runs"),
) -> None:
    """Run the training loop: play episodes, learn from each, log metrics.

    For every episode this computes the hybrid-policy α from memory, runs the
    episode, records the outcome, performs the post-episode learning step,
    persists the updated α and prints a progress line. Metrics are appended
    to ``<log_dir>/metrics.jsonl`` in batches of five.

    Args:
        env_url: Base URL of the environment server.
        task: Fixed task name, or ``None`` to follow the curriculum schedule.
        n_episodes: Number of episodes to run.
        temperature: LLM sampling temperature passed to each episode.
        verbose: When true, each episode prints its individual steps.
        log_dir: Directory that receives ``metrics.jsonl``.
    """
    banner = "━" * 70
    print(f"\n{banner}")
    print(" Fake Gang Detection — Hybrid Policy Training")
    print(f" OpenEnv server : {env_url}")
    print(f" Episodes : {n_episodes}")
    print(f" Task schedule : {task or 'curriculum (easy→medium→hard)'}")
    print(" LLM : Qwen3 via AWS Bedrock")
    print(" Policy : Hybrid (rules ↔ LLM, dynamic α)")
    print(" Learning : Reflexion (reflections + few-shot in prompt)")
    print(f"{banner}\n")

    memory = AgentMemory()
    tracker = ProgressTracker()
    pending_metrics: List[Dict] = []
    print(memory.summary())
    print()

    with FakeGangEnvClient(base_url=env_url) as env:
        try:
            for ep in range(n_episodes):
                current_task = _curriculum_task(ep, task)
                seed = _episode_seed(ep, current_task)

                # --- Phase announcements ---
                # Curriculum banners only apply when no task override is set;
                # with a fixed task, announce that task once instead.
                if task is None:
                    if ep == 0:
                        print("=== Phase 1: Easy — learning basic signal detection ===")
                    elif ep == 20:
                        print("\n=== Phase 2: Medium — handling evasion ===")
                    elif ep == 35:
                        print("\n=== Phase 3: Hard — feature-only detection ===")
                elif ep == 0:
                    print(f"=== Fixed task: {task} ===")

                # --- Compute α for this episode from the current memory state ---
                n_refs = memory.reflection_count(current_task)
                wr = memory.recent_win_rate(current_task, n=10)
                alpha = compute_alpha(recent_win_rate=wr, n_reflections=n_refs, task=current_task)

                # --- Run episode ---
                metrics = run_episode(
                    env=env,
                    task=current_task,
                    seed=seed,
                    memory=memory,
                    episode_num=ep + 1,
                    alpha=alpha,
                    temperature=temperature,
                    verbose=verbose,
                )

                # --- Post-episode: record the outcome, then learn from it ---
                memory.record_win(current_task, metrics["won"], ep + 1)
                learn_status = learning_step(metrics, memory)

                # --- Persist α ---
                # Recomputed after the learning step so the stored value uses
                # both the new win/loss and any reflection that was just added.
                new_alpha = compute_alpha(
                    recent_win_rate=memory.recent_win_rate(current_task, n=10),
                    n_reflections=memory.reflection_count(current_task),
                    task=current_task,
                )
                memory.save_alpha(current_task, new_alpha)

                # --- Log ---
                tracker.record(metrics)
                pending_metrics.append(metrics)
                tracker.print_episode(metrics, learn_status)

                # Flush metrics every 5 episodes
                if len(pending_metrics) >= 5:
                    save_metrics(pending_metrics, log_dir)
                    pending_metrics.clear()

                # Phase summary every 10 episodes
                if (ep + 1) % 10 == 0:
                    tracker.print_summary(f"Episodes 1–{ep+1}")
        finally:
            # Flush on every exit path so buffered episodes are not lost when
            # an episode raises part-way through a batch.
            if pending_metrics:
                save_metrics(pending_metrics, log_dir)
                pending_metrics.clear()

    tracker.print_summary("FINAL")
    print(memory.summary())
    print(f"\nMetrics saved to: {log_dir / 'metrics.jsonl'}")
    print(f"Memory saved to: {memory.memory_dir}/")
| # --------------------------------------------------------------------------- | |
| # CLI | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to train().
    cli = argparse.ArgumentParser(
        description="Train the Fake Gang Detection LLM agent (Reflexion + Episodic Memory)."
    )
    cli.add_argument(
        "--env-url",
        default="http://localhost:8000",
        help="OpenEnv server URL (default: http://localhost:8000)",
    )
    cli.add_argument(
        "--task",
        choices=["easy", "medium", "hard"],
        default=None,
        help="Fix task (default: curriculum easy→medium→hard)",
    )
    cli.add_argument(
        "--episodes",
        type=int,
        default=50,
        help="Total training episodes (default: 50)",
    )
    cli.add_argument(
        "--temperature",
        type=float,
        default=0.4,
        help="LLM temperature (default: 0.4)",
    )
    cli.add_argument(
        "--verbose",
        action="store_true",
        help="Print each action step",
    )
    cli.add_argument(
        "--log-dir",
        default="runs",
        help="Directory for metrics output (default: runs/)",
    )
    opts = cli.parse_args()

    # --log-dir arrives as a string; train() expects a Path.
    metrics_dir = Path(opts.log_dir)
    train(
        env_url=opts.env_url,
        task=opts.task,
        n_episodes=opts.episodes,
        temperature=opts.temperature,
        verbose=opts.verbose,
        log_dir=metrics_dir,
    )