#!/usr/bin/env python3
"""Train a tabular Q-learning policy for VcGeminiV0Environment.

This script does not modify the environment implementation; it learns by
interacting with `server.vc_gemini_v0_environment.VcGeminiV0Environment`
directly.

Usage:
    python train_v0_qlearning.py --episodes 5000 --output artifacts/v0_q_table.json
"""
from __future__ import annotations

import argparse
import json
import random
import shutil
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from statistics import mean
from typing import DefaultDict

from models import VcGeminiV0Action
from server.vc_gemini_v0_environment import VcGeminiV0Environment

# Nested mapping of state key -> action key -> Q-value, defaulting to 0.0.
QTable = DefaultDict[str, DefaultDict[str, float]]


@dataclass
class EpisodeStats:
    reward: float
    tvpi: float
    steps: int


def make_q_table() -> QTable:
    return defaultdict(lambda: defaultdict(float))


def state_key(env: VcGeminiV0Environment) -> str:
    """Discretize the environment into a compact, hashable key for the Q-table."""
    inbox_ids = sorted(p["startup_id"] for p in getattr(env, "inbox_pitches", []))
    active_investments = sum(1 for p in env.portfolio if p.get("active"))
    # Bucket the remaining budget into $20M increments to keep the state space small.
    budget_bucket = int(round(env.fund_budget / 20_000_000.0))
    return (
        f"q={env.quarter}|t={max(env.turns_remaining, 0)}|"
        f"b={budget_bucket}|a={active_investments}|"
        f"inbox={','.join(inbox_ids)}"
    )
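
# For illustration, a key produced above might look like (hypothetical values):
#   "q=3|t=9|b=4|a=2|inbox=s01,s07"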


def available_actions(env: VcGeminiV0Environment) -> list[str]:
    actions = ["wait"]
    for scenario in getattr(env, "inbox_pitches", []):
        actions.append(f"invest::{scenario['startup_id']}")
    return actions
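
# e.g. with two pitches in the inbox (hypothetical ids), this yields
# ["wait", "invest::s01", "invest::s07"].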


def action_to_payload(action_key_str: str, env: VcGeminiV0Environment) -> dict:
    """Translate an internal action key into a VcGeminiV0Action payload."""
    if action_key_str == "wait":
        return {"action_type": "wait", "parameters": {}}
    if action_key_str.startswith("invest::"):
        target_id = action_key_str.split("::", 1)[1]
        for scenario in getattr(env, "inbox_pitches", []):
            if scenario["startup_id"] == target_id:
                return {
                    "action_type": "invest",
                    "parameters": {"startup_name": scenario["startup_name"]},
                }
    # Fall back to waiting if the target pitch is no longer in the inbox.
    return {"action_type": "wait", "parameters": {}}


def epsilon_for_episode(
    episode_idx: int,
    epsilon_start: float,
    epsilon_end: float,
    epsilon_decay_episodes: int,
) -> float:
    """Linearly anneal epsilon from start to end over the decay window."""
    if epsilon_decay_episodes <= 0:
        return epsilon_end
    progress = min(1.0, episode_idx / float(epsilon_decay_episodes))
    return epsilon_start + (epsilon_end - epsilon_start) * progress
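
# With the CLI defaults below (start=1.0, end=0.05, decay over 3500 episodes),
# episode 1750 sits halfway through the decay: 1.0 + (0.05 - 1.0) * 0.5 = 0.525.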


def pick_action(
    q: QTable,
    state: str,
    actions: list[str],
    epsilon: float,
) -> str:
    """Epsilon-greedy selection; ties among best-valued actions break at random."""
    if not actions:
        return "wait"
    if random.random() < epsilon:
        return random.choice(actions)
    best_value = max(q[state][a] for a in actions)
    best_actions = [a for a in actions if q[state][a] == best_value]
    return random.choice(best_actions)


def cleanup_env(env: VcGeminiV0Environment) -> None:
    workspace = getattr(env, "workspace_dir", None)
    if workspace:
        shutil.rmtree(workspace, ignore_errors=True)


def run_episode(
    env: VcGeminiV0Environment,
    q: QTable,
    alpha: float,
    gamma: float,
    epsilon: float,
    train_mode: bool,
) -> EpisodeStats:
    obs = env.reset()
    last_obs = obs
    done = bool(obs.done)
    total_reward = 0.0
    steps = 0
    while not done:
        s = state_key(env)
        actions = available_actions(env)
        # Explore only during training; evaluation acts purely greedily.
        action_key_str = pick_action(q, s, actions, epsilon if train_mode else 0.0)
        payload = action_to_payload(action_key_str, env)
        next_obs = env.step(VcGeminiV0Action(**payload))
        last_obs = next_obs
        reward = float(next_obs.reward or 0.0)
        done = bool(next_obs.done)
        total_reward += reward
        steps += 1
        if train_mode:
            # Q-learning TD update toward r + gamma * max_a' Q(s', a').
            if done:
                td_target = reward
            else:
                s_next = state_key(env)
                next_actions = available_actions(env)
                max_next_q = max((q[s_next][a] for a in next_actions), default=0.0)
                td_target = reward + gamma * max_next_q
            td_error = td_target - q[s][action_key_str]
            q[s][action_key_str] += alpha * td_error
    # Guard against a final observation whose `data` field is missing or None.
    tvpi = float((getattr(last_obs, "data", None) or {}).get("tvpi", 0.0))
    return EpisodeStats(reward=total_reward, tvpi=tvpi, steps=steps)
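
# One worked update, with hypothetical numbers: alpha=0.15, gamma=0.98,
# reward=0.2, max_next_q=1.0, and Q(s, a)=0.5 give
# td_target = 0.2 + 0.98 * 1.0 = 1.18, so Q(s, a) becomes
# 0.5 + 0.15 * (1.18 - 0.5) = 0.602.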


def train(
    episodes: int,
    alpha: float,
    gamma: float,
    epsilon_start: float,
    epsilon_end: float,
    epsilon_decay_episodes: int,
    log_every: int,
) -> tuple[QTable, list[EpisodeStats]]:
    q = make_q_table()
    history: list[EpisodeStats] = []
    env = VcGeminiV0Environment()
    try:
        for ep in range(1, episodes + 1):
            epsilon = epsilon_for_episode(
                episode_idx=ep,
                epsilon_start=epsilon_start,
                epsilon_end=epsilon_end,
                epsilon_decay_episodes=epsilon_decay_episodes,
            )
            stats = run_episode(
                env=env,
                q=q,
                alpha=alpha,
                gamma=gamma,
                epsilon=epsilon,
                train_mode=True,
            )
            history.append(stats)
            if log_every > 0 and ep % log_every == 0:
                window = history[-log_every:]
                avg_reward = mean(s.reward for s in window)
                avg_tvpi = mean(s.tvpi for s in window)
                avg_steps = mean(s.steps for s in window)
                print(
                    f"[train] episode={ep:5d} epsilon={epsilon:.3f} "
                    f"avg_reward={avg_reward:.4f} avg_tvpi={avg_tvpi:.4f} avg_steps={avg_steps:.2f}"
                )
    finally:
        cleanup_env(env)
    return q, history


def evaluate(q: QTable, episodes: int) -> list[EpisodeStats]:
    results: list[EpisodeStats] = []
    env = VcGeminiV0Environment()
    try:
        for _ in range(episodes):
            stats = run_episode(
                env=env,
                q=q,
                alpha=0.0,
                gamma=0.0,
                epsilon=0.0,
                train_mode=False,
            )
            results.append(stats)
    finally:
        cleanup_env(env)
    return results


def serialize_q_table(q: QTable) -> dict[str, dict[str, float]]:
    return {state: dict(actions) for state, actions in q.items()}
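
# A minimal sketch of reading the artifact back and acting greedily from a known
# state key (assumes the JSON layout written by main() below; `some_state_key`
# is a hypothetical key produced by state_key()):
#
#   table = json.loads(Path("artifacts/v0_q_table.json").read_text())["q_table"]
#   actions = table.get(some_state_key, {})
#   best = max(actions, key=actions.get) if actions else "wait"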


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train tabular Q-learning on vc_gemini_v0")
    parser.add_argument("--episodes", type=int, default=5000, help="Number of training episodes")
    parser.add_argument("--eval-episodes", type=int, default=300, help="Greedy evaluation episodes")
    parser.add_argument("--alpha", type=float, default=0.15, help="Learning rate")
    parser.add_argument("--gamma", type=float, default=0.98, help="Discount factor")
    parser.add_argument("--epsilon-start", type=float, default=1.0, help="Initial epsilon")
    parser.add_argument("--epsilon-end", type=float, default=0.05, help="Final epsilon")
    parser.add_argument(
        "--epsilon-decay-episodes",
        type=int,
        default=3500,
        help="Episodes to linearly decay epsilon from start to end",
    )
    parser.add_argument("--log-every", type=int, default=200, help="Log every N episodes")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("artifacts/v0_q_table.json"),
        help="Path to save learned Q-table and summary",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    random.seed(args.seed)
    q, train_history = train(
        episodes=args.episodes,
        alpha=args.alpha,
        gamma=args.gamma,
        epsilon_start=args.epsilon_start,
        epsilon_end=args.epsilon_end,
        epsilon_decay_episodes=args.epsilon_decay_episodes,
        log_every=args.log_every,
    )
    eval_history = evaluate(q=q, episodes=args.eval_episodes)
    train_avg_reward = mean(s.reward for s in train_history)
    eval_avg_reward = mean(s.reward for s in eval_history)
    eval_avg_tvpi = mean(s.tvpi for s in eval_history)
    print("\n=== Training complete ===")
    print(f"train episodes: {args.episodes}")
    print(f"avg train reward: {train_avg_reward:.4f}")
    print(f"eval episodes: {args.eval_episodes}")
    print(f"avg eval reward (greedy): {eval_avg_reward:.4f}")
    print(f"avg eval tvpi (greedy): {eval_avg_tvpi:.4f}x")
    payload = {
        "config": {
            "episodes": args.episodes,
            "eval_episodes": args.eval_episodes,
            "alpha": args.alpha,
            "gamma": args.gamma,
            "epsilon_start": args.epsilon_start,
            "epsilon_end": args.epsilon_end,
            "epsilon_decay_episodes": args.epsilon_decay_episodes,
            "seed": args.seed,
        },
        "metrics": {
            "avg_train_reward": train_avg_reward,
            "avg_eval_reward": eval_avg_reward,
            "avg_eval_tvpi": eval_avg_tvpi,
        },
        "q_table": serialize_q_table(q),
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(payload, indent=2))
    print(f"saved: {args.output}")


if __name__ == "__main__":
    main()