Spaces:
Sleeping
Sleeping
| """ | |
| Samples and displays actual Arbitrator generations from a training checkpoint. | |
| Run during or after training to check for reward hacking patterns. | |
| Usage: | |
| python scripts/inspect_generations.py --checkpoint outputs/checkpoints/checkpoint-50 --n 10 | |
| python scripts/inspect_generations.py --checkpoint outputs/checkpoints/final_model --n 20 | |
| """ | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| from collections import Counter | |
| from rich.console import Console | |
| from rich.panel import Panel | |
| from rich.table import Table | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| console = Console() | |
| REWARD_HACK_PATTERNS = [ | |
| ("same_action_repeat", lambda actions: len(set(actions)) == 1 and len(actions) >= 3), | |
| ("empty_reasoning", lambda actions: any(len(a.get("reasoning", "")) < 10 for a in actions)), | |
| ("hook_fixation", lambda actions: all(a.get("action_type") == "hook_rewrite" for a in actions)), | |
| ("ignores_debate", lambda actions: any(not a.get("critique_claim_id") for a in actions)), | |
| ] | |
| def inspect_checkpoint(checkpoint_path: str, n_samples: int): | |
| """ | |
| Load the model from checkpoint and run N episodes with the trained Arbitrator. | |
| Display each generated action and flag any reward hacking patterns. | |
| """ | |
| from viral_script_engine.environment.env import ViralScriptEnv | |
| console.print(f"\n[bold cyan]Inspecting checkpoint:[/bold cyan] {checkpoint_path}") | |
| console.print(f"[dim]Running {n_samples} sample episodes...[/dim]\n") | |
| try: | |
| from unsloth import FastLanguageModel | |
| model, tokenizer = FastLanguageModel.from_pretrained( | |
| model_name=checkpoint_path, | |
| max_seq_length=2048, | |
| dtype=None, | |
| load_in_4bit=True, | |
| ) | |
| FastLanguageModel.for_inference(model) | |
| model_loaded = True | |
| except Exception as e: | |
| console.print(f"[yellow]Warning: Could not load model ({e}). Running with baseline agent.[/yellow]") | |
| model_loaded = False | |
| from viral_script_engine.agents.baseline_arbitrator import BaselineArbitratorAgent | |
| agent = BaselineArbitratorAgent() | |
| ROOT = Path(__file__).parent.parent / "viral_script_engine" | |
| env = ViralScriptEnv( | |
| scripts_path=str(ROOT / "data" / "test_scripts" / "scripts.json"), | |
| cultural_kb_path=str(ROOT / "data" / "cultural_kb.json"), | |
| max_steps=3, | |
| difficulty="easy", | |
| use_escalation=False, | |
| ) | |
| all_episode_actions = [] | |
| all_rewards = [] | |
| for ep_num in range(1, n_samples + 1): | |
| obs, _ = env.reset() | |
| episode_actions = [] | |
| episode_reward = 0.0 | |
| for _ in range(env.max_steps): | |
| action = agent.act(obs) | |
| episode_actions.append(action) | |
| obs, reward, terminated, truncated, info = env.step(action) | |
| episode_reward = reward | |
| if terminated or truncated: | |
| break | |
| all_episode_actions.append(episode_actions) | |
| all_rewards.append(episode_reward) | |
| console.print(f" Ep {ep_num:02d} | reward={episode_reward:.3f} | actions={[a.get('action_type','?') for a in episode_actions]}") | |
| console.print() | |
| # Action type distribution | |
| all_action_types = [a.get("action_type", "unknown") for eps in all_episode_actions for a in eps] | |
| action_counts = Counter(all_action_types) | |
| table = Table(title="Action Type Distribution", show_header=True) | |
| table.add_column("Action Type", style="cyan") | |
| table.add_column("Count", justify="right") | |
| table.add_column("Pct", justify="right") | |
| total_actions = sum(action_counts.values()) | |
| for action_type, count in action_counts.most_common(): | |
| pct = 100 * count / total_actions if total_actions > 0 else 0 | |
| table.add_row(action_type, str(count), f"{pct:.1f}%") | |
| console.print(table) | |
| # Reward hacking detection | |
| console.print("\n[bold]Reward Hacking Pattern Check:[/bold]") | |
| hacking_episodes = 0 | |
| for ep_idx, episode_actions in enumerate(all_episode_actions): | |
| flags = [] | |
| for pattern_name, check_fn in REWARD_HACK_PATTERNS: | |
| try: | |
| if check_fn(episode_actions): | |
| flags.append(pattern_name) | |
| except Exception: | |
| pass | |
| if flags: | |
| hacking_episodes += 1 | |
| console.print(f" [red]Ep {ep_idx + 1:02d}: {flags}[/red]") | |
| if hacking_episodes == 0: | |
| console.print(" [green]No reward hacking patterns detected[/green]") | |
| console.print(f"\n[bold]{hacking_episodes}/{n_samples} episodes show potential reward hacking patterns[/bold]") | |
| console.print(f"[bold]Mean reward across {n_samples} episodes: {sum(all_rewards)/len(all_rewards):.3f}[/bold]") | |
| return hacking_episodes, all_rewards | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--checkpoint", required=True, help="Path to model checkpoint directory") | |
| parser.add_argument("--n", type=int, default=10, help="Number of sample episodes to run") | |
| args = parser.parse_args() | |
| hacking_count, rewards = inspect_checkpoint(args.checkpoint, args.n) | |
| sys.exit(0 if hacking_count == 0 else 1) | |