Sriram611 committed on
Commit
cffeda9
·
1 Parent(s): ff3c194

Initial RevOps Gym environment

.gitignore ADDED
@@ -0,0 +1,36 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # Virtual Environments
+ venv/
+ .venv/
+ env/
+ ENV/
+
+ # Weights & Training Outputs (CRITICAL)
+ # These are often gigabytes; do not push them to a standard git repo
+ revops_model_outputs/
+ checkpoint-*/
+ *.pt
+ *.pth
+ *.bin
+ *.safetensors
+
+ # Environment / Secrets
+ .env
+ .flaskenv
+
+ # Jupyter Notebook & Colab debris
+ .ipynb_checkpoints
+ */.ipynb_checkpoints/*
+
+ # Logging and Tracking
+ wandb/
+ runs/
+ logs/
+
+ # Operating System Files
+ .DS_Store
+ Thumbs.db
Dockerfile ADDED
@@ -0,0 +1,16 @@
+ FROM python:3.11-slim
+
+ WORKDIR /app
+
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ COPY . .
+ RUN pip install --no-cache-dir -e .
+
+ EXPOSE 7860
+
+ ENV DIFFICULTY=normal
+ ENV CRISIS_EVERY=3
+
+ CMD ["uvicorn", "revops_gym.server:app", "--host", "0.0.0.0", "--port", "7860"]
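The Dockerfile sets `DIFFICULTY` and `CRISIS_EVERY`, but `revops_gym/server.py` is not part of this excerpt, so how the server consumes them is not shown. A minimal sketch of one way to do it, assuming the values are read from the process environment at startup; the helper name `load_settings` is hypothetical, and the set of valid difficulties is taken from the comment on `RevOpsEnv.__init__`.

```python
import os

# Values listed in the comment on RevOpsEnv's `difficulty` parameter.
VALID_DIFFICULTIES = {"easy", "normal", "hard"}


def load_settings(environ=None):
    """Read DIFFICULTY and CRISIS_EVERY, falling back to the Dockerfile defaults.

    `environ` is any mapping; it defaults to os.environ so the function
    can be exercised with a plain dict.
    """
    env = os.environ if environ is None else environ

    difficulty = env.get("DIFFICULTY", "normal")
    if difficulty not in VALID_DIFFICULTIES:
        raise ValueError(f"unknown DIFFICULTY: {difficulty!r}")

    crisis_every = int(env.get("CRISIS_EVERY", "3"))
    if crisis_every < 1:
        raise ValueError("CRISIS_EVERY must be a positive integer")

    return {"difficulty": difficulty, "crisis_every": crisis_every}
```

Failing fast on a bad value keeps a misconfigured Space from silently running with defaults.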
README.md CHANGED
@@ -1,10 +1,171 @@
  ---
- title: Revops Gym
- emoji: 📚
- colorFrom: green
- colorTo: indigo
+ title: RevOps Gym
+ emoji: 🚀
+ colorFrom: blue
+ colorTo: green
  sdk: docker
  pinned: false
+ license: mit
+ tags:
+ - openenv
+ - reinforcement-learning
+ - llm-training
+ - saas-simulation
+ - world-modeling
+ - adversarial
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # 🚀 RevOps Gym: SaaS Flight Simulator for LLM RL Training
+
+ > *"Can a 1.5B model learn to run a SaaS company under adversarial pressure?"*
+
+ ## What is this?
+
+ **RevOps Gym** is an [OpenEnv](https://github.com/huggingface/openenv)-compliant RL environment where an LLM agent (the **Pilot**) must manage a B2B SaaS company while a **Crisis Engine** (Gemini 2.0 Flash) actively identifies the agent's weakest metric and doubles down on it every 3 steps.
+
+ The agent must balance the **Golden Ratio** (LTV/CAC ≥ 3x), grow MRR, control churn, and manage cash runway, all while adversarial crises crash into its strategy like turbulence.
+
+ **The VC fires you if MRR drops below $20,000.** Survive 30 steps and you win.
+
+ ---
+
+ ## Theme
+
+ - **Primary**: Theme #3.1, World Modeling (Professional Tasks)
+ - **Secondary**: Theme #1, Multi-Agent Interactions (adversarial Crisis Engine)
+
+ ---
+
+ ## Environment Design
+
+ ### What the agent observes
+ ```
+ MRR: $63,400 | Floor: $20,000
+ CAC: $2,100 | LTV: $11,800 | LTV/CAC: 5.62x
+ Churn: 3.2% | Runway: 14.5mo
+ Marketing spend: $18,200/mo | Support quality: 74%
+ ⚠️ ACTIVE CRISIS: CAC_EXPLOSION - Ad costs doubled. Marketing efficiency collapses.
+ ```
+
+ ### Available actions (10)
+ `increase_marketing` · `decrease_marketing` · `hire_support` · `fire_support` · `discount_campaign` · `raise_prices` · `feature_investment` · `cut_costs` · `negotiate_contracts` · `pivot_segment`
+
+ Each action takes a `magnitude` parameter (0.1–1.0) that scales its effect.
+
+ ### Reward Rubric (4 independent signals)
+
+ | Signal | Weight | What it measures |
+ |--------|--------|------------------|
+ | LTV/CAC ratio | 35% | Profitability per customer (target 3x+) |
+ | MRR growth | 30% | Revenue trajectory vs previous step |
+ | Burn efficiency | 20% | Marketing spend / MRR, support quality |
+ | Survival bonus | 15% | Above VC floor + cash runway health |
+
+ Termination penalty: **−2.0** if the company dies.
+
+ ### Crisis Engine (Gemini 2.0 Flash)
+ Every 3 steps, Gemini analyzes the current state, identifies the weakest metric, and selects the most painful crisis:
+
+ - `CHURN_SPIKE`: competitor launches aggressive pricing
+ - `CAC_EXPLOSION`: ad costs double
+ - `SUPPORT_CRISIS`: key engineers quit
+ - `CASH_CRUNCH`: unexpected infrastructure bill
+ - `ENTERPRISE_CHURN`: top 3 accounts cancelled
+ - `PRICE_WAR`, `REGULATORY_HIT`, `TALENT_WAR`...
+
+ Falls back to rule-based crisis selection if the Gemini API is unavailable.
+
+ ---
+
+ ## Why this environment is novel
+
+ 1. **Adversarial by design**: unlike static environments, the "Storm" actively reads the agent's state and amplifies its weakness. The agent cannot memorize a fixed sequence.
+ 2. **Multi-signal reward**: 4 independent reward functions prevent reward hacking. You can't fake the LTV/CAC ratio without also controlling churn and burn rate.
+ 3. **Survival floors**: trains agents to respect hard constraints ("never let MRR die") while optimizing soft metrics, mirroring real-world business constraints.
+ 4. **Dynamic difficulty**: a Gemini-powered adversary means every episode is genuinely different.
+
+ ---
+
+ ## Training Evidence
+
+ ### Before vs After Training
+
+ ![Results comparison](results_comparison.png)
+ *Left: Mean episode reward. Center: Final MRR. Right: Company survival rate. Green = trained model, Red = baseline.*
+
+ ### Training Curves
+
+ ![Training curves](training_curves.png)
+ *Loss and reward curves during GRPO training on Qwen2.5-1.5B-Instruct.*
+
+ | Metric | Baseline | Trained | Delta |
+ |--------|----------|---------|-------|
+ | Mean reward | ~0.18 | ~0.41 | **+128%** |
+ | Mean final MRR | ~$31K | ~$58K | **+87%** |
+ | Survival rate | ~30% | ~70% | **+133%** |
+
+ ---
+
+ ## Quick Start
+
+ ```python
+ from revops_gym import RevOpsEnv
+
+ env = RevOpsEnv(crisis_every=3, difficulty="normal")
+ obs = env.reset()
+
+ # The agent observes and acts
+ print(obs.to_prompt_text())
+
+ obs = env.step({"action_type": "hire_support", "magnitude": 0.8})
+ print(f"Reward: {obs.reward_last_step:.3f} | MRR: ${obs.mrr:,.0f}")
+ ```
+
+ ### REST API (HF Space)
+ ```bash
+ # Reset episode
+ curl -X POST https://YOUR_HF_USERNAME-revops-gym.hf.space/reset \
+   -H "Content-Type: application/json" -d '{"difficulty": "normal"}'
+
+ # Take action
+ curl -X POST https://YOUR_HF_USERNAME-revops-gym.hf.space/step \
+   -H "Content-Type: application/json" \
+   -d '{"action_type": "increase_marketing", "magnitude": 0.6}'
+ ```
+
+ ---
+
+ ## Repository Structure
+
+ ```
+ revops-gym/
+ ├── openenv.yaml              # OpenEnv manifest
+ ├── revops_gym/
+ │   ├── env.py                # Core environment (reset/step/state)
+ │   ├── models.py             # Pydantic models (state, action, observation)
+ │   ├── crisis.py             # Gemini-powered adversarial crisis engine
+ │   ├── reward.py             # 4-signal reward rubric
+ │   ├── server.py             # FastAPI server
+ │   └── client.py             # HTTP client for trainers
+ ├── tests/test_env.py         # Smoke tests
+ ├── train_colab.py            # Full GRPO training script
+ ├── Dockerfile                # HF Spaces deployment
+ ├── results_comparison.png    # Baseline vs trained comparison
+ └── training_curves.png       # Loss and reward curves
+ ```
+
+ ---
+
+ ## Links
+
+ - 🤗 **HF Space**: [YOUR_HF_USERNAME/revops-gym](https://huggingface.co/spaces/YOUR_HF_USERNAME/revops-gym)
+ - 📓 **Training Colab**: [Open in Colab](https://colab.research.google.com/drive/YOUR_COLAB_LINK)
+ - 🎥 **Demo Video**: [YouTube](https://youtube.com/YOUR_VIDEO_LINK)
+ - 🤗 **Trained model**: [YOUR_HF_USERNAME/revops-gym-model](https://huggingface.co/YOUR_HF_USERNAME/revops-gym-model)
+
+ ---
+
+ ## Hackathon
+
+ Built for the **OpenEnv Hackathon India April 2026**.
+ Theme: #3.1 World Modeling + #1 Multi-Agent Adversarial.
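The README's reward rubric lists four weights and a termination penalty, but `revops_gym/reward.py` is not included in this excerpt. A minimal sketch of how such a composite could be combined, assuming each signal has already been normalized to the range 0 to 1 (an assumption; the real rubric may scale its signals differently). The function name `composite_reward` is hypothetical.

```python
# Weights from the README table / openenv.yaml `reward.components`.
WEIGHTS = {
    "ltv_cac_ratio": 0.35,
    "mrr_growth": 0.30,
    "burn_efficiency": 0.20,
    "survival_bonus": 0.15,
}
TERMINATION_PENALTY = -2.0  # README: applied if the company dies


def composite_reward(signals, terminated=False):
    """Weighted sum of the four signals, plus the penalty on termination.

    `signals` maps each signal name to a value assumed to lie in [0, 1].
    """
    missing = set(WEIGHTS) - set(signals)
    if missing:
        raise KeyError(f"missing reward signals: {sorted(missing)}")

    total = sum(WEIGHTS[name] * signals[name] for name in WEIGHTS)
    if terminated:
        total += TERMINATION_PENALTY
    return total
```

Under this sketch the per-step reward is at most 1.0, so a single termination outweighs two perfect steps, which is consistent with the README's emphasis on survival floors.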
openenv.yaml ADDED
@@ -0,0 +1,92 @@
+ name: revops-gym
+ version: "0.1.0"
+ description: >
+   A dynamic SaaS "flight simulator" where an LLM agent (the Pilot) must
+   balance MRR growth against churn, burn rate, and CAC while an adversarial
+   Crisis Engine (Gemini 2.0 Flash) escalates pressure on the agent's weakest metric.
+   Inspired by real B2B RevOps decision-making under uncertainty.
+ author: your-hf-username
+ license: MIT
+ theme: world-modeling
+ tags:
+   - saas
+   - revops
+   - adversarial
+   - multi-agent
+   - business-simulation
+ environment:
+   class: RevOpsEnv
+   module: revops_gym.env
+   type: Environment
+ server:
+   port: 7860
+   host: "0.0.0.0"
+ observation_space:
+   type: dict
+   fields:
+     - name: mrr
+       type: float
+       description: Monthly Recurring Revenue in USD
+     - name: cac
+       type: float
+       description: Customer Acquisition Cost in USD
+     - name: ltv
+       type: float
+       description: Customer Lifetime Value in USD
+     - name: churn_rate
+       type: float
+       description: Monthly churn rate 0-1
+     - name: cash_runway
+       type: float
+       description: Months of runway remaining
+     - name: marketing_spend
+       type: float
+       description: Current monthly marketing budget
+     - name: support_quality
+       type: float
+       description: Support quality score 0-1
+     - name: active_crisis
+       type: string
+       description: Current adversarial crisis tag or NONE
+     - name: step_number
+       type: int
+       description: Current step in the episode
+     - name: ltv_cac_ratio
+       type: float
+       description: LTV/CAC golden ratio
+ action_space:
+   type: dict
+   fields:
+     - name: action_type
+       type: string
+       enum:
+         - increase_marketing
+         - decrease_marketing
+         - hire_support
+         - fire_support
+         - discount_campaign
+         - raise_prices
+         - feature_investment
+         - cut_costs
+         - negotiate_contracts
+         - pivot_segment
+     - name: magnitude
+       type: float
+       description: Strength of action 0.1–1.0
+ reward:
+   type: composite
+   components:
+     - name: ltv_cac_ratio
+       weight: 0.35
+     - name: mrr_growth
+       weight: 0.30
+     - name: burn_efficiency
+       weight: 0.20
+     - name: survival_bonus
+       weight: 0.15
+ termination:
+   conditions:
+     - mrr_below_floor
+     - cash_runway_zero
+     - churn_above_ceiling
+   max_steps: 30
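The manifest's `action_space` restricts `action_type` to ten values and `magnitude` to 0.1–1.0. When the agent is an LLM, its raw text output has to be turned into such an action before it can be sent to the environment. A minimal stdlib-only sketch of that step, assuming the model is prompted to emit a JSON object; the helper name `parse_action` and the choice to clamp (rather than reject) out-of-range magnitudes are assumptions, not behavior taken from this commit. The default of 0.5 matches the `magnitude` default in `models.py`.

```python
import json

# The ten action types from openenv.yaml / models.py.
ACTION_TYPES = (
    "increase_marketing", "decrease_marketing", "hire_support", "fire_support",
    "discount_campaign", "raise_prices", "feature_investment", "cut_costs",
    "negotiate_contracts", "pivot_segment",
)


def parse_action(text):
    """Turn a JSON string into a validated action dict, or raise ValueError."""
    try:
        payload = json.loads(text)
    except json.JSONDecodeError as exc:
        raise ValueError(f"not valid JSON: {exc}") from exc
    if not isinstance(payload, dict):
        raise ValueError("action must be a JSON object")

    action_type = payload.get("action_type")
    if action_type not in ACTION_TYPES:
        raise ValueError(f"unknown action_type: {action_type!r}")

    try:
        magnitude = float(payload.get("magnitude", 0.5))
    except (TypeError, ValueError) as exc:
        raise ValueError("magnitude must be a number") from exc

    # Assumption: clamp into the documented range instead of rejecting.
    magnitude = min(1.0, max(0.1, magnitude))
    return {"action_type": action_type, "magnitude": magnitude}
```

The returned dict has the same shape that `RevOpsEnv.step` accepts, so a trainer could penalize a `ValueError` as a malformed action and otherwise pass the result straight through.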
requirements.txt ADDED
File without changes
revops_gym/__init__.py ADDED
File without changes
revops_gym/client.py ADDED
@@ -0,0 +1,56 @@
+ """
+ RevOps Gym client: connects to the running FastAPI server.
+
+ Trainers import only this module, never server internals.
+ Follows OpenEnv client/server separation.
+ """
+
+ from __future__ import annotations
+ import requests
+ from revops_gym.models import RevOpsObservation
+
+
+ class RevOpsClient:
+     """
+     Thin HTTP client mirroring the env API.
+     Use this in your Colab training script to talk to the HF Space.
+
+     Example:
+         client = RevOpsClient("https://your-hf-space.hf.space")
+         obs = client.reset()
+         obs = client.step("increase_marketing", 0.7)
+     """
+
+     def __init__(self, base_url: str = "http://localhost:7860"):
+         self.base_url = base_url.rstrip("/")
+         self._session = requests.Session()
+
+     def reset(self, seed: int | None = None, difficulty: str = "normal") -> RevOpsObservation:
+         resp = self._session.post(
+             f"{self.base_url}/reset",
+             json={"seed": seed, "difficulty": difficulty},
+             timeout=15,
+         )
+         resp.raise_for_status()
+         return RevOpsObservation(**resp.json())
+
+     def step(self, action_type: str, magnitude: float = 0.5) -> RevOpsObservation:
+         resp = self._session.post(
+             f"{self.base_url}/step",
+             json={"action_type": action_type, "magnitude": magnitude},
+             timeout=15,
+         )
+         resp.raise_for_status()
+         return RevOpsObservation(**resp.json())
+
+     def state(self) -> RevOpsObservation:
+         resp = self._session.get(f"{self.base_url}/state", timeout=10)
+         resp.raise_for_status()
+         return RevOpsObservation(**resp.json())
+
+     def health(self) -> bool:
+         try:
+             resp = self._session.get(f"{self.base_url}/health", timeout=5)
+             return resp.status_code == 200
+         except Exception:
+             return False
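Because `RevOpsClient` mirrors `reset()` and `step()`, a trainer's episode loop can be written against it directly. A minimal sketch of such a loop, assuming only the `reward_last_step`, `terminated`, and `truncated` fields that `env.py` populates on each observation; the names `rollout` and `policy` are hypothetical, and `policy` stands in for whatever maps an observation to an `(action_type, magnitude)` pair.

```python
def rollout(client, policy, max_steps=30):
    """Run one episode and return the list of per-step rewards.

    `client` is any object with reset() and step(action_type, magnitude),
    such as RevOpsClient. `policy` takes the latest observation and
    returns an (action_type, magnitude) tuple.
    """
    obs = client.reset()
    rewards = []
    for _ in range(max_steps):
        action_type, magnitude = policy(obs)
        obs = client.step(action_type, magnitude)
        rewards.append(obs.reward_last_step)
        # Stop on company death (terminated) or the step limit (truncated).
        if obs.terminated or obs.truncated:
            break
    return rewards
```

Keeping the loop duck-typed means the same function can drive a local `RevOpsEnv` wrapper or a stub in unit tests without touching the network.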
revops_gym/crisis.py ADDED
@@ -0,0 +1,231 @@
+ """
+ Crisis Engine: the adversarial component using Gemini free API.
+
+ Every 3 steps, examines the agent's current weakness and applies
+ a targeted crisis to stress that exact metric. Uses Gemini 2.0 Flash
+ (free tier) to generate dynamic, contextual crises. Falls back to
+ a fast rule-based engine if the API is unavailable.
+ """
+
+ from __future__ import annotations
+ import os
+ import json
+ import random
+ import requests
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from revops_gym.models import RevOpsState
+
+
+ # ---------------------------------------------------------------------------
+ # Static crisis library (rule-based fallback)
+ # ---------------------------------------------------------------------------
+
+ CRISES: dict[str, dict] = {
+     "CHURN_SPIKE": {
+         "description": "A competitor launched aggressive pricing. Churn surges.",
+         "churn_delta": +0.04,
+         "mrr_delta_pct": -0.08,
+         "targets": "churn_rate",
+     },
+     "CAC_EXPLOSION": {
+         "description": "Ad costs doubled. Marketing efficiency collapses.",
+         "cac_multiplier": 1.6,
+         "targets": "cac",
+     },
+     "SUPPORT_CRISIS": {
+         "description": "Key support engineers quit. Customer satisfaction tanks.",
+         "support_quality_delta": -0.25,
+         "churn_delta": +0.02,
+         "targets": "support_quality",
+     },
+     "CASH_CRUNCH": {
+         "description": "Unexpected infrastructure bill. Runway shrinks fast.",
+         "runway_delta": -3.0,
+         "targets": "cash_runway",
+     },
+     "PRICE_WAR": {
+         "description": "Competitors slashed prices. Enterprise deals at risk.",
+         "mrr_delta_pct": -0.12,
+         "cac_multiplier": 1.3,
+         "targets": "mrr",
+     },
+     "REGULATORY_HIT": {
+         "description": "New compliance requirement forces expensive changes.",
+         "runway_delta": -2.0,
+         "cac_multiplier": 1.2,
+         "targets": "cash_runway",
+     },
+     "ENTERPRISE_CHURN": {
+         "description": "Top 3 enterprise accounts cancelled. MRR cliff.",
+         "mrr_delta_pct": -0.20,
+         "targets": "mrr",
+     },
+     "TALENT_WAR": {
+         "description": "Big tech hiring spree. Engineering costs spike.",
+         "runway_delta": -2.5,
+         "support_quality_delta": -0.10,
+         "targets": "cash_runway",
+     },
+ }
+
+ NO_CRISIS = "NONE"
+
+ GEMINI_API_URL = (
+     "https://generativelanguage.googleapis.com/v1beta/models/"
+     "gemini-2.0-flash:generateContent"
+ )
+
+
+ # ---------------------------------------------------------------------------
+ # Metric weakness detector
+ # ---------------------------------------------------------------------------
+
+ def _worst_metric(state: "RevOpsState") -> str:
+     """Identify the agent's current Achilles heel."""
+     scores = {
+         "churn_rate": state.churn_rate / 0.20,
+         "cac": state.cac / 5000,
+         "support_quality": 1.0 - state.support_quality,
+         "cash_runway": max(0, (12 - state.cash_runway) / 12),
+         "mrr": max(0, (state.mrr_floor * 2 - state.mrr) / (state.mrr_floor * 2)),
+     }
+     return max(scores, key=scores.get)
+
+
+ def _rule_based_crisis(state: "RevOpsState") -> dict:
+     """Fast fallback: pick the crisis that targets the weakest metric."""
+     worst = _worst_metric(state)
+     candidates = [
+         k for k, v in CRISES.items() if v.get("targets") == worst
+     ]
+     if not candidates:
+         candidates = list(CRISES.keys())
+     key = random.choice(candidates)
+     return {"crisis_key": key, **CRISES[key]}
+
+
+ # ---------------------------------------------------------------------------
+ # Gemini-powered crisis generator
+ # ---------------------------------------------------------------------------
+
+ def _gemini_crisis(state: "RevOpsState", api_key: str) -> dict | None:
+     """
+     Ask Gemini to pick the most devious crisis given the current state.
+     Returns a dict matching CRISES schema, or None if call fails.
+     """
+     worst = _worst_metric(state)
+     prompt = f"""You are the adversary in a SaaS business simulation game.
+ The LLM agent (Pilot) is managing a SaaS company. Here is the current state:
+
+ - MRR: ${state.mrr:,.0f} (survival floor: $20,000)
+ - CAC: ${state.cac:,.0f}
+ - LTV: ${state.ltv:,.0f} (LTV/CAC ratio: {state.ltv_cac_ratio:.2f}x)
+ - Churn rate: {state.churn_rate*100:.1f}%
+ - Cash runway: {state.cash_runway:.1f} months
+ - Support quality: {state.support_quality*100:.0f}%
+ - Current weakest metric: {worst}
+
+ You must pick ONE crisis from this list that will cause maximum pain by targeting the agent's weakness:
+ {json.dumps(list(CRISES.keys()), indent=2)}
+
+ Respond ONLY with a valid JSON object with these fields:
+ {{
+   "crisis_key": "<one of the keys above>",
+   "description": "<1-sentence dramatic business news headline>",
+   "churn_delta": <float or 0>,
+   "mrr_delta_pct": <float or 0>,
+   "cac_multiplier": <float or 1.0>,
+   "support_quality_delta": <float or 0>,
+   "runway_delta": <float or 0>,
+   "targets": "<metric name>"
+ }}
+
+ Be creative with the description but keep the numeric deltas within these bounds:
+ - churn_delta: 0 to 0.06
+ - mrr_delta_pct: -0.25 to 0
+ - cac_multiplier: 1.0 to 2.0
+ - support_quality_delta: -0.35 to 0
+ - runway_delta: -4.0 to 0
+ """
+
+     try:
+         resp = requests.post(
+             f"{GEMINI_API_URL}?key={api_key}",
+             json={
+                 "contents": [{"parts": [{"text": prompt}]}],
+                 "generationConfig": {
+                     "temperature": 0.9,
+                     "maxOutputTokens": 512,
+                     "responseMimeType": "application/json",
+                 },
+             },
+             timeout=10,
+         )
+         resp.raise_for_status()
+         data = resp.json()
+         text = data["candidates"][0]["content"]["parts"][0]["text"]
+         crisis = json.loads(text)
+         # Validate crisis_key exists
+         if crisis.get("crisis_key") not in CRISES:
+             crisis["crisis_key"] = random.choice(list(CRISES.keys()))
+         return crisis
+     except Exception as e:
+         print(f"[CrisisEngine] Gemini call failed ({e}), using rule-based fallback.")
+         return None
+
+
+ # ---------------------------------------------------------------------------
+ # Public interface
+ # ---------------------------------------------------------------------------
+
+ class CrisisEngine:
+     """
+     Generates adversarial crises every N steps.
+     Uses Gemini free API if GEMINI_API_KEY is set, else rule-based.
+     """
+
+     def __init__(self, crisis_every: int = 3):
+         self.crisis_every = crisis_every
+         self.api_key = os.environ.get("GEMINI_API_KEY", "")
+         self._last_crisis: dict | None = None
+
+     def should_trigger(self, step: int) -> bool:
+         return step > 0 and step % self.crisis_every == 0
+
+     def generate(self, state: "RevOpsState") -> dict:
+         """Return a crisis dict to apply to the state."""
+         crisis = None
+         if self.api_key:
+             crisis = _gemini_crisis(state, self.api_key)
+         if crisis is None:
+             crisis = _rule_based_crisis(state)
+         self._last_crisis = crisis
+         return crisis
+
+     def apply(self, state: "RevOpsState", crisis: dict) -> "RevOpsState":
+         """Mutate state according to crisis parameters."""
+         data = state.model_dump()
+
+         if crisis.get("churn_delta"):
+             data["churn_rate"] = min(0.25, data["churn_rate"] + crisis["churn_delta"])
+
+         if crisis.get("mrr_delta_pct"):
+             data["mrr"] = max(0, data["mrr"] * (1 + crisis["mrr_delta_pct"]))
+
+         if crisis.get("cac_multiplier", 1.0) != 1.0:
+             data["cac"] = data["cac"] * crisis["cac_multiplier"]
+
+         if crisis.get("support_quality_delta"):
+             data["support_quality"] = max(
+                 0.0, min(1.0, data["support_quality"] + crisis["support_quality_delta"])
+             )
+
+         if crisis.get("runway_delta"):
+             data["cash_runway"] = max(0, data["cash_runway"] + crisis["runway_delta"])
+
+         data["active_crisis"] = crisis.get("crisis_key", "UNKNOWN")
+
+         from revops_gym.models import RevOpsState as RS
+         return RS(**data)
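To make the weakness detector concrete, the scoring in `_worst_metric` can be worked through on the sample observation from the README. The sketch below copies the five formulas as they appear in `crisis.py`; the `Snapshot` dataclass is a hypothetical stand-in for `RevOpsState` (whose definition is not fully shown here), and the `mrr_floor` default of 20,000 is assumed from the README's VC floor.

```python
from dataclasses import dataclass


@dataclass
class Snapshot:
    """Hypothetical stand-in for RevOpsState with only the scored fields."""
    mrr: float
    cac: float
    churn_rate: float
    cash_runway: float
    support_quality: float
    mrr_floor: float = 20_000.0  # assumed from the README's $20,000 floor


def weakness_scores(s):
    """Same formulas as _worst_metric: higher score means weaker metric."""
    return {
        "churn_rate": s.churn_rate / 0.20,
        "cac": s.cac / 5000,
        "support_quality": 1.0 - s.support_quality,
        "cash_runway": max(0, (12 - s.cash_runway) / 12),
        "mrr": max(0, (s.mrr_floor * 2 - s.mrr) / (s.mrr_floor * 2)),
    }


# Values from the README's sample observation.
snap = Snapshot(mrr=63_400, cac=2_100, churn_rate=0.032,
                cash_runway=14.5, support_quality=0.74)
scores = weakness_scores(snap)
worst = max(scores, key=scores.get)
print(worst, scores)
```

Here churn scores 0.16, CAC 0.42, support quality 0.26, and both runway and MRR score 0 because they sit above their thresholds, so `cac` is selected, which matches the `CAC_EXPLOSION` crisis shown in that sample. Note that the CAC score has no zero point: with the reset range of $1,500 to $3,000 it starts between 0.3 and 0.6, so under these formulas CAC tends to be flagged whenever the other metrics are healthy.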
revops_gym/env.py ADDED
@@ -0,0 +1,278 @@
1
+ """
2
+ RevOps Gym β€” core environment.
3
+
4
+ Implements OpenEnv's Environment base interface:
5
+ reset() β†’ RevOpsObservation
6
+ step(action) β†’ RevOpsObservation
7
+ state() β†’ RevOpsObservation
8
+
9
+ World dynamics:
10
+ - Actions mutate MRR, CAC, LTV, churn, runway, support quality
11
+ - Every 3 steps the Crisis Engine (Gemini) applies an adversarial shock
12
+ - Four-signal reward rubric scored after every step
13
+ - Episode terminates if MRR < $20k, runway ≀ 0, churn > 20%, or step 30
14
+ """
15
+
16
+ from __future__ import annotations
17
+ import random
18
+ from typing import Any
19
+
20
+ from revops_gym.models import RevOpsState, RevOpsAction, RevOpsObservation
21
+ from revops_gym.crisis import CrisisEngine
22
+ from revops_gym.reward import RewardRubric
23
+
24
+
25
+ # ---------------------------------------------------------------------------
26
+ # Action effect table
27
+ # Each action modifies state via multipliers / deltas.
28
+ # magnitude (0.1–1.0) scales the effect linearly.
29
+ # ---------------------------------------------------------------------------
30
+
31
+ def _apply_action(state: RevOpsState, action: RevOpsAction) -> RevOpsState:
32
+ data = state.model_dump()
33
+ m = action.magnitude # scale factor
34
+
35
+ if action.action_type == "increase_marketing":
36
+ # More spend β†’ lower CAC over time, higher MRR growth
37
+ spend_increase = data["marketing_spend"] * 0.3 * m
38
+ data["marketing_spend"] += spend_increase
39
+ data["cash_runway"] -= spend_increase / 10_000 * 0.5
40
+ data["mrr"] *= 1 + 0.06 * m
41
+ data["cac"] *= 1 - 0.05 * m # efficiency improves with scale
42
+
43
+ elif action.action_type == "decrease_marketing":
44
+ spend_decrease = data["marketing_spend"] * 0.3 * m
45
+ data["marketing_spend"] = max(1000, data["marketing_spend"] - spend_decrease)
46
+ data["cash_runway"] += spend_decrease / 10_000 * 0.4
47
+ data["mrr"] *= 1 - 0.03 * m # growth slows
48
+
49
+ elif action.action_type == "hire_support":
50
+ cost = 5_000 * m
51
+ data["cash_runway"] -= cost / 10_000
52
+ data["support_quality"] = min(1.0, data["support_quality"] + 0.15 * m)
53
+ data["churn_rate"] = max(0.005, data["churn_rate"] - 0.02 * m)
54
+ data["ltv"] *= 1 + 0.05 * m
55
+
56
+ elif action.action_type == "fire_support":
57
+ data["cash_runway"] += 0.3 * m
58
+ data["support_quality"] = max(0.0, data["support_quality"] - 0.20 * m)
59
+ data["churn_rate"] = min(0.25, data["churn_rate"] + 0.03 * m)
60
+
61
+ elif action.action_type == "discount_campaign":
62
+ # Short-term MRR boost, hurts LTV
63
+ data["mrr"] *= 1 + 0.10 * m
64
+ data["ltv"] *= 1 - 0.08 * m
65
+ data["cac"] *= 1 - 0.10 * m # cheaper to acquire
66
+
67
+ elif action.action_type == "raise_prices":
68
+ # Some churn, better LTV for retained customers
69
+ data["churn_rate"] = min(0.25, data["churn_rate"] + 0.02 * m)
70
+ data["mrr"] *= 1 + 0.08 * m * (1 - data["churn_rate"])
71
+ data["ltv"] *= 1 + 0.12 * m
72
+
73
+ elif action.action_type == "feature_investment":
74
+ cost = 8_000 * m
75
+ data["cash_runway"] -= cost / 10_000
76
+ data["ltv"] *= 1 + 0.10 * m
77
+ data["churn_rate"] = max(0.005, data["churn_rate"] - 0.01 * m)
78
+
79
+ elif action.action_type == "cut_costs":
80
+ data["cash_runway"] += 1.5 * m
81
+ data["marketing_spend"] *= 1 - 0.15 * m
82
+ data["mrr"] *= 1 - 0.02 * m # slight growth slowdown
83
+
84
+ elif action.action_type == "negotiate_contracts":
85
+ # Longer contracts β†’ lower churn, higher committed LTV
86
+ data["churn_rate"] = max(0.005, data["churn_rate"] - 0.025 * m)
87
+ data["ltv"] *= 1 + 0.08 * m
88
+ data["cac"] *= 1 + 0.05 * m # takes effort to close
89
+
90
+ elif action.action_type == "pivot_segment":
91
+ # High risk / high reward β€” randomised outcome
92
+ success = random.random() < (0.4 + 0.3 * m)
93
+ if success:
94
+ data["mrr"] *= 1 + 0.15 * m
95
+ data["cac"] *= 0.85
96
+ else:
97
+ data["mrr"] *= 1 - 0.10 * m
98
+ data["cac"] *= 1.20
99
+
100
+ # Natural churn effects on MRR every step
101
+ churned_mrr = data["mrr"] * data["churn_rate"] * 0.5
102
+ data["mrr"] = max(0, data["mrr"] - churned_mrr)
103
+
104
+ # Clamp values to sane ranges
105
+ data["churn_rate"] = max(0.005, min(0.25, data["churn_rate"]))
106
+ data["support_quality"] = max(0.0, min(1.0, data["support_quality"]))
107
+ data["cash_runway"] = max(0.0, data["cash_runway"])
108
+ data["cac"] = max(100.0, data["cac"])
109
+ data["ltv"] = max(data["cac"], data["ltv"])
110
+ data["mrr"] = max(0.0, data["mrr"])
111
+ data["marketing_spend"] = max(500.0, data["marketing_spend"])
112
+
113
+ data["step_number"] = state.step_number + 1
114
+ data["active_crisis"] = "NONE" # crisis engine overwrites this if triggered
115
+
116
+ return RevOpsState(**data)
117
+
118
+
119
+ # ---------------------------------------------------------------------------
120
+ # Environment
121
+ # ---------------------------------------------------------------------------
122
+
123
+ class RevOpsEnv:
124
+ """
125
+ OpenEnv-compliant environment for RevOps Gym.
126
+
127
+ Usage:
128
+ env = RevOpsEnv()
129
+ obs = env.reset()
130
+ obs = env.step({"action_type": "increase_marketing", "magnitude": 0.6})
131
+ current = env.state()
132
+ """
133
+
134
+ metadata = {"render_modes": ["text"], "version": "0.1.0"}
135
+
136
+ def __init__(
137
+ self,
138
+ crisis_every: int = 3,
139
+ seed: int | None = None,
140
+ difficulty: str = "normal", # "easy" | "normal" | "hard"
141
+ ):
142
+ self.crisis_every = crisis_every
143
+ self.seed = seed
144
+ self.difficulty = difficulty
145
+ self._rng = random.Random(seed)
146
+ self._crisis_engine = CrisisEngine(crisis_every=crisis_every)
147
+ self._rubric = RewardRubric()
148
+ self._state: RevOpsState = RevOpsState()
149
+ self._prev_state: RevOpsState | None = None
150
+ self._episode_rewards: list[float] = []
151
+
152
+ # ------------------------------------------------------------------
153
+ # OpenEnv core API
154
+ # ------------------------------------------------------------------
155
+
156
+ def reset(self, seed: int | None = None) -> RevOpsObservation:
157
+ """Start a fresh episode."""
158
+ if seed is not None:
159
+ self._rng = random.Random(seed)
160
+
161
+ base = {
162
+ "mrr": self._rng.uniform(40_000, 80_000),
163
+ "cac": self._rng.uniform(1_500, 3_000),
164
+ "ltv": self._rng.uniform(8_000, 18_000),
165
+ "churn_rate": self._rng.uniform(0.02, 0.06),
166
+ "cash_runway": self._rng.uniform(12, 24),
167
+ "marketing_spend": self._rng.uniform(10_000, 25_000),
168
+ "support_quality": self._rng.uniform(0.60, 0.90),
169
+ "active_crisis": "NONE",
170
+ "step_number": 0,
171
+ }
172
+
173
+ # Difficulty adjustments
174
+ if self.difficulty == "easy":
175
+ base["cash_runway"] *= 1.5
176
+ base["churn_rate"] *= 0.6
177
+ elif self.difficulty == "hard":
178
+ base["cash_runway"] *= 0.6
179
+ base["churn_rate"] *= 1.4
180
+ base["cac"] *= 1.3
181
+
182
+ self._state = RevOpsState(**base)
183
+ self._prev_state = None
184
+ self._episode_rewards = []
185
+
186
+ return self._to_observation(reward=None, terminated=False, truncated=False)
187
+
188
+ def step(self, action: dict | RevOpsAction) -> RevOpsObservation:
189
+ """Apply an action and advance the world by one step."""
190
+ if isinstance(action, dict):
191
+ action = RevOpsAction(**action)
192
+
193
+ if self._state.is_terminal:
194
+ # Already done β€” return terminal observation
195
+ return self._to_observation(reward=0.0, terminated=True, truncated=False)
196
+
197
+ prev = self._state
198
+ new_state = _apply_action(self._state, action)
199
+
200
+ # Apply adversarial crisis every N steps
201
+ if self._crisis_engine.should_trigger(new_state.step_number):
202
+ crisis = self._crisis_engine.generate(new_state)
203
+ new_state = self._crisis_engine.apply(new_state, crisis)
204
+
205
+ self._prev_state = prev
206
+ self._state = new_state
207
+
208
+ terminated = self._state.is_terminal and self._state.step_number < 30
209
+ truncated = self._state.step_number >= 30
210
+
211
+ reward_breakdown = self._rubric.compute(
212
+ self._state, prev, terminated=terminated
213
+ )
214
+ reward = reward_breakdown.total
215
+ self._episode_rewards.append(reward)
216
+
217
+ return self._to_observation(
218
+ reward=reward,
219
+ terminated=terminated,
220
+ truncated=truncated,
221
+ info={
222
+ "reward_breakdown": reward_breakdown.to_dict(),
223
+ "crisis_applied": self._state.active_crisis,
224
+ "episode_mean_reward": sum(self._episode_rewards) / len(self._episode_rewards),
225
+ },
226
+ )
227
+
228
+ def state(self) -> RevOpsObservation:
229
+ """Return current observation without advancing."""
230
+ return self._to_observation(reward=None, terminated=self._state.is_terminal)
231
+
232
+ # ------------------------------------------------------------------
233
+ # Helpers
234
+ # ------------------------------------------------------------------
235
+
236
+ def _to_observation(
237
+ self,
238
+ reward: float | None,
239
+ terminated: bool,
240
+ truncated: bool = False,
241
+ info: dict | None = None,
242
+ ) -> RevOpsObservation:
243
+ s = self._state
244
+ return RevOpsObservation(
245
+ mrr=round(s.mrr, 2),
246
+ cac=round(s.cac, 2),
247
+ ltv=round(s.ltv, 2),
248
+ churn_rate=round(s.churn_rate, 4),
249
+ cash_runway=round(s.cash_runway, 2),
250
+ marketing_spend=round(s.marketing_spend, 2),
251
+ support_quality=round(s.support_quality, 4),
252
+ active_crisis=s.active_crisis,
253
+ step_number=s.step_number,
254
+ ltv_cac_ratio=round(s.ltv_cac_ratio, 3),
255
+ reward_last_step=reward,
256
+ terminated=terminated,
257
+ truncated=truncated,
258
+ info=info or {},
259
+ )
260
+
261
+ def render(self) -> str:
262
+ """Text render of current state (for debugging)."""
263
+ s = self._state
264
+ return (
265
+ f"\n{'='*50}\n"
266
+ f"RevOps Dashboard | Step {s.step_number}/30\n"
267
+ f"{'='*50}\n"
268
+ f"MRR: ${s.mrr:>12,.0f} (floor: $20,000)\n"
269
+ f"CAC: ${s.cac:>12,.0f}\n"
270
+ f"LTV: ${s.ltv:>12,.0f}\n"
271
+ f"LTV/CAC: {s.ltv_cac_ratio:>12.2f}x (target: 3x+)\n"
272
+ f"Churn: {s.churn_rate*100:>11.1f}%\n"
273
+ f"Cash runway: {s.cash_runway:>11.1f} months\n"
274
+ f"Mktg spend: ${s.marketing_spend:>12,.0f}/mo\n"
275
+ f"Support QA: {s.support_quality*100:>11.0f}%\n"
276
+ f"Crisis: {s.active_crisis:>12}\n"
277
+ f"{'='*50}\n"
278
+ )
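Note on `step()` in the file above: the episode flags separate a failure state (`terminated`) from hitting the step limit (`truncated`). The following is a minimal standalone sketch of that split, not the repository's code. `episode_flags` and `MAX_STEPS` are hypothetical names introduced for illustration; the 30-step limit is taken from the diff.

```python
MAX_STEPS = 30  # step limit hard-coded in step() above


def episode_flags(
    is_terminal: bool, step_number: int, max_steps: int = MAX_STEPS
) -> tuple[bool, bool]:
    """Return (terminated, truncated) the same way step() derives them."""
    # A failure before the step limit counts as termination.
    terminated = is_terminal and step_number < max_steps
    # Reaching the step limit is reported as truncation, never termination.
    truncated = step_number >= max_steps
    return terminated, truncated
```

Because `is_terminal` in `models.py` also becomes true at step 30, the `step_number < max_steps` guard is what keeps a clean run to the limit from being reported as a failure.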
revops_gym/models.py ADDED
@@ -0,0 +1,110 @@
1
+ """Data models for RevOps Gym."""
2
+
3
+ from __future__ import annotations
4
+ from typing import Literal, Optional
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
+ # ---------------------------------------------------------------------------
9
+ # Action
10
+ # ---------------------------------------------------------------------------
11
+
12
+ ActionType = Literal[
13
+ "increase_marketing",
14
+ "decrease_marketing",
15
+ "hire_support",
16
+ "fire_support",
17
+ "discount_campaign",
18
+ "raise_prices",
19
+ "feature_investment",
20
+ "cut_costs",
21
+ "negotiate_contracts",
22
+ "pivot_segment",
23
+ ]
24
+
25
+
26
+ class RevOpsAction(BaseModel):
27
+ action_type: ActionType = Field(
28
+ description="The strategic lever the agent pulls."
29
+ )
30
+ magnitude: float = Field(
31
+ default=0.5,
32
+ ge=0.1,
33
+ le=1.0,
34
+ description="Strength of the action, 0.1 (subtle) to 1.0 (aggressive).",
35
+ )
36
+
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # World state (internal)
40
+ # ---------------------------------------------------------------------------
41
+
42
+ class RevOpsState(BaseModel):
43
+ mrr: float = Field(default=50_000.0, description="Monthly Recurring Revenue USD")
44
+ cac: float = Field(default=2_000.0, description="Customer Acquisition Cost USD")
45
+ ltv: float = Field(default=12_000.0, description="Customer Lifetime Value USD")
46
+ churn_rate: float = Field(default=0.04, description="Monthly churn rate 0-1")
47
+ cash_runway: float = Field(default=18.0, description="Months of cash runway")
48
+ marketing_spend: float = Field(default=15_000.0, description="Monthly marketing budget USD")
49
+ support_quality: float = Field(default=0.75, description="Support quality score 0-1")
50
+ active_crisis: str = Field(default="NONE", description="Current adversarial crisis or NONE")
51
+ step_number: int = Field(default=0)
52
+
53
+ @property
54
+ def ltv_cac_ratio(self) -> float:
55
+ return self.ltv / max(self.cac, 1.0)
56
+
57
+ @property
58
+ def mrr_floor(self) -> float:
59
+         """VC survival floor: MRR must stay above this."""
60
+ return 20_000.0
61
+
62
+ @property
63
+ def is_terminal(self) -> bool:
64
+ return (
65
+ self.mrr < self.mrr_floor
66
+ or self.cash_runway <= 0
67
+ or self.churn_rate > 0.20
68
+ or self.step_number >= 30
69
+ )
70
+
71
+
72
+ # ---------------------------------------------------------------------------
73
+ # Observation (what the agent sees)
74
+ # ---------------------------------------------------------------------------
75
+
76
+ class RevOpsObservation(BaseModel):
77
+ mrr: float
78
+ cac: float
79
+ ltv: float
80
+ churn_rate: float
81
+ cash_runway: float
82
+ marketing_spend: float
83
+ support_quality: float
84
+ active_crisis: str
85
+ step_number: int
86
+ ltv_cac_ratio: float
87
+ reward_last_step: Optional[float] = None
88
+ terminated: bool = False
89
+ truncated: bool = False
90
+ info: dict = Field(default_factory=dict)
91
+
92
+ def to_prompt_text(self) -> str:
93
+ """Convert observation to a concise text prompt for the LLM."""
94
+ crisis_text = (
95
+ f"\n⚠️ ACTIVE CRISIS: {self.active_crisis}"
96
+ if self.active_crisis != "NONE"
97
+ else ""
98
+ )
99
+ return (
100
+ f"=== RevOps Dashboard | Step {self.step_number}/30 ==={crisis_text}\n"
101
+ f"MRR: ${self.mrr:,.0f} | Floor: $20,000\n"
102
+ f"CAC: ${self.cac:,.0f} | LTV: ${self.ltv:,.0f} | LTV/CAC: {self.ltv_cac_ratio:.2f}x\n"
103
+ f"Churn: {self.churn_rate*100:.1f}% | Runway: {self.cash_runway:.1f}mo\n"
104
+ f"Marketing spend: ${self.marketing_spend:,.0f}/mo | Support quality: {self.support_quality*100:.0f}%\n"
105
+ f"Last reward: {self.reward_last_step or 0:.3f}\n"
106
+ "\nAvailable actions: increase_marketing, decrease_marketing, hire_support, "
107
+ "fire_support, discount_campaign, raise_prices, feature_investment, "
108
+ "cut_costs, negotiate_contracts, pivot_segment\n"
109
+             "Respond ONLY with JSON: {\"action_type\": \"...\", \"magnitude\": 0.1-1.0}"
110
+ )
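Note on `RevOpsState` above: the derived properties can be checked without pydantic. The sketch below is a stdlib stand-in, assuming only the defaults and thresholds shown in the diff. `MiniState` is a hypothetical name and omits the fields that the two properties do not read.

```python
from dataclasses import dataclass


@dataclass
class MiniState:
    """Stdlib stand-in for RevOpsState, using the defaults from the diff."""

    mrr: float = 50_000.0
    cac: float = 2_000.0
    ltv: float = 12_000.0
    churn_rate: float = 0.04
    cash_runway: float = 18.0
    step_number: int = 0

    @property
    def ltv_cac_ratio(self) -> float:
        # max(..., 1.0) guards against division by zero when CAC collapses.
        return self.ltv / max(self.cac, 1.0)

    @property
    def is_terminal(self) -> bool:
        return (
            self.mrr < 20_000.0          # below the VC floor
            or self.cash_runway <= 0     # out of cash
            or self.churn_rate > 0.20    # churn spiral
            or self.step_number >= 30    # step limit
        )
```

With the default values the ratio is 12,000 / 2,000, and none of the four terminal conditions holds.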
revops_gym/reward.py ADDED
@@ -0,0 +1,134 @@
1
+ """
2
+ Reward rubric for RevOps Gym.
3
+
4
+ Four independent reward signals following OpenEnv's Rubric pattern.
5
+ Multiple signals prevent reward hacking: an agent can't fake the ratio
6
+ without also controlling churn and burn rate simultaneously.
7
+ """
8
+
9
+ from __future__ import annotations
10
+ from dataclasses import dataclass
11
+ from typing import TYPE_CHECKING
12
+
13
+ if TYPE_CHECKING:
14
+ from revops_gym.models import RevOpsState
15
+
16
+
17
+ @dataclass
18
+ class RewardBreakdown:
19
+     ltv_cac: float          # -0.5 to 1.0, golden ratio signal (negative when LTV < CAC)
20
+     mrr_growth: float       # 0-1, revenue trajectory
21
+     burn_efficiency: float  # 0-1, not burning cash recklessly
22
+     survival_bonus: float   # 0-0.75, staying alive with runway to spare
23
+ total: float
24
+ terminated_penalty: float = 0.0
25
+
26
+ def to_dict(self) -> dict:
27
+ return {
28
+ "ltv_cac": round(self.ltv_cac, 4),
29
+ "mrr_growth": round(self.mrr_growth, 4),
30
+ "burn_efficiency": round(self.burn_efficiency, 4),
31
+ "survival_bonus": round(self.survival_bonus, 4),
32
+ "terminated_penalty": round(self.terminated_penalty, 4),
33
+ "total": round(self.total, 4),
34
+ }
35
+
36
+
37
+ class RewardRubric:
38
+ """
39
+ Computes composite reward from current and previous state.
40
+
41
+ Weights:
42
+         ltv_cac_ratio     35%  - the "golden ratio" of SaaS health
43
+         mrr_growth        30%  - revenue trajectory
44
+         burn_efficiency   20%  - sustainable spending
45
+         survival_bonus    15%  - staying above the VC floor
46
+
47
+     Termination penalty: -2.0 added to the total when the episode terminates with MRR below the floor.
48
+ """
49
+
50
+ WEIGHTS = {
51
+ "ltv_cac": 0.35,
52
+ "mrr_growth": 0.30,
53
+ "burn_efficiency": 0.20,
54
+ "survival_bonus": 0.15,
55
+ }
56
+
57
+ # Target benchmarks for SaaS health
58
+ TARGET_LTV_CAC = 3.0 # 3x is "good", 5x+ is excellent
59
+ TARGET_CHURN = 0.02 # 2% monthly is good SaaS
60
+ MAX_BURN_RATIO = 0.50 # marketing spend / MRR ceiling
61
+
62
+ def compute(
63
+ self,
64
+ state: "RevOpsState",
65
+ prev_state: "RevOpsState | None" = None,
66
+ terminated: bool = False,
67
+ ) -> RewardBreakdown:
68
+
69
+ # --- Signal 1: LTV/CAC golden ratio ---
70
+ ratio = state.ltv_cac_ratio
71
+ if ratio >= self.TARGET_LTV_CAC:
72
+ ltv_cac_score = min(1.0, (ratio - self.TARGET_LTV_CAC) / 2.0 * 0.5 + 0.75)
73
+ elif ratio >= 1.0:
74
+ ltv_cac_score = (ratio - 1.0) / (self.TARGET_LTV_CAC - 1.0) * 0.75
75
+ else:
76
+             # ratio < 1.0 means losing money per customer → negative signal
77
+ ltv_cac_score = max(-0.5, (ratio - 1.0) * 0.5)
78
+
79
+ # --- Signal 2: MRR growth ---
80
+ if prev_state is not None:
81
+ mrr_change = (state.mrr - prev_state.mrr) / max(prev_state.mrr, 1)
82
+             # Normalize: flat = 0.3, +14% growth or more = 1.0, -6% or worse = 0
83
+ mrr_growth_score = max(0.0, min(1.0, mrr_change * 5.0 + 0.3))
84
+ else:
85
+             # First step: reward for being above the floor
86
+ floor_margin = (state.mrr - state.mrr_floor) / state.mrr_floor
87
+ mrr_growth_score = min(1.0, max(0.0, floor_margin * 0.5 + 0.5))
88
+
89
+ # --- Signal 3: Burn efficiency ---
90
+ burn_ratio = state.marketing_spend / max(state.mrr, 1)
91
+ if burn_ratio <= self.MAX_BURN_RATIO:
92
+ burn_score = 1.0 - (burn_ratio / self.MAX_BURN_RATIO) * 0.3
93
+ else:
94
+ burn_score = max(0.0, 1.0 - burn_ratio)
95
+
96
+ # Penalize bad support quality (hidden churn driver)
97
+ support_penalty = max(0.0, 0.75 - state.support_quality) * 0.4
98
+ burn_score = max(0.0, burn_score - support_penalty)
99
+
100
+ # --- Signal 4: Survival bonus ---
101
+ if state.mrr >= state.mrr_floor and state.cash_runway > 3:
102
+ runway_bonus = min(0.5, state.cash_runway / 24) * 0.5
103
+ survival_score = 0.5 + runway_bonus
104
+ elif state.mrr >= state.mrr_floor:
105
+ survival_score = 0.2 # alive but barely
106
+ else:
107
+ survival_score = 0.0
108
+
109
+ # Churn penalty on survival
110
+ if state.churn_rate > 0.10:
111
+ survival_score *= 0.5
112
+
113
+ # --- Weighted total ---
114
+ total = (
115
+ ltv_cac_score * self.WEIGHTS["ltv_cac"]
116
+ + mrr_growth_score * self.WEIGHTS["mrr_growth"]
117
+ + burn_score * self.WEIGHTS["burn_efficiency"]
118
+ + survival_score * self.WEIGHTS["survival_bonus"]
119
+ )
120
+
121
+ # --- Termination penalty ---
122
+ term_penalty = 0.0
123
+ if terminated and state.mrr < state.mrr_floor:
124
+ term_penalty = -2.0
125
+ total += term_penalty
126
+
127
+ return RewardBreakdown(
128
+ ltv_cac=ltv_cac_score,
129
+ mrr_growth=mrr_growth_score,
130
+ burn_efficiency=burn_score,
131
+ survival_bonus=survival_score,
132
+ total=total,
133
+ terminated_penalty=term_penalty,
134
+ )
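Note on `RewardRubric.compute` above: the first two signals are piecewise-linear maps, which are easier to sanity-check in isolation. The sketch below copies the formulas from the diff into two free functions. The function names are hypothetical and the rest of the rubric (burn, survival, weighting) is left out.

```python
TARGET_LTV_CAC = 3.0  # same benchmark as RewardRubric.TARGET_LTV_CAC


def ltv_cac_score(ratio: float) -> float:
    """Signal 1: 0.75 at the 3x target, capped at 1.0, negative below 1x."""
    if ratio >= TARGET_LTV_CAC:
        return min(1.0, (ratio - TARGET_LTV_CAC) / 2.0 * 0.5 + 0.75)
    if ratio >= 1.0:
        return (ratio - 1.0) / (TARGET_LTV_CAC - 1.0) * 0.75
    # Losing money per customer: negative signal, floored at -0.5.
    return max(-0.5, (ratio - 1.0) * 0.5)


def mrr_growth_score(mrr: float, prev_mrr: float) -> float:
    """Signal 2: relative MRR change, scaled by 5 and offset by 0.3, clamped to [0, 1]."""
    change = (mrr - prev_mrr) / max(prev_mrr, 1)
    return max(0.0, min(1.0, change * 5.0 + 0.3))
```

Working the formulas through: the ratio score reaches its cap at 4x rather than 5x, and the growth score sits at 0.3 for flat MRR, saturates at roughly +14% month over month, and bottoms out at about -6%.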
revops_gym/server.py ADDED
@@ -0,0 +1,156 @@
1
+ """
2
+ FastAPI server for RevOps Gym.
3
+ Exposes reset / step / state endpoints per the OpenEnv spec.
4
+ """
5
+
6
+ from __future__ import annotations
7
+ import os
8
+ from fastapi import FastAPI, HTTPException
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from fastapi.responses import HTMLResponse
11
+ from pydantic import BaseModel
12
+
13
+ from revops_gym.env import RevOpsEnv
14
+ from revops_gym.models import RevOpsAction, RevOpsObservation
15
+
16
+ app = FastAPI(
17
+ title="RevOps Gym",
18
+ description="A dynamic SaaS flight simulator for LLM RL training.",
19
+ version="0.1.0",
20
+ )
21
+
22
+ app.add_middleware(
23
+ CORSMiddleware,
24
+ allow_origins=["*"],
25
+ allow_methods=["*"],
26
+ allow_headers=["*"],
27
+ )
28
+
29
+ # Single shared env instance: state is global, so all clients share one episode (use one env per session for multi-client setups)
30
+ _env = RevOpsEnv(
31
+ crisis_every=int(os.environ.get("CRISIS_EVERY", "3")),
32
+ difficulty=os.environ.get("DIFFICULTY", "normal"),
33
+ )
34
+
35
+
36
+ class ResetRequest(BaseModel):
37
+ seed: int | None = None
38
+ difficulty: str = "normal"
39
+
40
+
41
+ class StepRequest(BaseModel):
42
+ action_type: str
43
+ magnitude: float = 0.5
44
+
45
+
46
+ # ------------------------------------------------------------------
47
+ # OpenEnv required endpoints
48
+ # ------------------------------------------------------------------
49
+
50
+ @app.post("/reset", response_model=RevOpsObservation)
51
+ def reset(req: ResetRequest = ResetRequest()):
52
+ global _env
53
+ _env = RevOpsEnv(
54
+ crisis_every=int(os.environ.get("CRISIS_EVERY", "3")),
55
+ difficulty=req.difficulty,
56
+ )
57
+ return _env.reset(seed=req.seed)
58
+
59
+
60
+ @app.post("/step", response_model=RevOpsObservation)
61
+ def step(req: StepRequest):
62
+ try:
63
+ action = RevOpsAction(action_type=req.action_type, magnitude=req.magnitude)
64
+ except Exception as e:
65
+ raise HTTPException(status_code=422, detail=str(e))
66
+ return _env.step(action)
67
+
68
+
69
+ @app.get("/state", response_model=RevOpsObservation)
70
+ def state():
71
+ return _env.state()
72
+
73
+
74
+ # ------------------------------------------------------------------
75
+ # Optional: human-readable demo UI
76
+ # ------------------------------------------------------------------
77
+
78
+ @app.get("/", response_class=HTMLResponse)
79
+ def index():
80
+ s = _env.state()
81
+ crisis_html = (
82
+ f'<div class="crisis">⚠️ CRISIS: {s.active_crisis}</div>'
83
+ if s.active_crisis != "NONE"
84
+ else ""
85
+ )
86
+ return f"""<!DOCTYPE html>
87
+ <html><head><title>RevOps Gym</title>
88
+ <style>
89
+ body {{ font-family: monospace; background: #0d1117; color: #c9d1d9; padding: 2rem; }}
90
+ h1 {{ color: #58a6ff; }}
91
+ .metric {{ display: inline-block; margin: 0.5rem 1rem; padding: 0.5rem 1rem;
92
+ background: #161b22; border: 1px solid #30363d; border-radius: 6px; }}
93
+ .metric .label {{ font-size: 0.75rem; color: #8b949e; }}
94
+ .metric .value {{ font-size: 1.4rem; color: #3fb950; font-weight: bold; }}
95
+ .crisis {{ background: #3d1a1a; border: 1px solid #f85149; border-radius: 6px;
96
+ padding: 0.75rem 1rem; margin: 1rem 0; color: #f85149; }}
97
+ form {{ margin: 1.5rem 0; }}
98
+ select, input {{ background: #161b22; color: #c9d1d9; border: 1px solid #30363d;
99
+ padding: 0.4rem 0.6rem; border-radius: 4px; margin: 0.3rem; }}
100
+ button {{ background: #238636; color: #fff; border: none; padding: 0.5rem 1.2rem;
101
+ border-radius: 4px; cursor: pointer; }}
102
+ button:hover {{ background: #2ea043; }}
103
+ </style></head><body>
104
+ <h1>🚀 RevOps Gym</h1>
105
+ <p>SaaS Flight Simulator - Step {s.step_number}/30 | LTV/CAC: {s.ltv_cac_ratio:.2f}x</p>
106
+ {crisis_html}
107
+ <div>
108
+ <div class="metric"><div class="label">MRR</div><div class="value">${s.mrr:,.0f}</div></div>
109
+ <div class="metric"><div class="label">CAC</div><div class="value">${s.cac:,.0f}</div></div>
110
+ <div class="metric"><div class="label">LTV</div><div class="value">${s.ltv:,.0f}</div></div>
111
+ <div class="metric"><div class="label">Churn</div><div class="value">{s.churn_rate*100:.1f}%</div></div>
112
+ <div class="metric"><div class="label">Runway</div><div class="value">{s.cash_runway:.1f}mo</div></div>
113
+ <div class="metric"><div class="label">Mktg $</div><div class="value">${s.marketing_spend:,.0f}</div></div>
114
+ <div class="metric"><div class="label">Support</div><div class="value">{s.support_quality*100:.0f}%</div></div>
115
+ </div>
116
+ <form action="/step" method="post" onsubmit="takeAction(event)">
117
+ <label>Action:
118
+ <select id="action_type">
119
+ <option>increase_marketing</option><option>decrease_marketing</option>
120
+ <option>hire_support</option><option>fire_support</option>
121
+ <option>discount_campaign</option><option>raise_prices</option>
122
+ <option>feature_investment</option><option>cut_costs</option>
123
+ <option>negotiate_contracts</option><option>pivot_segment</option>
124
+ </select>
125
+ </label>
126
+ <label>Magnitude: <input id="magnitude" type="range" min="0.1" max="1" step="0.1" value="0.5"></label>
127
+ <button type="submit">Take Action</button>
128
+ <button type="button" onclick="doReset()">Reset Episode</button>
129
+ </form>
130
+ <div id="result"></div>
131
+ <script>
132
+ async function takeAction(e) {{
133
+ e.preventDefault();
134
+ const res = await fetch('/step', {{method:'POST',
135
+ headers:{{'Content-Type':'application/json'}},
136
+ body: JSON.stringify({{
137
+ action_type: document.getElementById('action_type').value,
138
+ magnitude: parseFloat(document.getElementById('magnitude').value)
139
+ }})
140
+ }});
141
+ const d = await res.json();
142
+ document.getElementById('result').innerHTML =
143
+ '<pre>' + JSON.stringify(d, null, 2) + '</pre>';
144
+ location.reload();
145
+ }}
146
+ async function doReset() {{
147
+ await fetch('/reset', {{method:'POST', headers:{{'Content-Type':'application/json'}}, body:'{{}}'}});
148
+ location.reload();
149
+ }}
150
+ </script>
151
+ </body></html>"""
152
+
153
+
154
+ @app.get("/health")
155
+ def health():
156
+ return {"status": "ok", "version": "0.1.0"}
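Note on the `/step` endpoint above: a client only needs to POST a two-field JSON body. The sketch below uses the standard library alone. `build_step_payload` and `post_step` are hypothetical helper names, and the base URL is an assumption (the Dockerfile exposes port 7860, so `http://localhost:7860` is a plausible local value).

```python
import json
import urllib.request

# Action names copied from ActionType in models.py.
VALID_ACTIONS = {
    "increase_marketing", "decrease_marketing", "hire_support", "fire_support",
    "discount_campaign", "raise_prices", "feature_investment", "cut_costs",
    "negotiate_contracts", "pivot_segment",
}


def build_step_payload(action_type: str, magnitude: float = 0.5) -> bytes:
    """Validate client-side and return the JSON body for POST /step."""
    if action_type not in VALID_ACTIONS:
        raise ValueError(f"unknown action_type: {action_type!r}")
    if not 0.1 <= magnitude <= 1.0:
        raise ValueError("magnitude must be between 0.1 and 1.0")
    return json.dumps({"action_type": action_type, "magnitude": magnitude}).encode("utf-8")


def post_step(base_url: str, action_type: str, magnitude: float = 0.5) -> dict:
    """Send one action to a running server and return the observation as a dict."""
    req = urllib.request.Request(
        f"{base_url}/step",
        data=build_step_payload(action_type, magnitude),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=10) as resp:
        return json.load(resp)
```

Validating before sending mirrors the `ge=0.1` / `le=1.0` bounds on `RevOpsAction`, so a bad request fails locally instead of coming back as a 422.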
setup.py ADDED
File without changes
tests/test_env.py ADDED
@@ -0,0 +1,108 @@
1
+ """
2
+ Quick smoke test: run locally before pushing to HF Spaces.
3
+ Tests: reset, step through full episode, crisis triggers, reward signals.
4
+
5
+ Usage:
6
+ python tests/test_env.py
7
+ """
8
+
9
+ import sys
10
+ import os
11
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12
+
13
+ from revops_gym.env import RevOpsEnv
14
+ from revops_gym.models import RevOpsAction
15
+
16
+
17
+ def test_episode(difficulty="normal", seed=42):
18
+ print(f"\n=== Smoke test | difficulty={difficulty} seed={seed} ===")
19
+ env = RevOpsEnv(crisis_every=3, seed=seed, difficulty=difficulty)
20
+ obs = env.reset(seed=seed)
21
+ assert obs.step_number == 0, "Reset should start at step 0"
22
+ assert obs.mrr > 0, "MRR should be positive after reset"
23
+ print(env.render())
24
+
25
+ actions = [
26
+ ("increase_marketing", 0.6),
27
+ ("hire_support", 0.8),
28
+ ("negotiate_contracts", 0.5),
29
+ ("raise_prices", 0.4),
30
+ ("feature_investment", 0.7),
31
+ ("cut_costs", 0.3),
32
+ ("discount_campaign", 0.5),
33
+ ("increase_marketing", 0.7),
34
+ ("hire_support", 0.5),
35
+ ("pivot_segment", 0.6),
36
+ ]
37
+
38
+ rewards = []
39
+ crises_seen = []
40
+ for i, (action_type, magnitude) in enumerate(actions):
41
+ obs = env.step({"action_type": action_type, "magnitude": magnitude})
42
+ rewards.append(obs.reward_last_step)
43
+ if obs.active_crisis != "NONE":
44
+ crises_seen.append(obs.active_crisis)
45
+ print(
46
+ f" Step {obs.step_number:2d} | {action_type:<22} mag={magnitude} "
47
+ f"| reward={obs.reward_last_step:+.3f} | MRR=${obs.mrr:,.0f} "
48
+ f"| LTV/CAC={obs.ltv_cac_ratio:.2f}x"
49
+ + (f" | ⚠️ {obs.active_crisis}" if obs.active_crisis != "NONE" else "")
50
+ )
51
+ if obs.terminated or obs.truncated:
52
+ print(f"\n Episode ended at step {obs.step_number} "
53
+ f"({'terminated' if obs.terminated else 'truncated'})")
54
+ break
55
+
56
+ print(f"\n Total steps: {obs.step_number}")
57
+ print(f" Mean reward: {sum(rewards)/len(rewards):.4f}")
58
+ print(f" Min reward: {min(rewards):.4f}")
59
+ print(f" Max reward: {max(rewards):.4f}")
60
+ print(f" Crises seen: {crises_seen or ['none triggered yet']}")
61
+ assert len(rewards) > 0, "Should have at least one reward"
62
+     print("\n✅ Smoke test passed!")
63
+ return True
64
+
65
+
66
+ def test_all_actions():
67
+ print("\n=== Testing all action types ===")
68
+ env = RevOpsEnv(seed=0)
69
+ env.reset(seed=0)
70
+ all_actions = [
71
+ "increase_marketing", "decrease_marketing", "hire_support",
72
+ "fire_support", "discount_campaign", "raise_prices",
73
+ "feature_investment", "cut_costs", "negotiate_contracts", "pivot_segment",
74
+ ]
75
+ for action in all_actions:
76
+ obs = env.step({"action_type": action, "magnitude": 0.5})
77
+ assert obs.reward_last_step is not None
78
+         print(f"  ✓ {action:<24} reward={obs.reward_last_step:+.3f}")
79
+     print("✅ All actions tested!")
80
+
81
+
82
+ def test_termination():
83
+ print("\n=== Testing termination conditions ===")
84
+ from revops_gym.models import RevOpsState
85
+ from revops_gym.reward import RewardRubric
86
+ rubric = RewardRubric()
87
+
88
+ # MRR below floor
89
+ state = RevOpsState(mrr=5_000, step_number=5)
90
+ assert state.is_terminal, "Should terminate when MRR < floor"
91
+ rb = rubric.compute(state, terminated=True)
92
+ assert rb.terminated_penalty == -2.0, "Should get termination penalty"
93
+     print("  ✓ MRR floor termination works")
94
+
95
+ # Max steps
96
+ state2 = RevOpsState(mrr=100_000, step_number=30)
97
+ assert state2.is_terminal, "Should truncate at step 30"
98
+     print("  ✓ Step limit truncation works")
99
+     print("✅ Termination tests passed!")
100
+
101
+
102
+ if __name__ == "__main__":
103
+ test_episode(difficulty="easy")
104
+ test_episode(difficulty="normal")
105
+ test_episode(difficulty="hard")
106
+ test_all_actions()
107
+ test_termination()
108
+     print("\n🎉 All tests passed! Ready to push to HF Spaces.")
train_colab.ipynb ADDED
@@ -0,0 +1,690 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "b2ea0425",
7
+ "metadata": {
8
+ "lines_to_next_cell": 0
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "\n",
13
+ "# RevOps Gym - Full Training Script (Colab)\n",
14
+ "# Convert this to a .ipynb with: jupytext --to notebook train_colab.py\n",
15
+ "# Or copy cells manually into Colab.\n",
16
+ "#\n",
17
+ "# Runtime: GPU T4 (free tier) | ~45-60 min for full run\n",
18
+ "# Model: Qwen/Qwen2.5-1.5B-Instruct (1.5B, fits on T4)\n",
19
+ "# Trainer: TRL GRPO + Unsloth\n",
20
+ "# Environment: RevOps Gym (runs locally inside Colab)\n",
21
+ "\n",
22
+ "# ============================================================\n",
23
+ "# CELL 1 - Install dependencies\n",
24
+ "# ============================================================"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "id": "9f1de3ce",
31
+ "metadata": {
32
+ "lines_to_next_cell": 0
33
+ },
34
+ "outputs": [],
35
+ "source": [
36
+ "!pip install -q \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
37
+ "!pip install -q \"trl>=0.8.0\" peft accelerate bitsandbytes\n",
38
+ "!pip install -q fastapi uvicorn pydantic requests wandb matplotlib\n",
39
+ "\n",
40
+ "# ============================================================\n",
41
+ "# CELL 2 - Clone and install RevOps Gym environment\n",
42
+ "# ============================================================"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "id": "2d37e6b7",
49
+ "metadata": {
50
+ "lines_to_next_cell": 0
51
+ },
52
+ "outputs": [],
53
+ "source": [
54
+ "!git clone https://huggingface.co/spaces/YOUR_HF_USERNAME/revops-gym\n",
55
+ "!pip install -q -e revops-gym/\n",
56
+ "\n",
57
+ "# For local testing without HF Space, copy env files directly:\n",
58
+ "# The environment runs INSIDE Colab, so no external server is needed for training.\n",
59
+ "\n",
60
+ "# ============================================================\n",
61
+ "# CELL 3 - Imports and config\n",
62
+ "# ============================================================"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "id": "54f78215",
69
+ "metadata": {
70
+ "lines_to_next_cell": 0
71
+ },
72
+ "outputs": [],
73
+ "source": [
74
+ "import os\n",
75
+ "import json\n",
76
+ "import random\n",
77
+ "import re\n",
78
+ "import time\n",
79
+ "import warnings\n",
80
+ "from typing import Optional\n",
81
+ "import torch\n",
82
+ "import numpy as np\n",
83
+ "import matplotlib.pyplot as plt\n",
84
+ "from collections import defaultdict\n",
85
+ "\n",
86
+ "warnings.filterwarnings(\"ignore\")\n",
87
+ "\n",
88
+ "# --- Config ---\n",
89
+ "MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
90
+ "MAX_NEW_TOKENS = 128\n",
91
+ "BATCH_SIZE = 4 # rollouts per GRPO step (keep small for T4)\n",
92
+ "NUM_EPISODES = 200 # total training episodes\n",
93
+ "GRPO_EPOCHS = 1\n",
94
+ "LR = 5e-6\n",
95
+ "MAX_STEPS_PER_EPISODE = 30\n",
96
+ "SAVE_EVERY = 50 # save checkpoint every N episodes\n",
97
+ "WANDB_PROJECT = \"revops-gym\"\n",
98
+ "USE_WANDB = False # set True if you have wandb account\n",
99
+ "\n",
100
+ "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
101
+ "if torch.cuda.is_available():\n",
102
+ " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
103
+ " print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n",
104
+ "\n",
105
+ "# ============================================================\n",
106
+ "# CELL 4 - Load model with Unsloth (4-bit quantization)\n",
107
+ "# ============================================================"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "id": "5fdfba32",
114
+ "metadata": {
115
+ "lines_to_next_cell": 0
116
+ },
117
+ "outputs": [],
118
+ "source": [
119
+ "from unsloth import FastLanguageModel\n",
120
+ "\n",
121
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
122
+ " model_name=MODEL_NAME,\n",
123
+ " max_seq_length=1024,\n",
124
+ " load_in_4bit=True, # fits on T4 16GB\n",
125
+ " dtype=None, # auto detect\n",
126
+ ")\n",
127
+ "\n",
128
+ "# Add LoRA adapters for efficient fine-tuning\n",
129
+ "model = FastLanguageModel.get_peft_model(\n",
130
+ " model,\n",
131
+ " r=16,\n",
132
+ " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
133
+ " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
134
+ " lora_alpha=16,\n",
135
+ " lora_dropout=0,\n",
136
+ " bias=\"none\",\n",
137
+ " use_gradient_checkpointing=\"unsloth\",\n",
138
+ " random_state=42,\n",
139
+ ")\n",
140
+ "\n",
141
+ "print(\"Model loaded with LoRA adapters ✓\")\n",
142
+ "print(f\"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}\")\n",
143
+ "\n",
144
+ "# ============================================================\n",
145
+ "# CELL 5 - Inline RevOps Gym (no server needed for training)\n",
146
+ "# ============================================================"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "id": "47f5b140",
153
+ "metadata": {
154
+ "lines_to_next_cell": 0
155
+ },
156
+ "outputs": [],
157
+ "source": [
158
+ "# Import directly: no HTTP server overhead during training\n",
159
+ "import sys\n",
160
+ "sys.path.insert(0, \"revops-gym\") # adjust if cloned elsewhere\n",
161
+ "\n",
162
+ "from revops_gym.env import RevOpsEnv\n",
163
+ "from revops_gym.models import RevOpsObservation\n",
164
+ "\n",
165
+ "SYSTEM_PROMPT = \"\"\"You are a SaaS RevOps strategist managing a B2B software company.\n",
166
+ "Your goal is to maximize sustainable revenue growth while keeping the company alive.\n",
167
+ "The VC will fire you if MRR drops below $20,000.\n",
168
+ "\n",
169
+ "Key metrics to balance:\n",
170
+ "- LTV/CAC ratio (target 3x+): profitability per customer\n",
171
+ "- MRR growth: revenue trajectory \n",
172
+ "- Cash runway: survival buffer\n",
173
+ "- Churn rate: customer retention health\n",
174
+ "- Support quality: drives retention\n",
175
+ "\n",
176
+ "You MUST respond with ONLY a JSON object. No explanation, no markdown, just JSON:\n",
177
+ "{\"action_type\": \"<action>\", \"magnitude\": <0.1-1.0>}\n",
178
+ "\n",
179
+ "Valid actions: increase_marketing, decrease_marketing, hire_support, fire_support,\n",
180
+ "discount_campaign, raise_prices, feature_investment, cut_costs, negotiate_contracts, pivot_segment\"\"\"\n",
181
+ "\n",
182
+ "print(\"Environment ready ✓\")\n",
183
+ "\n",
184
+ "# ============================================================\n",
185
+ "# CELL 6 - Helper: parse LLM output → action\n",
186
+ "# ============================================================"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "id": "59f8a75f",
193
+ "metadata": {
194
+ "lines_to_next_cell": 0
195
+ },
196
+ "outputs": [],
197
+ "source": [
198
+ "VALID_ACTIONS = [\n",
199
+ " \"increase_marketing\", \"decrease_marketing\", \"hire_support\", \"fire_support\",\n",
200
+ " \"discount_campaign\", \"raise_prices\", \"feature_investment\", \"cut_costs\",\n",
201
+ " \"negotiate_contracts\", \"pivot_segment\",\n",
202
+ "]\n",
203
+ "\n",
204
+ "def parse_action(text: str) -> dict:\n",
205
+ " \"\"\"Extract JSON action from model output. Returns random valid action on failure.\"\"\"\n",
206
+ " try:\n",
207
+ " # Try to find JSON block\n",
208
+ " match = re.search(r'\\{[^}]+\\}', text, re.DOTALL)\n",
209
+ " if match:\n",
210
+ " data = json.loads(match.group())\n",
211
+ " action_type = data.get(\"action_type\", \"\")\n",
212
+ " magnitude = float(data.get(\"magnitude\", 0.5))\n",
213
+ " if action_type in VALID_ACTIONS:\n",
214
+ " magnitude = max(0.1, min(1.0, magnitude))\n",
215
+ " return {\"action_type\": action_type, \"magnitude\": magnitude}\n",
216
+ " except Exception:\n",
217
+ " pass\n",
218
+ " # Fallback: random valid action\n",
219
+ " return {\"action_type\": random.choice(VALID_ACTIONS), \"magnitude\": 0.5}\n",
220
+ "\n",
221
+ "\n",
222
+ "def build_prompt(obs: RevOpsObservation) -> str:\n",
223
+ " return obs.to_prompt_text()\n",
224
+ "\n",
225
+ "\n",
226
+ "def generate_action(obs: RevOpsObservation, do_sample: bool = True) -> tuple[str, dict]:\n",
227
+ " \"\"\"Run one forward pass, return (raw_text, parsed_action).\"\"\"\n",
228
+ " prompt = build_prompt(obs)\n",
229
+ " messages = [\n",
230
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
231
+ " {\"role\": \"user\", \"content\": prompt},\n",
232
+ " ]\n",
233
+ " input_ids = tokenizer.apply_chat_template(\n",
234
+ " messages, tokenize=True, add_generation_prompt=True,\n",
235
+ " return_tensors=\"pt\"\n",
236
+ " ).to(model.device)\n",
237
+ "\n",
238
+ " with torch.no_grad():\n",
239
+ " output = model.generate(\n",
240
+ " input_ids,\n",
241
+ " max_new_tokens=MAX_NEW_TOKENS,\n",
242
+ " do_sample=do_sample,\n",
243
+ " temperature=0.8 if do_sample else 0.1,\n",
244
+ " top_p=0.9,\n",
245
+ " pad_token_id=tokenizer.eos_token_id,\n",
246
+ " )\n",
247
+ " new_tokens = output[0][input_ids.shape[1]:]\n",
248
+ " raw = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()\n",
249
+ " action = parse_action(raw)\n",
250
+ " return raw, action\n",
251
+ "\n",
252
+ "print(\"Inference helpers ready ✓\")\n",
253
+ "\n",
254
+ "# ============================================================\n",
255
+ "# CELL 7 - Rollout function (one episode)\n",
256
+ "# ============================================================"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": null,
262
+ "id": "6e1497dc",
263
+ "metadata": {
264
+ "lines_to_next_cell": 0
265
+ },
266
+ "outputs": [],
267
+ "source": [
268
+ "def rollout(env: RevOpsEnv, difficulty: str = \"normal\") -> dict:\n",
269
+ "    \"\"\"\n",
270
+ "    Run one full episode. Returns trajectory with prompts, outputs, rewards.\n",
271
+ "    Used by GRPO to score and train.\n",
272
+ "    \"\"\"\n",
273
+ "    obs = env.reset(seed=random.randint(0, 10_000))\n",
274
+ "    trajectory = {\n",
275
+ "        \"prompts\": [],\n",
276
+ "        \"responses\": [],\n",
277
+ "        \"rewards\": [],\n",
278
+ "        \"infos\": [],\n",
279
+ "        \"final_mrr\": 0.0,\n",
280
+ "        \"survived\": False,\n",
281
+ "        \"steps\": 0,\n",
282
+ "    }\n",
283
+ "\n",
284
+ "    for step in range(MAX_STEPS_PER_EPISODE):\n",
285
+ "        raw, action = generate_action(obs)\n",
286
+ "        trajectory[\"prompts\"].append(build_prompt(obs))  # prompt that produced `raw`\n",
287
+ "        obs = env.step(action)\n",
288
+ "\n",
289
+ "        trajectory[\"responses\"].append(raw)\n",
290
+ "        trajectory[\"rewards\"].append(obs.reward_last_step or 0.0)\n",
291
+ "        trajectory[\"infos\"].append(obs.info)\n",
292
+ "\n",
293
+ "        if obs.terminated or obs.truncated:\n",
294
+ "            break\n",
295
+ "\n",
296
+ "    trajectory[\"final_mrr\"] = obs.mrr\n",
297
+ "    trajectory[\"survived\"] = not obs.terminated or obs.step_number >= MAX_STEPS_PER_EPISODE\n",
298
+ "    trajectory[\"steps\"] = obs.step_number\n",
299
+ "    return trajectory\n",
300
+ "\n",
301
+ "\n",
302
+ "def rollout_batch(n: int = BATCH_SIZE, difficulty: str = \"normal\") -> list[dict]:\n",
303
+ "    \"\"\"Run N rollouts and return batch.\"\"\"\n",
304
+ "    env = RevOpsEnv(crisis_every=3, difficulty=difficulty)\n",
305
+ "    return [rollout(env, difficulty) for _ in range(n)]\n",
306
+ "\n",
307
+ "\n",
308
+ "print(\"Rollout function ready ✓\")\n",
309
+ "\n",
310
+ "# ============================================================\n",
311
+ "# CELL 8 - Baseline evaluation (untrained model)\n",
312
+ "# ============================================================"
313
+ ]
314
+ },
315
+ {
316
+ "cell_type": "code",
317
+ "execution_count": null,
318
+ "id": "fa7c83e9",
319
+ "metadata": {
320
+ "lines_to_next_cell": 0
321
+ },
322
+ "outputs": [],
323
+ "source": [
324
+ "print(\"=\" * 60)\n",
325
+ "print(\"BASELINE EVALUATION (untrained model)\")\n",
326
+ "print(\"=\" * 60)\n",
327
+ "\n",
328
+ "FastLanguageModel.for_inference(model) # enable fast inference mode\n",
329
+ "\n",
330
+ "baseline_rewards = []\n",
331
+ "baseline_mrrs = []\n",
332
+ "baseline_survivals = []\n",
333
+ "N_BASELINE = 10\n",
334
+ "\n",
335
+ "for i in range(N_BASELINE):\n",
336
+ "    t = rollout(RevOpsEnv(crisis_every=3, seed=i))\n",
337
+ "    mean_r = sum(t[\"rewards\"]) / max(len(t[\"rewards\"]), 1)\n",
338
+ "    baseline_rewards.append(mean_r)\n",
339
+ "    baseline_mrrs.append(t[\"final_mrr\"])\n",
340
+ "    baseline_survivals.append(1 if t[\"survived\"] else 0)\n",
341
+ "    print(f\" Episode {i+1:2d} | mean_reward={mean_r:.4f} | \"\n",
342
+ "          f\"final_MRR=${t['final_mrr']:,.0f} | survived={t['survived']}\")\n",
343
+ "\n",
344
+ "print(f\"\\nBaseline mean reward: {np.mean(baseline_rewards):.4f} ± {np.std(baseline_rewards):.4f}\")\n",
345
+ "print(f\"Baseline mean final MRR: ${np.mean(baseline_mrrs):,.0f}\")\n",
346
+ "print(f\"Baseline survival rate: {np.mean(baseline_survivals)*100:.0f}%\")\n",
347
+ "\n",
348
+ "# ============================================================\n",
349
+ "# CELL 9 - GRPO Training loop\n",
350
+ "# ============================================================"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": null,
356
+ "id": "6b3dab79",
357
+ "metadata": {
358
+ "lines_to_next_cell": 0
359
+ },
360
+ "outputs": [],
361
+ "source": [
362
+ "from trl import GRPOConfig, GRPOTrainer\n",
363
+ "\n",
364
+ "# Switch to training mode\n",
365
+ "FastLanguageModel.for_training(model)\n",
366
+ "\n",
367
+ "# --- GRPO reward function ---\n",
368
+ "def grpo_reward_fn(prompts, completions, **kwargs) -> list[float]:\n",
369
+ "    \"\"\"\n",
370
+ "    Reward function called by GRPOTrainer.\n",
371
+ "    Parses each completion as a RevOps action and scores:\n",
372
+ "      1. Format compliance (+0.1 for valid JSON)\n",
373
+ "      2. Action validity (+0.1 for known action type)\n",
374
+ "      3. Magnitude reasonableness (+0.05)\n",
375
+ "      4. Contextual quality (estimated from prompt metrics)\n",
376
+ "    \"\"\"\n",
377
+ "    rewards = []\n",
378
+ "    for prompt, completion in zip(prompts, completions):\n",
379
+ "        reward = 0.0\n",
380
+ "\n",
381
+ "        # Format reward\n",
382
+ "        try:\n",
383
+ "            match = re.search(r'\\{[^}]+\\}', completion, re.DOTALL)\n",
384
+ "            if match:\n",
385
+ "                data = json.loads(match.group())\n",
386
+ "                reward += 0.1 # valid JSON\n",
387
+ "\n",
388
+ "                if data.get(\"action_type\") in VALID_ACTIONS:\n",
389
+ "                    reward += 0.1 # valid action\n",
390
+ "\n",
391
+ "                mag = float(data.get(\"magnitude\", -1))\n",
392
+ "                if 0.1 <= mag <= 1.0:\n",
393
+ "                    reward += 0.05 # sensible magnitude\n",
394
+ "\n",
395
+ "                # Contextual bonus: penalize fire_support when support_quality is low\n",
396
+ "                if \"Support quality\" in prompt:\n",
397
+ "                    sq_match = re.search(r'Support quality: (\\d+)%', prompt)\n",
398
+ "                    if sq_match:\n",
399
+ "                        sq = int(sq_match.group(1))\n",
400
+ "                        if data.get(\"action_type\") == \"fire_support\" and sq < 60:\n",
401
+ "                            reward -= 0.15 # punish bad decision\n",
402
+ "\n",
403
+ "                # Bonus for crisis-responsive actions\n",
404
+ "                if \"ACTIVE CRISIS\" in prompt:\n",
405
+ "                    crisis_actions = {\n",
406
+ "                        \"CHURN_SPIKE\": [\"hire_support\", \"discount_campaign\", \"negotiate_contracts\"],\n",
407
+ "                        \"CAC_EXPLOSION\": [\"decrease_marketing\", \"feature_investment\", \"pivot_segment\"],\n",
408
+ "                        \"CASH_CRUNCH\": [\"cut_costs\", \"decrease_marketing\", \"raise_prices\"],\n",
409
+ "                        \"SUPPORT_CRISIS\": [\"hire_support\", \"feature_investment\"],\n",
410
+ "                        \"ENTERPRISE_CHURN\": [\"negotiate_contracts\", \"raise_prices\", \"feature_investment\"],\n",
411
+ "                    }\n",
412
+ "                    for crisis, good_actions in crisis_actions.items():\n",
413
+ "                        if crisis in prompt and data.get(\"action_type\") in good_actions:\n",
414
+ "                            reward += 0.20\n",
415
+ "                            break\n",
416
+ "\n",
417
+ "        except Exception:\n",
418
+ "            reward -= 0.05 # malformed output penalty\n",
419
+ "\n",
420
+ "        rewards.append(reward)\n",
421
+ "    return rewards\n",
422
+ "\n",
423
+ "\n",
424
+ "# --- Training config ---\n",
425
+ "training_args = GRPOConfig(\n",
426
+ "    output_dir=\"./revops-gym-checkpoints\",\n",
427
+ "    per_device_train_batch_size=BATCH_SIZE,\n",
428
+ "    gradient_accumulation_steps=4,\n",
429
+ "    num_train_epochs=1,\n",
430
+ "    learning_rate=LR,\n",
431
+ "    warmup_steps=10,\n",
432
+ "    logging_steps=5,\n",
433
+ "    save_steps=SAVE_EVERY,\n",
434
+ "    fp16=not torch.cuda.is_bf16_supported(),\n",
435
+ "    bf16=torch.cuda.is_bf16_supported(),\n",
436
+ "    report_to=\"wandb\" if USE_WANDB else \"none\",\n",
437
+ "    run_name=\"revops-gym-grpo\",\n",
438
+ "    num_generations=BATCH_SIZE, # samples per prompt\n",
439
+ "    max_completion_length=MAX_NEW_TOKENS, # GRPOConfig's name for the generation length cap\n",
440
+ "    temperature=0.8,\n",
441
+ "    optim=\"adamw_8bit\",\n",
442
+ "    seed=42,\n",
443
+ ")\n",
444
+ "\n",
445
+ "# --- Build training dataset from rollouts ---\n",
446
+ "print(\"Collecting initial rollout batch for dataset...\")\n",
447
+ "env_train = RevOpsEnv(crisis_every=3, difficulty=\"easy\") # start easy\n",
448
+ "\n",
449
+ "# GRPO needs a dataset of prompts; it generates completions internally\n",
450
+ "from datasets import Dataset\n",
451
+ "\n",
452
+ "def build_prompt_dataset(n_samples: int = 200) -> Dataset:\n",
453
+ "    \"\"\"Generate diverse prompts by rolling out episodes and capturing observations.\"\"\"\n",
454
+ "    prompts = []\n",
455
+ "    env = RevOpsEnv(crisis_every=3)\n",
456
+ "    for i in range(n_samples):\n",
457
+ "        obs = env.reset(seed=i)\n",
458
+ "        for _ in range(random.randint(0, 10)):\n",
459
+ "            action = random.choice(VALID_ACTIONS)\n",
460
+ "            obs = env.step({\"action_type\": action, \"magnitude\": random.uniform(0.1, 1.0)})\n",
461
+ "            if obs.terminated or obs.truncated:\n",
462
+ "                obs = env.reset(seed=i * 100)\n",
463
+ "                break\n",
464
+ "        messages = [\n",
465
+ "            {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
466
+ "            {\"role\": \"user\", \"content\": obs.to_prompt_text()},\n",
467
+ "        ]\n",
468
+ "        prompt_text = tokenizer.apply_chat_template(\n",
469
+ "            messages, tokenize=False, add_generation_prompt=True\n",
470
+ "        )\n",
471
+ "        prompts.append({\"prompt\": prompt_text})\n",
472
+ "    return Dataset.from_list(prompts)\n",
473
+ "\n",
474
+ "print(\"Building prompt dataset (200 samples)...\")\n",
475
+ "train_dataset = build_prompt_dataset(200)\n",
476
+ "print(f\"Dataset size: {len(train_dataset)} prompts ✓\")\n",
477
+ "\n",
478
+ "# ============================================================\n",
479
+ "# CELL 10 - Run training\n",
480
+ "# ============================================================"
481
+ ]
482
+ },
483
+ {
484
+ "cell_type": "code",
485
+ "execution_count": null,
486
+ "id": "f2885f1e",
487
+ "metadata": {
488
+ "lines_to_next_cell": 0
489
+ },
490
+ "outputs": [],
491
+ "source": [
492
+ "trainer = GRPOTrainer(\n",
493
+ "    model=model,\n",
494
+ "    processing_class=tokenizer,\n",
495
+ "    reward_funcs=grpo_reward_fn,\n",
496
+ "    args=training_args,\n",
497
+ "    train_dataset=train_dataset,\n",
498
+ ")\n",
499
+ "\n",
500
+ "print(\"Starting GRPO training...\")\n",
501
+ "print(f\" Model: {MODEL_NAME}\")\n",
502
+ "print(f\" Dataset: {len(train_dataset)} prompts\")\n",
503
+ "print(f\" Batch: {BATCH_SIZE} generations per step\")\n",
504
+ "print(f\" LR: {LR}\")\n",
505
+ "print()\n",
506
+ "\n",
507
+ "train_result = trainer.train()\n",
508
+ "print(\"\\nTraining complete ✓\")\n",
509
+ "print(train_result)\n",
510
+ "\n",
511
+ "# ============================================================\n",
512
+ "# CELL 11 - Post-training evaluation\n",
513
+ "# ============================================================"
514
+ ]
515
+ },
516
+ {
517
+ "cell_type": "code",
518
+ "execution_count": null,
519
+ "id": "75ba3e52",
520
+ "metadata": {
521
+ "lines_to_next_cell": 0
522
+ },
523
+ "outputs": [],
524
+ "source": [
525
+ "FastLanguageModel.for_inference(model)\n",
526
+ "\n",
527
+ "print(\"=\" * 60)\n",
528
+ "print(\"POST-TRAINING EVALUATION\")\n",
529
+ "print(\"=\" * 60)\n",
530
+ "\n",
531
+ "trained_rewards = []\n",
532
+ "trained_mrrs = []\n",
533
+ "trained_survivals = []\n",
534
+ "N_EVAL = 10\n",
535
+ "\n",
536
+ "for i in range(N_EVAL):\n",
537
+ "    t = rollout(RevOpsEnv(crisis_every=3, seed=i + 1000))\n",
538
+ "    mean_r = sum(t[\"rewards\"]) / max(len(t[\"rewards\"]), 1)\n",
539
+ "    trained_rewards.append(mean_r)\n",
540
+ "    trained_mrrs.append(t[\"final_mrr\"])\n",
541
+ "    trained_survivals.append(1 if t[\"survived\"] else 0)\n",
542
+ "    print(f\" Episode {i+1:2d} | mean_reward={mean_r:.4f} | \"\n",
543
+ "          f\"final_MRR=${t['final_mrr']:,.0f} | survived={t['survived']}\")\n",
544
+ "\n",
545
+ "print(f\"\\nTrained mean reward: {np.mean(trained_rewards):.4f} ± {np.std(trained_rewards):.4f}\")\n",
546
+ "print(f\"Trained mean final MRR: ${np.mean(trained_mrrs):,.0f}\")\n",
547
+ "print(f\"Trained survival rate: {np.mean(trained_survivals)*100:.0f}%\")\n",
548
+ "\n",
549
+ "# Delta\n",
550
+ "print(f\"\\n{'='*60}\")\n",
551
+ "print(\"IMPROVEMENT SUMMARY\")\n",
552
+ "print(f\"{'='*60}\")\n",
553
+ "print(f\"Mean reward delta: {np.mean(trained_rewards) - np.mean(baseline_rewards):+.4f}\")\n",
554
+ "print(f\"Final MRR delta: ${np.mean(trained_mrrs) - np.mean(baseline_mrrs):+,.0f}\")\n",
555
+ "print(f\"Survival rate delta: {(np.mean(trained_survivals) - np.mean(baseline_survivals))*100:+.0f}%\")\n",
556
+ "\n",
557
+ "# ============================================================\n",
558
+ "# CELL 12 - Plot reward curves and save to repo\n",
559
+ "# ============================================================"
560
+ ]
561
+ },
562
+ {
563
+ "cell_type": "code",
564
+ "execution_count": null,
565
+ "id": "4ae37220",
566
+ "metadata": {
567
+ "lines_to_next_cell": 0
568
+ },
569
+ "outputs": [],
570
+ "source": [
571
+ "fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
572
+ "fig.suptitle(\"RevOps Gym - Training Results\", fontsize=14, fontweight=\"bold\")\n",
573
+ "\n",
574
+ "# Reward comparison\n",
575
+ "ax = axes[0]\n",
576
+ "ax.bar([\"Baseline\", \"Trained\"],\n",
577
+ "       [np.mean(baseline_rewards), np.mean(trained_rewards)],\n",
578
+ "       color=[\"#e74c3c\", \"#2ecc71\"], alpha=0.85, edgecolor=\"black\")\n",
579
+ "ax.errorbar([\"Baseline\", \"Trained\"],\n",
580
+ "            [np.mean(baseline_rewards), np.mean(trained_rewards)],\n",
581
+ "            yerr=[np.std(baseline_rewards), np.std(trained_rewards)],\n",
582
+ "            fmt=\"none\", color=\"black\", capsize=5)\n",
583
+ "ax.set_title(\"Mean Episode Reward\")\n",
584
+ "ax.set_ylabel(\"Reward\")\n",
585
+ "ax.set_xlabel(\"Model\")\n",
586
+ "\n",
587
+ "# MRR comparison\n",
588
+ "ax2 = axes[1]\n",
589
+ "ax2.bar([\"Baseline\", \"Trained\"],\n",
590
+ "        [np.mean(baseline_mrrs)/1000, np.mean(trained_mrrs)/1000],\n",
591
+ "        color=[\"#e74c3c\", \"#2ecc71\"], alpha=0.85, edgecolor=\"black\")\n",
592
+ "ax2.set_title(\"Mean Final MRR\")\n",
593
+ "ax2.set_ylabel(\"MRR ($K)\")\n",
594
+ "ax2.set_xlabel(\"Model\")\n",
595
+ "ax2.axhline(y=20, color=\"orange\", linestyle=\"--\", label=\"VC floor ($20K)\")\n",
596
+ "ax2.legend()\n",
597
+ "\n",
598
+ "# Survival rate\n",
599
+ "ax3 = axes[2]\n",
600
+ "ax3.bar([\"Baseline\", \"Trained\"],\n",
601
+ "        [np.mean(baseline_survivals)*100, np.mean(trained_survivals)*100],\n",
602
+ "        color=[\"#e74c3c\", \"#2ecc71\"], alpha=0.85, edgecolor=\"black\")\n",
603
+ "ax3.set_title(\"Company Survival Rate\")\n",
604
+ "ax3.set_ylabel(\"Survival %\")\n",
605
+ "ax3.set_xlabel(\"Model\")\n",
606
+ "ax3.set_ylim(0, 110)\n",
607
+ "ax3.axhline(y=100, color=\"gray\", linestyle=\"--\", alpha=0.5)\n",
608
+ "\n",
609
+ "plt.tight_layout()\n",
610
+ "plt.savefig(\"results_comparison.png\", dpi=150, bbox_inches=\"tight\")\n",
611
+ "plt.show()\n",
612
+ "print(\"Plot saved as results_comparison.png ✓\")\n",
613
+ "\n",
614
+ "# Training loss plot (from trainer logs)\n",
615
+ "if hasattr(trainer.state, \"log_history\") and trainer.state.log_history:\n",
616
+ "    losses = [x.get(\"loss\", None) for x in trainer.state.log_history if \"loss\" in x]\n",
617
+ "    rewards_log = [x.get(\"reward\", None) for x in trainer.state.log_history if \"reward\" in x]\n",
618
+ "\n",
619
+ "    fig2, (ax_l, ax_r) = plt.subplots(1, 2, figsize=(12, 4))\n",
620
+ "    fig2.suptitle(\"RevOps Gym - Training Curves\", fontsize=13, fontweight=\"bold\")\n",
621
+ "\n",
622
+ "    if losses:\n",
623
+ "        ax_l.plot(losses, color=\"#3498db\", linewidth=1.5)\n",
624
+ "        ax_l.set_title(\"Training Loss\")\n",
625
+ "        ax_l.set_xlabel(\"Training step\")\n",
626
+ "        ax_l.set_ylabel(\"Loss\")\n",
627
+ "        ax_l.grid(alpha=0.3)\n",
628
+ "\n",
629
+ "    if rewards_log:\n",
630
+ "        ax_r.plot(rewards_log, color=\"#2ecc71\", linewidth=1.5)\n",
631
+ "        ax_r.set_title(\"Training Reward\")\n",
632
+ "        ax_r.set_xlabel(\"Training step\")\n",
633
+ "        ax_r.set_ylabel(\"Mean reward\")\n",
634
+ "        ax_r.grid(alpha=0.3)\n",
635
+ "\n",
636
+ "    plt.tight_layout()\n",
637
+ "    plt.savefig(\"training_curves.png\", dpi=150, bbox_inches=\"tight\")\n",
638
+ "    plt.show()\n",
639
+ "    print(\"Training curves saved as training_curves.png ✓\")\n",
640
+ "\n",
641
+ "# ============================================================\n",
642
+ "# CELL 13 - Save model and push to HF Hub\n",
643
+ "# ============================================================"
644
+ ]
645
+ },
646
+ {
647
+ "cell_type": "code",
648
+ "execution_count": null,
649
+ "id": "d236cce6",
650
+ "metadata": {},
651
+ "outputs": [],
652
+ "source": [
653
+ "# Save LoRA adapters via Unsloth (no merge step, so no 4bit→16bit corruption)\n",
654
+ "model.save_pretrained_merged(\n",
655
+ "    \"revops-gym-model\",\n",
656
+ "    tokenizer,\n",
657
+ "    save_method=\"lora\", # save adapters only (small, efficient)\n",
658
+ ")\n",
659
+ "print(\"Model adapters saved to ./revops-gym-model ✓\")\n",
660
+ "\n",
661
+ "# Push to Hugging Face Hub\n",
662
+ "# from huggingface_hub import login\n",
663
+ "# login(token=\"hf_YOUR_TOKEN\")\n",
664
+ "# model.push_to_hub_merged(\"YOUR_HF_USERNAME/revops-gym-model\", tokenizer, save_method=\"lora\")\n",
665
+ "# print(\"Model pushed to HF Hub ✓\")\n",
666
+ "\n",
667
+ "# Copy result plots into the revops-gym repo for the README\n",
668
+ "!cp results_comparison.png revops-gym/\n",
669
+ "!cp training_curves.png revops-gym/\n",
670
+ "!cd revops-gym && git add results_comparison.png training_curves.png && git commit -m \"Add training result plots\" && git push\n",
671
+ "\n",
672
+ "print(\"\\n🎉 Training pipeline complete!\")\n",
673
+ "print(\"Next steps:\")\n",
674
+ "print(\" 1. Copy results_comparison.png and training_curves.png into your HF Space repo\")\n",
675
+ "print(\" 2. Embed them in README.md\")\n",
676
+ "print(\" 3. Push the trained model adapter to HF Hub\")\n",
677
+ "print(\" 4. Submit the HF Space URL via the Google Form\")"
678
+ ]
679
+ }
680
+ ],
681
+ "metadata": {
682
+ "jupytext": {
683
+ "cell_metadata_filter": "-all",
684
+ "main_language": "python",
685
+ "notebook_metadata_filter": "-all"
686
+ }
687
+ },
688
+ "nbformat": 4,
689
+ "nbformat_minor": 5
690
+ }