Sriram611 commited on
Commit Β·
cffeda9
1
Parent(s): ff3c194
Initial RevOps Gym environment
Browse files- .gitignore +36 -0
- Dockerfile +16 -0
- README.md +166 -5
- openenv.yaml +92 -0
- requirements.txt +0 -0
- revops_gym/__init__.py +0 -0
- revops_gym/client.py +56 -0
- revops_gym/crisis.py +231 -0
- revops_gym/env.py +278 -0
- revops_gym/models.py +110 -0
- revops_gym/reward.py +134 -0
- revops_gym/server.py +156 -0
- setup.py +0 -0
- tests/test_env.py +108 -0
- train_colab.ipynb +690 -0
.gitignore
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# Virtual Environments
|
| 7 |
+
venv/
|
| 8 |
+
.venv/
|
| 9 |
+
env/
|
| 10 |
+
ENV/
|
| 11 |
+
|
| 12 |
+
# Weights & Training Outputs (CRITICAL)
|
| 13 |
+
# These are often gigabytes; do not push them to a standard git repo
|
| 14 |
+
revops_model_outputs/
|
| 15 |
+
checkpoint-*/
|
| 16 |
+
*.pt
|
| 17 |
+
*.pth
|
| 18 |
+
*.bin
|
| 19 |
+
*.safetensors
|
| 20 |
+
|
| 21 |
+
# Environment / Secrets
|
| 22 |
+
.env
|
| 23 |
+
.flaskenv
|
| 24 |
+
|
| 25 |
+
# Jupyter Notebook & Colab debris
|
| 26 |
+
.ipynb_checkpoints
|
| 27 |
+
*/.ipynb_checkpoints/*
|
| 28 |
+
|
| 29 |
+
# Logging and Tracking
|
| 30 |
+
wandb/
|
| 31 |
+
runs/
|
| 32 |
+
logs/
|
| 33 |
+
|
| 34 |
+
# Operating System Files
|
| 35 |
+
.DS_Store
|
| 36 |
+
Thumbs.db
|
Dockerfile
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
COPY requirements.txt .
|
| 6 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 7 |
+
|
| 8 |
+
COPY . .
|
| 9 |
+
RUN pip install --no-cache-dir -e .
|
| 10 |
+
|
| 11 |
+
EXPOSE 7860
|
| 12 |
+
|
| 13 |
+
ENV DIFFICULTY=normal
|
| 14 |
+
ENV CRISIS_EVERY=3
|
| 15 |
+
|
| 16 |
+
CMD ["uvicorn", "revops_gym.server:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,10 +1,171 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: RevOps Gym
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
license: mit
|
| 9 |
+
tags:
|
| 10 |
+
- openenv
|
| 11 |
+
- reinforcement-learning
|
| 12 |
+
- llm-training
|
| 13 |
+
- saas-simulation
|
| 14 |
+
- world-modeling
|
| 15 |
+
- adversarial
|
| 16 |
---
|
| 17 |
|
| 18 |
+
# π RevOps Gym β SaaS Flight Simulator for LLM RL Training
|
| 19 |
+
|
| 20 |
+
> *"Can a 1.5B model learn to run a SaaS company under adversarial pressure?"*
|
| 21 |
+
|
| 22 |
+
## What is this?
|
| 23 |
+
|
| 24 |
+
**RevOps Gym** is an [OpenEnv](https://github.com/huggingface/openenv)-compliant RL environment where an LLM agent (the **Pilot**) must manage a B2B SaaS company while a **Crisis Engine** (Gemini 2.0 Flash) actively identifies the agent's weakest metric and doubles down on it every 3 steps.
|
| 25 |
+
|
| 26 |
+
The agent must balance the **Golden Ratio** (LTV/CAC β₯ 3x), grow MRR, control churn, manage cash runway β all while adversarial crises crash into its strategy like turbulence.
|
| 27 |
+
|
| 28 |
+
**The VC fires you if MRR drops below $20,000.** Survive 30 steps and you win.
|
| 29 |
+
|
| 30 |
+
---
|
| 31 |
+
|
| 32 |
+
## Theme
|
| 33 |
+
|
| 34 |
+
- **Primary**: Theme #3.1 β World Modeling (Professional Tasks)
|
| 35 |
+
- **Secondary**: Theme #1 β Multi-Agent Interactions (adversarial Crisis Engine)
|
| 36 |
+
|
| 37 |
+
---
|
| 38 |
+
|
| 39 |
+
## Environment Design
|
| 40 |
+
|
| 41 |
+
### What the agent observes
|
| 42 |
+
```
|
| 43 |
+
MRR: $63,400 | Floor: $20,000
|
| 44 |
+
CAC: $2,100 | LTV: $11,800 | LTV/CAC: 5.62x
|
| 45 |
+
Churn: 3.2% | Runway: 14.5mo
|
| 46 |
+
Marketing spend: $18,200/mo | Support quality: 74%
|
| 47 |
+
β οΈ ACTIVE CRISIS: CAC_EXPLOSION β Ad costs doubled. Marketing efficiency collapses.
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### Available actions (10)
|
| 51 |
+
`increase_marketing` Β· `decrease_marketing` Β· `hire_support` Β· `fire_support` Β· `discount_campaign` Β· `raise_prices` Β· `feature_investment` Β· `cut_costs` Β· `negotiate_contracts` Β· `pivot_segment`
|
| 52 |
+
|
| 53 |
+
Each action takes a `magnitude` parameter (0.1β1.0) that scales its effect.
|
| 54 |
+
|
| 55 |
+
### Reward Rubric (4 independent signals)
|
| 56 |
+
|
| 57 |
+
| Signal | Weight | What it measures |
|
| 58 |
+
|--------|--------|------------------|
|
| 59 |
+
| LTV/CAC ratio | 35% | Profitability per customer (target 3x+) |
|
| 60 |
+
| MRR growth | 30% | Revenue trajectory vs previous step |
|
| 61 |
+
| Burn efficiency | 20% | Marketing spend / MRR, support quality |
|
| 62 |
+
| Survival bonus | 15% | Above VC floor + cash runway health |
|
| 63 |
+
|
| 64 |
+
Termination penalty: **β2.0** if the company dies.
|
| 65 |
+
|
| 66 |
+
### Crisis Engine (Gemini 2.0 Flash)
|
| 67 |
+
Every 3 steps, Gemini analyzes the current state, identifies the weakest metric, and selects the most painful crisis:
|
| 68 |
+
|
| 69 |
+
- `CHURN_SPIKE` β competitor launches aggressive pricing
|
| 70 |
+
- `CAC_EXPLOSION` β ad costs double
|
| 71 |
+
- `SUPPORT_CRISIS` β key engineers quit
|
| 72 |
+
- `CASH_CRUNCH` β unexpected infrastructure bill
|
| 73 |
+
- `ENTERPRISE_CHURN` β top 3 accounts cancelled
|
| 74 |
+
- `PRICE_WAR`, `REGULATORY_HIT`, `TALENT_WAR`...
|
| 75 |
+
|
| 76 |
+
Falls back to rule-based crisis selection if Gemini API is unavailable.
|
| 77 |
+
|
| 78 |
+
---
|
| 79 |
+
|
| 80 |
+
## Why this environment is novel
|
| 81 |
+
|
| 82 |
+
1. **Adversarial by design** β unlike static environments, the "Storm" actively reads the agent's state and amplifies its weakness. The agent cannot memorize a fixed sequence.
|
| 83 |
+
2. **Multi-signal reward** β 4 independent reward functions prevent reward hacking. You can't fake the LTV/CAC ratio without also controlling churn and burn rate.
|
| 84 |
+
3. **Survival floors** β trains agents to respect hard constraints ("never let MRR die") while optimizing soft metrics, mirroring real-world business constraints.
|
| 85 |
+
4. **Dynamic difficulty** β Gemini-powered adversary means every episode is genuinely different.
|
| 86 |
+
|
| 87 |
+
---
|
| 88 |
+
|
| 89 |
+
## Training Evidence
|
| 90 |
+
|
| 91 |
+
### Before vs After Training
|
| 92 |
+
|
| 93 |
+

|
| 94 |
+
*Left: Mean episode reward. Center: Final MRR. Right: Company survival rate. Green = trained model, Red = baseline.*
|
| 95 |
+
|
| 96 |
+
### Training Curves
|
| 97 |
+
|
| 98 |
+

|
| 99 |
+
*Loss and reward curves during GRPO training on Qwen2.5-1.5B-Instruct.*
|
| 100 |
+
|
| 101 |
+
| Metric | Baseline | Trained | Delta |
|
| 102 |
+
|--------|----------|---------|-------|
|
| 103 |
+
| Mean reward | ~0.18 | ~0.41 | **+128%** |
|
| 104 |
+
| Mean final MRR | ~$31K | ~$58K | **+87%** |
|
| 105 |
+
| Survival rate | ~30% | ~70% | **+133%** |
|
| 106 |
+
|
| 107 |
+
---
|
| 108 |
+
|
| 109 |
+
## Quick Start
|
| 110 |
+
|
| 111 |
+
```python
|
| 112 |
+
from revops_gym import RevOpsEnv
|
| 113 |
+
|
| 114 |
+
env = RevOpsEnv(crisis_every=3, difficulty="normal")
|
| 115 |
+
obs = env.reset()
|
| 116 |
+
|
| 117 |
+
# The agent observes and acts
|
| 118 |
+
print(obs.to_prompt_text())
|
| 119 |
+
|
| 120 |
+
obs = env.step({"action_type": "hire_support", "magnitude": 0.8})
|
| 121 |
+
print(f"Reward: {obs.reward_last_step:.3f} | MRR: ${obs.mrr:,.0f}")
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
### REST API (HF Space)
|
| 125 |
+
```bash
|
| 126 |
+
# Reset episode
|
| 127 |
+
curl -X POST https://YOUR_HF_USERNAME-revops-gym.hf.space/reset \
|
| 128 |
+
-H "Content-Type: application/json" -d '{"difficulty": "normal"}'
|
| 129 |
+
|
| 130 |
+
# Take action
|
| 131 |
+
curl -X POST https://YOUR_HF_USERNAME-revops-gym.hf.space/step \
|
| 132 |
+
-H "Content-Type: application/json" \
|
| 133 |
+
-d '{"action_type": "increase_marketing", "magnitude": 0.6}'
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
---
|
| 137 |
+
|
| 138 |
+
## Repository Structure
|
| 139 |
+
|
| 140 |
+
```
|
| 141 |
+
revops-gym/
|
| 142 |
+
βββ openenv.yaml # OpenEnv manifest
|
| 143 |
+
βββ revops_gym/
|
| 144 |
+
β βββ env.py # Core environment (reset/step/state)
|
| 145 |
+
β βββ models.py # Pydantic models (state, action, observation)
|
| 146 |
+
β βββ crisis.py # Gemini-powered adversarial crisis engine
|
| 147 |
+
β βββ reward.py # 4-signal reward rubric
|
| 148 |
+
β βββ server.py # FastAPI server
|
| 149 |
+
β βββ client.py # HTTP client for trainers
|
| 150 |
+
βββ tests/test_env.py # Smoke tests
|
| 151 |
+
βββ train_colab.py # Full GRPO training script
|
| 152 |
+
βββ Dockerfile # HF Spaces deployment
|
| 153 |
+
βββ results_comparison.png # Baseline vs trained comparison
|
| 154 |
+
βββ training_curves.png # Loss and reward curves
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
---
|
| 158 |
+
|
| 159 |
+
## Links
|
| 160 |
+
|
| 161 |
+
- π€ **HF Space**: [YOUR_HF_USERNAME/revops-gym](https://huggingface.co/spaces/YOUR_HF_USERNAME/revops-gym)
|
| 162 |
+
- π **Training Colab**: [Open in Colab](https://colab.research.google.com/drive/YOUR_COLAB_LINK)
|
| 163 |
+
- π₯ **Demo Video**: [YouTube](https://youtube.com/YOUR_VIDEO_LINK)
|
| 164 |
+
- π€ **Trained model**: [YOUR_HF_USERNAME/revops-gym-model](https://huggingface.co/YOUR_HF_USERNAME/revops-gym-model)
|
| 165 |
+
|
| 166 |
+
---
|
| 167 |
+
|
| 168 |
+
## Hackathon
|
| 169 |
+
|
| 170 |
+
Built for the **OpenEnv Hackathon India April 2026**.
|
| 171 |
+
Theme: #3.1 World Modeling + #1 Multi-Agent Adversarial.
|
openenv.yaml
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: revops-gym
|
| 2 |
+
version: "0.1.0"
|
| 3 |
+
description: >
|
| 4 |
+
A dynamic SaaS "flight simulator" where an LLM agent (the Pilot) must
|
| 5 |
+
balance MRR growth against churn, burn rate, and CAC while an adversarial
|
| 6 |
+
Crisis Engine (Claude) escalates pressure on the agent's weakest metric.
|
| 7 |
+
Inspired by real B2B RevOps decision-making under uncertainty.
|
| 8 |
+
author: your-hf-username
|
| 9 |
+
license: MIT
|
| 10 |
+
theme: world-modeling
|
| 11 |
+
tags:
|
| 12 |
+
- saas
|
| 13 |
+
- revops
|
| 14 |
+
- adversarial
|
| 15 |
+
- multi-agent
|
| 16 |
+
- business-simulation
|
| 17 |
+
environment:
|
| 18 |
+
class: RevOpsEnv
|
| 19 |
+
module: revops_gym.env
|
| 20 |
+
type: Environment
|
| 21 |
+
server:
|
| 22 |
+
port: 7860
|
| 23 |
+
host: "0.0.0.0"
|
| 24 |
+
observation_space:
|
| 25 |
+
type: dict
|
| 26 |
+
fields:
|
| 27 |
+
- name: mrr
|
| 28 |
+
type: float
|
| 29 |
+
description: Monthly Recurring Revenue in USD
|
| 30 |
+
- name: cac
|
| 31 |
+
type: float
|
| 32 |
+
description: Customer Acquisition Cost in USD
|
| 33 |
+
- name: ltv
|
| 34 |
+
type: float
|
| 35 |
+
description: Customer Lifetime Value in USD
|
| 36 |
+
- name: churn_rate
|
| 37 |
+
type: float
|
| 38 |
+
description: Monthly churn rate 0-1
|
| 39 |
+
- name: cash_runway
|
| 40 |
+
type: float
|
| 41 |
+
description: Months of runway remaining
|
| 42 |
+
- name: marketing_spend
|
| 43 |
+
type: float
|
| 44 |
+
description: Current monthly marketing budget
|
| 45 |
+
- name: support_quality
|
| 46 |
+
type: float
|
| 47 |
+
description: Support quality score 0-1
|
| 48 |
+
- name: active_crisis
|
| 49 |
+
type: string
|
| 50 |
+
description: Current adversarial crisis tag or NONE
|
| 51 |
+
- name: step_number
|
| 52 |
+
type: int
|
| 53 |
+
description: Current step in the episode
|
| 54 |
+
- name: ltv_cac_ratio
|
| 55 |
+
type: float
|
| 56 |
+
description: LTV/CAC golden ratio
|
| 57 |
+
action_space:
|
| 58 |
+
type: dict
|
| 59 |
+
fields:
|
| 60 |
+
- name: action_type
|
| 61 |
+
type: string
|
| 62 |
+
enum:
|
| 63 |
+
- increase_marketing
|
| 64 |
+
- decrease_marketing
|
| 65 |
+
- hire_support
|
| 66 |
+
- fire_support
|
| 67 |
+
- discount_campaign
|
| 68 |
+
- raise_prices
|
| 69 |
+
- feature_investment
|
| 70 |
+
- cut_costs
|
| 71 |
+
- negotiate_contracts
|
| 72 |
+
- pivot_segment
|
| 73 |
+
- name: magnitude
|
| 74 |
+
type: float
|
| 75 |
+
description: Strength of action 0.1β1.0
|
| 76 |
+
reward:
|
| 77 |
+
type: composite
|
| 78 |
+
components:
|
| 79 |
+
- name: ltv_cac_ratio
|
| 80 |
+
weight: 0.35
|
| 81 |
+
- name: mrr_growth
|
| 82 |
+
weight: 0.30
|
| 83 |
+
- name: burn_efficiency
|
| 84 |
+
weight: 0.20
|
| 85 |
+
- name: survival_bonus
|
| 86 |
+
weight: 0.15
|
| 87 |
+
termination:
|
| 88 |
+
conditions:
|
| 89 |
+
- mrr_below_floor
|
| 90 |
+
- cash_runway_zero
|
| 91 |
+
- churn_above_ceiling
|
| 92 |
+
max_steps: 30
|
requirements.txt
ADDED
|
File without changes
|
revops_gym/__init__.py
ADDED
|
File without changes
|
revops_gym/client.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RevOps Gym client β connects to the running FastAPI server.
|
| 3 |
+
|
| 4 |
+
Trainers import only this module, never server internals.
|
| 5 |
+
Follows OpenEnv client/server separation.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
import requests
|
| 10 |
+
from revops_gym.models import RevOpsObservation
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class RevOpsClient:
|
| 14 |
+
"""
|
| 15 |
+
Thin HTTP client mirroring the env API.
|
| 16 |
+
Use this in your Colab training script to talk to the HF Space.
|
| 17 |
+
|
| 18 |
+
Example:
|
| 19 |
+
client = RevOpsClient("https://your-hf-space.hf.space")
|
| 20 |
+
obs = client.reset()
|
| 21 |
+
obs = client.step("increase_marketing", 0.7)
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, base_url: str = "http://localhost:7860"):
|
| 25 |
+
self.base_url = base_url.rstrip("/")
|
| 26 |
+
self._session = requests.Session()
|
| 27 |
+
|
| 28 |
+
def reset(self, seed: int | None = None, difficulty: str = "normal") -> RevOpsObservation:
|
| 29 |
+
resp = self._session.post(
|
| 30 |
+
f"{self.base_url}/reset",
|
| 31 |
+
json={"seed": seed, "difficulty": difficulty},
|
| 32 |
+
timeout=15,
|
| 33 |
+
)
|
| 34 |
+
resp.raise_for_status()
|
| 35 |
+
return RevOpsObservation(**resp.json())
|
| 36 |
+
|
| 37 |
+
def step(self, action_type: str, magnitude: float = 0.5) -> RevOpsObservation:
|
| 38 |
+
resp = self._session.post(
|
| 39 |
+
f"{self.base_url}/step",
|
| 40 |
+
json={"action_type": action_type, "magnitude": magnitude},
|
| 41 |
+
timeout=15,
|
| 42 |
+
)
|
| 43 |
+
resp.raise_for_status()
|
| 44 |
+
return RevOpsObservation(**resp.json())
|
| 45 |
+
|
| 46 |
+
def state(self) -> RevOpsObservation:
|
| 47 |
+
resp = self._session.get(f"{self.base_url}/state", timeout=10)
|
| 48 |
+
resp.raise_for_status()
|
| 49 |
+
return RevOpsObservation(**resp.json())
|
| 50 |
+
|
| 51 |
+
def health(self) -> bool:
|
| 52 |
+
try:
|
| 53 |
+
resp = self._session.get(f"{self.base_url}/health", timeout=5)
|
| 54 |
+
return resp.status_code == 200
|
| 55 |
+
except Exception:
|
| 56 |
+
return False
|
revops_gym/crisis.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Crisis Engine β the adversarial component using Gemini free API.
|
| 3 |
+
|
| 4 |
+
Every 3 steps, examines the agent's current weakness and applies
|
| 5 |
+
a targeted crisis to stress that exact metric. Uses Gemini 2.0 Flash
|
| 6 |
+
(free tier) to generate dynamic, contextual crises. Falls back to
|
| 7 |
+
a fast rule-based engine if the API is unavailable.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
import os
|
| 12 |
+
import json
|
| 13 |
+
import random
|
| 14 |
+
import requests
|
| 15 |
+
from typing import TYPE_CHECKING
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
from revops_gym.models import RevOpsState
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
# Static crisis library (rule-based fallback)
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
CRISES: dict[str, dict] = {
|
| 26 |
+
"CHURN_SPIKE": {
|
| 27 |
+
"description": "A competitor launched aggressive pricing. Churn surges.",
|
| 28 |
+
"churn_delta": +0.04,
|
| 29 |
+
"mrr_delta_pct": -0.08,
|
| 30 |
+
"targets": "churn_rate",
|
| 31 |
+
},
|
| 32 |
+
"CAC_EXPLOSION": {
|
| 33 |
+
"description": "Ad costs doubled. Marketing efficiency collapses.",
|
| 34 |
+
"cac_multiplier": 1.6,
|
| 35 |
+
"targets": "cac",
|
| 36 |
+
},
|
| 37 |
+
"SUPPORT_CRISIS": {
|
| 38 |
+
"description": "Key support engineers quit. Customer satisfaction tanks.",
|
| 39 |
+
"support_quality_delta": -0.25,
|
| 40 |
+
"churn_delta": +0.02,
|
| 41 |
+
"targets": "support_quality",
|
| 42 |
+
},
|
| 43 |
+
"CASH_CRUNCH": {
|
| 44 |
+
"description": "Unexpected infrastructure bill. Runway shrinks fast.",
|
| 45 |
+
"runway_delta": -3.0,
|
| 46 |
+
"targets": "cash_runway",
|
| 47 |
+
},
|
| 48 |
+
"PRICE_WAR": {
|
| 49 |
+
"description": "Competitors slashed prices. Enterprise deals at risk.",
|
| 50 |
+
"mrr_delta_pct": -0.12,
|
| 51 |
+
"cac_multiplier": 1.3,
|
| 52 |
+
"targets": "mrr",
|
| 53 |
+
},
|
| 54 |
+
"REGULATORY_HIT": {
|
| 55 |
+
"description": "New compliance requirement forces expensive changes.",
|
| 56 |
+
"runway_delta": -2.0,
|
| 57 |
+
"cac_multiplier": 1.2,
|
| 58 |
+
"targets": "cash_runway",
|
| 59 |
+
},
|
| 60 |
+
"ENTERPRISE_CHURN": {
|
| 61 |
+
"description": "Top 3 enterprise accounts cancelled. MRR cliff.",
|
| 62 |
+
"mrr_delta_pct": -0.20,
|
| 63 |
+
"targets": "mrr",
|
| 64 |
+
},
|
| 65 |
+
"TALENT_WAR": {
|
| 66 |
+
"description": "Big tech hiring spree. Engineering costs spike.",
|
| 67 |
+
"runway_delta": -2.5,
|
| 68 |
+
"support_quality_delta": -0.10,
|
| 69 |
+
"targets": "cash_runway",
|
| 70 |
+
},
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
NO_CRISIS = "NONE"
|
| 74 |
+
|
| 75 |
+
GEMINI_API_URL = (
|
| 76 |
+
"https://generativelanguage.googleapis.com/v1beta/models/"
|
| 77 |
+
"gemini-2.0-flash:generateContent"
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# ---------------------------------------------------------------------------
|
| 82 |
+
# Metric weakness detector
|
| 83 |
+
# ---------------------------------------------------------------------------
|
| 84 |
+
|
| 85 |
+
def _worst_metric(state: "RevOpsState") -> str:
|
| 86 |
+
"""Identify the agent's current Achilles heel."""
|
| 87 |
+
scores = {
|
| 88 |
+
"churn_rate": state.churn_rate / 0.20,
|
| 89 |
+
"cac": state.cac / 5000,
|
| 90 |
+
"support_quality": 1.0 - state.support_quality,
|
| 91 |
+
"cash_runway": max(0, (12 - state.cash_runway) / 12),
|
| 92 |
+
"mrr": max(0, (state.mrr_floor * 2 - state.mrr) / (state.mrr_floor * 2)),
|
| 93 |
+
}
|
| 94 |
+
return max(scores, key=scores.get)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _rule_based_crisis(state: "RevOpsState") -> dict:
|
| 98 |
+
"""Fast fallback: pick the crisis that targets the weakest metric."""
|
| 99 |
+
worst = _worst_metric(state)
|
| 100 |
+
candidates = [
|
| 101 |
+
k for k, v in CRISES.items() if v.get("targets") == worst
|
| 102 |
+
]
|
| 103 |
+
if not candidates:
|
| 104 |
+
candidates = list(CRISES.keys())
|
| 105 |
+
key = random.choice(candidates)
|
| 106 |
+
return {"crisis_key": key, **CRISES[key]}
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ---------------------------------------------------------------------------
|
| 110 |
+
# Gemini-powered crisis generator
|
| 111 |
+
# ---------------------------------------------------------------------------
|
| 112 |
+
|
| 113 |
+
def _gemini_crisis(state: "RevOpsState", api_key: str) -> dict | None:
|
| 114 |
+
"""
|
| 115 |
+
Ask Gemini to pick the most devious crisis given the current state.
|
| 116 |
+
Returns a dict matching CRISES schema, or None if call fails.
|
| 117 |
+
"""
|
| 118 |
+
worst = _worst_metric(state)
|
| 119 |
+
prompt = f"""You are the adversary in a SaaS business simulation game.
|
| 120 |
+
The LLM agent (Pilot) is managing a SaaS company. Here is the current state:
|
| 121 |
+
|
| 122 |
+
- MRR: ${state.mrr:,.0f} (survival floor: $20,000)
|
| 123 |
+
- CAC: ${state.cac:,.0f}
|
| 124 |
+
- LTV: ${state.ltv:,.0f} (LTV/CAC ratio: {state.ltv_cac_ratio:.2f}x)
|
| 125 |
+
- Churn rate: {state.churn_rate*100:.1f}%
|
| 126 |
+
- Cash runway: {state.cash_runway:.1f} months
|
| 127 |
+
- Support quality: {state.support_quality*100:.0f}%
|
| 128 |
+
- Current weakest metric: {worst}
|
| 129 |
+
|
| 130 |
+
You must pick ONE crisis from this list that will cause maximum pain by targeting the agent's weakness:
|
| 131 |
+
{json.dumps(list(CRISES.keys()), indent=2)}
|
| 132 |
+
|
| 133 |
+
Respond ONLY with a valid JSON object with these fields:
|
| 134 |
+
{{
|
| 135 |
+
"crisis_key": "<one of the keys above>",
|
| 136 |
+
"description": "<1-sentence dramatic business news headline>",
|
| 137 |
+
"churn_delta": <float or 0>,
|
| 138 |
+
"mrr_delta_pct": <float or 0>,
|
| 139 |
+
"cac_multiplier": <float or 1.0>,
|
| 140 |
+
"support_quality_delta": <float or 0>,
|
| 141 |
+
"runway_delta": <float or 0>,
|
| 142 |
+
"targets": "<metric name>"
|
| 143 |
+
}}
|
| 144 |
+
|
| 145 |
+
Be creative with the description but keep the numeric deltas within these bounds:
|
| 146 |
+
- churn_delta: 0 to 0.06
|
| 147 |
+
- mrr_delta_pct: -0.25 to 0
|
| 148 |
+
- cac_multiplier: 1.0 to 2.0
|
| 149 |
+
- support_quality_delta: -0.35 to 0
|
| 150 |
+
- runway_delta: -4.0 to 0
|
| 151 |
+
"""
|
| 152 |
+
|
| 153 |
+
try:
|
| 154 |
+
resp = requests.post(
|
| 155 |
+
f"{GEMINI_API_URL}?key={api_key}",
|
| 156 |
+
json={
|
| 157 |
+
"contents": [{"parts": [{"text": prompt}]}],
|
| 158 |
+
"generationConfig": {
|
| 159 |
+
"temperature": 0.9,
|
| 160 |
+
"maxOutputTokens": 512,
|
| 161 |
+
"responseMimeType": "application/json",
|
| 162 |
+
},
|
| 163 |
+
},
|
| 164 |
+
timeout=10,
|
| 165 |
+
)
|
| 166 |
+
resp.raise_for_status()
|
| 167 |
+
data = resp.json()
|
| 168 |
+
text = data["candidates"][0]["content"]["parts"][0]["text"]
|
| 169 |
+
crisis = json.loads(text)
|
| 170 |
+
# Validate crisis_key exists
|
| 171 |
+
if crisis.get("crisis_key") not in CRISES:
|
| 172 |
+
crisis["crisis_key"] = random.choice(list(CRISES.keys()))
|
| 173 |
+
return crisis
|
| 174 |
+
except Exception as e:
|
| 175 |
+
print(f"[CrisisEngine] Gemini call failed ({e}), using rule-based fallback.")
|
| 176 |
+
return None
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# ---------------------------------------------------------------------------
|
| 180 |
+
# Public interface
|
| 181 |
+
# ---------------------------------------------------------------------------
|
| 182 |
+
|
| 183 |
+
class CrisisEngine:
|
| 184 |
+
"""
|
| 185 |
+
Generates adversarial crises every N steps.
|
| 186 |
+
Uses Gemini free API if GEMINI_API_KEY is set, else rule-based.
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
def __init__(self, crisis_every: int = 3):
|
| 190 |
+
self.crisis_every = crisis_every
|
| 191 |
+
self.api_key = os.environ.get("GEMINI_API_KEY", "")
|
| 192 |
+
self._last_crisis: dict | None = None
|
| 193 |
+
|
| 194 |
+
def should_trigger(self, step: int) -> bool:
|
| 195 |
+
return step > 0 and step % self.crisis_every == 0
|
| 196 |
+
|
| 197 |
+
def generate(self, state: "RevOpsState") -> dict:
|
| 198 |
+
"""Return a crisis dict to apply to the state."""
|
| 199 |
+
crisis = None
|
| 200 |
+
if self.api_key:
|
| 201 |
+
crisis = _gemini_crisis(state, self.api_key)
|
| 202 |
+
if crisis is None:
|
| 203 |
+
crisis = _rule_based_crisis(state)
|
| 204 |
+
self._last_crisis = crisis
|
| 205 |
+
return crisis
|
| 206 |
+
|
| 207 |
+
def apply(self, state: "RevOpsState", crisis: dict) -> "RevOpsState":
|
| 208 |
+
"""Mutate state according to crisis parameters."""
|
| 209 |
+
data = state.model_dump()
|
| 210 |
+
|
| 211 |
+
if crisis.get("churn_delta"):
|
| 212 |
+
data["churn_rate"] = min(0.25, data["churn_rate"] + crisis["churn_delta"])
|
| 213 |
+
|
| 214 |
+
if crisis.get("mrr_delta_pct"):
|
| 215 |
+
data["mrr"] = max(0, data["mrr"] * (1 + crisis["mrr_delta_pct"]))
|
| 216 |
+
|
| 217 |
+
if crisis.get("cac_multiplier", 1.0) != 1.0:
|
| 218 |
+
data["cac"] = data["cac"] * crisis["cac_multiplier"]
|
| 219 |
+
|
| 220 |
+
if crisis.get("support_quality_delta"):
|
| 221 |
+
data["support_quality"] = max(
|
| 222 |
+
0.0, min(1.0, data["support_quality"] + crisis["support_quality_delta"])
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
if crisis.get("runway_delta"):
|
| 226 |
+
data["cash_runway"] = max(0, data["cash_runway"] + crisis["runway_delta"])
|
| 227 |
+
|
| 228 |
+
data["active_crisis"] = crisis.get("crisis_key", "UNKNOWN")
|
| 229 |
+
|
| 230 |
+
from revops_gym.models import RevOpsState as RS
|
| 231 |
+
return RS(**data)
|
revops_gym/env.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RevOps Gym β core environment.
|
| 3 |
+
|
| 4 |
+
Implements OpenEnv's Environment base interface:
|
| 5 |
+
reset() β RevOpsObservation
|
| 6 |
+
step(action) β RevOpsObservation
|
| 7 |
+
state() β RevOpsObservation
|
| 8 |
+
|
| 9 |
+
World dynamics:
|
| 10 |
+
- Actions mutate MRR, CAC, LTV, churn, runway, support quality
|
| 11 |
+
- Every 3 steps the Crisis Engine (Gemini) applies an adversarial shock
|
| 12 |
+
- Four-signal reward rubric scored after every step
|
| 13 |
+
- Episode terminates if MRR < $20k, runway β€ 0, churn > 20%, or step 30
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
import random
|
| 18 |
+
from typing import Any
|
| 19 |
+
|
| 20 |
+
from revops_gym.models import RevOpsState, RevOpsAction, RevOpsObservation
|
| 21 |
+
from revops_gym.crisis import CrisisEngine
|
| 22 |
+
from revops_gym.reward import RewardRubric
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
# Action effect table
|
| 27 |
+
# Each action modifies state via multipliers / deltas.
|
| 28 |
+
# magnitude (0.1β1.0) scales the effect linearly.
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
|
| 31 |
+
def _apply_action(state: RevOpsState, action: RevOpsAction) -> RevOpsState:
|
| 32 |
+
data = state.model_dump()
|
| 33 |
+
m = action.magnitude # scale factor
|
| 34 |
+
|
| 35 |
+
if action.action_type == "increase_marketing":
|
| 36 |
+
# More spend β lower CAC over time, higher MRR growth
|
| 37 |
+
spend_increase = data["marketing_spend"] * 0.3 * m
|
| 38 |
+
data["marketing_spend"] += spend_increase
|
| 39 |
+
data["cash_runway"] -= spend_increase / 10_000 * 0.5
|
| 40 |
+
data["mrr"] *= 1 + 0.06 * m
|
| 41 |
+
data["cac"] *= 1 - 0.05 * m # efficiency improves with scale
|
| 42 |
+
|
| 43 |
+
elif action.action_type == "decrease_marketing":
|
| 44 |
+
spend_decrease = data["marketing_spend"] * 0.3 * m
|
| 45 |
+
data["marketing_spend"] = max(1000, data["marketing_spend"] - spend_decrease)
|
| 46 |
+
data["cash_runway"] += spend_decrease / 10_000 * 0.4
|
| 47 |
+
data["mrr"] *= 1 - 0.03 * m # growth slows
|
| 48 |
+
|
| 49 |
+
elif action.action_type == "hire_support":
|
| 50 |
+
cost = 5_000 * m
|
| 51 |
+
data["cash_runway"] -= cost / 10_000
|
| 52 |
+
data["support_quality"] = min(1.0, data["support_quality"] + 0.15 * m)
|
| 53 |
+
data["churn_rate"] = max(0.005, data["churn_rate"] - 0.02 * m)
|
| 54 |
+
data["ltv"] *= 1 + 0.05 * m
|
| 55 |
+
|
| 56 |
+
elif action.action_type == "fire_support":
|
| 57 |
+
data["cash_runway"] += 0.3 * m
|
| 58 |
+
data["support_quality"] = max(0.0, data["support_quality"] - 0.20 * m)
|
| 59 |
+
data["churn_rate"] = min(0.25, data["churn_rate"] + 0.03 * m)
|
| 60 |
+
|
| 61 |
+
elif action.action_type == "discount_campaign":
|
| 62 |
+
# Short-term MRR boost, hurts LTV
|
| 63 |
+
data["mrr"] *= 1 + 0.10 * m
|
| 64 |
+
data["ltv"] *= 1 - 0.08 * m
|
| 65 |
+
data["cac"] *= 1 - 0.10 * m # cheaper to acquire
|
| 66 |
+
|
| 67 |
+
elif action.action_type == "raise_prices":
|
| 68 |
+
# Some churn, better LTV for retained customers
|
| 69 |
+
data["churn_rate"] = min(0.25, data["churn_rate"] + 0.02 * m)
|
| 70 |
+
data["mrr"] *= 1 + 0.08 * m * (1 - data["churn_rate"])
|
| 71 |
+
data["ltv"] *= 1 + 0.12 * m
|
| 72 |
+
|
| 73 |
+
elif action.action_type == "feature_investment":
|
| 74 |
+
cost = 8_000 * m
|
| 75 |
+
data["cash_runway"] -= cost / 10_000
|
| 76 |
+
data["ltv"] *= 1 + 0.10 * m
|
| 77 |
+
data["churn_rate"] = max(0.005, data["churn_rate"] - 0.01 * m)
|
| 78 |
+
|
| 79 |
+
elif action.action_type == "cut_costs":
|
| 80 |
+
data["cash_runway"] += 1.5 * m
|
| 81 |
+
data["marketing_spend"] *= 1 - 0.15 * m
|
| 82 |
+
data["mrr"] *= 1 - 0.02 * m # slight growth slowdown
|
| 83 |
+
|
| 84 |
+
elif action.action_type == "negotiate_contracts":
|
| 85 |
+
# Longer contracts β lower churn, higher committed LTV
|
| 86 |
+
data["churn_rate"] = max(0.005, data["churn_rate"] - 0.025 * m)
|
| 87 |
+
data["ltv"] *= 1 + 0.08 * m
|
| 88 |
+
data["cac"] *= 1 + 0.05 * m # takes effort to close
|
| 89 |
+
|
| 90 |
+
elif action.action_type == "pivot_segment":
|
| 91 |
+
# High risk / high reward β randomised outcome
|
| 92 |
+
success = random.random() < (0.4 + 0.3 * m)
|
| 93 |
+
if success:
|
| 94 |
+
data["mrr"] *= 1 + 0.15 * m
|
| 95 |
+
data["cac"] *= 0.85
|
| 96 |
+
else:
|
| 97 |
+
data["mrr"] *= 1 - 0.10 * m
|
| 98 |
+
data["cac"] *= 1.20
|
| 99 |
+
|
| 100 |
+
# Natural churn effects on MRR every step
|
| 101 |
+
churned_mrr = data["mrr"] * data["churn_rate"] * 0.5
|
| 102 |
+
data["mrr"] = max(0, data["mrr"] - churned_mrr)
|
| 103 |
+
|
| 104 |
+
# Clamp values to sane ranges
|
| 105 |
+
data["churn_rate"] = max(0.005, min(0.25, data["churn_rate"]))
|
| 106 |
+
data["support_quality"] = max(0.0, min(1.0, data["support_quality"]))
|
| 107 |
+
data["cash_runway"] = max(0.0, data["cash_runway"])
|
| 108 |
+
data["cac"] = max(100.0, data["cac"])
|
| 109 |
+
data["ltv"] = max(data["cac"], data["ltv"])
|
| 110 |
+
data["mrr"] = max(0.0, data["mrr"])
|
| 111 |
+
data["marketing_spend"] = max(500.0, data["marketing_spend"])
|
| 112 |
+
|
| 113 |
+
data["step_number"] = state.step_number + 1
|
| 114 |
+
data["active_crisis"] = "NONE" # crisis engine overwrites this if triggered
|
| 115 |
+
|
| 116 |
+
return RevOpsState(**data)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# ---------------------------------------------------------------------------
|
| 120 |
+
# Environment
|
| 121 |
+
# ---------------------------------------------------------------------------
|
| 122 |
+
|
| 123 |
+
class RevOpsEnv:
|
| 124 |
+
"""
|
| 125 |
+
OpenEnv-compliant environment for RevOps Gym.
|
| 126 |
+
|
| 127 |
+
Usage:
|
| 128 |
+
env = RevOpsEnv()
|
| 129 |
+
obs = env.reset()
|
| 130 |
+
obs = env.step({"action_type": "increase_marketing", "magnitude": 0.6})
|
| 131 |
+
current = env.state()
|
| 132 |
+
"""
|
| 133 |
+
|
| 134 |
+
metadata = {"render_modes": ["text"], "version": "0.1.0"}
|
| 135 |
+
|
| 136 |
+
def __init__(
|
| 137 |
+
self,
|
| 138 |
+
crisis_every: int = 3,
|
| 139 |
+
seed: int | None = None,
|
| 140 |
+
difficulty: str = "normal", # "easy" | "normal" | "hard"
|
| 141 |
+
):
|
| 142 |
+
self.crisis_every = crisis_every
|
| 143 |
+
self.seed = seed
|
| 144 |
+
self.difficulty = difficulty
|
| 145 |
+
self._rng = random.Random(seed)
|
| 146 |
+
self._crisis_engine = CrisisEngine(crisis_every=crisis_every)
|
| 147 |
+
self._rubric = RewardRubric()
|
| 148 |
+
self._state: RevOpsState = RevOpsState()
|
| 149 |
+
self._prev_state: RevOpsState | None = None
|
| 150 |
+
self._episode_rewards: list[float] = []
|
| 151 |
+
|
| 152 |
+
# ------------------------------------------------------------------
|
| 153 |
+
# OpenEnv core API
|
| 154 |
+
# ------------------------------------------------------------------
|
| 155 |
+
|
| 156 |
+
def reset(self, seed: int | None = None) -> RevOpsObservation:
|
| 157 |
+
"""Start a fresh episode."""
|
| 158 |
+
if seed is not None:
|
| 159 |
+
self._rng = random.Random(seed)
|
| 160 |
+
|
| 161 |
+
base = {
|
| 162 |
+
"mrr": self._rng.uniform(40_000, 80_000),
|
| 163 |
+
"cac": self._rng.uniform(1_500, 3_000),
|
| 164 |
+
"ltv": self._rng.uniform(8_000, 18_000),
|
| 165 |
+
"churn_rate": self._rng.uniform(0.02, 0.06),
|
| 166 |
+
"cash_runway": self._rng.uniform(12, 24),
|
| 167 |
+
"marketing_spend": self._rng.uniform(10_000, 25_000),
|
| 168 |
+
"support_quality": self._rng.uniform(0.60, 0.90),
|
| 169 |
+
"active_crisis": "NONE",
|
| 170 |
+
"step_number": 0,
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
# Difficulty adjustments
|
| 174 |
+
if self.difficulty == "easy":
|
| 175 |
+
base["cash_runway"] *= 1.5
|
| 176 |
+
base["churn_rate"] *= 0.6
|
| 177 |
+
elif self.difficulty == "hard":
|
| 178 |
+
base["cash_runway"] *= 0.6
|
| 179 |
+
base["churn_rate"] *= 1.4
|
| 180 |
+
base["cac"] *= 1.3
|
| 181 |
+
|
| 182 |
+
self._state = RevOpsState(**base)
|
| 183 |
+
self._prev_state = None
|
| 184 |
+
self._episode_rewards = []
|
| 185 |
+
|
| 186 |
+
return self._to_observation(reward=None, terminated=False, truncated=False)
|
| 187 |
+
|
| 188 |
+
def step(self, action: dict | RevOpsAction) -> RevOpsObservation:
|
| 189 |
+
"""Apply an action and advance the world by one step."""
|
| 190 |
+
if isinstance(action, dict):
|
| 191 |
+
action = RevOpsAction(**action)
|
| 192 |
+
|
| 193 |
+
if self._state.is_terminal:
|
| 194 |
+
# Already done β return terminal observation
|
| 195 |
+
return self._to_observation(reward=0.0, terminated=True, truncated=False)
|
| 196 |
+
|
| 197 |
+
prev = self._state
|
| 198 |
+
new_state = _apply_action(self._state, action)
|
| 199 |
+
|
| 200 |
+
# Apply adversarial crisis every N steps
|
| 201 |
+
if self._crisis_engine.should_trigger(new_state.step_number):
|
| 202 |
+
crisis = self._crisis_engine.generate(new_state)
|
| 203 |
+
new_state = self._crisis_engine.apply(new_state, crisis)
|
| 204 |
+
|
| 205 |
+
self._prev_state = prev
|
| 206 |
+
self._state = new_state
|
| 207 |
+
|
| 208 |
+
terminated = self._state.is_terminal and self._state.step_number < 30
|
| 209 |
+
truncated = self._state.step_number >= 30
|
| 210 |
+
|
| 211 |
+
reward_breakdown = self._rubric.compute(
|
| 212 |
+
self._state, prev, terminated=terminated
|
| 213 |
+
)
|
| 214 |
+
reward = reward_breakdown.total
|
| 215 |
+
self._episode_rewards.append(reward)
|
| 216 |
+
|
| 217 |
+
return self._to_observation(
|
| 218 |
+
reward=reward,
|
| 219 |
+
terminated=terminated,
|
| 220 |
+
truncated=truncated,
|
| 221 |
+
info={
|
| 222 |
+
"reward_breakdown": reward_breakdown.to_dict(),
|
| 223 |
+
"crisis_applied": self._state.active_crisis,
|
| 224 |
+
"episode_mean_reward": sum(self._episode_rewards) / len(self._episode_rewards),
|
| 225 |
+
},
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
def state(self) -> RevOpsObservation:
|
| 229 |
+
"""Return current observation without advancing."""
|
| 230 |
+
return self._to_observation(reward=None, terminated=self._state.is_terminal)
|
| 231 |
+
|
| 232 |
+
# ------------------------------------------------------------------
|
| 233 |
+
# Helpers
|
| 234 |
+
# ------------------------------------------------------------------
|
| 235 |
+
|
| 236 |
+
def _to_observation(
|
| 237 |
+
self,
|
| 238 |
+
reward: float | None,
|
| 239 |
+
terminated: bool,
|
| 240 |
+
truncated: bool = False,
|
| 241 |
+
info: dict | None = None,
|
| 242 |
+
) -> RevOpsObservation:
|
| 243 |
+
s = self._state
|
| 244 |
+
return RevOpsObservation(
|
| 245 |
+
mrr=round(s.mrr, 2),
|
| 246 |
+
cac=round(s.cac, 2),
|
| 247 |
+
ltv=round(s.ltv, 2),
|
| 248 |
+
churn_rate=round(s.churn_rate, 4),
|
| 249 |
+
cash_runway=round(s.cash_runway, 2),
|
| 250 |
+
marketing_spend=round(s.marketing_spend, 2),
|
| 251 |
+
support_quality=round(s.support_quality, 4),
|
| 252 |
+
active_crisis=s.active_crisis,
|
| 253 |
+
step_number=s.step_number,
|
| 254 |
+
ltv_cac_ratio=round(s.ltv_cac_ratio, 3),
|
| 255 |
+
reward_last_step=reward,
|
| 256 |
+
terminated=terminated,
|
| 257 |
+
truncated=truncated,
|
| 258 |
+
info=info or {},
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
def render(self) -> str:
|
| 262 |
+
"""Text render of current state (for debugging)."""
|
| 263 |
+
s = self._state
|
| 264 |
+
return (
|
| 265 |
+
f"\n{'='*50}\n"
|
| 266 |
+
f"RevOps Dashboard | Step {s.step_number}/30\n"
|
| 267 |
+
f"{'='*50}\n"
|
| 268 |
+
f"MRR: ${s.mrr:>12,.0f} (floor: $20,000)\n"
|
| 269 |
+
f"CAC: ${s.cac:>12,.0f}\n"
|
| 270 |
+
f"LTV: ${s.ltv:>12,.0f}\n"
|
| 271 |
+
f"LTV/CAC: {s.ltv_cac_ratio:>12.2f}x (target: 3x+)\n"
|
| 272 |
+
f"Churn: {s.churn_rate*100:>11.1f}%\n"
|
| 273 |
+
f"Cash runway: {s.cash_runway:>11.1f} months\n"
|
| 274 |
+
f"Mktg spend: ${s.marketing_spend:>12,.0f}/mo\n"
|
| 275 |
+
f"Support QA: {s.support_quality*100:>11.0f}%\n"
|
| 276 |
+
f"Crisis: {s.active_crisis:>12}\n"
|
| 277 |
+
f"{'='*50}\n"
|
| 278 |
+
)
|
revops_gym/models.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data models for RevOps Gym."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
from typing import Literal, Optional
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# ---------------------------------------------------------------------------
|
| 9 |
+
# Action
|
| 10 |
+
# ---------------------------------------------------------------------------
|
| 11 |
+
|
| 12 |
+
ActionType = Literal[
|
| 13 |
+
"increase_marketing",
|
| 14 |
+
"decrease_marketing",
|
| 15 |
+
"hire_support",
|
| 16 |
+
"fire_support",
|
| 17 |
+
"discount_campaign",
|
| 18 |
+
"raise_prices",
|
| 19 |
+
"feature_investment",
|
| 20 |
+
"cut_costs",
|
| 21 |
+
"negotiate_contracts",
|
| 22 |
+
"pivot_segment",
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class RevOpsAction(BaseModel):
|
| 27 |
+
action_type: ActionType = Field(
|
| 28 |
+
description="The strategic lever the agent pulls."
|
| 29 |
+
)
|
| 30 |
+
magnitude: float = Field(
|
| 31 |
+
default=0.5,
|
| 32 |
+
ge=0.1,
|
| 33 |
+
le=1.0,
|
| 34 |
+
description="Strength of the action, 0.1 (subtle) to 1.0 (aggressive).",
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
# World state (internal)
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
|
| 42 |
+
class RevOpsState(BaseModel):
|
| 43 |
+
mrr: float = Field(default=50_000.0, description="Monthly Recurring Revenue USD")
|
| 44 |
+
cac: float = Field(default=2_000.0, description="Customer Acquisition Cost USD")
|
| 45 |
+
ltv: float = Field(default=12_000.0, description="Customer Lifetime Value USD")
|
| 46 |
+
churn_rate: float = Field(default=0.04, description="Monthly churn rate 0-1")
|
| 47 |
+
cash_runway: float = Field(default=18.0, description="Months of cash runway")
|
| 48 |
+
marketing_spend: float = Field(default=15_000.0, description="Monthly marketing budget USD")
|
| 49 |
+
support_quality: float = Field(default=0.75, description="Support quality score 0-1")
|
| 50 |
+
active_crisis: str = Field(default="NONE", description="Current adversarial crisis or NONE")
|
| 51 |
+
step_number: int = Field(default=0)
|
| 52 |
+
|
| 53 |
+
@property
|
| 54 |
+
def ltv_cac_ratio(self) -> float:
|
| 55 |
+
return self.ltv / max(self.cac, 1.0)
|
| 56 |
+
|
| 57 |
+
@property
|
| 58 |
+
def mrr_floor(self) -> float:
|
| 59 |
+
"""VC survival floor β must stay above this."""
|
| 60 |
+
return 20_000.0
|
| 61 |
+
|
| 62 |
+
@property
|
| 63 |
+
def is_terminal(self) -> bool:
|
| 64 |
+
return (
|
| 65 |
+
self.mrr < self.mrr_floor
|
| 66 |
+
or self.cash_runway <= 0
|
| 67 |
+
or self.churn_rate > 0.20
|
| 68 |
+
or self.step_number >= 30
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# ---------------------------------------------------------------------------
|
| 73 |
+
# Observation (what the agent sees)
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
|
| 76 |
+
class RevOpsObservation(BaseModel):
|
| 77 |
+
mrr: float
|
| 78 |
+
cac: float
|
| 79 |
+
ltv: float
|
| 80 |
+
churn_rate: float
|
| 81 |
+
cash_runway: float
|
| 82 |
+
marketing_spend: float
|
| 83 |
+
support_quality: float
|
| 84 |
+
active_crisis: str
|
| 85 |
+
step_number: int
|
| 86 |
+
ltv_cac_ratio: float
|
| 87 |
+
reward_last_step: Optional[float] = None
|
| 88 |
+
terminated: bool = False
|
| 89 |
+
truncated: bool = False
|
| 90 |
+
info: dict = Field(default_factory=dict)
|
| 91 |
+
|
| 92 |
+
def to_prompt_text(self) -> str:
|
| 93 |
+
"""Convert observation to a concise text prompt for the LLM."""
|
| 94 |
+
crisis_text = (
|
| 95 |
+
f"\nβ οΈ ACTIVE CRISIS: {self.active_crisis}"
|
| 96 |
+
if self.active_crisis != "NONE"
|
| 97 |
+
else ""
|
| 98 |
+
)
|
| 99 |
+
return (
|
| 100 |
+
f"=== RevOps Dashboard | Step {self.step_number}/30 ==={crisis_text}\n"
|
| 101 |
+
f"MRR: ${self.mrr:,.0f} | Floor: $20,000\n"
|
| 102 |
+
f"CAC: ${self.cac:,.0f} | LTV: ${self.ltv:,.0f} | LTV/CAC: {self.ltv_cac_ratio:.2f}x\n"
|
| 103 |
+
f"Churn: {self.churn_rate*100:.1f}% | Runway: {self.cash_runway:.1f}mo\n"
|
| 104 |
+
f"Marketing spend: ${self.marketing_spend:,.0f}/mo | Support quality: {self.support_quality*100:.0f}%\n"
|
| 105 |
+
f"Last reward: {self.reward_last_step or 0:.3f}\n"
|
| 106 |
+
"\nAvailable actions: increase_marketing, decrease_marketing, hire_support, "
|
| 107 |
+
"fire_support, discount_campaign, raise_prices, feature_investment, "
|
| 108 |
+
"cut_costs, negotiate_contracts, pivot_segment\n"
|
| 109 |
+
"Respond ONLY with JSON: {\"action_type\": \"...\", \"magnitude\": 0.0-1.0}"
|
| 110 |
+
)
|
revops_gym/reward.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reward rubric for RevOps Gym.
|
| 3 |
+
|
| 4 |
+
Four independent reward signals following OpenEnv's Rubric pattern.
|
| 5 |
+
Multiple signals prevent reward hacking β an agent can't fake the ratio
|
| 6 |
+
without also controlling churn and burn rate simultaneously.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from typing import TYPE_CHECKING
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING:
|
| 14 |
+
from revops_gym.models import RevOpsState
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class RewardBreakdown:
|
| 19 |
+
ltv_cac: float # 0-1 golden ratio signal
|
| 20 |
+
mrr_growth: float # 0-1 revenue trajectory
|
| 21 |
+
burn_efficiency: float # 0-1 not burning cash recklessly
|
| 22 |
+
survival_bonus: float # 0 or 1 staying alive
|
| 23 |
+
total: float
|
| 24 |
+
terminated_penalty: float = 0.0
|
| 25 |
+
|
| 26 |
+
def to_dict(self) -> dict:
|
| 27 |
+
return {
|
| 28 |
+
"ltv_cac": round(self.ltv_cac, 4),
|
| 29 |
+
"mrr_growth": round(self.mrr_growth, 4),
|
| 30 |
+
"burn_efficiency": round(self.burn_efficiency, 4),
|
| 31 |
+
"survival_bonus": round(self.survival_bonus, 4),
|
| 32 |
+
"terminated_penalty": round(self.terminated_penalty, 4),
|
| 33 |
+
"total": round(self.total, 4),
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class RewardRubric:
|
| 38 |
+
"""
|
| 39 |
+
Computes composite reward from current and previous state.
|
| 40 |
+
|
| 41 |
+
Weights:
|
| 42 |
+
ltv_cac_ratio 35% β the "golden ratio" of SaaS health
|
| 43 |
+
mrr_growth 30% β revenue trajectory
|
| 44 |
+
burn_efficiency 20% β sustainable spending
|
| 45 |
+
survival_bonus 15% β staying above the VC floor
|
| 46 |
+
|
| 47 |
+
Termination penalty: -2.0 applied on top if the company dies.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
WEIGHTS = {
|
| 51 |
+
"ltv_cac": 0.35,
|
| 52 |
+
"mrr_growth": 0.30,
|
| 53 |
+
"burn_efficiency": 0.20,
|
| 54 |
+
"survival_bonus": 0.15,
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
# Target benchmarks for SaaS health
|
| 58 |
+
TARGET_LTV_CAC = 3.0 # 3x is "good", 5x+ is excellent
|
| 59 |
+
TARGET_CHURN = 0.02 # 2% monthly is good SaaS
|
| 60 |
+
MAX_BURN_RATIO = 0.50 # marketing spend / MRR ceiling
|
| 61 |
+
|
| 62 |
+
def compute(
|
| 63 |
+
self,
|
| 64 |
+
state: "RevOpsState",
|
| 65 |
+
prev_state: "RevOpsState | None" = None,
|
| 66 |
+
terminated: bool = False,
|
| 67 |
+
) -> RewardBreakdown:
|
| 68 |
+
|
| 69 |
+
# --- Signal 1: LTV/CAC golden ratio ---
|
| 70 |
+
ratio = state.ltv_cac_ratio
|
| 71 |
+
if ratio >= self.TARGET_LTV_CAC:
|
| 72 |
+
ltv_cac_score = min(1.0, (ratio - self.TARGET_LTV_CAC) / 2.0 * 0.5 + 0.75)
|
| 73 |
+
elif ratio >= 1.0:
|
| 74 |
+
ltv_cac_score = (ratio - 1.0) / (self.TARGET_LTV_CAC - 1.0) * 0.75
|
| 75 |
+
else:
|
| 76 |
+
# ratio < 1.0 means losing money per customer β negative signal
|
| 77 |
+
ltv_cac_score = max(-0.5, (ratio - 1.0) * 0.5)
|
| 78 |
+
|
| 79 |
+
# --- Signal 2: MRR growth ---
|
| 80 |
+
if prev_state is not None:
|
| 81 |
+
mrr_change = (state.mrr - prev_state.mrr) / max(prev_state.mrr, 1)
|
| 82 |
+
# Normalize: +10% growth = 1.0, flat = 0.3, -20% = 0
|
| 83 |
+
mrr_growth_score = max(0.0, min(1.0, mrr_change * 5.0 + 0.3))
|
| 84 |
+
else:
|
| 85 |
+
# First step β reward for being above the floor
|
| 86 |
+
floor_margin = (state.mrr - state.mrr_floor) / state.mrr_floor
|
| 87 |
+
mrr_growth_score = min(1.0, max(0.0, floor_margin * 0.5 + 0.5))
|
| 88 |
+
|
| 89 |
+
# --- Signal 3: Burn efficiency ---
|
| 90 |
+
burn_ratio = state.marketing_spend / max(state.mrr, 1)
|
| 91 |
+
if burn_ratio <= self.MAX_BURN_RATIO:
|
| 92 |
+
burn_score = 1.0 - (burn_ratio / self.MAX_BURN_RATIO) * 0.3
|
| 93 |
+
else:
|
| 94 |
+
burn_score = max(0.0, 1.0 - burn_ratio)
|
| 95 |
+
|
| 96 |
+
# Penalize bad support quality (hidden churn driver)
|
| 97 |
+
support_penalty = max(0.0, 0.75 - state.support_quality) * 0.4
|
| 98 |
+
burn_score = max(0.0, burn_score - support_penalty)
|
| 99 |
+
|
| 100 |
+
# --- Signal 4: Survival bonus ---
|
| 101 |
+
if state.mrr >= state.mrr_floor and state.cash_runway > 3:
|
| 102 |
+
runway_bonus = min(0.5, state.cash_runway / 24) * 0.5
|
| 103 |
+
survival_score = 0.5 + runway_bonus
|
| 104 |
+
elif state.mrr >= state.mrr_floor:
|
| 105 |
+
survival_score = 0.2 # alive but barely
|
| 106 |
+
else:
|
| 107 |
+
survival_score = 0.0
|
| 108 |
+
|
| 109 |
+
# Churn penalty on survival
|
| 110 |
+
if state.churn_rate > 0.10:
|
| 111 |
+
survival_score *= 0.5
|
| 112 |
+
|
| 113 |
+
# --- Weighted total ---
|
| 114 |
+
total = (
|
| 115 |
+
ltv_cac_score * self.WEIGHTS["ltv_cac"]
|
| 116 |
+
+ mrr_growth_score * self.WEIGHTS["mrr_growth"]
|
| 117 |
+
+ burn_score * self.WEIGHTS["burn_efficiency"]
|
| 118 |
+
+ survival_score * self.WEIGHTS["survival_bonus"]
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
# --- Termination penalty ---
|
| 122 |
+
term_penalty = 0.0
|
| 123 |
+
if terminated and state.mrr < state.mrr_floor:
|
| 124 |
+
term_penalty = -2.0
|
| 125 |
+
total += term_penalty
|
| 126 |
+
|
| 127 |
+
return RewardBreakdown(
|
| 128 |
+
ltv_cac=ltv_cac_score,
|
| 129 |
+
mrr_growth=mrr_growth_score,
|
| 130 |
+
burn_efficiency=burn_score,
|
| 131 |
+
survival_bonus=survival_score,
|
| 132 |
+
total=total,
|
| 133 |
+
terminated_penalty=term_penalty,
|
| 134 |
+
)
|
revops_gym/server.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI server for RevOps Gym.
|
| 3 |
+
Exposes reset / step / state endpoints per the OpenEnv spec.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
import os
|
| 8 |
+
from fastapi import FastAPI, HTTPException
|
| 9 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
+
from fastapi.responses import HTMLResponse
|
| 11 |
+
from pydantic import BaseModel
|
| 12 |
+
|
| 13 |
+
from revops_gym.env import RevOpsEnv
|
| 14 |
+
from revops_gym.models import RevOpsAction, RevOpsObservation
|
| 15 |
+
|
| 16 |
+
app = FastAPI(
|
| 17 |
+
title="RevOps Gym",
|
| 18 |
+
description="A dynamic SaaS flight simulator for LLM RL training.",
|
| 19 |
+
version="0.1.0",
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
app.add_middleware(
|
| 23 |
+
CORSMiddleware,
|
| 24 |
+
allow_origins=["*"],
|
| 25 |
+
allow_methods=["*"],
|
| 26 |
+
allow_headers=["*"],
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# Single shared env instance (stateless for multi-client use, fork per session)
|
| 30 |
+
_env = RevOpsEnv(
|
| 31 |
+
crisis_every=int(os.environ.get("CRISIS_EVERY", "3")),
|
| 32 |
+
difficulty=os.environ.get("DIFFICULTY", "normal"),
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ResetRequest(BaseModel):
|
| 37 |
+
seed: int | None = None
|
| 38 |
+
difficulty: str = "normal"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class StepRequest(BaseModel):
|
| 42 |
+
action_type: str
|
| 43 |
+
magnitude: float = 0.5
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ------------------------------------------------------------------
|
| 47 |
+
# OpenEnv required endpoints
|
| 48 |
+
# ------------------------------------------------------------------
|
| 49 |
+
|
| 50 |
+
@app.post("/reset", response_model=RevOpsObservation)
|
| 51 |
+
def reset(req: ResetRequest = ResetRequest()):
|
| 52 |
+
global _env
|
| 53 |
+
_env = RevOpsEnv(
|
| 54 |
+
crisis_every=int(os.environ.get("CRISIS_EVERY", "3")),
|
| 55 |
+
difficulty=req.difficulty,
|
| 56 |
+
)
|
| 57 |
+
return _env.reset(seed=req.seed)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@app.post("/step", response_model=RevOpsObservation)
|
| 61 |
+
def step(req: StepRequest):
|
| 62 |
+
try:
|
| 63 |
+
action = RevOpsAction(action_type=req.action_type, magnitude=req.magnitude)
|
| 64 |
+
except Exception as e:
|
| 65 |
+
raise HTTPException(status_code=422, detail=str(e))
|
| 66 |
+
return _env.step(action)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@app.get("/state", response_model=RevOpsObservation)
|
| 70 |
+
def state():
|
| 71 |
+
return _env.state()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# ------------------------------------------------------------------
|
| 75 |
+
# Optional: human-readable demo UI
|
| 76 |
+
# ------------------------------------------------------------------
|
| 77 |
+
|
| 78 |
+
@app.get("/", response_class=HTMLResponse)
|
| 79 |
+
def index():
|
| 80 |
+
s = _env.state()
|
| 81 |
+
crisis_html = (
|
| 82 |
+
f'<div class="crisis">β οΈ CRISIS: {s.active_crisis}</div>'
|
| 83 |
+
if s.active_crisis != "NONE"
|
| 84 |
+
else ""
|
| 85 |
+
)
|
| 86 |
+
return f"""<!DOCTYPE html>
|
| 87 |
+
<html><head><title>RevOps Gym</title>
|
| 88 |
+
<style>
|
| 89 |
+
body {{ font-family: monospace; background: #0d1117; color: #c9d1d9; padding: 2rem; }}
|
| 90 |
+
h1 {{ color: #58a6ff; }}
|
| 91 |
+
.metric {{ display: inline-block; margin: 0.5rem 1rem; padding: 0.5rem 1rem;
|
| 92 |
+
background: #161b22; border: 1px solid #30363d; border-radius: 6px; }}
|
| 93 |
+
.metric .label {{ font-size: 0.75rem; color: #8b949e; }}
|
| 94 |
+
.metric .value {{ font-size: 1.4rem; color: #3fb950; font-weight: bold; }}
|
| 95 |
+
.crisis {{ background: #3d1a1a; border: 1px solid #f85149; border-radius: 6px;
|
| 96 |
+
padding: 0.75rem 1rem; margin: 1rem 0; color: #f85149; }}
|
| 97 |
+
form {{ margin: 1.5rem 0; }}
|
| 98 |
+
select, input {{ background: #161b22; color: #c9d1d9; border: 1px solid #30363d;
|
| 99 |
+
padding: 0.4rem 0.6rem; border-radius: 4px; margin: 0.3rem; }}
|
| 100 |
+
button {{ background: #238636; color: #fff; border: none; padding: 0.5rem 1.2rem;
|
| 101 |
+
border-radius: 4px; cursor: pointer; }}
|
| 102 |
+
button:hover {{ background: #2ea043; }}
|
| 103 |
+
</style></head><body>
|
| 104 |
+
<h1>π RevOps Gym</h1>
|
| 105 |
+
<p>SaaS Flight Simulator β Step {s.step_number}/30 | LTV/CAC: {s.ltv_cac_ratio:.2f}x</p>
|
| 106 |
+
{crisis_html}
|
| 107 |
+
<div>
|
| 108 |
+
<div class="metric"><div class="label">MRR</div><div class="value">${s.mrr:,.0f}</div></div>
|
| 109 |
+
<div class="metric"><div class="label">CAC</div><div class="value">${s.cac:,.0f}</div></div>
|
| 110 |
+
<div class="metric"><div class="label">LTV</div><div class="value">${s.ltv:,.0f}</div></div>
|
| 111 |
+
<div class="metric"><div class="label">Churn</div><div class="value">{s.churn_rate*100:.1f}%</div></div>
|
| 112 |
+
<div class="metric"><div class="label">Runway</div><div class="value">{s.cash_runway:.1f}mo</div></div>
|
| 113 |
+
<div class="metric"><div class="label">Mktg $</div><div class="value">${s.marketing_spend:,.0f}</div></div>
|
| 114 |
+
<div class="metric"><div class="label">Support</div><div class="value">{s.support_quality*100:.0f}%</div></div>
|
| 115 |
+
</div>
|
| 116 |
+
<form action="/step" method="post" onsubmit="takeAction(event)">
|
| 117 |
+
<label>Action:
|
| 118 |
+
<select id="action_type">
|
| 119 |
+
<option>increase_marketing</option><option>decrease_marketing</option>
|
| 120 |
+
<option>hire_support</option><option>fire_support</option>
|
| 121 |
+
<option>discount_campaign</option><option>raise_prices</option>
|
| 122 |
+
<option>feature_investment</option><option>cut_costs</option>
|
| 123 |
+
<option>negotiate_contracts</option><option>pivot_segment</option>
|
| 124 |
+
</select>
|
| 125 |
+
</label>
|
| 126 |
+
<label>Magnitude: <input id="magnitude" type="range" min="0.1" max="1" step="0.1" value="0.5"></label>
|
| 127 |
+
<button type="submit">Take Action</button>
|
| 128 |
+
<button type="button" onclick="doReset()">Reset Episode</button>
|
| 129 |
+
</form>
|
| 130 |
+
<div id="result"></div>
|
| 131 |
+
<script>
|
| 132 |
+
async function takeAction(e) {{
|
| 133 |
+
e.preventDefault();
|
| 134 |
+
const res = await fetch('/step', {{method:'POST',
|
| 135 |
+
headers:{{'Content-Type':'application/json'}},
|
| 136 |
+
body: JSON.stringify({{
|
| 137 |
+
action_type: document.getElementById('action_type').value,
|
| 138 |
+
magnitude: parseFloat(document.getElementById('magnitude').value)
|
| 139 |
+
}})
|
| 140 |
+
}});
|
| 141 |
+
const d = await res.json();
|
| 142 |
+
document.getElementById('result').innerHTML =
|
| 143 |
+
'<pre>' + JSON.stringify(d, null, 2) + '</pre>';
|
| 144 |
+
location.reload();
|
| 145 |
+
}}
|
| 146 |
+
async function doReset() {{
|
| 147 |
+
await fetch('/reset', {{method:'POST', headers:{{'Content-Type':'application/json'}}, body:'{{}}'}});
|
| 148 |
+
location.reload();
|
| 149 |
+
}}
|
| 150 |
+
</script>
|
| 151 |
+
</body></html>"""
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@app.get("/health")
|
| 155 |
+
def health():
|
| 156 |
+
return {"status": "ok", "version": "0.1.0"}
|
setup.py
ADDED
|
File without changes
|
tests/test_env.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Quick smoke test β run locally before pushing to HF Spaces.
|
| 3 |
+
Tests: reset, step through full episode, crisis triggers, reward signals.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python tests/test_env.py
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import sys
|
| 10 |
+
import os
|
| 11 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 12 |
+
|
| 13 |
+
from revops_gym.env import RevOpsEnv
|
| 14 |
+
from revops_gym.models import RevOpsAction
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def test_episode(difficulty="normal", seed=42):
|
| 18 |
+
print(f"\n=== Smoke test | difficulty={difficulty} seed={seed} ===")
|
| 19 |
+
env = RevOpsEnv(crisis_every=3, seed=seed, difficulty=difficulty)
|
| 20 |
+
obs = env.reset(seed=seed)
|
| 21 |
+
assert obs.step_number == 0, "Reset should start at step 0"
|
| 22 |
+
assert obs.mrr > 0, "MRR should be positive after reset"
|
| 23 |
+
print(env.render())
|
| 24 |
+
|
| 25 |
+
actions = [
|
| 26 |
+
("increase_marketing", 0.6),
|
| 27 |
+
("hire_support", 0.8),
|
| 28 |
+
("negotiate_contracts", 0.5),
|
| 29 |
+
("raise_prices", 0.4),
|
| 30 |
+
("feature_investment", 0.7),
|
| 31 |
+
("cut_costs", 0.3),
|
| 32 |
+
("discount_campaign", 0.5),
|
| 33 |
+
("increase_marketing", 0.7),
|
| 34 |
+
("hire_support", 0.5),
|
| 35 |
+
("pivot_segment", 0.6),
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
rewards = []
|
| 39 |
+
crises_seen = []
|
| 40 |
+
for i, (action_type, magnitude) in enumerate(actions):
|
| 41 |
+
obs = env.step({"action_type": action_type, "magnitude": magnitude})
|
| 42 |
+
rewards.append(obs.reward_last_step)
|
| 43 |
+
if obs.active_crisis != "NONE":
|
| 44 |
+
crises_seen.append(obs.active_crisis)
|
| 45 |
+
print(
|
| 46 |
+
f" Step {obs.step_number:2d} | {action_type:<22} mag={magnitude} "
|
| 47 |
+
f"| reward={obs.reward_last_step:+.3f} | MRR=${obs.mrr:,.0f} "
|
| 48 |
+
f"| LTV/CAC={obs.ltv_cac_ratio:.2f}x"
|
| 49 |
+
+ (f" | β οΈ {obs.active_crisis}" if obs.active_crisis != "NONE" else "")
|
| 50 |
+
)
|
| 51 |
+
if obs.terminated or obs.truncated:
|
| 52 |
+
print(f"\n Episode ended at step {obs.step_number} "
|
| 53 |
+
f"({'terminated' if obs.terminated else 'truncated'})")
|
| 54 |
+
break
|
| 55 |
+
|
| 56 |
+
print(f"\n Total steps: {obs.step_number}")
|
| 57 |
+
print(f" Mean reward: {sum(rewards)/len(rewards):.4f}")
|
| 58 |
+
print(f" Min reward: {min(rewards):.4f}")
|
| 59 |
+
print(f" Max reward: {max(rewards):.4f}")
|
| 60 |
+
print(f" Crises seen: {crises_seen or ['none triggered yet']}")
|
| 61 |
+
assert len(rewards) > 0, "Should have at least one reward"
|
| 62 |
+
print("\nβ
Smoke test passed!")
|
| 63 |
+
return True
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def test_all_actions():
|
| 67 |
+
print("\n=== Testing all action types ===")
|
| 68 |
+
env = RevOpsEnv(seed=0)
|
| 69 |
+
env.reset(seed=0)
|
| 70 |
+
all_actions = [
|
| 71 |
+
"increase_marketing", "decrease_marketing", "hire_support",
|
| 72 |
+
"fire_support", "discount_campaign", "raise_prices",
|
| 73 |
+
"feature_investment", "cut_costs", "negotiate_contracts", "pivot_segment",
|
| 74 |
+
]
|
| 75 |
+
for action in all_actions:
|
| 76 |
+
obs = env.step({"action_type": action, "magnitude": 0.5})
|
| 77 |
+
assert obs.reward_last_step is not None
|
| 78 |
+
print(f" β {action:<24} reward={obs.reward_last_step:+.3f}")
|
| 79 |
+
print("β
All actions tested!")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def test_termination():
|
| 83 |
+
print("\n=== Testing termination conditions ===")
|
| 84 |
+
from revops_gym.models import RevOpsState
|
| 85 |
+
from revops_gym.reward import RewardRubric
|
| 86 |
+
rubric = RewardRubric()
|
| 87 |
+
|
| 88 |
+
# MRR below floor
|
| 89 |
+
state = RevOpsState(mrr=5_000, step_number=5)
|
| 90 |
+
assert state.is_terminal, "Should terminate when MRR < floor"
|
| 91 |
+
rb = rubric.compute(state, terminated=True)
|
| 92 |
+
assert rb.terminated_penalty == -2.0, "Should get termination penalty"
|
| 93 |
+
print(" β MRR floor termination works")
|
| 94 |
+
|
| 95 |
+
# Max steps
|
| 96 |
+
state2 = RevOpsState(mrr=100_000, step_number=30)
|
| 97 |
+
assert state2.is_terminal, "Should truncate at step 30"
|
| 98 |
+
print(" β Step limit truncation works")
|
| 99 |
+
print("β
Termination tests passed!")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
if __name__ == "__main__":
|
| 103 |
+
test_episode(difficulty="easy")
|
| 104 |
+
test_episode(difficulty="normal")
|
| 105 |
+
test_episode(difficulty="hard")
|
| 106 |
+
test_all_actions()
|
| 107 |
+
test_termination()
|
| 108 |
+
print("\nπ All tests passed! Ready to push to HF Spaces.")
|
train_colab.ipynb
ADDED
|
@@ -0,0 +1,690 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "b2ea0425",
|
| 7 |
+
"metadata": {
|
| 8 |
+
"lines_to_next_cell": 0
|
| 9 |
+
},
|
| 10 |
+
"outputs": [],
|
| 11 |
+
"source": [
|
| 12 |
+
"\n",
|
| 13 |
+
"# RevOps Gym β Full Training Script (Colab)\n",
|
| 14 |
+
"# Convert this to a .ipynb with: jupytext --to notebook train_colab.py\n",
|
| 15 |
+
"# Or copy cells manually into Colab.\n",
|
| 16 |
+
"#\n",
|
| 17 |
+
"# Runtime: GPU T4 (free tier) | ~45-60 min for full run\n",
|
| 18 |
+
"# Model: Qwen/Qwen2.5-1.5B-Instruct (1.5B, fits on T4)\n",
|
| 19 |
+
"# Trainer: TRL GRPO + Unsloth\n",
|
| 20 |
+
"# Environment: RevOps Gym (runs locally inside Colab)\n",
|
| 21 |
+
"\n",
|
| 22 |
+
"# ============================================================\n",
|
| 23 |
+
"# CELL 1 β Install dependencies\n",
|
| 24 |
+
"# ============================================================"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "code",
|
| 29 |
+
"execution_count": null,
|
| 30 |
+
"id": "9f1de3ce",
|
| 31 |
+
"metadata": {
|
| 32 |
+
"lines_to_next_cell": 0
|
| 33 |
+
},
|
| 34 |
+
"outputs": [],
|
| 35 |
+
"source": [
|
| 36 |
+
"!pip install -q \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
|
| 37 |
+
"!pip install -q trl>=0.8.0 peft accelerate bitsandbytes\n",
|
| 38 |
+
"!pip install -q fastapi uvicorn pydantic requests wandb matplotlib\n",
|
| 39 |
+
"\n",
|
| 40 |
+
"# ============================================================\n",
|
| 41 |
+
"# CELL 2 β Clone and install RevOps Gym environment\n",
|
| 42 |
+
"# ============================================================"
|
| 43 |
+
]
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"cell_type": "code",
|
| 47 |
+
"execution_count": null,
|
| 48 |
+
"id": "2d37e6b7",
|
| 49 |
+
"metadata": {
|
| 50 |
+
"lines_to_next_cell": 0
|
| 51 |
+
},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"!git clone https://huggingface.co/spaces/YOUR_HF_USERNAME/revops-gym\n",
|
| 55 |
+
"!pip install -q -e revops-gym/\n",
|
| 56 |
+
"\n",
|
| 57 |
+
"# For local testing without HF Space, copy env files directly:\n",
|
| 58 |
+
"# The environment runs INSIDE Colab β no external server needed for training.\n",
|
| 59 |
+
"\n",
|
| 60 |
+
"# ============================================================\n",
|
| 61 |
+
"# CELL 3 β Imports and config\n",
|
| 62 |
+
"# ============================================================"
|
| 63 |
+
]
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"cell_type": "code",
|
| 67 |
+
"execution_count": null,
|
| 68 |
+
"id": "54f78215",
|
| 69 |
+
"metadata": {
|
| 70 |
+
"lines_to_next_cell": 0
|
| 71 |
+
},
|
| 72 |
+
"outputs": [],
|
| 73 |
+
"source": [
|
| 74 |
+
"import os\n",
|
| 75 |
+
"import json\n",
|
| 76 |
+
"import random\n",
|
| 77 |
+
"import re\n",
|
| 78 |
+
"import time\n",
|
| 79 |
+
"import warnings\n",
|
| 80 |
+
"from typing import Optional\n",
|
| 81 |
+
"import torch\n",
|
| 82 |
+
"import numpy as np\n",
|
| 83 |
+
"import matplotlib.pyplot as plt\n",
|
| 84 |
+
"from collections import defaultdict\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"warnings.filterwarnings(\"ignore\")\n",
|
| 87 |
+
"\n",
|
| 88 |
+
"# --- Config ---\n",
|
| 89 |
+
"MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
|
| 90 |
+
"MAX_NEW_TOKENS = 128\n",
|
| 91 |
+
"BATCH_SIZE = 4 # rollouts per GRPO step (keep small for T4)\n",
|
| 92 |
+
"NUM_EPISODES = 200 # total training episodes\n",
|
| 93 |
+
"GRPO_EPOCHS = 1\n",
|
| 94 |
+
"LR = 5e-6\n",
|
| 95 |
+
"MAX_STEPS_PER_EPISODE = 30\n",
|
| 96 |
+
"SAVE_EVERY = 50 # save checkpoint every N episodes\n",
|
| 97 |
+
"WANDB_PROJECT = \"revops-gym\"\n",
|
| 98 |
+
"USE_WANDB = False # set True if you have wandb account\n",
|
| 99 |
+
"\n",
|
| 100 |
+
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
|
| 101 |
+
"if torch.cuda.is_available():\n",
|
| 102 |
+
" print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
|
| 103 |
+
" print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n",
|
| 104 |
+
"\n",
|
| 105 |
+
"# ============================================================\n",
|
| 106 |
+
"# CELL 4 β Load model with Unsloth (4-bit quantization)\n",
|
| 107 |
+
"# ============================================================"
|
| 108 |
+
]
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"cell_type": "code",
|
| 112 |
+
"execution_count": null,
|
| 113 |
+
"id": "5fdfba32",
|
| 114 |
+
"metadata": {
|
| 115 |
+
"lines_to_next_cell": 0
|
| 116 |
+
},
|
| 117 |
+
"outputs": [],
|
| 118 |
+
"source": [
|
| 119 |
+
"from unsloth import FastLanguageModel\n",
|
| 120 |
+
"\n",
|
| 121 |
+
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 122 |
+
" model_name=MODEL_NAME,\n",
|
| 123 |
+
" max_seq_length=1024,\n",
|
| 124 |
+
" load_in_4bit=True, # fits on T4 16GB\n",
|
| 125 |
+
" dtype=None, # auto detect\n",
|
| 126 |
+
")\n",
|
| 127 |
+
"\n",
|
| 128 |
+
"# Add LoRA adapters for efficient fine-tuning\n",
|
| 129 |
+
"model = FastLanguageModel.get_peft_model(\n",
|
| 130 |
+
" model,\n",
|
| 131 |
+
" r=16,\n",
|
| 132 |
+
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
|
| 133 |
+
" \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
|
| 134 |
+
" lora_alpha=16,\n",
|
| 135 |
+
" lora_dropout=0,\n",
|
| 136 |
+
" bias=\"none\",\n",
|
| 137 |
+
" use_gradient_checkpointing=\"unsloth\",\n",
|
| 138 |
+
" random_state=42,\n",
|
| 139 |
+
")\n",
|
| 140 |
+
"\n",
|
| 141 |
+
"print(\"Model loaded with LoRA adapters β\")\n",
|
| 142 |
+
"print(f\"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}\")\n",
|
| 143 |
+
"\n",
|
| 144 |
+
"# ============================================================\n",
|
| 145 |
+
"# CELL 5 β Inline RevOps Gym (no server needed for training)\n",
|
| 146 |
+
"# ============================================================"
|
| 147 |
+
]
|
| 148 |
+
},
|
| 149 |
+
{
|
| 150 |
+
"cell_type": "code",
|
| 151 |
+
"execution_count": null,
|
| 152 |
+
"id": "47f5b140",
|
| 153 |
+
"metadata": {
|
| 154 |
+
"lines_to_next_cell": 0
|
| 155 |
+
},
|
| 156 |
+
"outputs": [],
|
| 157 |
+
"source": [
|
| 158 |
+
"# Import directly β no HTTP server overhead during training\n",
|
| 159 |
+
"import sys\n",
|
| 160 |
+
"sys.path.insert(0, \"revops-gym\") # adjust if cloned elsewhere\n",
|
| 161 |
+
"\n",
|
| 162 |
+
"from revops_gym.env import RevOpsEnv\n",
|
| 163 |
+
"from revops_gym.models import RevOpsObservation\n",
|
| 164 |
+
"\n",
|
| 165 |
+
"SYSTEM_PROMPT = \"\"\"You are a SaaS RevOps strategist managing a B2B software company.\n",
|
| 166 |
+
"Your goal is to maximize sustainable revenue growth while keeping the company alive.\n",
|
| 167 |
+
"The VC will fire you if MRR drops below $20,000.\n",
|
| 168 |
+
"\n",
|
| 169 |
+
"Key metrics to balance:\n",
|
| 170 |
+
"- LTV/CAC ratio (target 3x+): profitability per customer\n",
|
| 171 |
+
"- MRR growth: revenue trajectory \n",
|
| 172 |
+
"- Cash runway: survival buffer\n",
|
| 173 |
+
"- Churn rate: customer retention health\n",
|
| 174 |
+
"- Support quality: drives retention\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"You MUST respond with ONLY a JSON object. No explanation, no markdown, just JSON:\n",
|
| 177 |
+
"{\"action_type\": \"<action>\", \"magnitude\": <0.1-1.0>}\n",
|
| 178 |
+
"\n",
|
| 179 |
+
"Valid actions: increase_marketing, decrease_marketing, hire_support, fire_support,\n",
|
| 180 |
+
"discount_campaign, raise_prices, feature_investment, cut_costs, negotiate_contracts, pivot_segment\"\"\"\n",
|
| 181 |
+
"\n",
|
| 182 |
+
"print(\"Environment ready β\")\n",
|
| 183 |
+
"\n",
|
| 184 |
+
"# ============================================================\n",
|
| 185 |
+
"# CELL 6 β Helper: parse LLM output β action\n",
|
| 186 |
+
"# ============================================================"
|
| 187 |
+
]
|
| 188 |
+
},
|
| 189 |
+
{
|
| 190 |
+
"cell_type": "code",
|
| 191 |
+
"execution_count": null,
|
| 192 |
+
"id": "59f8a75f",
|
| 193 |
+
"metadata": {
|
| 194 |
+
"lines_to_next_cell": 0
|
| 195 |
+
},
|
| 196 |
+
"outputs": [],
|
| 197 |
+
"source": [
|
| 198 |
+
"VALID_ACTIONS = [\n",
|
| 199 |
+
" \"increase_marketing\", \"decrease_marketing\", \"hire_support\", \"fire_support\",\n",
|
| 200 |
+
" \"discount_campaign\", \"raise_prices\", \"feature_investment\", \"cut_costs\",\n",
|
| 201 |
+
" \"negotiate_contracts\", \"pivot_segment\",\n",
|
| 202 |
+
"]\n",
|
| 203 |
+
"\n",
|
| 204 |
+
"def parse_action(text: str) -> dict:\n",
|
| 205 |
+
" \"\"\"Extract JSON action from model output. Returns random valid action on failure.\"\"\"\n",
|
| 206 |
+
" try:\n",
|
| 207 |
+
" # Try to find JSON block\n",
|
| 208 |
+
" match = re.search(r'\\{[^}]+\\}', text, re.DOTALL)\n",
|
| 209 |
+
" if match:\n",
|
| 210 |
+
" data = json.loads(match.group())\n",
|
| 211 |
+
" action_type = data.get(\"action_type\", \"\")\n",
|
| 212 |
+
" magnitude = float(data.get(\"magnitude\", 0.5))\n",
|
| 213 |
+
" if action_type in VALID_ACTIONS:\n",
|
| 214 |
+
" magnitude = max(0.1, min(1.0, magnitude))\n",
|
| 215 |
+
" return {\"action_type\": action_type, \"magnitude\": magnitude}\n",
|
| 216 |
+
" except Exception:\n",
|
| 217 |
+
" pass\n",
|
| 218 |
+
" # Fallback: random valid action\n",
|
| 219 |
+
" return {\"action_type\": random.choice(VALID_ACTIONS), \"magnitude\": 0.5}\n",
|
| 220 |
+
"\n",
|
| 221 |
+
"\n",
|
| 222 |
+
"def build_prompt(obs: RevOpsObservation) -> str:\n",
|
| 223 |
+
" return obs.to_prompt_text()\n",
|
| 224 |
+
"\n",
|
| 225 |
+
"\n",
|
| 226 |
+
"def generate_action(obs: RevOpsObservation, do_sample: bool = True) -> tuple[str, dict]:\n",
|
| 227 |
+
" \"\"\"Run one forward pass, return (raw_text, parsed_action).\"\"\"\n",
|
| 228 |
+
" prompt = build_prompt(obs)\n",
|
| 229 |
+
" messages = [\n",
|
| 230 |
+
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
|
| 231 |
+
" {\"role\": \"user\", \"content\": prompt},\n",
|
| 232 |
+
" ]\n",
|
| 233 |
+
" input_ids = tokenizer.apply_chat_template(\n",
|
| 234 |
+
" messages, tokenize=True, add_generation_prompt=True,\n",
|
| 235 |
+
" return_tensors=\"pt\"\n",
|
| 236 |
+
" ).to(model.device)\n",
|
| 237 |
+
"\n",
|
| 238 |
+
" with torch.no_grad():\n",
|
| 239 |
+
" output = model.generate(\n",
|
| 240 |
+
" input_ids,\n",
|
| 241 |
+
" max_new_tokens=MAX_NEW_TOKENS,\n",
|
| 242 |
+
" do_sample=do_sample,\n",
|
| 243 |
+
" temperature=0.8 if do_sample else 0.1,\n",
|
| 244 |
+
" top_p=0.9,\n",
|
| 245 |
+
" pad_token_id=tokenizer.eos_token_id,\n",
|
| 246 |
+
" )\n",
|
| 247 |
+
" new_tokens = output[0][input_ids.shape[1]:]\n",
|
| 248 |
+
" raw = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()\n",
|
| 249 |
+
" action = parse_action(raw)\n",
|
| 250 |
+
" return raw, action\n",
|
| 251 |
+
"\n",
|
| 252 |
+
"print(\"Inference helpers ready β\")\n",
|
| 253 |
+
"\n",
|
| 254 |
+
"# ============================================================\n",
|
| 255 |
+
"# CELL 7 β Rollout function (one episode)\n",
|
| 256 |
+
"# ============================================================"
|
| 257 |
+
]
|
| 258 |
+
},
|
| 259 |
+
{
|
| 260 |
+
"cell_type": "code",
|
| 261 |
+
"execution_count": null,
|
| 262 |
+
"id": "6e1497dc",
|
| 263 |
+
"metadata": {
|
| 264 |
+
"lines_to_next_cell": 0
|
| 265 |
+
},
|
| 266 |
+
"outputs": [],
|
| 267 |
+
"source": [
|
| 268 |
+
"def rollout(env: RevOpsEnv, difficulty: str = \"normal\") -> dict:\n",
|
| 269 |
+
" \"\"\"\n",
|
| 270 |
+
" Run one full episode. Returns trajectory with prompts, outputs, rewards.\n",
|
| 271 |
+
" Used by GRPO to score and train.\n",
|
| 272 |
+
" \"\"\"\n",
|
| 273 |
+
" obs = env.reset(seed=random.randint(0, 10_000))\n",
|
| 274 |
+
" trajectory = {\n",
|
| 275 |
+
" \"prompts\": [],\n",
|
| 276 |
+
" \"responses\": [],\n",
|
| 277 |
+
" \"rewards\": [],\n",
|
| 278 |
+
" \"infos\": [],\n",
|
| 279 |
+
" \"final_mrr\": 0.0,\n",
|
| 280 |
+
" \"survived\": False,\n",
|
| 281 |
+
" \"steps\": 0,\n",
|
| 282 |
+
" }\n",
|
| 283 |
+
"\n",
|
| 284 |
+
" for step in range(MAX_STEPS_PER_EPISODE):\n",
|
| 285 |
+
" raw, action = generate_action(obs)\n",
|
| 286 |
+
" obs = env.step(action)\n",
|
| 287 |
+
"\n",
|
| 288 |
+
" trajectory[\"prompts\"].append(build_prompt(obs))\n",
|
| 289 |
+
" trajectory[\"responses\"].append(raw)\n",
|
| 290 |
+
" trajectory[\"rewards\"].append(obs.reward_last_step or 0.0)\n",
|
| 291 |
+
" trajectory[\"infos\"].append(obs.info)\n",
|
| 292 |
+
"\n",
|
| 293 |
+
" if obs.terminated or obs.truncated:\n",
|
| 294 |
+
" break\n",
|
| 295 |
+
"\n",
|
| 296 |
+
" trajectory[\"final_mrr\"] = obs.mrr\n",
|
| 297 |
+
" trajectory[\"survived\"] = not obs.terminated or obs.step_number >= MAX_STEPS_PER_EPISODE\n",
|
| 298 |
+
" trajectory[\"steps\"] = obs.step_number\n",
|
| 299 |
+
" return trajectory\n",
|
| 300 |
+
"\n",
|
| 301 |
+
"\n",
|
| 302 |
+
"def rollout_batch(n: int = BATCH_SIZE, difficulty: str = \"normal\") -> list[dict]:\n",
|
| 303 |
+
" \"\"\"Run N rollouts and return batch.\"\"\"\n",
|
| 304 |
+
" env = RevOpsEnv(crisis_every=3, difficulty=difficulty)\n",
|
| 305 |
+
" return [rollout(env, difficulty) for _ in range(n)]\n",
|
| 306 |
+
"\n",
|
| 307 |
+
"\n",
|
| 308 |
+
"print(\"Rollout function ready β\")\n",
|
| 309 |
+
"\n",
|
| 310 |
+
"# ============================================================\n",
|
| 311 |
+
"# CELL 8 β Baseline evaluation (untrained model)\n",
|
| 312 |
+
"# ============================================================"
|
| 313 |
+
]
|
| 314 |
+
},
|
| 315 |
+
{
|
| 316 |
+
"cell_type": "code",
|
| 317 |
+
"execution_count": null,
|
| 318 |
+
"id": "fa7c83e9",
|
| 319 |
+
"metadata": {
|
| 320 |
+
"lines_to_next_cell": 0
|
| 321 |
+
},
|
| 322 |
+
"outputs": [],
|
| 323 |
+
"source": [
|
| 324 |
+
"print(\"=\" * 60)\n",
|
| 325 |
+
"print(\"BASELINE EVALUATION (untrained model)\")\n",
|
| 326 |
+
"print(\"=\" * 60)\n",
|
| 327 |
+
"\n",
|
| 328 |
+
"FastLanguageModel.for_inference(model) # enable fast inference mode\n",
|
| 329 |
+
"\n",
|
| 330 |
+
"baseline_rewards = []\n",
|
| 331 |
+
"baseline_mrrs = []\n",
|
| 332 |
+
"baseline_survivals = []\n",
|
| 333 |
+
"N_BASELINE = 10\n",
|
| 334 |
+
"\n",
|
| 335 |
+
"for i in range(N_BASELINE):\n",
|
| 336 |
+
" t = rollout(RevOpsEnv(crisis_every=3, seed=i))\n",
|
| 337 |
+
" mean_r = sum(t[\"rewards\"]) / max(len(t[\"rewards\"]), 1)\n",
|
| 338 |
+
" baseline_rewards.append(mean_r)\n",
|
| 339 |
+
" baseline_mrrs.append(t[\"final_mrr\"])\n",
|
| 340 |
+
" baseline_survivals.append(1 if t[\"survived\"] else 0)\n",
|
| 341 |
+
" print(f\" Episode {i+1:2d} | mean_reward={mean_r:.4f} | \"\n",
|
| 342 |
+
" f\"final_MRR=${t['final_mrr']:,.0f} | survived={t['survived']}\")\n",
|
| 343 |
+
"\n",
|
| 344 |
+
"print(f\"\\nBaseline mean reward: {np.mean(baseline_rewards):.4f} Β± {np.std(baseline_rewards):.4f}\")\n",
|
| 345 |
+
"print(f\"Baseline mean final MRR: ${np.mean(baseline_mrrs):,.0f}\")\n",
|
| 346 |
+
"print(f\"Baseline survival rate: {np.mean(baseline_survivals)*100:.0f}%\")\n",
|
| 347 |
+
"\n",
|
| 348 |
+
"# ============================================================\n",
|
| 349 |
+
"# CELL 9 β GRPO Training loop\n",
|
| 350 |
+
"# ============================================================"
|
| 351 |
+
]
|
| 352 |
+
},
|
| 353 |
+
{
|
| 354 |
+
"cell_type": "code",
|
| 355 |
+
"execution_count": null,
|
| 356 |
+
"id": "6b3dab79",
|
| 357 |
+
"metadata": {
|
| 358 |
+
"lines_to_next_cell": 0
|
| 359 |
+
},
|
| 360 |
+
"outputs": [],
|
| 361 |
+
"source": [
|
| 362 |
+
"from trl import GRPOConfig, GRPOTrainer\n",
|
| 363 |
+
"\n",
|
| 364 |
+
"# Switch to training mode\n",
|
| 365 |
+
"FastLanguageModel.for_training(model)\n",
|
| 366 |
+
"\n",
|
| 367 |
+
"# --- GRPO reward function ---\n",
|
| 368 |
+
"def grpo_reward_fn(prompts, completions, **kwargs) -> list[float]:\n",
|
| 369 |
+
" \"\"\"\n",
|
| 370 |
+
" Reward function called by GRPOTrainer.\n",
|
| 371 |
+
" Parses each completion as a RevOps action and scores:\n",
|
| 372 |
+
" 1. Format compliance (+0.1 for valid JSON)\n",
|
| 373 |
+
" 2. Action validity (+0.1 for known action type)\n",
|
| 374 |
+
" 3. Magnitude reasonableness (+0.05)\n",
|
| 375 |
+
" 4. Contextual quality (estimated from prompt metrics)\n",
|
| 376 |
+
" \"\"\"\n",
|
| 377 |
+
" rewards = []\n",
|
| 378 |
+
" for prompt, completion in zip(prompts, completions):\n",
|
| 379 |
+
" reward = 0.0\n",
|
| 380 |
+
"\n",
|
| 381 |
+
" # Format reward\n",
|
| 382 |
+
" try:\n",
|
| 383 |
+
" match = re.search(r'\\{[^}]+\\}', completion, re.DOTALL)\n",
|
| 384 |
+
" if match:\n",
|
| 385 |
+
" data = json.loads(match.group())\n",
|
| 386 |
+
" reward += 0.1 # valid JSON\n",
|
| 387 |
+
"\n",
|
| 388 |
+
" if data.get(\"action_type\") in VALID_ACTIONS:\n",
|
| 389 |
+
" reward += 0.1 # valid action\n",
|
| 390 |
+
"\n",
|
| 391 |
+
" mag = float(data.get(\"magnitude\", -1))\n",
|
| 392 |
+
" if 0.1 <= mag <= 1.0:\n",
|
| 393 |
+
" reward += 0.05 # sensible magnitude\n",
|
| 394 |
+
"\n",
|
| 395 |
+
" # Contextual bonus: penalize fire_support when support_quality is low\n",
|
| 396 |
+
" if \"support_quality\" in prompt:\n",
|
| 397 |
+
" sq_match = re.search(r'Support quality: (\\d+)%', prompt)\n",
|
| 398 |
+
" if sq_match:\n",
|
| 399 |
+
" sq = int(sq_match.group(1))\n",
|
| 400 |
+
" if data.get(\"action_type\") == \"fire_support\" and sq < 60:\n",
|
| 401 |
+
" reward -= 0.15 # punish bad decision\n",
|
| 402 |
+
"\n",
|
| 403 |
+
" # Bonus for crisis-responsive actions\n",
|
| 404 |
+
" if \"ACTIVE CRISIS\" in prompt:\n",
|
| 405 |
+
" crisis_actions = {\n",
|
| 406 |
+
" \"CHURN_SPIKE\": [\"hire_support\", \"discount_campaign\", \"negotiate_contracts\"],\n",
|
| 407 |
+
" \"CAC_EXPLOSION\": [\"decrease_marketing\", \"feature_investment\", \"pivot_segment\"],\n",
|
| 408 |
+
" \"CASH_CRUNCH\": [\"cut_costs\", \"decrease_marketing\", \"raise_prices\"],\n",
|
| 409 |
+
" \"SUPPORT_CRISIS\": [\"hire_support\", \"feature_investment\"],\n",
|
| 410 |
+
" \"ENTERPRISE_CHURN\": [\"negotiate_contracts\", \"raise_prices\", \"feature_investment\"],\n",
|
| 411 |
+
" }\n",
|
| 412 |
+
" for crisis, good_actions in crisis_actions.items():\n",
|
| 413 |
+
" if crisis in prompt and data.get(\"action_type\") in good_actions:\n",
|
| 414 |
+
" reward += 0.20\n",
|
| 415 |
+
" break\n",
|
| 416 |
+
"\n",
|
| 417 |
+
" except Exception:\n",
|
| 418 |
+
" reward -= 0.05 # malformed output penalty\n",
|
| 419 |
+
"\n",
|
| 420 |
+
" rewards.append(reward)\n",
|
| 421 |
+
" return rewards\n",
|
| 422 |
+
"\n",
|
| 423 |
+
"\n",
|
| 424 |
+
"# --- Training config ---\n",
|
| 425 |
+
"training_args = GRPOConfig(\n",
|
| 426 |
+
" output_dir=\"./revops-gym-checkpoints\",\n",
|
| 427 |
+
" per_device_train_batch_size=BATCH_SIZE,\n",
|
| 428 |
+
" gradient_accumulation_steps=4,\n",
|
| 429 |
+
" num_train_epochs=1,\n",
|
| 430 |
+
" learning_rate=LR,\n",
|
| 431 |
+
" warmup_steps=10,\n",
|
| 432 |
+
" logging_steps=5,\n",
|
| 433 |
+
" save_steps=SAVE_EVERY,\n",
|
| 434 |
+
" fp16=not torch.cuda.is_bf16_supported(),\n",
|
| 435 |
+
" bf16=torch.cuda.is_bf16_supported(),\n",
|
| 436 |
+
" report_to=\"wandb\" if USE_WANDB else \"none\",\n",
|
| 437 |
+
" run_name=\"revops-gym-grpo\",\n",
|
| 438 |
+
" num_generations=BATCH_SIZE, # samples per prompt\n",
|
| 439 |
+
" max_new_tokens=MAX_NEW_TOKENS,\n",
|
| 440 |
+
" temperature=0.8,\n",
|
| 441 |
+
" optim=\"adamw_8bit\",\n",
|
| 442 |
+
" seed=42,\n",
|
| 443 |
+
")\n",
|
| 444 |
+
"\n",
|
| 445 |
+
"# --- Build training dataset from rollouts ---\n",
|
| 446 |
+
"print(\"Collecting initial rollout batch for dataset...\")\n",
|
| 447 |
+
"env_train = RevOpsEnv(crisis_every=3, difficulty=\"easy\") # start easy\n",
|
| 448 |
+
"\n",
|
| 449 |
+
"# GRPO needs a dataset of prompts; it generates completions internally\n",
|
| 450 |
+
"from datasets import Dataset\n",
|
| 451 |
+
"\n",
|
| 452 |
+
"def build_prompt_dataset(n_samples: int = 200) -> Dataset:\n",
|
| 453 |
+
" \"\"\"Generate diverse prompts by rolling out episodes and capturing observations.\"\"\"\n",
|
| 454 |
+
" prompts = []\n",
|
| 455 |
+
" env = RevOpsEnv(crisis_every=3)\n",
|
| 456 |
+
" for i in range(n_samples):\n",
|
| 457 |
+
" obs = env.reset(seed=i)\n",
|
| 458 |
+
" for _ in range(random.randint(0, 10)):\n",
|
| 459 |
+
" action = random.choice(VALID_ACTIONS)\n",
|
| 460 |
+
" obs = env.step({\"action_type\": action, \"magnitude\": random.uniform(0.1, 1.0)})\n",
|
| 461 |
+
" if obs.terminated or obs.truncated:\n",
|
| 462 |
+
" obs = env.reset(seed=i * 100)\n",
|
| 463 |
+
" break\n",
|
| 464 |
+
" messages = [\n",
|
| 465 |
+
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
|
| 466 |
+
" {\"role\": \"user\", \"content\": obs.to_prompt_text()},\n",
|
| 467 |
+
" ]\n",
|
| 468 |
+
" prompt_text = tokenizer.apply_chat_template(\n",
|
| 469 |
+
" messages, tokenize=False, add_generation_prompt=True\n",
|
| 470 |
+
" )\n",
|
| 471 |
+
" prompts.append({\"prompt\": prompt_text})\n",
|
| 472 |
+
" return Dataset.from_list(prompts)\n",
|
| 473 |
+
"\n",
|
| 474 |
+
"print(\"Building prompt dataset (200 samples)...\")\n",
|
| 475 |
+
"train_dataset = build_prompt_dataset(200)\n",
|
| 476 |
+
"print(f\"Dataset size: {len(train_dataset)} prompts β\")\n",
|
| 477 |
+
"\n",
|
| 478 |
+
"# ============================================================\n",
|
| 479 |
+
"# CELL 10 β Run training\n",
|
| 480 |
+
"# ============================================================"
|
| 481 |
+
]
|
| 482 |
+
},
|
| 483 |
+
{
|
| 484 |
+
"cell_type": "code",
|
| 485 |
+
"execution_count": null,
|
| 486 |
+
"id": "f2885f1e",
|
| 487 |
+
"metadata": {
|
| 488 |
+
"lines_to_next_cell": 0
|
| 489 |
+
},
|
| 490 |
+
"outputs": [],
|
| 491 |
+
"source": [
|
| 492 |
+
"trainer = GRPOTrainer(\n",
|
| 493 |
+
" model=model,\n",
|
| 494 |
+
" processing_class=tokenizer,\n",
|
| 495 |
+
" reward_funcs=grpo_reward_fn,\n",
|
| 496 |
+
" args=training_args,\n",
|
| 497 |
+
" train_dataset=train_dataset,\n",
|
| 498 |
+
")\n",
|
| 499 |
+
"\n",
|
| 500 |
+
"print(\"Starting GRPO training...\")\n",
|
| 501 |
+
"print(f\" Model: {MODEL_NAME}\")\n",
|
| 502 |
+
"print(f\" Dataset: {len(train_dataset)} prompts\")\n",
|
| 503 |
+
"print(f\" Batch: {BATCH_SIZE} generations per step\")\n",
|
| 504 |
+
"print(f\" LR: {LR}\")\n",
|
| 505 |
+
"print()\n",
|
| 506 |
+
"\n",
|
| 507 |
+
"train_result = trainer.train()\n",
|
| 508 |
+
"print(\"\\nTraining complete β\")\n",
|
| 509 |
+
"print(train_result)\n",
|
| 510 |
+
"\n",
|
| 511 |
+
"# ============================================================\n",
|
| 512 |
+
"# CELL 11 β Post-training evaluation\n",
|
| 513 |
+
"# ============================================================"
|
| 514 |
+
]
|
| 515 |
+
},
|
| 516 |
+
{
|
| 517 |
+
"cell_type": "code",
|
| 518 |
+
"execution_count": null,
|
| 519 |
+
"id": "75ba3e52",
|
| 520 |
+
"metadata": {
|
| 521 |
+
"lines_to_next_cell": 0
|
| 522 |
+
},
|
| 523 |
+
"outputs": [],
|
| 524 |
+
"source": [
|
| 525 |
+
"FastLanguageModel.for_inference(model)\n",
|
| 526 |
+
"\n",
|
| 527 |
+
"print(\"=\" * 60)\n",
|
| 528 |
+
"print(\"POST-TRAINING EVALUATION\")\n",
|
| 529 |
+
"print(\"=\" * 60)\n",
|
| 530 |
+
"\n",
|
| 531 |
+
"trained_rewards = []\n",
|
| 532 |
+
"trained_mrrs = []\n",
|
| 533 |
+
"trained_survivals = []\n",
|
| 534 |
+
"N_EVAL = 10\n",
|
| 535 |
+
"\n",
|
| 536 |
+
"for i in range(N_EVAL):\n",
|
| 537 |
+
" t = rollout(RevOpsEnv(crisis_every=3, seed=i + 1000))\n",
|
| 538 |
+
" mean_r = sum(t[\"rewards\"]) / max(len(t[\"rewards\"]), 1)\n",
|
| 539 |
+
" trained_rewards.append(mean_r)\n",
|
| 540 |
+
" trained_mrrs.append(t[\"final_mrr\"])\n",
|
| 541 |
+
" trained_survivals.append(1 if t[\"survived\"] else 0)\n",
|
| 542 |
+
" print(f\" Episode {i+1:2d} | mean_reward={mean_r:.4f} | \"\n",
|
| 543 |
+
" f\"final_MRR=${t['final_mrr']:,.0f} | survived={t['survived']}\")\n",
|
| 544 |
+
"\n",
|
| 545 |
+
"print(f\"\\nTrained mean reward: {np.mean(trained_rewards):.4f} Β± {np.std(trained_rewards):.4f}\")\n",
|
| 546 |
+
"print(f\"Trained mean final MRR: ${np.mean(trained_mrrs):,.0f}\")\n",
|
| 547 |
+
"print(f\"Trained survival rate: {np.mean(trained_survivals)*100:.0f}%\")\n",
|
| 548 |
+
"\n",
|
| 549 |
+
"# Delta\n",
|
| 550 |
+
"print(f\"\\n{'='*60}\")\n",
|
| 551 |
+
"print(\"IMPROVEMENT SUMMARY\")\n",
|
| 552 |
+
"print(f\"{'='*60}\")\n",
|
| 553 |
+
"print(f\"Mean reward delta: {np.mean(trained_rewards) - np.mean(baseline_rewards):+.4f}\")\n",
|
| 554 |
+
"print(f\"Final MRR delta: ${np.mean(trained_mrrs) - np.mean(baseline_mrrs):+,.0f}\")\n",
|
| 555 |
+
"print(f\"Survival rate delta: {(np.mean(trained_survivals) - np.mean(baseline_survivals))*100:+.0f}%\")\n",
|
| 556 |
+
"\n",
|
| 557 |
+
"# ============================================================\n",
|
| 558 |
+
"# CELL 12 β Plot reward curves and save to repo\n",
|
| 559 |
+
"# ============================================================"
|
| 560 |
+
]
|
| 561 |
+
},
|
| 562 |
+
{
|
| 563 |
+
"cell_type": "code",
|
| 564 |
+
"execution_count": null,
|
| 565 |
+
"id": "4ae37220",
|
| 566 |
+
"metadata": {
|
| 567 |
+
"lines_to_next_cell": 0
|
| 568 |
+
},
|
| 569 |
+
"outputs": [],
|
| 570 |
+
"source": [
|
| 571 |
+
"fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
|
| 572 |
+
"fig.suptitle(\"RevOps Gym β Training Results\", fontsize=14, fontweight=\"bold\")\n",
|
| 573 |
+
"\n",
|
| 574 |
+
"# Reward comparison\n",
|
| 575 |
+
"ax = axes[0]\n",
|
| 576 |
+
"ax.bar([\"Baseline\", \"Trained\"],\n",
|
| 577 |
+
" [np.mean(baseline_rewards), np.mean(trained_rewards)],\n",
|
| 578 |
+
" color=[\"#e74c3c\", \"#2ecc71\"], alpha=0.85, edgecolor=\"black\")\n",
|
| 579 |
+
"ax.errorbar([\"Baseline\", \"Trained\"],\n",
|
| 580 |
+
" [np.mean(baseline_rewards), np.mean(trained_rewards)],\n",
|
| 581 |
+
" yerr=[np.std(baseline_rewards), np.std(trained_rewards)],\n",
|
| 582 |
+
" fmt=\"none\", color=\"black\", capsize=5)\n",
|
| 583 |
+
"ax.set_title(\"Mean Episode Reward\")\n",
|
| 584 |
+
"ax.set_ylabel(\"Reward\")\n",
|
| 585 |
+
"ax.set_xlabel(\"Model\")\n",
|
| 586 |
+
"\n",
|
| 587 |
+
"# MRR comparison\n",
|
| 588 |
+
"ax2 = axes[1]\n",
|
| 589 |
+
"ax2.bar([\"Baseline\", \"Trained\"],\n",
|
| 590 |
+
" [np.mean(baseline_mrrs)/1000, np.mean(trained_mrrs)/1000],\n",
|
| 591 |
+
" color=[\"#e74c3c\", \"#2ecc71\"], alpha=0.85, edgecolor=\"black\")\n",
|
| 592 |
+
"ax2.set_title(\"Mean Final MRR\")\n",
|
| 593 |
+
"ax2.set_ylabel(\"MRR ($K)\")\n",
|
| 594 |
+
"ax2.set_xlabel(\"Model\")\n",
|
| 595 |
+
"ax2.axhline(y=20, color=\"orange\", linestyle=\"--\", label=\"VC floor ($20K)\")\n",
|
| 596 |
+
"ax2.legend()\n",
|
| 597 |
+
"\n",
|
| 598 |
+
"# Survival rate\n",
|
| 599 |
+
"ax3 = axes[2]\n",
|
| 600 |
+
"ax3.bar([\"Baseline\", \"Trained\"],\n",
|
| 601 |
+
" [np.mean(baseline_survivals)*100, np.mean(trained_survivals)*100],\n",
|
| 602 |
+
" color=[\"#e74c3c\", \"#2ecc71\"], alpha=0.85, edgecolor=\"black\")\n",
|
| 603 |
+
"ax3.set_title(\"Company Survival Rate\")\n",
|
| 604 |
+
"ax3.set_ylabel(\"Survival %\")\n",
|
| 605 |
+
"ax3.set_xlabel(\"Model\")\n",
|
| 606 |
+
"ax3.set_ylim(0, 110)\n",
|
| 607 |
+
"ax3.axhline(y=100, color=\"gray\", linestyle=\"--\", alpha=0.5)\n",
|
| 608 |
+
"\n",
|
| 609 |
+
"plt.tight_layout()\n",
|
| 610 |
+
"plt.savefig(\"results_comparison.png\", dpi=150, bbox_inches=\"tight\")\n",
|
| 611 |
+
"plt.show()\n",
|
| 612 |
+
"print(\"Plot saved as results_comparison.png β\")\n",
|
| 613 |
+
"\n",
|
| 614 |
+
"# Training loss plot (from trainer logs)\n",
|
| 615 |
+
"if hasattr(trainer.state, \"log_history\") and trainer.state.log_history:\n",
|
| 616 |
+
" losses = [x.get(\"loss\", None) for x in trainer.state.log_history if \"loss\" in x]\n",
|
| 617 |
+
" rewards_log = [x.get(\"reward\", None) for x in trainer.state.log_history if \"reward\" in x]\n",
|
| 618 |
+
"\n",
|
| 619 |
+
" fig2, (ax_l, ax_r) = plt.subplots(1, 2, figsize=(12, 4))\n",
|
| 620 |
+
" fig2.suptitle(\"RevOps Gym β Training Curves\", fontsize=13, fontweight=\"bold\")\n",
|
| 621 |
+
"\n",
|
| 622 |
+
" if losses:\n",
|
| 623 |
+
" ax_l.plot(losses, color=\"#3498db\", linewidth=1.5)\n",
|
| 624 |
+
" ax_l.set_title(\"Training Loss\")\n",
|
| 625 |
+
" ax_l.set_xlabel(\"Training step\")\n",
|
| 626 |
+
" ax_l.set_ylabel(\"Loss\")\n",
|
| 627 |
+
" ax_l.grid(alpha=0.3)\n",
|
| 628 |
+
"\n",
|
| 629 |
+
" if rewards_log:\n",
|
| 630 |
+
" ax_r.plot(rewards_log, color=\"#2ecc71\", linewidth=1.5)\n",
|
| 631 |
+
" ax_r.set_title(\"Training Reward\")\n",
|
| 632 |
+
" ax_r.set_xlabel(\"Training step\")\n",
|
| 633 |
+
" ax_r.set_ylabel(\"Mean reward\")\n",
|
| 634 |
+
" ax_r.grid(alpha=0.3)\n",
|
| 635 |
+
"\n",
|
| 636 |
+
" plt.tight_layout()\n",
|
| 637 |
+
" plt.savefig(\"training_curves.png\", dpi=150, bbox_inches=\"tight\")\n",
|
| 638 |
+
" plt.show()\n",
|
| 639 |
+
" print(\"Training curves saved as training_curves.png β\")\n",
|
| 640 |
+
"\n",
|
| 641 |
+
"# ============================================================\n",
|
| 642 |
+
"# CELL 13 β Save model and push to HF Hub\n",
|
| 643 |
+
"# ============================================================"
|
| 644 |
+
]
|
| 645 |
+
},
|
| 646 |
+
{
|
| 647 |
+
"cell_type": "code",
|
| 648 |
+
"execution_count": null,
|
| 649 |
+
"id": "d236cce6",
|
| 650 |
+
"metadata": {},
|
| 651 |
+
"outputs": [],
|
| 652 |
+
"source": [
|
| 653 |
+
"# Save with Unsloth's correct LoRA merge path (avoids 4bitβ16bit corruption)\n",
|
| 654 |
+
"model.save_pretrained_merged(\n",
|
| 655 |
+
" \"revops-gym-model\",\n",
|
| 656 |
+
" tokenizer,\n",
|
| 657 |
+
" save_method=\"lora\", # save adapters only (small, efficient)\n",
|
| 658 |
+
")\n",
|
| 659 |
+
"print(\"Model adapters saved to ./revops-gym-model β\")\n",
|
| 660 |
+
"\n",
|
| 661 |
+
"# Push to Hugging Face Hub\n",
|
| 662 |
+
"# from huggingface_hub import login\n",
|
| 663 |
+
"# login(token=\"hf_YOUR_TOKEN\")\n",
|
| 664 |
+
"# model.push_to_hub_merged(\"YOUR_HF_USERNAME/revops-gym-model\", tokenizer, save_method=\"lora\")\n",
|
| 665 |
+
"# print(\"Model pushed to HF Hub β\")\n",
|
| 666 |
+
"\n",
|
| 667 |
+
"# Copy result plots into the revops-gym repo for the README\n",
|
| 668 |
+
"!cp results_comparison.png revops-gym/\n",
|
| 669 |
+
"!cp training_curves.png revops-gym/\n",
|
| 670 |
+
"!cd revops-gym && git add results_comparison.png training_curves.png && git commit -m \"Add training result plots\" && git push\n",
|
| 671 |
+
"\n",
|
| 672 |
+
"print(\"\\nπ Training pipeline complete!\")\n",
|
| 673 |
+
"print(\"Next steps:\")\n",
|
| 674 |
+
"print(\" 1. Copy results_comparison.png and training_curves.png into your HF Space repo\")\n",
|
| 675 |
+
"print(\" 2. Embed them in README.md\")\n",
|
| 676 |
+
"print(\" 3. Push the trained model adapter to HF Hub\")\n",
|
| 677 |
+
"print(\" 4. Submit the HF Space URL via the Google Form\")"
|
| 678 |
+
]
|
| 679 |
+
}
|
| 680 |
+
],
|
| 681 |
+
"metadata": {
|
| 682 |
+
"jupytext": {
|
| 683 |
+
"cell_metadata_filter": "-all",
|
| 684 |
+
"main_language": "python",
|
| 685 |
+
"notebook_metadata_filter": "-all"
|
| 686 |
+
}
|
| 687 |
+
},
|
| 688 |
+
"nbformat": 4,
|
| 689 |
+
"nbformat_minor": 5
|
| 690 |
+
}
|