#!/usr/bin/env python3
"""Train a tabular Q-learning policy for VcGeminiV0Environment.

This script does not modify the environment implementation; it learns by interacting
with `server.vc_gemini_v0_environment.VcGeminiV0Environment` directly.

Usage:
    python train_v0_qlearning.py --episodes 5000 --output artifacts/v0_q_table.json
"""
from __future__ import annotations

import argparse
import json
import random
import shutil
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from statistics import mean
from typing import DefaultDict

from models import VcGeminiV0Action
from server.vc_gemini_v0_environment import VcGeminiV0Environment

QTable = DefaultDict[str, DefaultDict[str, float]]


@dataclass
class EpisodeStats:
    """Per-episode summary: total reward, final TVPI, and step count."""

    reward: float
    tvpi: float
    steps: int


def make_q_table() -> QTable:
    return defaultdict(lambda: defaultdict(float))


def state_key(env: VcGeminiV0Environment) -> str:
    """Discretize the environment into a compact, hashable state key."""
    inbox_ids = sorted(p["startup_id"] for p in getattr(env, "inbox_pitches", []))
    active_investments = sum(1 for p in env.portfolio if p.get("active"))
    # Bucket the remaining budget to the nearest $20M to keep the table small.
    budget_bucket = int(round(env.fund_budget / 20_000_000.0))
    return (
        f"q={env.quarter}|t={max(env.turns_remaining, 0)}|"
        f"b={budget_bucket}|a={active_investments}|"
        f"inbox={','.join(inbox_ids)}"
    )
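
# An illustrative key produced by state_key (the startup ids here are
# hypothetical): "q=3|t=9|b=4|a=2|inbox=s_001,s_007"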


def available_actions(env: VcGeminiV0Environment) -> list[str]:
    # "wait" is always legal; each pitch currently in the inbox adds one invest action.
    actions = ["wait"]
    for scenario in getattr(env, "inbox_pitches", []):
        actions.append(f"invest::{scenario['startup_id']}")
    return actions


def action_to_payload(action_key_str: str, env: VcGeminiV0Environment) -> dict:
    """Translate an internal action key into the environment's action payload."""
    if action_key_str == "wait":
        return {"action_type": "wait", "parameters": {}}
    if action_key_str.startswith("invest::"):
        target_id = action_key_str.split("::", 1)[1]
        for scenario in getattr(env, "inbox_pitches", []):
            if scenario["startup_id"] == target_id:
                return {
                    "action_type": "invest",
                    "parameters": {"startup_name": scenario["startup_name"]},
                }
    # Fall back to waiting if the targeted pitch is no longer in the inbox.
    return {"action_type": "wait", "parameters": {}}


def epsilon_for_episode(
    episode_idx: int,
    epsilon_start: float,
    epsilon_end: float,
    epsilon_decay_episodes: int,
) -> float:
    # Linear decay from epsilon_start to epsilon_end over the first
    # epsilon_decay_episodes episodes, held at epsilon_end afterwards.
    if epsilon_decay_episodes <= 0:
        return epsilon_end
    progress = min(1.0, episode_idx / float(epsilon_decay_episodes))
    return epsilon_start + (epsilon_end - epsilon_start) * progress
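
# Example: with the defaults used by parse_args below (start=1.0, end=0.05,
# decay over 3500 episodes), epsilon halfway through the decay (episode 1750)
# is 1.0 + (0.05 - 1.0) * 0.5 = 0.525.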


def pick_action(
    q: QTable,
    state: str,
    actions: list[str],
    epsilon: float,
) -> str:
    """Epsilon-greedy selection with random tie-breaking among best actions."""
    if not actions:
        return "wait"
    if random.random() < epsilon:
        return random.choice(actions)
    best_value = max(q[state][a] for a in actions)
    best_actions = [a for a in actions if q[state][a] == best_value]
    return random.choice(best_actions)


def cleanup_env(env: VcGeminiV0Environment) -> None:
    workspace = getattr(env, "workspace_dir", None)
    if workspace:
        shutil.rmtree(workspace, ignore_errors=True)


def run_episode(
    env: VcGeminiV0Environment,
    q: QTable,
    alpha: float,
    gamma: float,
    epsilon: float,
    train_mode: bool,
) -> EpisodeStats:
    obs = env.reset()
    last_obs = obs
    done = bool(obs.done)
    total_reward = 0.0
    steps = 0
    while not done:
        s = state_key(env)
        actions = available_actions(env)
        # Explore with probability epsilon during training; act greedily in eval.
        action_key_str = pick_action(q, s, actions, epsilon if train_mode else 0.0)
        payload = action_to_payload(action_key_str, env)
        next_obs = env.step(VcGeminiV0Action(**payload))
        last_obs = next_obs
        reward = float(next_obs.reward or 0.0)
        done = bool(next_obs.done)
        total_reward += reward
        steps += 1
        if train_mode:
            # Q-learning TD update: bootstrap from the best next-state action
            # value, or use the bare reward at a terminal step.
            if done:
                td_target = reward
            else:
                s_next = state_key(env)
                next_actions = available_actions(env)
                max_next_q = max((q[s_next][a] for a in next_actions), default=0.0)
                td_target = reward + gamma * max_next_q
            td_error = td_target - q[s][action_key_str]
            q[s][action_key_str] += alpha * td_error
    # Guard against a missing or None `data` payload on the final observation.
    tvpi = float((getattr(last_obs, "data", None) or {}).get("tvpi", 0.0))
    return EpisodeStats(reward=total_reward, tvpi=tvpi, steps=steps)


def train(
    episodes: int,
    alpha: float,
    gamma: float,
    epsilon_start: float,
    epsilon_end: float,
    epsilon_decay_episodes: int,
    log_every: int,
) -> tuple[QTable, list[EpisodeStats]]:
    q = make_q_table()
    history: list[EpisodeStats] = []
    env = VcGeminiV0Environment()
    try:
        for ep in range(1, episodes + 1):
            epsilon = epsilon_for_episode(
                episode_idx=ep,
                epsilon_start=epsilon_start,
                epsilon_end=epsilon_end,
                epsilon_decay_episodes=epsilon_decay_episodes,
            )
            stats = run_episode(
                env=env,
                q=q,
                alpha=alpha,
                gamma=gamma,
                epsilon=epsilon,
                train_mode=True,
            )
            history.append(stats)
            if log_every > 0 and ep % log_every == 0:
                window = history[-log_every:]
                avg_reward = mean(s.reward for s in window)
                avg_tvpi = mean(s.tvpi for s in window)
                avg_steps = mean(s.steps for s in window)
                print(
                    f"[train] episode={ep:5d} epsilon={epsilon:.3f} "
                    f"avg_reward={avg_reward:.4f} avg_tvpi={avg_tvpi:.4f} avg_steps={avg_steps:.2f}"
                )
    finally:
        cleanup_env(env)
    return q, history


def evaluate(q: QTable, episodes: int) -> list[EpisodeStats]:
    results: list[EpisodeStats] = []
    env = VcGeminiV0Environment()
    try:
        for _ in range(episodes):
            stats = run_episode(
                env=env,
                q=q,
                alpha=0.0,
                gamma=0.0,
                epsilon=0.0,
                train_mode=False,
            )
            results.append(stats)
    finally:
        cleanup_env(env)
    return results


def serialize_q_table(q: QTable) -> dict[str, dict[str, float]]:
    return {state: dict(actions) for state, actions in q.items()}
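

# A minimal sketch of the inverse of serialize_q_table: rebuild a QTable from
# an artifact written by main() below. This helper is illustrative and is not
# used by the training loop itself.
def load_q_table(path: Path) -> QTable:
    q = make_q_table()
    payload = json.loads(Path(path).read_text())
    for state, actions in payload["q_table"].items():
        for action_key_str, value in actions.items():
            q[state][action_key_str] = float(value)
    return q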


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train tabular Q-learning on vc_gemini_v0")
    parser.add_argument("--episodes", type=int, default=5000, help="Number of training episodes")
    parser.add_argument("--eval-episodes", type=int, default=300, help="Greedy evaluation episodes")
    parser.add_argument("--alpha", type=float, default=0.15, help="Learning rate")
    parser.add_argument("--gamma", type=float, default=0.98, help="Discount factor")
    parser.add_argument("--epsilon-start", type=float, default=1.0, help="Initial epsilon")
    parser.add_argument("--epsilon-end", type=float, default=0.05, help="Final epsilon")
    parser.add_argument(
        "--epsilon-decay-episodes",
        type=int,
        default=3500,
        help="Episodes to linearly decay epsilon from start to end",
    )
    parser.add_argument("--log-every", type=int, default=200, help="Log every N episodes")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("artifacts/v0_q_table.json"),
        help="Path to save learned Q-table and summary",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    random.seed(args.seed)
    q, train_history = train(
        episodes=args.episodes,
        alpha=args.alpha,
        gamma=args.gamma,
        epsilon_start=args.epsilon_start,
        epsilon_end=args.epsilon_end,
        epsilon_decay_episodes=args.epsilon_decay_episodes,
        log_every=args.log_every,
    )
    eval_history = evaluate(q=q, episodes=args.eval_episodes)

    train_avg_reward = mean(s.reward for s in train_history)
    eval_avg_reward = mean(s.reward for s in eval_history)
    eval_avg_tvpi = mean(s.tvpi for s in eval_history)

    print("\n=== Training complete ===")
    print(f"train episodes: {args.episodes}")
    print(f"avg train reward: {train_avg_reward:.4f}")
    print(f"eval episodes: {args.eval_episodes}")
    print(f"avg eval reward (greedy): {eval_avg_reward:.4f}")
    print(f"avg eval tvpi (greedy): {eval_avg_tvpi:.4f}x")

    payload = {
        "config": {
            "episodes": args.episodes,
            "eval_episodes": args.eval_episodes,
            "alpha": args.alpha,
            "gamma": args.gamma,
            "epsilon_start": args.epsilon_start,
            "epsilon_end": args.epsilon_end,
            "epsilon_decay_episodes": args.epsilon_decay_episodes,
            "seed": args.seed,
        },
        "metrics": {
            "avg_train_reward": train_avg_reward,
            "avg_eval_reward": eval_avg_reward,
            "avg_eval_tvpi": eval_avg_tvpi,
        },
        "q_table": serialize_q_table(q),
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(payload, indent=2))
    print(f"saved: {args.output}")


if __name__ == "__main__":
    main()