File size: 3,556 Bytes
e42a7af
 
 
 
312c390
e42a7af
 
 
 
 
 
 
 
0e23a69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e42a7af
0e23a69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e42a7af
0e23a69
 
 
e42a7af
0e23a69
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Sanity-check random-policy script for BoardSimEnv.

NOTE: this is **not** the canonical baseline used in the headline
trained-vs-baseline comparison. The canonical baseline is
**base Qwen3-0.6B without LoRA**, computed inside `notebooks/train_grpo_v2.ipynb`
(and the mirrored `Training.py` script). A uniform random pick is not a
competitive opponent for a 0.6 B language model choosing among 3
well-formed strings; we keep this script only as a quick env-health
check (it confirms the env is reachable and prints reward statistics so
you can see that rewards stay in range).

Outputs:
  - assets/random_sanity.csv               per-episode final profitability
  - assets/random_sanity_distribution.png  histogram of final profitabilities
"""

from __future__ import annotations

import csv
import os
import random
import statistics
import sys

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Make `envs.board_sim_env...` importable: put the env package directory and
# ROOT (the parent of this file's directory) at the front of sys.path, with
# the env package directory searched first.
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path[:0] = [os.path.join(ROOT, "envs", "board_sim_env"), ROOT]

from envs.board_sim_env.server.board_sim_env_environment import BoardSimEnvironment  # noqa: E402
from envs.board_sim_env.models import BoardSimAction  # noqa: E402


# Number of episodes main() rolls out; each episode index is also passed to
# env.reset() as that episode's seed.
N_EPISODES = 200


def _run_random_episode(env, seed: int) -> tuple[float, float, bool]:
    """Roll out one episode, picking uniformly among the offered options.

    Args:
        env: Environment exposing ``reset(seed=...)`` and ``step(action)``.
        seed: Seed passed to ``env.reset`` and used to derive the policy RNG.

    Returns:
        ``(final_profitability, total_reward, survived)`` where ``survived``
        is False only when the final ``done_reason`` is
        ``"runway_exhausted"``.
    """
    # Dedicated RNG so the run is reproducible regardless of global `random`
    # state. It is seeded with a string rather than the bare int so that the
    # policy stream does not coincide with an env RNG seeded from the same
    # integer (NOTE(review): how the env uses its seed is not visible here).
    rng = random.Random(f"random-policy:{seed}")

    obs = env.reset(seed=seed)
    total_reward = 0.0
    done = False
    while not done:
        options = obs.options
        if not options:
            # Nothing to choose from: stop and score the last observation.
            break
        obs = env.step(BoardSimAction(decision=rng.choice(options)))
        # A missing (None) reward counts as zero.
        total_reward += float(obs.reward or 0.0)
        done = obs.done

    survived = obs.state.get("done_reason") != "runway_exhausted"
    return obs.state["profitability_score"], total_reward, survived


def _write_results_csv(csv_path: str, profits: list[float], rewards: list[float]) -> None:
    """Write one CSV row per episode: index, final profitability, total reward."""
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["episode", "final_profitability", "total_reward"])
        writer.writerows(
            [i, f"{p:.4f}", f"{r:.4f}"]
            for i, (p, r) in enumerate(zip(profits, rewards))
        )


def _save_profit_histogram(png_path: str, profits: list[float], mean_p: float) -> None:
    """Save a histogram of final profitabilities with the mean marked."""
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.hist(profits, bins=20, color="#c44", edgecolor="white", alpha=0.85)
        ax.axvline(mean_p, color="black", linestyle="--", linewidth=2, label=f"mean = {mean_p:.1f}")
        ax.set_title(f"Random-policy baseline — final profitability ({N_EPISODES} episodes)")
        ax.set_xlabel("Final profitability score (0–100)")
        ax.set_ylabel("Episodes")
        ax.legend()
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(png_path, dpi=120)
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)


def main() -> None:
    """Run ``N_EPISODES`` random-policy episodes and report the results.

    Prints summary statistics (mean/std final profitability, mean total
    reward, survival rate) and writes ``random_sanity.csv`` and
    ``random_sanity_distribution.png`` into ``<ROOT>/assets``, creating the
    directory if needed.
    """
    env = BoardSimEnvironment()
    final_profits: list[float] = []
    total_reward_per_ep: list[float] = []
    survival = 0

    for ep in range(N_EPISODES):
        profit, ep_reward, survived = _run_random_episode(env, ep)
        final_profits.append(profit)
        total_reward_per_ep.append(ep_reward)
        survival += survived

    mean_p = statistics.mean(final_profits)
    # stdev() needs at least two samples; report 0.0 for a single episode
    # instead of raising StatisticsError.
    std_p = statistics.stdev(final_profits) if len(final_profits) > 1 else 0.0
    mean_r = statistics.mean(total_reward_per_ep)
    surv_rate = survival / N_EPISODES

    print(f"Random baseline over {N_EPISODES} episodes:")
    print(f"  mean final profitability = {mean_p:6.2f}  (std {std_p:.2f})")
    print(f"  mean total episode reward = {mean_r:6.2f}")
    print(f"  survival rate (no bankruptcy) = {surv_rate:.1%}")

    assets_dir = os.path.join(ROOT, "assets")
    os.makedirs(assets_dir, exist_ok=True)
    csv_path = os.path.join(assets_dir, "random_sanity.csv")
    png_path = os.path.join(assets_dir, "random_sanity_distribution.png")

    _write_results_csv(csv_path, final_profits, total_reward_per_ep)
    _save_profit_histogram(png_path, final_profits, mean_p)

    print(f"\nWrote {csv_path}")
    print(f"Wrote {png_path}")


# Run the sanity check only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()