File size: 3,918 Bytes
cffeda9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""
Quick smoke test — run locally before pushing to HF Spaces.
Tests: reset, step through full episode, crisis triggers, reward signals.

Usage:
    python tests/test_env.py
"""

import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from revops_gym.env import RevOpsEnv
from revops_gym.models import RevOpsAction


def test_episode(difficulty="normal", seed=42):
    """Run a scripted episode against ``RevOpsEnv`` and print per-step diagnostics.

    Resets the environment, checks the initial observation, then plays a fixed
    list of ten (action_type, magnitude) pairs, stopping early if the
    observation reports ``terminated`` or ``truncated``. Summary reward
    statistics are printed at the end.

    Args:
        difficulty: Passed through to ``RevOpsEnv(difficulty=...)``.
        seed: Passed to both the ``RevOpsEnv`` constructor and ``reset``.

    Returns:
        True when every assertion holds.

    Raises:
        AssertionError: If the reset observation is not at step 0, has
            non-positive MRR, or no reward was collected.
    """
    print(f"\n=== Smoke test | difficulty={difficulty} seed={seed} ===")
    env = RevOpsEnv(crisis_every=3, seed=seed, difficulty=difficulty)
    obs = env.reset(seed=seed)
    assert obs.step_number == 0, "Reset should start at step 0"
    assert obs.mrr > 0, "MRR should be positive after reset"
    print(env.render())

    actions = [
        ("increase_marketing", 0.6),
        ("hire_support", 0.8),
        ("negotiate_contracts", 0.5),
        ("raise_prices", 0.4),
        ("feature_investment", 0.7),
        ("cut_costs", 0.3),
        ("discount_campaign", 0.5),
        ("increase_marketing", 0.7),
        ("hire_support", 0.5),
        ("pivot_segment", 0.6),
    ]

    rewards = []
    crises_seen = []
    for action_type, magnitude in actions:
        obs = env.step({"action_type": action_type, "magnitude": magnitude})
        rewards.append(obs.reward_last_step)
        crisis_active = obs.active_crisis != "NONE"
        if crisis_active:
            crises_seen.append(obs.active_crisis)
        print(
            f"  Step {obs.step_number:2d} | {action_type:<22} mag={magnitude} "
            f"| reward={obs.reward_last_step:+.3f} | MRR=${obs.mrr:,.0f} "
            f"| LTV/CAC={obs.ltv_cac_ratio:.2f}x"
            + (f" | ⚠️ {obs.active_crisis}" if crisis_active else "")
        )
        if obs.terminated or obs.truncated:
            print(f"\n  Episode ended at step {obs.step_number} "
                  f"({'terminated' if obs.terminated else 'truncated'})")
            break

    # Check for rewards *before* computing statistics: the mean/min/max below
    # would raise on an empty list and hide the intended assertion message.
    assert rewards, "Should have at least one reward"

    print(f"\n  Total steps: {obs.step_number}")
    print(f"  Mean reward: {sum(rewards)/len(rewards):.4f}")
    print(f"  Min reward:  {min(rewards):.4f}")
    print(f"  Max reward:  {max(rewards):.4f}")
    print(f"  Crises seen: {crises_seen or ['none triggered yet']}")
    print("\n✅ Smoke test passed!")
    return True


def test_all_actions():
    """Step the environment once with each listed action type at magnitude 0.5.

    Asserts that every step yields a non-None ``reward_last_step`` and prints
    the reward for each action.

    Raises:
        AssertionError: If any step returns an observation whose
            ``reward_last_step`` is None.
    """
    print("\n=== Testing all action types ===")
    env = RevOpsEnv(seed=0)
    env.reset(seed=0)
    all_actions = [
        "increase_marketing", "decrease_marketing", "hire_support",
        "fire_support", "discount_campaign", "raise_prices",
        "feature_investment", "cut_costs", "negotiate_contracts", "pivot_segment",
    ]
    for action in all_actions:
        obs = env.step({"action_type": action, "magnitude": 0.5})
        assert obs.reward_last_step is not None, (
            f"Action {action!r} returned no reward"
        )
        print(f"  ✓ {action:<24} reward={obs.reward_last_step:+.3f}")
        if obs.terminated or obs.truncated:
            # The episode ended mid-sweep; start a fresh one so the remaining
            # actions are exercised on a live environment rather than a
            # finished episode.
            env.reset(seed=0)
    print("✅ All actions tested!")


def test_termination():
    """Check terminal-state detection and the termination penalty.

    Builds two ``RevOpsState`` instances directly (no environment stepping):
    one with a low MRR and one at step 30, and asserts that both report
    ``is_terminal``. For the low-MRR state it also asserts that
    ``RewardRubric.compute(..., terminated=True)`` yields a
    ``terminated_penalty`` of -2.0.

    Raises:
        AssertionError: If either state is not terminal or the penalty differs.
    """
    print("\n=== Testing termination conditions ===")
    # Imported locally, as in the original, so these names are only needed
    # when this test actually runs.
    from revops_gym.models import RevOpsState
    from revops_gym.reward import RewardRubric
    rubric = RewardRubric()

    # MRR below floor
    low_mrr_state = RevOpsState(mrr=5_000, step_number=5)
    assert low_mrr_state.is_terminal, "Should terminate when MRR < floor"
    breakdown = rubric.compute(low_mrr_state, terminated=True)
    assert breakdown.terminated_penalty == -2.0, (
        f"Should get termination penalty, got {breakdown.terminated_penalty!r}"
    )
    print("  ✓ MRR floor termination works")

    # Max steps
    step_limit_state = RevOpsState(mrr=100_000, step_number=30)
    assert step_limit_state.is_terminal, "Should truncate at step 30"
    print("  ✓ Step limit truncation works")
    print("✅ Termination tests passed!")


if __name__ == "__main__":
    # Script entry point: run the episode smoke test once per difficulty
    # level, then the action-coverage and termination checks. Any failed
    # assertion aborts the run before the final banner is printed.
    for level in ("easy", "normal", "hard"):
        test_episode(difficulty=level)
    test_all_actions()
    test_termination()
    print("\n🎉 All tests passed! Ready to push to HF Spaces.")