saravanatanjiro committed
Commit dfc5996 · 1 Parent(s): af6bbef

Migrate LLM pipeline to custom GRPO with robust rewards

Replace the REINFORCE-style loop with grouped GRPO optimization and add decomposed reward telemetry, GRPO diagnostics, and UI controls so training is more stable, observable, and reproducible.

README.md CHANGED
@@ -5,18 +5,69 @@ colorFrom: blue
  colorTo: purple
  sdk: docker
  pinned: false
- short_description: Cloud Arena Mathematical Model RL Training
+ short_description: Cloud Arena RL with Custom GRPO
  ---

- # Cloud Arena — Mathematical Model RL Training
+ # Cloud Arena RL

- Multi-objective cloud operations RL environment trained with **MaskablePPO**.
+ This Space contains two independent RL systems:

- This is the **Mathematical Model** (MLP + stable-baselines3), NOT the LLM model.
+ - Mathematical model RL (`MaskablePPO` + MLP) for structured cloud-ops optimization.
+ - LLM RL using a **custom GRPO** training loop with LoRA adaptation.

- ## Features
- - 125-dim observation space, 150 discrete actions
- - 6-phase curriculum learning
- - Action masking, fog-of-war, chaos events
- - Boss fight evaluation scenarios
- - Interactive training dashboard
+ ## LLM Algorithm
+
+ The LLM pipeline now uses Group Relative Policy Optimization (GRPO):
+
+ 1. For each state, sample `K` responses.
+ 2. Simulate each sampled action on an environment clone.
+ 3. Compute group-relative normalized advantage.
+ 4. Optimize a clipped policy objective with KL and entropy regularization.
+
+ Key implementation file:
+ - `cloud_arena/llm_training.py`
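The four steps in the README leave the advantage and the objective unspecified. The following is a minimal sketch of steps 3 and 4, not the code in `cloud_arena/llm_training.py`. It assumes one sequence-level log-probability per sampled response, a population standard deviation for the group normalization, and a simple non-negative KL estimator against a frozen reference policy. The function names are hypothetical, and the entropy bonus is left out for brevity.

```python
import math
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-8):
    """Normalize the K rewards sampled for one state to zero mean and unit std."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def clipped_grpo_loss(new_logps, old_logps, ref_logps, advantages,
                      clip_epsilon=0.2, kl_coef=0.01):
    """Mean clipped surrogate loss plus a KL penalty (a quantity to minimize).

    Every argument holds one value per sampled response in the group.
    """
    losses = []
    for new_lp, old_lp, ref_lp, adv in zip(new_logps, old_logps, ref_logps, advantages):
        # Importance ratio between the current policy and the sampling policy.
        ratio = math.exp(new_lp - old_lp)
        clipped = min(max(ratio, 1.0 - clip_epsilon), 1.0 + clip_epsilon)
        # Pessimistic (PPO-style) bound: take the smaller of the two objectives.
        surrogate = min(ratio * adv, clipped * adv)
        # Non-negative KL estimator; it is zero when new and reference log-probs match.
        log_ratio_ref = ref_lp - new_lp
        kl = math.exp(log_ratio_ref) - log_ratio_ref - 1.0
        losses.append(-surrogate + kl_coef * kl)
    return sum(losses) / len(losses)


if __name__ == "__main__":
    rewards = [0.4, 1.2, -0.3, 0.9]  # illustrative rewards for K = 4 samples
    advantages = group_relative_advantages(rewards)
    logps = [-3.0, -2.5, -4.0, -2.8]  # illustrative log-probabilities
    print(advantages)
    print(clipped_grpo_loss(logps, logps, logps, advantages))
```

When every reward in a group is identical, all advantages are zero and the group contributes no policy gradient, which is one reason a group size of at least 2 with some reward spread matters.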
+
+ ## Reward Design
+
+ The LLM environment reward is decomposed into robust components:
+
+ - `cost_delta`: incentivize concrete savings.
+ - `risk`: reward lower operational risk.
+ - `reliability`: reward safe, stable outcomes.
+ - `action_quality`: valid action bonus, tool misuse and veto penalties.
+ - `anti_loop`: repeated-action and hesitation penalties.
+ - `terminal`: success bonus and failure penalties.
+
+ Safety protections:
+ - Semantic veto for production-like resources.
+ - Structural crash penalty on dependency-breaking stop/delete actions.
+ - Tool reward caps and repeated-action penalties to prevent farming loops.
+
+ Key environment file:
+ - `cloud_arena/llm_environment.py`
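One way to read the decomposition: each named component is bounded on its own, so a single large term cannot dominate the step reward. The sketch below is an illustration under assumptions, not the environment's code. It reuses the `MAX_COMPONENT_ABS = 5.0` bound that appears in the `llm_environment.py` diff, but the plain sum used for the total and the function name `combine_reward_components` are hypothetical.

```python
MAX_COMPONENT_ABS = 5.0  # same bound as the constant in the environment diff


def combine_reward_components(components, max_abs=MAX_COMPONENT_ABS):
    """Clip each reward component to [-max_abs, max_abs] and sum the result.

    Returns (total, clipped_components) so the per-component values can be
    logged as telemetry next to the scalar reward.
    """
    clipped = {
        name: max(-max_abs, min(max_abs, float(value)))
        for name, value in components.items()
    }
    # Assumption: the scalar reward is the plain sum of the clipped components.
    total = sum(clipped.values())
    return total, clipped


if __name__ == "__main__":
    # Illustrative values only; the keys follow the README component names.
    step_components = {
        "cost_delta": 7.2,
        "risk": 0.5,
        "reliability": 0.35,
        "action_quality": 0.2,
        "anti_loop": -0.06,
        "terminal": 0.0,
    }
    total, clipped = combine_reward_components(step_components)
    print(total, clipped)
```

Returning the clipped dictionary alongside the total keeps the per-component telemetry consistent with the value the optimizer actually sees.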
+
+ ## GRPO Runtime Optimizations
+
+ - LoRA fine-tuning over causal LMs.
+ - Gradient accumulation and gradient clipping.
+ - KL watchdog with adaptive KL coefficient.
+ - VRAM cleanup between model runs.
+ - Deterministic seeds for reproducible smoke checks.
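The README does not say how the KL watchdog adapts the coefficient. A common heuristic, borrowed from the adaptive-penalty variant of PPO, is to raise the coefficient when the measured KL overshoots a target band and lower it when the KL undershoots. The sketch below assumes that rule. The class name, the target of `0.05`, the band of 1.5x, and the clamp limits are all hypothetical choices rather than values taken from the repository.

```python
import math


class AdaptiveKLController:
    """Keep the measured KL near a target by scaling the penalty coefficient."""

    def __init__(self, kl_coef=0.01, target_kl=0.05, factor=2.0,
                 min_coef=1e-4, max_coef=1.0):
        self.kl_coef = kl_coef
        self.target_kl = target_kl
        self.factor = factor
        self.min_coef = min_coef
        self.max_coef = max_coef

    def update(self, observed_kl):
        """Adjust and return the coefficient after one optimization step."""
        if not math.isfinite(observed_kl):
            # Treat NaN or inf as a blowup and apply the strongest penalty.
            self.kl_coef = self.max_coef
        elif observed_kl > 1.5 * self.target_kl:
            self.kl_coef = min(self.kl_coef * self.factor, self.max_coef)
        elif observed_kl < self.target_kl / 1.5:
            self.kl_coef = max(self.kl_coef / self.factor, self.min_coef)
        # Inside the band the coefficient is left unchanged.
        return self.kl_coef


if __name__ == "__main__":
    controller = AdaptiveKLController()
    for kl in (0.2, 0.05, 0.001):  # illustrative per-iteration KL measurements
        print(kl, controller.update(kl))
```

Clamping to `[min_coef, max_coef]` prevents the coefficient from drifting to zero or growing without bound over a long run.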
+
+ ## Recommended Config (Smoke Benchmark)
+
+ - Iterations: `30`
+ - Steps per episode: `12`
+ - Group size: `4`
+ - Clip epsilon: `0.2`
+ - KL coefficient: `0.01`
+ - Entropy coefficient: `0.001`
+ - Max generation tokens: `80`
+ - Temperature: `0.7`
+
+ ## Validation Criteria
+
+ - Determinism: repeated runs with fixed seed show similar trends.
+ - Safety: veto/violation rates stay stable or improve.
+ - Learning: post-training reward exceeds pre-training baseline for at least one default model.
+ - Stability: no persistent NaNs, KL blowups, or reward collapse.
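The recommended smoke-benchmark values can be kept in one place so repeated runs use identical settings. The sketch below is a hypothetical helper, not part of the commit. The field names mirror the keyword arguments that `app.py` passes to `train_llm` in this diff, the defaults are the README values, and the range checks in `validate` are assumptions about what counts as a sane setting.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class GRPOSmokeConfig:
    """README smoke-benchmark defaults (class name is hypothetical)."""

    num_iterations: int = 30
    steps_per_episode: int = 12
    group_size: int = 4
    clip_epsilon: float = 0.2
    kl_coef: float = 0.01
    entropy_coef: float = 0.001
    max_gen_tokens: int = 80
    temperature: float = 0.7

    def validate(self):
        """Raise ValueError on settings that would make a run meaningless."""
        if self.num_iterations < 1 or self.steps_per_episode < 1:
            raise ValueError("num_iterations and steps_per_episode must be >= 1")
        if self.group_size < 2:
            # A group of one has zero spread, so every advantage would be zero.
            raise ValueError("group_size must be >= 2")
        if not 0.0 < self.clip_epsilon < 1.0:
            raise ValueError("clip_epsilon must be in (0, 1)")
        if self.kl_coef < 0.0 or self.entropy_coef < 0.0:
            raise ValueError("kl_coef and entropy_coef must be >= 0")
        if self.max_gen_tokens < 1 or self.temperature <= 0.0:
            raise ValueError("max_gen_tokens must be >= 1 and temperature > 0")
        return self

    def as_train_kwargs(self):
        """Return the settings as a plain dict of keyword arguments."""
        return asdict(self)


if __name__ == "__main__":
    config = GRPOSmokeConfig().validate()
    print(config.as_train_kwargs())
```

Assuming `train_llm` accepts these keywords as the `app.py` call suggests, the dictionary could be expanded into that call together with a `model_name`.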
app.py CHANGED
@@ -54,7 +54,17 @@ def run_math_evaluation():

  # ── LLM Model Training ───────────────────────────────────────────────────────

- def run_llm_training(model_name, num_iterations, steps_per_episode):
+ def run_llm_training(
+     model_name,
+     num_iterations,
+     steps_per_episode,
+     group_size,
+     clip_epsilon,
+     kl_coef,
+     entropy_coef,
+     max_gen_tokens,
+     temperature,
+ ):
      from cloud_arena.llm_training import train_llm
      try:
          iters = int(num_iterations)
@@ -62,11 +72,18 @@ def run_llm_training(model_name, num_iterations, steps_per_episode):
              model_name=model_name,
              num_iterations=iters,
              steps_per_episode=int(steps_per_episode),
+             group_size=int(group_size),
+             clip_epsilon=float(clip_epsilon),
+             kl_coef=float(kl_coef),
+             entropy_coef=float(entropy_coef),
+             max_gen_tokens=int(max_gen_tokens),
+             temperature=float(temperature),
          )
          delta = all_rewards[-1] - all_rewards[0]
          summary = (
-             f"✅ LLM Training Complete\n"
+             f"✅ LLM GRPO Training Complete\n"
              f"Model: {model_name}\n"
+             f"Algorithm: Custom GRPO\n"
              f"Pre-training reward: {all_rewards[0]:+.3f}\n"
              f"Post-training reward: {all_rewards[-1]:+.3f}\n"
              f"Δ Change: {delta:+.3f}\n\n"
@@ -98,7 +115,7 @@ with gr.Blocks(title="Cloud Arena RL") as demo:
          eval_btn.click(run_math_evaluation, outputs=eval_output)

      with gr.Tab("🧠 LLM RL"):
-         gr.Markdown("### Multi-Model RL Benchmark — REINFORCE + LoRA")
+         gr.Markdown("### Multi-Model RL Benchmark — Custom GRPO + LoRA")
          gr.Markdown("> Comma-separate model names to benchmark multiple models sequentially")
          llm_model = gr.Textbox(
              value="unsloth/Qwen2.5-Math-7B-Instruct-bnb-4bit, unsloth/gemma-2b-it-bnb-4bit, unsloth/llama-3-8b-Instruct-bnb-4bit",
@@ -106,11 +123,30 @@ with gr.Blocks(title="Cloud Arena RL") as demo:
          )
          llm_iters = gr.Number(value=200, label="Training Iterations per Model")
          llm_steps = gr.Number(value=15, label="Steps per Episode")
+         grpo_group = gr.Number(value=4, label="GRPO Group Size (K)")
+         grpo_clip = gr.Number(value=0.2, label="GRPO Clip Epsilon")
+         grpo_kl = gr.Number(value=0.01, label="KL Coefficient")
+         grpo_entropy = gr.Number(value=0.001, label="Entropy Coefficient")
+         grpo_tokens = gr.Number(value=80, label="Max Generation Tokens")
+         grpo_temp = gr.Number(value=0.7, label="Sampling Temperature")
          llm_btn = gr.Button("🚀 Start LLM Training", variant="primary")
          llm_output = gr.Textbox(label="Training Log", lines=15)
          llm_img = gr.Image(label="Results")
-         llm_btn.click(run_llm_training, inputs=[llm_model, llm_iters, llm_steps],
-                       outputs=[llm_output, llm_img])
+         llm_btn.click(
+             run_llm_training,
+             inputs=[
+                 llm_model,
+                 llm_iters,
+                 llm_steps,
+                 grpo_group,
+                 grpo_clip,
+                 grpo_kl,
+                 grpo_entropy,
+                 grpo_tokens,
+                 grpo_temp,
+             ],
+             outputs=[llm_output, llm_img],
+         )

  if __name__ == "__main__":
      demo.launch(server_name="0.0.0.0", server_port=7860, theme=gr.themes.Base())
cloud_arena/evaluation.py CHANGED
@@ -124,3 +124,88 @@ def run_boss_fights(model_path="./models/cloud_arena_final",
          boss_scores[s_id] = score

      return boss_scores
+
+
+ def evaluate_llm_grpo(model, tokenizer, n_eval=20, steps_per_episode=15, seed=123):
+     """
+     Evaluate LLM policy quality on the FinOps environment using the same
+     ACTION parser logic as training.
+     """
+     import random
+     import torch
+
+     from cloud_arena.llm_environment import SB3Adapter
+     from cloud_arena.llm_training import extract_action_and_reasoning, format_prompt
+
+     random.seed(seed)
+     np.random.seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+
+     env = SB3Adapter()
+     metrics = {
+         "episodes": n_eval,
+         "win_rate": 0.0,
+         "avg_savings_pct": 0.0,
+         "avg_episode_len": 0.0,
+         "safety_violation_rate": 0.0,
+         "action_distribution": {str(i): 0 for i in range(5)},
+         "avg_reward_components": {},
+     }
+
+     wins = 0
+     total_savings = 0.0
+     total_steps = 0
+     total_safety_violations = 0
+     reward_components_sum = {}
+     total_component_steps = 0
+
+     for _ in range(n_eval):
+         _, _ = env.reset()
+         done = False
+         step_count = 0
+         last_info = {}
+         while not done and step_count < steps_per_episode:
+             state_dict = env.core._get_internal_state()
+             prompt = format_prompt(state_dict)
+             inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
+             input_ids = inputs["input_ids"].to(model.device)
+             attn_mask = inputs["attention_mask"].to(model.device)
+             with torch.no_grad():
+                 out = model.generate(
+                     input_ids=input_ids,
+                     attention_mask=attn_mask,
+                     max_new_tokens=80,
+                     do_sample=False,
+                     pad_token_id=tokenizer.pad_token_id,
+                 )
+             response = tokenizer.decode(out[0][input_ids.shape[1] :], skip_special_tokens=True)
+             action, _ = extract_action_and_reasoning(response)
+             metrics["action_distribution"][str(action)] += 1
+
+             _, _, terminated, truncated, info = env.step(action)
+             done = bool(terminated or truncated)
+             step_count += 1
+             last_info = info
+             total_safety_violations += int(info.get("safety_violation", 0))
+             rc = info.get("reward_components", {})
+             for k, v in rc.items():
+                 reward_components_sum[k] = reward_components_sum.get(k, 0.0) + float(v)
+             total_component_steps += 1
+
+         wins += int(last_info.get("win", False))
+         total_savings += float(last_info.get("savings_pct", 0.0))
+         total_steps += step_count
+
+     total_actions = max(sum(metrics["action_distribution"].values()), 1)
+     metrics["action_distribution"] = {
+         k: round(v / total_actions, 4) for k, v in metrics["action_distribution"].items()
+     }
+     metrics["win_rate"] = round(wins / max(n_eval, 1), 4)
+     metrics["avg_savings_pct"] = round(total_savings / max(n_eval, 1), 3)
+     metrics["avg_episode_len"] = round(total_steps / max(n_eval, 1), 3)
+     metrics["safety_violation_rate"] = round(total_safety_violations / max(total_steps, 1), 4)
+     metrics["avg_reward_components"] = {
+         k: round(v / max(total_component_steps, 1), 4) for k, v in reward_components_sum.items()
+     }
+     return metrics
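The aggregation at the end of `evaluate_llm_grpo` can be checked without a model by separating it from the rollout loop. The sketch below is a hypothetical, model-free restatement of that arithmetic. The function name `summarize_episodes` and the per-episode record layout (`actions`, `win`, `savings_pct`, `safety_violations`) are assumptions made for illustration, while the rounding and the `max(..., 1)` guards follow the diff.

```python
def summarize_episodes(episodes, num_actions=5):
    """Aggregate per-episode records into evaluation metrics.

    Each record is a dict with the keys "actions" (list of ints), "win" (bool),
    "savings_pct" (float) and "safety_violations" (int).
    """
    counts = {str(i): 0 for i in range(num_actions)}
    wins = 0
    total_steps = 0
    violations = 0
    total_savings = 0.0
    for episode in episodes:
        for action in episode["actions"]:
            counts[str(action)] += 1
        total_steps += len(episode["actions"])
        wins += int(episode["win"])
        total_savings += float(episode["savings_pct"])
        violations += int(episode["safety_violations"])

    # max(..., 1) avoids division by zero when nothing was evaluated.
    n_episodes = max(len(episodes), 1)
    total_actions = max(sum(counts.values()), 1)
    return {
        "episodes": len(episodes),
        "win_rate": round(wins / n_episodes, 4),
        "avg_savings_pct": round(total_savings / n_episodes, 3),
        "avg_episode_len": round(total_steps / n_episodes, 3),
        "safety_violation_rate": round(violations / max(total_steps, 1), 4),
        "action_distribution": {k: round(v / total_actions, 4) for k, v in counts.items()},
    }


if __name__ == "__main__":
    # Two made-up episodes, used only to exercise the arithmetic.
    records = [
        {"actions": [1, 2, 4], "win": True, "savings_pct": 22.5, "safety_violations": 0},
        {"actions": [0, 3], "win": False, "savings_pct": 5.0, "safety_violations": 1},
    ]
    print(summarize_episodes(records))
```

Note that the safety violation rate is per step, while win rate and savings are per episode, so the two kinds of metric are not directly comparable.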
cloud_arena/llm_environment.py CHANGED
@@ -1,419 +1,341 @@
1
- # ============================================================
2
- # CELL 3 — Cloud FinOps Environment (Final Fixed Version)
3
- #
4
- # ALL loopholes closed:
5
- # 1. CHECK_DEPENDENCIES after cap → hesitation penalty (not 0.0)
6
- # This kills the "+0.200 every episode" passive policy
7
- # 2. W_HESITATION = 0.10 — strong enough to force action
8
- # 3. Win bonus +2.0 — rewards completing the goal, not just steps
9
- # 4. RESIZE guaranteed to reduce cost (uniform 0.40-0.65)
10
- # 5. MIN_DELETABLE_COST_RATIO = 0.35 — win is always reachable
11
- # 6. Stronger semantic veto — also catches high-dependency temp nodes
12
- # ============================================================
13
 
14
- import numpy as np
15
  import gymnasium as gym
 
16
  from gymnasium import spaces
17
- from enum import IntEnum
18
- import random
19
 
20
  random.seed(42)
21
  np.random.seed(42)
22
 
23
- # ─── Action Space ─────────────────────────────────────────────────────────────
24
 
25
  class Action(IntEnum):
26
- NOOP = 0
27
  CHECK_DEPENDENCIES = 1
28
- RESIZE = 2
29
- STOP = 3
30
- DELETE = 4
31
-
32
- NUM_ACTIONS = len(Action)
33
 
34
- # ─── Constants ────────────────────────────────────────────────────────────────
35
 
 
36
  N_RESOURCES = 6
37
  OBS_PER_RES = 5
38
- OBS_DIM = N_RESOURCES * OBS_PER_RES + 2 # = 32
39
 
40
  PROD_NAMES = [
41
- "storage-prod-db", "core-auth-router", "primary-k8s-master",
42
- "billing-db-01", "payment-gateway-prod", "prod-cache-redis",
43
- "prod-elb-frontend", "rds-prod-main", "main-api-prod",
44
- "prod-cosmos-db", "primary-gke-cluster", "prod-spanner-db",
45
  ]
46
  TEMP_NAMES = [
47
- "worker-node-temp", "test-frontend-ui", "sandbox-db-04",
48
- "batch-processor-temp", "dev-cache-redis", "temp-worker-88",
49
- "staging-api-v2", "dev-log-collector", "temp-ecs-task",
50
- "dev-gke-node", "test-bigquery-scratch", "sandbox-spanner-dev",
51
  ]
52
 
53
- # ─── Cloud Resource ───────────────────────────────────────────────────────────
54
 
55
  class CloudResource:
56
  def __init__(self, name, cost_per_hr, cpu_pct, dependencies, is_prod):
57
- self.name = name
58
- self.cost_per_hr = cost_per_hr
59
- self.cpu_pct = cpu_pct
60
  self.dependencies = dependencies
61
- self.is_prod = is_prod
62
- self.active = True
63
 
64
  def to_obs_array(self):
65
- return np.array([
66
- self.cpu_pct / 100.0,
67
- self.cost_per_hr / 5.0,
68
- len(self.dependencies) / 14.0,
69
- float(self.is_prod),
70
- float(self.active),
71
- ], dtype=np.float32)
72
-
 
 
73
 
74
- # ─── Resource Generator ───────────────────────────────────────────────────────
75
 
76
  def generate_resources(n=N_RESOURCES):
77
- """
78
- Creates a randomised pool of cloud resources per episode.
79
-
80
- Guarantees:
81
- 1. One production trap with low CPU (looks deletable but isn't)
82
- 2. Temp resources account for >= 35% of total cost
83
- so the 20% savings win condition is always reachable
84
- """
85
  resources = []
86
-
87
- # Guaranteed prod trap — low CPU makes it look safe to delete
88
  prod_name = random.choice(PROD_NAMES)
89
- resources.append(CloudResource(
90
- name = prod_name,
91
- cost_per_hr = round(random.uniform(0.5, 3.0), 2),
92
- cpu_pct = random.randint(2, 12),
93
- dependencies = random.sample(TEMP_NAMES, k=random.randint(2, 4)),
94
- is_prod = True,
95
- ))
96
-
97
- # Fill remaining slots with random mix
 
98
  for _ in range(n - 1):
99
- is_prod = random.random() < 0.30 # 30% chance prod
100
  name_pool = PROD_NAMES if is_prod else TEMP_NAMES
101
  dep_count = random.randint(1, 5) if is_prod else random.randint(0, 3)
102
- resources.append(CloudResource(
103
- name = random.choice(name_pool),
104
- cost_per_hr = round(random.uniform(0.8, 4.0), 2),
105
- cpu_pct = random.randint(1, 95),
106
- dependencies = random.sample(TEMP_NAMES, k=min(dep_count, len(TEMP_NAMES))),
107
- is_prod = is_prod,
108
- ))
109
-
110
- # ── Guarantee minimum deletable cost ratio ────────────────────────────
111
- # Raises temp resource costs until they represent >= 35% of total.
112
- # Without this guarantee, some episodes are mathematically unwinnable.
113
- MIN_RATIO = 0.35
114
- for _ in range(10): # iterate up to 10x to converge
115
- total = sum(r.cost_per_hr for r in resources)
116
  temp_total = sum(r.cost_per_hr for r in resources if not r.is_prod)
117
- if total > 0 and (temp_total / total) < MIN_RATIO:
118
  for r in resources:
119
  if not r.is_prod:
120
  r.cost_per_hr = round(r.cost_per_hr * 1.3, 2)
121
  else:
122
  break
123
-
124
  return resources
125
 
126
 
127
- # ─── Core Environment (OpenEnv dict API) ─────────────────────────────────────
128
-
129
  class AWSCostEnv:
130
- """
131
- Cloud FinOps Optimisation Environment — OpenEnv dict API.
132
- Wrap with SB3Adapter for stable-baselines3 PPO training.
133
-
134
- REWARD FORMULA
135
- --------------
136
- Savings : clip(delta_cost_pct × W_SAVINGS, -5, +5)
137
- Win bonus: +W_WIN_BONUS when savings >= target (one-time)
138
- NOOP : -W_HESITATION per step
139
- Tool : +W_TOOL per new node checked (capped at W_TOOL_EPISODE_CAP)
140
- After cap -W_HESITATION (closes passive policy loophole)
141
- Veto : PENALTY_VETO (semantic guardrail blocked the action)
142
- Crash : PENALTY_CRASH, episode ends immediately
143
-
144
- KEY LOOPHOLE FIXES
145
- ------------------
146
- Fix 1 — CHECK after cap returns -W_HESITATION not 0.0
147
- Prevents "+0.200 every episode" passive exploit
148
- Fix 2 — RESIZE guaranteed to reduce cost (0.40-0.65 multiplier)
149
- Prevents zero-saving resize farming
150
- Fix 3 — Tool cap resets every episode via reset()
151
- Fix 4 — Semantic veto also catches high-dependency temp nodes
152
- Fix 5 — Min deletable ratio guarantee makes win always reachable
153
- """
154
-
155
- # ── Reward weights (do not change without updating Cell 4 too) ──────────
156
- W_SAVINGS = 20.0
157
- W_HESITATION = 0.10 # raised: strong enough to force decisive action
158
- W_TOOL = 0.20
159
- W_TOOL_EPISODE_CAP = 0.60 # max tool reward per episode (3 uses)
160
- W_WIN_BONUS = 2.0 # one-time bonus for completing the goal
161
- PENALTY_CRASH = -10.0
162
- PENALTY_VETO = -0.50
163
- MAX_STEPS = 100
164
 
165
  def __init__(self, n_resources=N_RESOURCES, target_savings=0.20):
166
- self.n_resources = n_resources
167
  self.target_savings = target_savings
168
- self.resources = []
169
- self.baseline_cost = 0.0
170
- self.current_cost = 0.0
171
- self.current_step = 0
172
  self.nodes_investigated_this_episode = set()
173
- self.total_tool_reward_this_episode = 0.0
174
-
175
- # ── Private helpers ────────────────────────────────────────────────────────
 
176
 
177
  def _resource_from_action(self, action_idx):
178
  idx = (action_idx - 2) % self.n_resources
179
  return self.resources[idx % len(self.resources)]
180
 
181
  def _has_dependency_violation(self, resource):
182
- """True if deleting this resource breaks any other active resource."""
183
  for other in self.resources:
184
- if other.active and other.name != resource.name:
185
- if resource.name in other.dependencies:
186
- return True
187
  return False
188
 
189
  def _calc_cost(self):
190
  return sum(r.cost_per_hr for r in self.resources if r.active)
191
 
192
  def _get_obs(self):
193
  obs = []
194
  for r in self.resources:
195
  obs.extend(r.to_obs_array())
196
- budget_used = (
197
- 1.0 - (self.current_cost / self.baseline_cost)
198
- if self.baseline_cost > 0 else 0.0
199
- )
200
  steps_left = 1.0 - (self.current_step / self.MAX_STEPS)
201
  obs.extend([budget_used, steps_left])
202
  return np.array(obs, dtype=np.float32)
203
 
204
  def _get_internal_state(self):
205
- """Human-readable state dict for OpenEnv /state endpoint."""
206
  return {
207
- "step": self.current_step,
208
  "baseline_cost": self.baseline_cost,
209
- "current_cost": self.current_cost,
210
- "savings_pct": round(
211
- (1 - self.current_cost / self.baseline_cost) * 100, 2
212
- ) if self.baseline_cost > 0 else 0.0,
213
- "resources": [{
214
- "name": r.name,
215
- "active": r.active,
216
- "is_prod": r.is_prod,
217
- "cost_per_hr": r.cost_per_hr,
218
- "cpu_pct": r.cpu_pct,
219
- "dependencies": r.dependencies,
220
- } for r in self.resources]
 
221
  }
222
 
223
- def _semantic_veto(self, name: str, dep_count: int) -> bool:
224
- """
225
- Semantic guardrail — returns True if action should be blocked.
226
-
227
- Two veto triggers:
228
- 1. Name contains production keywords (primary check)
229
- 2. High dependency count on any resource (structural safety net)
230
- Even temp-named nodes with 5+ deps get vetoed
231
- This catches the edge case that caused the -31.800 crash
232
-
233
- In production: replace with call to fine-tuned Llama inference endpoint.
234
- """
235
- name_lower = name.lower()
236
- prod_keywords = [
237
- "prod", "primary", "main", "core",
238
- "billing", "payment", "rds", "master"
239
- ]
240
- # Primary: semantic name check
241
- if any(kw in name_lower for kw in prod_keywords):
242
- return True
243
- # Secondary: structural safety net — high deps = critical regardless of name
244
- if dep_count >= 5:
245
- return True
246
- return False
247
 
248
- # ── Lifecycle ─────────────────────────────────────────────────────────────
249
 
250
  def reset(self):
251
- """Reset environment for a new episode. Returns OpenEnv dict."""
252
- self.current_step = 0
253
  self.nodes_investigated_this_episode = set()
254
- self.total_tool_reward_this_episode = 0.0
255
- self.resources = generate_resources(self.n_resources)
256
- self.baseline_cost = self._calc_cost()
257
- self.current_cost = self.baseline_cost
 
 
 
258
  return {
259
  "observation": self._get_obs(),
260
- "info": {
261
- "msg": "Episode reset",
262
- "baseline_cost": self.baseline_cost,
263
- }
264
  }
265
 
266
  def step(self, action):
267
- """
268
- Execute one environment step.
269
-
270
- Args:
271
- action : int, one of Action enum values (0-4)
272
-
273
- Returns:
274
- dict with keys: observation, state, reward, done, info
275
- """
276
  self.current_step += 1
277
  truncated = self.current_step >= self.MAX_STEPS
278
 
279
- # ── 1. NOOP — hesitation penalty ──────────────────────────────────
280
  if action == Action.NOOP:
281
- return {
282
- "observation": self._get_obs(),
283
- "state": self._get_internal_state(),
284
- "reward": float(-self.W_HESITATION),
285
- "done": bool(truncated),
286
- "info": {"msg": "Hesitation penalty", "win": False,
287
- "savings_pct": round(
288
- (1 - self.current_cost / self.baseline_cost) * 100, 2)}
289
- }
290
 
291
  target = self._resource_from_action(action)
292
 
293
- # ── 2. CHECK_DEPENDENCIES ─────────────────────────────────────────
294
- # LOOPHOLE FIX: After cap is reached, return hesitation penalty
295
- # instead of 0.0. This kills the passive "+0.200 every episode" policy.
296
  if action == Action.CHECK_DEPENDENCIES:
297
- under_cap = self.total_tool_reward_this_episode < self.W_TOOL_EPISODE_CAP
298
- new_node = target.name not in self.nodes_investigated_this_episode
299
-
300
  if new_node and under_cap:
301
- # Valid tool use — reward it
302
  self.nodes_investigated_this_episode.add(target.name)
303
  self.total_tool_reward_this_episode += self.W_TOOL
304
- tool_reward = self.W_TOOL
305
  msg = f"Checked {target.name}"
306
  else:
307
- # Cap reached or node already checked — penalise like NOOP
308
- tool_reward = -self.W_HESITATION
309
- msg = "Tool cap reached — penalised"
310
-
311
- return {
312
- "observation": self._get_obs(),
313
- "state": self._get_internal_state(),
314
- "reward": float(tool_reward),
315
- "done": bool(truncated),
316
- "info": {"msg": msg, "win": False,
317
- "savings_pct": round(
318
- (1 - self.current_cost / self.baseline_cost) * 100, 2)}
319
- }
320
-
321
- # ── 3. SEMANTIC + STRUCTURAL GUARDRAIL ────────────────────────────
322
- # Blocks dangerous actions using name keywords AND dependency count.
323
- # Dependency count fix closes the edge case that caused -31.800 crash.
324
  danger = action in (Action.STOP, Action.DELETE)
325
  if danger and self._semantic_veto(target.name, len(target.dependencies)):
326
- return {
327
- "observation": self._get_obs(),
328
- "state": self._get_internal_state(),
329
- "reward": float(self.PENALTY_VETO),
330
- "done": bool(truncated),
331
- "info": {"msg": f"SEMANTIC VETO on {target.name}",
332
- "win": False,
333
- "savings_pct": round(
334
- (1 - self.current_cost / self.baseline_cost) * 100, 2)}
335
- }
336
-
337
- # ── 4. EXECUTE ACTION ─────────────────────────────────────────────
338
- prev_cost = self.current_cost
339
-
340
- if action == Action.RESIZE:
341
- if target.active:
342
- old_cost = target.cost_per_hr
343
- # LOOPHOLE FIX: 0.40-0.65 multiplier guarantees meaningful reduction
344
- target.cost_per_hr = round(
345
- target.cost_per_hr * random.uniform(0.40, 0.65), 2
346
- )
347
- # Extra safety: if somehow no reduction, penalise
348
- if target.cost_per_hr >= old_cost:
349
- target.cost_per_hr = round(old_cost * 0.50, 2)
350
 
351
  elif action in (Action.STOP, Action.DELETE):
352
- # ── 5. STRUCTURAL DEPENDENCY CHECK ────────────────────────────
353
  if self._has_dependency_violation(target):
354
- return {
355
- "observation": self._get_obs(),
356
- "state": self._get_internal_state(),
357
- "reward": float(self.PENALTY_CRASH),
358
- "done": True,
359
- "info": {
360
- "msg": f"CATASTROPHIC FAILURE: {target.name}",
361
- "win": False,
362
- "savings_pct": round(
363
- (1 - self.current_cost / self.baseline_cost) * 100, 2)
364
- }
365
- }
366
  target.active = False
 
367
 
368
- # ── 6. FINANCIAL REWARD ───────────────────────────────────────────
369
  self.current_cost = self._calc_cost()
370
- delta_pct = (prev_cost - self.current_cost) / self.baseline_cost
371
- savings_reward = float(np.clip(delta_pct * self.W_SAVINGS, -5.0, 5.0))
372
-
373
- # ── 7. WIN CONDITION + BONUS ──────────────────────────────────────
374
- total_saved = (
375
- (self.baseline_cost - self.current_cost) / self.baseline_cost
376
- )
377
  is_win = total_saved >= self.target_savings
378
 
379
- # One-time win bonus — rewards completing the goal
380
- if is_win:
381
- savings_reward += self.W_WIN_BONUS
382
 
383
- is_done = bool(is_win or truncated)
 
 
384
 
385
- return {
386
- "observation": self._get_obs(),
387
- "state": self._get_internal_state(),
388
- "reward": savings_reward,
389
- "done": is_done,
390
- "info": {
391
- "msg": "Win!" if is_win else "Action Successful",
392
- "win": is_win,
393
- "savings_pct": round(total_saved * 100, 2),
394
- }
395
- }
396
 
 
 
 
397
 
398
- # ─── SB3 Adapter (Gymnasium wrapper for PPO) ─────────────────────────────────
399
 
400
  class SB3Adapter(gym.Env):
401
- """
402
- Wraps AWSCostEnv (OpenEnv dict API) into the Gymnasium 5-tuple API
403
- that stable-baselines3 PPO expects.
404
-
405
- terminated = agent achieved the savings target (win)
406
- truncated = MAX_STEPS reached without winning
407
- """
408
  metadata = {"render_modes": []}
409
 
410
  def __init__(self):
411
  super().__init__()
412
  self.core = AWSCostEnv()
413
  self.action_space = spaces.Discrete(NUM_ACTIONS)
414
- self.observation_space = spaces.Box(
415
- low=-np.inf, high=np.inf, shape=(OBS_DIM,), dtype=np.float32
416
- )
417
 
418
  def reset(self, seed=None, options=None):
419
  super().reset(seed=seed)
@@ -421,16 +343,10 @@ class SB3Adapter(gym.Env):
421
  return result["observation"], result["info"]
422
 
423
  def step(self, action):
424
- result = self.core.step(action)
425
  terminated = result["done"] and result["info"].get("win", False)
426
- truncated = result["done"] and not result["info"].get("win", False)
427
- return (
428
- result["observation"],
429
- result["reward"],
430
- terminated,
431
- truncated,
432
- result["info"],
433
- )
434
 
435
  def render(self):
436
  pass
 
1
+ import random
2
+ from enum import IntEnum
3
 
 
4
  import gymnasium as gym
5
+ import numpy as np
6
  from gymnasium import spaces
 
 
7
 
8
  random.seed(42)
9
  np.random.seed(42)
10
 
 
11
 
12
  class Action(IntEnum):
13
+ NOOP = 0
14
  CHECK_DEPENDENCIES = 1
15
+ RESIZE = 2
16
+ STOP = 3
17
+ DELETE = 4
 
 
18
 
 
19
 
20
+ NUM_ACTIONS = len(Action)
21
  N_RESOURCES = 6
22
  OBS_PER_RES = 5
23
+ OBS_DIM = N_RESOURCES * OBS_PER_RES + 2
24
 
25
  PROD_NAMES = [
26
+ "storage-prod-db", "core-auth-router", "primary-k8s-master",
27
+ "billing-db-01", "payment-gateway-prod", "prod-cache-redis",
28
+ "prod-elb-frontend", "rds-prod-main", "main-api-prod",
29
+ "prod-cosmos-db", "primary-gke-cluster", "prod-spanner-db",
30
  ]
31
  TEMP_NAMES = [
32
+ "worker-node-temp", "test-frontend-ui", "sandbox-db-04",
33
+ "batch-processor-temp", "dev-cache-redis", "temp-worker-88",
34
+ "staging-api-v2", "dev-log-collector", "temp-ecs-task",
35
+ "dev-gke-node", "test-bigquery-scratch", "sandbox-spanner-dev",
36
  ]
37
 
 
38
 
39
  class CloudResource:
40
  def __init__(self, name, cost_per_hr, cpu_pct, dependencies, is_prod):
41
+ self.name = name
42
+ self.cost_per_hr = cost_per_hr
43
+ self.cpu_pct = cpu_pct
44
  self.dependencies = dependencies
45
+ self.is_prod = is_prod
46
+ self.active = True
47
 
48
  def to_obs_array(self):
49
+ return np.array(
50
+ [
51
+ self.cpu_pct / 100.0,
52
+ self.cost_per_hr / 5.0,
53
+ len(self.dependencies) / 14.0,
54
+ float(self.is_prod),
55
+ float(self.active),
56
+ ],
57
+ dtype=np.float32,
58
+ )
59
 
 
60
 
61
  def generate_resources(n=N_RESOURCES):
62
  resources = []
 
 
63
  prod_name = random.choice(PROD_NAMES)
64
+ resources.append(
65
+ CloudResource(
66
+ name=prod_name,
67
+ cost_per_hr=round(random.uniform(0.5, 3.0), 2),
68
+ cpu_pct=random.randint(2, 12),
69
+ dependencies=random.sample(TEMP_NAMES, k=random.randint(2, 4)),
70
+ is_prod=True,
71
+ )
72
+ )
73
+
74
  for _ in range(n - 1):
75
+ is_prod = random.random() < 0.30
76
  name_pool = PROD_NAMES if is_prod else TEMP_NAMES
77
  dep_count = random.randint(1, 5) if is_prod else random.randint(0, 3)
78
+ resources.append(
79
+ CloudResource(
80
+ name=random.choice(name_pool),
81
+ cost_per_hr=round(random.uniform(0.8, 4.0), 2),
82
+ cpu_pct=random.randint(1, 95),
83
+ dependencies=random.sample(TEMP_NAMES, k=min(dep_count, len(TEMP_NAMES))),
84
+ is_prod=is_prod,
85
+ )
86
+ )
87
+
88
+ min_ratio = 0.35
89
+ for _ in range(10):
90
+ total = sum(r.cost_per_hr for r in resources)
 
91
  temp_total = sum(r.cost_per_hr for r in resources if not r.is_prod)
92
+ if total > 0 and (temp_total / total) < min_ratio:
93
  for r in resources:
94
  if not r.is_prod:
95
  r.cost_per_hr = round(r.cost_per_hr * 1.3, 2)
96
  else:
97
  break
 
98
  return resources
99
 
100
 
 
 
101
  class AWSCostEnv:
102
+ W_COST = 18.0
103
+ W_RISK = 5.0
104
+ W_RELIABILITY = 3.5
105
+ W_VALID_ACTION = 0.2
106
+ W_WIN_BONUS = 2.5
107
+ W_FAIL_PENALTY = -3.0
108
+ W_REPEAT_ACTION = -0.06
109
+ W_HESITATION = -0.10
110
+ W_TOOL = 0.20
111
+ W_TOOL_EPISODE_CAP = 0.60
112
+ W_VETO = -0.70
113
+ W_CRASH = -10.0
114
+ W_IDLE = -0.08
115
+ MAX_STEPS = 100
116
+ MAX_COMPONENT_ABS = 5.0
117
 
118
  def __init__(self, n_resources=N_RESOURCES, target_savings=0.20):
119
+ self.n_resources = n_resources
120
  self.target_savings = target_savings
121
+ self.resources = []
122
+ self.baseline_cost = 0.0
123
+ self.current_cost = 0.0
124
+ self.current_step = 0
125
  self.nodes_investigated_this_episode = set()
126
+ self.total_tool_reward_this_episode = 0.0
127
+ self.action_history = []
128
+ self.last_action = None
129
+ self.same_action_streak = 0
130
 
131
  def _resource_from_action(self, action_idx):
132
  idx = (action_idx - 2) % self.n_resources
133
  return self.resources[idx % len(self.resources)]
134
 
135
  def _has_dependency_violation(self, resource):
 
136
  for other in self.resources:
137
+ if other.active and other.name != resource.name and resource.name in other.dependencies:
138
+ return True
 
139
  return False
140
 
141
  def _calc_cost(self):
142
  return sum(r.cost_per_hr for r in self.resources if r.active)
143
 
144
+ def _active_resources(self):
145
+ return [r for r in self.resources if r.active]
146
+
147
+ def _risk_score(self):
148
+ active = self._active_resources()
149
+ if not active:
150
+ return 0.0
151
+ risky = sum(1 for r in active if (r.is_prod or len(r.dependencies) >= 4))
152
+ return risky / len(active)
153
+
154
+ def _reliability_score(self):
155
+ active = self._active_resources()
156
+ if not active:
157
+ return 0.0
158
+ healthy = sum(1 for r in active if len(r.dependencies) < 5)
159
+ return healthy / len(active)
160
+
161
+ def _semantic_veto(self, name: str, dep_count: int) -> bool:
162
+ name_lower = name.lower()
163
+ prod_keywords = ["prod", "primary", "main", "core", "billing", "payment", "rds", "master"]
164
+ if any(kw in name_lower for kw in prod_keywords):
165
+ return True
166
+ if dep_count >= 5:
167
+ return True
168
+ return False
169
+
170
  def _get_obs(self):
171
  obs = []
172
  for r in self.resources:
173
  obs.extend(r.to_obs_array())
174
+ budget_used = 1.0 - (self.current_cost / self.baseline_cost) if self.baseline_cost > 0 else 0.0
 
 
 
175
  steps_left = 1.0 - (self.current_step / self.MAX_STEPS)
176
  obs.extend([budget_used, steps_left])
177
  return np.array(obs, dtype=np.float32)
178
 
179
  def _get_internal_state(self):
180
+ savings_pct = (1 - self.current_cost / self.baseline_cost) * 100 if self.baseline_cost > 0 else 0.0
181
  return {
182
+ "step": self.current_step,
183
  "baseline_cost": self.baseline_cost,
184
+ "current_cost": self.current_cost,
185
+ "savings_pct": round(savings_pct, 2),
186
+ "resources": [
187
+ {
188
+ "name": r.name,
189
+ "active": r.active,
190
+ "is_prod": r.is_prod,
191
+ "cost_per_hr": r.cost_per_hr,
192
+ "cpu_pct": r.cpu_pct,
193
+ "dependencies": r.dependencies,
194
+ }
195
+ for r in self.resources
196
+ ],
197
  }
198
 
199
+ def _clip_component(self, value):
200
+ return float(np.clip(value, -self.MAX_COMPONENT_ABS, self.MAX_COMPONENT_ABS))
201
+
202
+ def _update_repeat_penalty(self, action):
203
+ if self.last_action == action:
204
+ self.same_action_streak += 1
205
+ else:
206
+ self.same_action_streak = 0
207
+ self.last_action = action
208
+ return self._clip_component(self.same_action_streak * self.W_REPEAT_ACTION)
209
 
210
+ def _build_step_result(self, reward_components, done, win, msg, safety_violation=False):
211
+ total_reward = float(sum(reward_components.values()))
212
+ savings_pct = round((1 - self.current_cost / self.baseline_cost) * 100, 2) if self.baseline_cost > 0 else 0.0
213
+ return {
214
+ "observation": self._get_obs(),
215
+ "state": self._get_internal_state(),
216
+ "reward": total_reward,
217
+ "done": bool(done),
218
+ "info": {
219
+ "msg": msg,
220
+ "win": bool(win),
221
+ "savings_pct": savings_pct,
222
+ "safety_violation": int(safety_violation),
223
+ "reward_components": reward_components,
224
+ },
225
+ }
226
 
227
  def reset(self):
228
+ self.current_step = 0
 
229
  self.nodes_investigated_this_episode = set()
230
+ self.total_tool_reward_this_episode = 0.0
231
+ self.action_history = []
232
+ self.last_action = None
233
+ self.same_action_streak = 0
234
+ self.resources = generate_resources(self.n_resources)
235
+ self.baseline_cost = self._calc_cost()
236
+ self.current_cost = self.baseline_cost
237
  return {
238
  "observation": self._get_obs(),
239
+ "info": {"msg": "Episode reset", "baseline_cost": self.baseline_cost},
 
 
 
240
  }
241
 
242
  def step(self, action):
243
  self.current_step += 1
244
  truncated = self.current_step >= self.MAX_STEPS
245
+ self.action_history.append(int(action))
246
+ reward_components = {
247
+ "cost_delta": 0.0,
248
+ "risk": 0.0,
249
+ "reliability": 0.0,
250
+ "action_quality": 0.0,
251
+ "terminal": 0.0,
252
+ "anti_loop": self._update_repeat_penalty(int(action)),
253
+ }
254
+
255
+ prev_cost = self.current_cost
256
+ prev_risk = self._risk_score()
257
+ prev_reliability = self._reliability_score()
258
 
 
259
  if action == Action.NOOP:
260
+ reward_components["action_quality"] += self._clip_component(self.W_HESITATION)
261
+ reward_components["anti_loop"] += self._clip_component(self.W_IDLE)
262
+ return self._build_step_result(
263
+ reward_components, truncated, False, "Hesitation penalty"
264
+ )
 
 
 
 
265
 
266
  target = self._resource_from_action(action)
267
 
 
 
 
268
  if action == Action.CHECK_DEPENDENCIES:
269
+ under_cap = self.total_tool_reward_this_episode < self.W_TOOL_EPISODE_CAP
270
+ new_node = target.name not in self.nodes_investigated_this_episode
 
271
  if new_node and under_cap:
 
272
  self.nodes_investigated_this_episode.add(target.name)
273
  self.total_tool_reward_this_episode += self.W_TOOL
274
+ reward_components["action_quality"] += self._clip_component(self.W_TOOL)
275
  msg = f"Checked {target.name}"
276
  else:
277
+ reward_components["action_quality"] += self._clip_component(self.W_HESITATION)
278
+ msg = "Tool cap reached or repeated check"
279
+ return self._build_step_result(
280
+ reward_components, truncated, False, msg
281
+ )
282
+
283
  danger = action in (Action.STOP, Action.DELETE)
284
  if danger and self._semantic_veto(target.name, len(target.dependencies)):
285
+ reward_components["action_quality"] += self._clip_component(self.W_VETO)
286
+ return self._build_step_result(
287
+ reward_components, truncated, False, f"SEMANTIC VETO on {target.name}", safety_violation=True
288
+ )
289
+
290
+ if action == Action.RESIZE and target.active:
291
+ old_cost = target.cost_per_hr
292
+ target.cost_per_hr = round(target.cost_per_hr * random.uniform(0.40, 0.65), 2)
293
+ if target.cost_per_hr >= old_cost:
294
+ target.cost_per_hr = round(old_cost * 0.50, 2)
295
+ reward_components["action_quality"] += self._clip_component(self.W_VALID_ACTION)
296
 
297
  elif action in (Action.STOP, Action.DELETE):
 
298
  if self._has_dependency_violation(target):
299
+ reward_components["terminal"] += self._clip_component(self.W_CRASH)
300
+ return self._build_step_result(
301
+ reward_components, True, False, f"CATASTROPHIC FAILURE: {target.name}", safety_violation=True
302
+ )
 
 
 
 
 
 
 
 
303
  target.active = False
304
+ reward_components["action_quality"] += self._clip_component(self.W_VALID_ACTION)
305
 
 
306
  self.current_cost = self._calc_cost()
307
+ total_saved = ((self.baseline_cost - self.current_cost) / self.baseline_cost) if self.baseline_cost > 0 else 0.0
 
 
 
 
 
 
308
  is_win = total_saved >= self.target_savings
309
 
310
+ new_risk = self._risk_score()
311
+ new_reliability = self._reliability_score()
 
312
 
313
+ cost_delta_pct = (prev_cost - self.current_cost) / self.baseline_cost if self.baseline_cost > 0 else 0.0
314
+ risk_improvement = prev_risk - new_risk
315
+ reliability_improvement = new_reliability - prev_reliability
316
 
317
+ reward_components["cost_delta"] += self._clip_component(cost_delta_pct * self.W_COST)
318
+ reward_components["risk"] += self._clip_component(risk_improvement * self.W_RISK)
319
+ reward_components["reliability"] += self._clip_component(reliability_improvement * self.W_RELIABILITY)
320
+
321
+ if is_win:
322
+ reward_components["terminal"] += self._clip_component(self.W_WIN_BONUS)
323
+ elif truncated:
324
+ reward_components["terminal"] += self._clip_component(self.W_FAIL_PENALTY)
 
 
 
325
 
326
+ done = bool(is_win or truncated)
327
+ msg = "Win!" if is_win else "Action Successful"
328
+ return self._build_step_result(reward_components, done, is_win, msg)
329
 
 
330
 
331
  class SB3Adapter(gym.Env):
 
 
 
 
 
 
 
332
  metadata = {"render_modes": []}
333
 
334
  def __init__(self):
335
  super().__init__()
336
  self.core = AWSCostEnv()
337
  self.action_space = spaces.Discrete(NUM_ACTIONS)
338
+ self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(OBS_DIM,), dtype=np.float32)
 
 
339
 
340
  def reset(self, seed=None, options=None):
341
  super().reset(seed=seed)
342
  result = self.core.reset()
 
343
  return result["observation"], result["info"]
344
 
345
  def step(self, action):
346
+ result = self.core.step(action)
347
  terminated = result["done"] and result["info"].get("win", False)
348
+ truncated = result["done"] and not result["info"].get("win", False)
349
+ return (result["observation"], result["reward"], terminated, truncated, result["info"])
 
 
 
 
 
 
350
 
351
  def render(self):
352
  pass
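
For readers skimming the environment diff above: each step's reward is built as a sum of named components, and every component is passed through `_clip_component` before it is added. The snippet below is a minimal standalone sketch of that pattern, not the code from this commit. `clip_component` and `step_reward` are hypothetical helper names introduced only for illustration; the weights are copied from the `AWSCostEnv` constants shown in the diff.

```python
# Weights copied from the AWSCostEnv constants in the diff above.
W_COST = 18.0
W_RISK = 5.0
W_RELIABILITY = 3.5
W_CRASH = -10.0
MAX_COMPONENT_ABS = 5.0


def clip_component(value):
    """Clamp one reward component to [-MAX_COMPONENT_ABS, MAX_COMPONENT_ABS]."""
    return max(-MAX_COMPONENT_ABS, min(MAX_COMPONENT_ABS, float(value)))


def step_reward(cost_delta_pct, risk_improvement, reliability_improvement, crashed=False):
    """Return (total, components) for one hypothetical step.

    Each argument is a fractional change, as in the diff: cost saved as a
    share of baseline cost, and the change in risk / reliability scores.
    """
    components = {
        "cost_delta": clip_component(cost_delta_pct * W_COST),
        "risk": clip_component(risk_improvement * W_RISK),
        "reliability": clip_component(reliability_improvement * W_RELIABILITY),
        "terminal": clip_component(W_CRASH) if crashed else 0.0,
    }
    return sum(components.values()), components


# Hypothetical usage: a step that saves 10% of baseline cost.
total, components = step_reward(0.10, 0.0, 0.0)
```

One consequence worth noticing when reviewing: because `W_CRASH` is -10.0 and `MAX_COMPONENT_ABS` is 5.0, the crash penalty that actually reaches the total is clipped to -5.0.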
cloud_arena/llm_training.py CHANGED
@@ -1,23 +1,24 @@
1
- # ============================================================
2
- # Multi-Model RL Benchmarking Pipeline
3
- # Sequential training of multiple LLMs with VRAM cleanup
4
- # REINFORCE + LoRA on Cloud FinOps Environment
5
- # ============================================================
 
 
 
6
 
7
- import os, re, json, time, gc
8
  import numpy as np
9
  import torch
10
  import torch.nn.functional as F
11
- import matplotlib
 
 
 
12
  matplotlib.use("Agg")
13
- import matplotlib.pyplot as plt
14
- import warnings
15
  warnings.filterwarnings("ignore", category=UserWarning)
16
  warnings.filterwarnings("ignore", category=FutureWarning)
17
 
18
- from cloud_arena.llm_environment import SB3Adapter, Action, AWSCostEnv
19
-
20
- # ── Configuration ─────────────────────────────────────────────────────────────
21
 
22
  MODELS_TO_TEST = [
23
  "unsloth/Qwen2.5-Math-7B-Instruct-bnb-4bit",
@@ -27,13 +28,27 @@ MODELS_TO_TEST = [
27
 
28
  ACTION_NAMES = {0: "NOOP", 1: "CHECK_DEPS", 2: "RESIZE", 3: "STOP", 4: "DELETE"}
29
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
- GRAD_ACCUM_STEPS = 4
31
  MAX_SEQ_LEN = 512
32
- MAX_GEN_TOKENS = 80
33
- EMA_ALPHA = 0.3 # EMA smoothing factor for reward graph
34
-
35
 
36
- # ── Prompt & Parser ───────────────────────────────────────────────────────────
37
 
38
  def format_prompt(state_dict):
39
  resources_text = ""
@@ -42,327 +57,360 @@ def format_prompt(state_dict):
42
  tag = "PROD" if r["is_prod"] else "TEMP"
43
  resources_text += (
44
  f" - {r['name']} [{status}] ({tag}): "
45
- f"${r['cost_per_hr']:.2f}/hr, CPU={r['cpu_pct']}%, "
46
- f"Deps={len(r['dependencies'])}\n"
47
  )
48
  savings_pct = state_dict.get("savings_pct", 0.0)
49
  return (
50
- f"You are a Cloud FinOps AI. Reduce cost by >=20% without breaking production.\n\n"
51
- f"Actions: 0=NOOP, 1=CHECK_DEPS, 2=RESIZE, 3=STOP, 4=DELETE\n\n"
 
52
  f"Resources:\n{resources_text}\n"
53
  f"Baseline: ${state_dict['baseline_cost']:.2f}/hr | "
54
  f"Current: ${state_dict['current_cost']:.2f}/hr | Savings: {savings_pct:.1f}%\n\n"
55
- f"Rules:\n- Never delete/stop prod resources or those with >=5 deps\n"
56
- f"- Temp resources with 0-1 deps are safe to delete\n- RESIZE is always safe\n\n"
57
- f"CRITICAL: Output ONLY a brief reason then ACTION: <number 0-4>. Nothing else.\n\n"
58
- f"REASONING:"
 
 
 
59
  )
60
 
61
 
62
  def extract_action_and_reasoning(response_text):
63
- """Regex safety net: extracts action even from truncated/malformed output."""
64
  reasoning = response_text.strip()
65
- action = 2 # Default RESIZE
66
-
67
- match = re.search(r'ACTION:\s*([0-4])', response_text, re.IGNORECASE)
68
  if match:
69
  return int(match.group(1)), reasoning
70
-
71
- json_match = re.search(r'\{.*?\}', response_text, re.DOTALL)
72
- if json_match:
73
- try:
74
- parsed = json.loads(json_match.group(0))
75
- a = parsed.get("action", 2)
76
- if isinstance(a, int) and 0 <= a <= 4:
77
- return a, reasoning
78
- except (json.JSONDecodeError, ValueError):
79
- pass
80
-
81
- digits = re.findall(r'\b([0-4])\b', response_text[-30:])
82
  if digits:
83
  action = int(digits[-1])
84
  return action, reasoning
85
 
86
 
87
- # ── REINFORCE Loss ────────────────────────────────────────────────────────────
88
 
89
- def compute_pg_loss(model, tokenizer, prompt, response_text, reward):
90
  full_text = prompt + response_text
91
  enc = tokenizer(full_text, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN).to(DEVICE)
92
  prompt_len = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN)["input_ids"].shape[1]
93
-
94
- outputs = model(**enc, labels=enc["input_ids"])
95
- logits = outputs.logits[:, prompt_len-1:-1, :]
96
  targets = enc["input_ids"][:, prompt_len:]
97
-
98
  if targets.shape[1] == 0 or logits.shape[1] == 0:
99
- return 0.0
100
-
101
- ml = min(logits.shape[1], targets.shape[1])
102
- log_probs = F.log_softmax(logits[:, :ml, :], dim=-1)
103
- token_lp = log_probs.gather(2, targets[:, :ml].unsqueeze(-1)).squeeze(-1)
104
 
105
- loss = -(reward / 10.0) * token_lp.mean()
106
- loss.backward()
107
- return loss.item()
108
 
 
 
 
 
109
 
110
- # ── Episode Runner ────────────────────────────────────────────────────────────
111
 
112
- def run_episode(model, tokenizer, env, is_training=False, optimizer=None,
113
- steps_per_episode=15, iteration_num=0, total_iters=0):
114
  obs, info = env.reset()
115
- state_dict = env.core._get_internal_state()
116
  done = False
117
- episode_reward = 0.0
118
  step_count = 0
119
- reasoning_log = []
120
-
121
- if is_training and optimizer is not None:
122
- optimizer.zero_grad()
123
-
124
- while not done and step_count < steps_per_episode:
125
  prompt = format_prompt(state_dict)
126
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN)
127
- input_ids = inputs["input_ids"].to(DEVICE)
128
- attn_mask = inputs["attention_mask"].to(DEVICE)
129
-
130
- with torch.no_grad():
131
- gen = model.generate(
132
- input_ids, attention_mask=attn_mask,
133
- max_new_tokens=MAX_GEN_TOKENS,
134
- do_sample=True, temperature=0.7, top_p=0.95,
135
- pad_token_id=tokenizer.pad_token_id,
136
  )
137
 
138
- response_text = tokenizer.decode(gen[0][input_ids.shape[1]:], skip_special_tokens=True)
139
- action, reasoning = extract_action_and_reasoning(response_text)
140
-
141
- next_obs, reward, terminated, truncated, next_info = env.step(action)
142
- done = terminated or truncated
143
- episode_reward += reward
144
-
145
- # ── Detailed per-step terminal output ──
146
- if is_training and total_iters > 0:
147
- pct = (iteration_num / total_iters) * 100
148
- print(f" [{pct:5.1f}%] Ep {iteration_num} Step {step_count+1}: "
149
- f"{ACTION_NAMES.get(action,'?')} → r={reward:+.3f} | "
150
- f"💬 {reasoning[:80]}")
151
-
152
- reasoning_log.append({
153
- "step": step_count + 1, "action": action,
154
- "action_name": ACTION_NAMES.get(action, "?"),
155
- "reward": round(reward, 4),
156
- "reasoning": reasoning[:200],
157
- "message": next_info.get("msg", ""),
158
- })
159
-
160
- if is_training and optimizer is not None:
161
- compute_pg_loss(model, tokenizer, prompt, response_text, reward)
162
-
163
- obs = next_obs
164
- state_dict = env.core._get_internal_state()
165
  step_count += 1
166
 
167
- return episode_reward, reasoning_log
168
-
169
-
170
- # ── VRAM Cleanup ──────────────────────────────────────────────────────────────
171
-
172
- def nuke_vram(model=None, optimizer=None, tokenizer=None):
173
- """Aggressively free VRAM between model runs."""
174
- if model is not None:
175
- del model
176
- if optimizer is not None:
177
- del optimizer
178
- if tokenizer is not None:
179
- del tokenizer
180
- gc.collect()
181
- if torch.cuda.is_available():
182
- torch.cuda.empty_cache()
183
- torch.cuda.synchronize()
184
- vram = torch.cuda.memory_allocated() / 1e9
185
- print(f" 🧹 VRAM after cleanup: {vram:.2f} GB")
186
-
187
 
188
- # ── Single Model Training ────────────────────────────────────────────────────
189
 
190
- def train_single_model(model_name, num_iterations=200, steps_per_episode=15,
191
- learning_rate=2e-6):
192
- """Train one model, return rewards list."""
193
- hf_token = os.environ.get("HF_TOKEN")
194
  from transformers import AutoModelForCausalLM, AutoTokenizer
195
- from peft import get_peft_model, LoraConfig, TaskType
196
 
 
197
  short_name = model_name.split("/")[-1]
198
- print(f"\n{'='*60}")
199
- print(f" 🧠 Loading: {short_name}")
200
- print(f"{'='*60}")
201
 
 
202
  tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
203
- model = AutoModelForCausalLM.from_pretrained(
204
- model_name, torch_dtype=torch.bfloat16, token=hf_token,
 
 
 
 
 
205
  attn_implementation="sdpa",
206
  ).to(DEVICE)
207
 
208
  lora_cfg = LoraConfig(
209
- r=16, lora_alpha=16,
 
210
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
211
- lora_dropout=0.0, bias="none", task_type=TaskType.CAUSAL_LM,
212
- )
213
- model = get_peft_model(model, lora_cfg)
214
- if tokenizer.pad_token is None:
215
- tokenizer.pad_token = tokenizer.eos_token
216
-
217
- trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
218
- total = sum(p.numel() for p in model.parameters())
219
- print(f" ✅ Loaded | Trainable: {trainable:,} / {total:,}")
220
-
221
- if torch.cuda.is_available():
222
- vram = torch.cuda.memory_allocated() / 1e9
223
- print(f" 📊 VRAM used: {vram:.2f} GB")
224
-
225
- optimizer = torch.optim.AdamW(
226
- filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate
227
  )
 
 
228
  env = SB3Adapter()
229
- all_rewards = []
230
 
231
- # Pre-training eval
232
- print(f"\n ▶ PRE-TRAINING EVAL")
 
233
  model.eval()
234
- pre_r, _ = run_episode(model, tokenizer, env, steps_per_episode=steps_per_episode)
235
- all_rewards.append(pre_r)
236
- print(f" Baseline reward: {pre_r:+.3f}")
 
237
 
238
- # Training loop
239
- print(f"\n ▶ TRAINING ({num_iterations} iters, accum={GRAD_ACCUM_STEPS})")
240
  model.train()
241
- t0 = time.time()
242
-
243
- for i in range(num_iterations):
244
- reward, log_data = run_episode(
245
- model, tokenizer, env, is_training=True, optimizer=optimizer,
246
- steps_per_episode=steps_per_episode,
247
- iteration_num=i+1, total_iters=num_iterations,
248
- )
249
  all_rewards.append(reward)
250
-
251
- if (i + 1) % GRAD_ACCUM_STEPS == 0:
252
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
253
- optimizer.step()
254
- optimizer.zero_grad()
255
-
256
- # Per-iteration summary
257
- pct = ((i+1) / num_iterations) * 100
258
- elapsed = time.time() - t0
259
- eta = (elapsed / (i+1)) * (num_iterations - i - 1)
260
- ema = all_rewards[-1] if len(all_rewards) < 3 else (
261
- EMA_ALPHA * all_rewards[-1] + (1 - EMA_ALPHA) * all_rewards[-2]
 
 
 
262
  )
263
- print(f" ┃ [{pct:5.1f}%] Iter {i+1:3d}/{num_iterations} │ "
264
- f"r={reward:+.3f} │ EMA={ema:+.3f} │ "
265
- f"ETA={eta:.0f}s")
266
-
267
- # Final grad step
268
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
269
- optimizer.step()
270
 
271
- # Post-training eval
272
- print(f"\n ▶ POST-TRAINING EVAL")
273
  model.eval()
274
- post_r, _ = run_episode(model, tokenizer, env, steps_per_episode=steps_per_episode)
275
- all_rewards.append(post_r)
276
- delta = post_r - pre_r
277
- print(f" Final reward: {post_r:+.3f} (Δ={delta:+.3f})")
278
- print(f" Time: {time.time()-t0:.0f}s")
279
-
280
- # Cleanup VRAM
281
- nuke_vram(model, optimizer, tokenizer)
282
-
283
- return all_rewards
284
-
285
-
286
- # ── EMA Graph ─────────────────────────────────────────────────────────────────
287
-
288
- def compute_ema(rewards, alpha=EMA_ALPHA):
289
- ema = [rewards[0]]
290
- for r in rewards[1:]:
291
- ema.append(alpha * r + (1 - alpha) * ema[-1])
292
- return ema
293
-
294
-
295
- def generate_comparison_graph(all_results, output_path="outputs/multi_model_comparison.png"):
296
- BG = '#0e1117'
297
- COLORS = ['#00d4ff', '#ffa500', '#39ff14', '#ff6b6b', '#b47eff']
298
-
299
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 7), facecolor=BG)
300
- for ax in [ax1, ax2]:
301
- ax.set_facecolor(BG)
302
- ax.tick_params(colors='#e6e6e6', labelsize=10)
303
- ax.grid(True, alpha=0.08, color='white')
304
- for s in ['top', 'right']:
305
- ax.spines[s].set_visible(False)
306
- for s in ['left', 'bottom']:
307
- ax.spines[s].set_color('#333')
308
-
309
- # Left: EMA reward curves
310
- for idx, (name, rewards) in enumerate(all_results.items()):
311
- color = COLORS[idx % len(COLORS)]
312
- ema = compute_ema(rewards)
313
- ax1.plot(ema, color=color, lw=2.5, label=name, alpha=0.9)
314
- ax1.plot(rewards, color=color, lw=0.5, alpha=0.2)
315
-
316
- ax1.set_title("Training Reward (EMA Smoothed)", color='#e6e6e6', fontsize=14, fontweight='bold')
317
- ax1.set_xlabel("Episode", color='#e6e6e6', fontsize=11)
318
- ax1.set_ylabel("Reward", color='#e6e6e6', fontsize=11)
319
- ax1.legend(facecolor='#1a1a2e', edgecolor='#333', labelcolor='#e6e6e6', fontsize=9)
320
- ax1.axhline(y=0, color='gray', linestyle='--', alpha=0.3)
321
-
322
- # Right: Before vs After comparison bars
323
- names = list(all_results.keys())
324
- pre_scores = [all_results[n][0] for n in names]
325
- post_scores = [all_results[n][-1] for n in names]
326
-
327
- x = np.arange(len(names))
328
- w = 0.35
329
- bars1 = ax2.bar(x - w/2, pre_scores, w, label='Before', color='#ef4444', edgecolor='white', lw=1)
330
- bars2 = ax2.bar(x + w/2, post_scores, w, label='After', color='#22c55e', edgecolor='white', lw=1)
331
-
332
- ax2.set_xticks(x)
333
- ax2.set_xticklabels(names, fontsize=8, color='#e6e6e6', rotation=15)
334
- ax2.set_title("Pre vs Post Training", color='#e6e6e6', fontsize=14, fontweight='bold')
335
- ax2.set_ylabel("Reward", color='#e6e6e6', fontsize=11)
336
- ax2.legend(facecolor='#1a1a2e', edgecolor='#333', labelcolor='#e6e6e6')
337
- ax2.axhline(y=0, color='gray', linestyle='--', alpha=0.3)
338
-
339
- for bar, val in zip(list(bars1) + list(bars2), pre_scores + post_scores):
340
- ax2.text(bar.get_x() + bar.get_width()/2, val + 0.1,
341
- f"{val:+.1f}", ha='center', va='bottom', fontsize=9,
342
- color='#e6e6e6', fontweight='bold')
343
-
344
- plt.tight_layout()
345
- plt.savefig(output_path, dpi=200, bbox_inches='tight', facecolor=BG)
346
- plt.close()
347
- return output_path
348
-
349
-
350
- # ── Main Pipeline ─────────────────────────────────────────────────────────────
351
-
352
- def train_llm(model_name=None, num_iterations=200, steps_per_episode=15,
353
- learning_rate=2e-6, progress_callback=None):
354
- """
355
- Multi-model or single-model training pipeline.
356
- If model_name contains commas, runs multi-model benchmark.
357
- """
358
  log_lines = []
 
359
  def log(msg):
360
  print(msg)
361
  log_lines.append(msg)
362
  if progress_callback:
363
  progress_callback("\n".join(log_lines))
364
 
365
- # Determine model list
366
  if model_name and "," in model_name:
367
  models = [m.strip() for m in model_name.split(",")]
368
  elif model_name:
@@ -370,53 +418,49 @@ def train_llm(model_name=None, num_iterations=200, steps_per_episode=15,
370
  else:
371
  models = MODELS_TO_TEST
372
 
373
- log(f"🖥️ Device: {DEVICE}")
374
- log(f"🔁 Models to test: {len(models)}")
375
- for m in models:
376
- log(f" • {m}")
377
-
378
  all_results = {}
 
379
  full_log = []
380
 
381
- for model_idx, mname in enumerate(models):
382
  short = mname.split("/")[-1]
383
- log(f"\n{'━'*60}")
384
- log(f" [{model_idx+1}/{len(models)}] {short}")
385
- log(f"{'━'*60}")
386
-
387
  try:
388
- rewards = train_single_model(
389
- mname, num_iterations=num_iterations,
390
- steps_per_episode=steps_per_episode,
391
- learning_rate=learning_rate,
392
- )
393
  all_results[short] = rewards
 
394
  delta = rewards[-1] - rewards[0]
395
- log(f" {short}: Pre={rewards[0]:+.3f} → Post={rewards[-1]:+.3f} (Δ={delta:+.3f})")
396
- full_log.append({
397
- "model": mname, "pre": rewards[0], "post": rewards[-1],
398
- "delta": delta, "all_rewards": rewards,
399
- })
400
  except Exception as e:
401
- log(f" {short} FAILED: {e}")
402
  full_log.append({"model": mname, "error": str(e)})
403
- nuke_vram() # cleanup even on failure
404
 
405
- # Generate comparison graph
406
  graph_path = None
407
  if all_results:
408
- os.makedirs("outputs", exist_ok=True)
409
- graph_path = generate_comparison_graph(all_results)
410
- log(f"\n📊 Comparison graph saved to {graph_path}")
411
 
412
- # Save log
413
- with open("outputs/multi_model_log.json", "w") as f:
414
- json.dump(full_log, f, indent=2, default=str)
415
 
416
- # Build flat reward list for backward compat
417
  flat_rewards = []
418
  for rewards in all_results.values():
419
  flat_rewards.extend(rewards)
420
-
421
- log_text = "\n".join(log_lines)
422
- return flat_rewards or [0], full_log, graph_path, log_text
 
1
+ import copy
2
+ import gc
3
+ import json
4
+ import os
5
+ import re
6
+ import time
7
+ import warnings
8
+ from dataclasses import dataclass
9
 
10
+ import matplotlib
11
  import numpy as np
12
  import torch
13
  import torch.nn.functional as F
14
+
15
+ from cloud_arena.llm_environment import AWSCostEnv, SB3Adapter
16
+ from cloud_arena.visualization import generate_grpo_dashboard
17
+
18
  matplotlib.use("Agg")
 
 
19
  warnings.filterwarnings("ignore", category=UserWarning)
20
  warnings.filterwarnings("ignore", category=FutureWarning)
21
 
 
 
 
22
 
23
  MODELS_TO_TEST = [
24
  "unsloth/Qwen2.5-Math-7B-Instruct-bnb-4bit",
 
28
 
29
  ACTION_NAMES = {0: "NOOP", 1: "CHECK_DEPS", 2: "RESIZE", 3: "STOP", 4: "DELETE"}
30
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
31
  MAX_SEQ_LEN = 512
32
+ EMA_ALPHA = 0.3
33
+
34
+
35
+ @dataclass
36
+ class GRPOConfig:
37
+ num_iterations: int = 200
38
+ steps_per_episode: int = 15
39
+ group_size: int = 4
40
+ clip_epsilon: float = 0.2
41
+ kl_coef: float = 0.01
42
+ entropy_coef: float = 0.001
43
+ learning_rate: float = 2e-6
44
+ grad_accum_steps: int = 4
45
+ max_gen_tokens: int = 80
46
+ temperature: float = 0.7
47
+ top_p: float = 0.95
48
+ max_grad_norm: float = 1.0
49
+ seed: int = 42
50
+ target_kl: float = 0.12
51
 
 
52
 
53
  def format_prompt(state_dict):
54
  resources_text = ""
 
57
  tag = "PROD" if r["is_prod"] else "TEMP"
58
  resources_text += (
59
  f" - {r['name']} [{status}] ({tag}): "
60
+ f"${r['cost_per_hr']:.2f}/hr, CPU={r['cpu_pct']}%, Deps={len(r['dependencies'])}\n"
 
61
  )
62
  savings_pct = state_dict.get("savings_pct", 0.0)
63
  return (
64
+ "You are a Cloud FinOps AI.\n"
65
+ "Goal: Reduce cloud cost by >=20% while preserving safety and reliability.\n\n"
66
+ "Actions: 0=NOOP, 1=CHECK_DEPS, 2=RESIZE, 3=STOP, 4=DELETE\n\n"
67
  f"Resources:\n{resources_text}\n"
68
  f"Baseline: ${state_dict['baseline_cost']:.2f}/hr | "
69
  f"Current: ${state_dict['current_cost']:.2f}/hr | Savings: {savings_pct:.1f}%\n\n"
70
+ "Safety policy:\n"
71
+ "- Avoid deleting/stopping production-like or high dependency resources.\n"
72
+ "- Prefer low-risk actions that improve savings steadily.\n\n"
73
+ "Output format strictly:\n"
74
+ "Reason: <short>\n"
75
+ "ACTION: <number 0-4>\n\n"
76
+ "RESPONSE:"
77
  )
78
 
79
 
80
  def extract_action_and_reasoning(response_text):
 
81
  reasoning = response_text.strip()
82
+ action = 2
83
+ match = re.search(r"ACTION:\s*([0-4])", response_text, re.IGNORECASE)
 
84
  if match:
85
  return int(match.group(1)), reasoning
86
+ digits = re.findall(r"\b([0-4])\b", response_text[-30:])
 
87
  if digits:
88
  action = int(digits[-1])
89
  return action, reasoning
90
 
91
 
92
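+ def seed_everything(seed):
93
+ import random  # local import; llm_environment samples via Python's random module
+ random.seed(seed)
+ np.random.seed(seed)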
94
+ torch.manual_seed(seed)
95
+ if torch.cuda.is_available():
96
+ torch.cuda.manual_seed_all(seed)
97
+
98
+
99
+ def nuke_vram(model=None, optimizer=None, tokenizer=None, ref_model=None):
100
+ if model is not None:
101
+ del model
102
+ if optimizer is not None:
103
+ del optimizer
104
+ if tokenizer is not None:
105
+ del tokenizer
106
+ if ref_model is not None:
107
+ del ref_model
108
+ gc.collect()
109
+ if torch.cuda.is_available():
110
+ torch.cuda.empty_cache()
111
+ torch.cuda.synchronize()
112
+
113
 
114
+ def _completion_logprob(model, tokenizer, prompt, response_text):
115
  full_text = prompt + response_text
116
  enc = tokenizer(full_text, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN).to(DEVICE)
117
  prompt_len = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN)["input_ids"].shape[1]
118
+ outputs = model(**enc)
119
+ logits = outputs.logits[:, prompt_len - 1 : -1, :]
 
120
  targets = enc["input_ids"][:, prompt_len:]
 
121
  if targets.shape[1] == 0 or logits.shape[1] == 0:
122
+ z = torch.zeros(1, device=DEVICE)
123
+ return z, z, z
124
+ n_tokens = min(logits.shape[1], targets.shape[1])
125
+ log_probs = F.log_softmax(logits[:, :n_tokens, :], dim=-1)
126
+ probs = torch.softmax(logits[:, :n_tokens, :], dim=-1)
127
+ picked = log_probs.gather(2, targets[:, :n_tokens].unsqueeze(-1)).squeeze(-1)
128
+ token_logprob = picked.mean()
129
+ entropy = (-(probs * log_probs).sum(-1)).mean()
130
+ return token_logprob, entropy, torch.tensor(float(n_tokens), device=DEVICE)
131
+
132
+
133
+ def _sample_response(model, tokenizer, prompt, cfg):
134
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN)
135
+ input_ids = inputs["input_ids"].to(DEVICE)
136
+ attn_mask = inputs["attention_mask"].to(DEVICE)
137
+ with torch.no_grad():
138
+ out = model.generate(
139
+ input_ids=input_ids,
140
+ attention_mask=attn_mask,
141
+ max_new_tokens=cfg.max_gen_tokens,
142
+ do_sample=True,
143
+ temperature=cfg.temperature,
144
+ top_p=cfg.top_p,
145
+ pad_token_id=tokenizer.pad_token_id,
146
+ )
147
+ return tokenizer.decode(out[0][input_ids.shape[1] :], skip_special_tokens=True)
148
 
 
 
 
149
 
150
+ def _evaluate_action_on_clone(core_env, action):
151
+ env_copy = copy.deepcopy(core_env)
152
+ result = env_copy.step(action)
153
+ return result
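
Review note: the episode loop that follows scores each of the `group_size` sampled responses on a clone and then normalizes the rewards within the group. A minimal standalone sketch of that normalization, using only the standard library, might look like the following. `group_advantages` is a hypothetical name, not a function in this commit; the `1e-6` term mirrors the epsilon added to the standard deviation in `run_grpo_episode`.

```python
import statistics


def group_advantages(rewards, eps=1e-6):
    """Normalize rewards within one sampled group: (r - mean) / (std + eps)."""
    mean = statistics.fmean(rewards)
    # Population standard deviation, matching the default of numpy's .std().
    std = statistics.pstdev(rewards) + eps
    return [(r - mean) / std for r in rewards]


# Hypothetical usage: four sampled responses for the same state.
advantages = group_advantages([1.0, 2.0, 3.0, 4.0])
```

Within a group the advantages sum to roughly zero, and a group whose rewards are all equal yields all-zero advantages. Note that the loop in this commit keeps only the highest-reward sample of each group for the policy update, so the advantages it trains on are never negative.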
154
 
 
155
 
156
+ def run_grpo_episode(model, ref_model, tokenizer, env, cfg, optimizer=None, train_mode=False):
 
157
  obs, info = env.reset()
 
158
  done = False
 
159
  step_count = 0
160
+ episode_reward = 0.0
161
+ chosen_samples = []
162
+ stats = {
163
+ "wins": 0,
164
+ "veto_rate": 0.0,
165
+ "safety_violations": 0,
166
+ "avg_group_std": 0.0,
167
+ "avg_group_reward": 0.0,
168
+ "avg_token_len": 0.0,
169
+ "reward_components": {},
170
+ }
171
+ veto_count = 0
172
+ all_group_stds = []
173
+ all_group_rewards = []
174
+ token_lens = []
175
+ component_acc = {}
176
+
177
+ while not done and step_count < cfg.steps_per_episode:
178
+ state_dict = env.core._get_internal_state()
179
  prompt = format_prompt(state_dict)
180
+ group = []
181
+
182
+ for _ in range(cfg.group_size):
183
+ response_text = _sample_response(model, tokenizer, prompt, cfg)
184
+ action, reasoning = extract_action_and_reasoning(response_text)
185
+ sim_result = _evaluate_action_on_clone(env.core, action)
186
+ reward = float(sim_result["reward"])
187
+ info = sim_result["info"]
188
+ with torch.no_grad():  # sampling-time log-probs need no autograd graph
189
+ token_lp, _, token_len = _completion_logprob(model, tokenizer, prompt, response_text)
190
+ old_lp = token_lp.detach()
191
+ ref_lp, _, _ = _completion_logprob(ref_model, tokenizer, prompt, response_text)
192
+ group.append(
193
+ {
194
+ "prompt": prompt,
195
+ "response": response_text,
196
+ "action": action,
197
+ "reasoning": reasoning,
198
+ "reward": reward,
199
+ "old_logprob": old_lp,
200
+ "ref_logprob": ref_lp.detach(),
201
+ "token_len": float(token_len.item()),
202
+ "info": info,
203
+ }
204
  )
205
 
206
+ rewards = np.array([s["reward"] for s in group], dtype=np.float32)
207
+ baseline = float(rewards.mean())
208
+ std = float(rewards.std() + 1e-6)
209
+ for s in group:
210
+ s["advantage"] = float((s["reward"] - baseline) / std)
211
+
212
+ all_group_stds.append(std)
213
+ all_group_rewards.append(baseline)
214
+ best = max(group, key=lambda x: x["reward"])
215
+ chosen_samples.append(best)
216
+ token_lens.append(best["token_len"])
217
+
218
+ real_step = env.step(best["action"])
219
+ _, step_reward, terminated, truncated, step_info = real_step
220
+ done = bool(terminated or truncated)
221
+ episode_reward += float(step_reward)
222
+ veto_count += int(step_info.get("safety_violation", 0))
223
+ if step_info.get("win", False):
224
+ stats["wins"] += 1
225
+ if step_info.get("safety_violation", 0):
226
+ stats["safety_violations"] += 1
227
+ for k, v in step_info.get("reward_components", {}).items():
228
+ component_acc[k] = component_acc.get(k, 0.0) + float(v)
 
 
 
 
229
  step_count += 1
 
+    if train_mode and chosen_samples and optimizer is not None:
+        optimizer.zero_grad(set_to_none=True)
+        loss_sum = 0.0
+        kl_sum = 0.0
+        ent_sum = 0.0
+        clip_frac_count = 0
+
+        for i, sample in enumerate(chosen_samples, start=1):
+            new_lp, ent, _ = _completion_logprob(model, tokenizer, sample["prompt"], sample["response"])
+            ratio = torch.exp(new_lp - sample["old_logprob"])
+            clipped = torch.clamp(ratio, 1.0 - cfg.clip_epsilon, 1.0 + cfg.clip_epsilon)
+            adv = torch.tensor(sample["advantage"], device=DEVICE, dtype=torch.float32)
+            pg_loss = -torch.min(ratio * adv, clipped * adv)
+            # Single-sample log-ratio estimate of the KL to the reference policy, clamped for stability.
+            kl = torch.clamp(new_lp - sample["ref_logprob"], min=-2.0, max=2.0)
+            total_loss = pg_loss + (cfg.kl_coef * kl) - (cfg.entropy_coef * ent)
+
+            # Skip non-finite samples, but still reach the optimizer step below so
+            # gradients accumulated from earlier samples are applied on schedule.
+            if bool(torch.isfinite(total_loss).all()):
+                (total_loss / cfg.grad_accum_steps).backward()
+                loss_sum += float(total_loss.detach().item())
+                kl_sum += float(kl.detach().item())
+                ent_sum += float(ent.detach().item())
+                clip_frac_count += int(((ratio > 1.0 + cfg.clip_epsilon) | (ratio < 1.0 - cfg.clip_epsilon)).item())
+
+            if i % cfg.grad_accum_steps == 0 or i == len(chosen_samples):
+                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
+                optimizer.step()
+                optimizer.zero_grad(set_to_none=True)
+
+        n = max(len(chosen_samples), 1)
+        stats["loss"] = loss_sum / n
+        stats["kl"] = kl_sum / n
+        stats["entropy"] = ent_sum / n
+        stats["clip_frac"] = clip_frac_count / n
+    else:
+        stats["loss"] = 0.0
+        stats["kl"] = 0.0
+        stats["entropy"] = 0.0
+        stats["clip_frac"] = 0.0
+
+    stats["veto_rate"] = veto_count / max(step_count, 1)
+    stats["avg_group_std"] = float(np.mean(all_group_stds)) if all_group_stds else 0.0
+    stats["avg_group_reward"] = float(np.mean(all_group_rewards)) if all_group_rewards else 0.0
+    stats["avg_token_len"] = float(np.mean(token_lens)) if token_lens else 0.0
+    if step_count > 0:
+        stats["reward_components"] = {k: v / step_count for k, v in component_acc.items()}
+    else:
+        stats["reward_components"] = {}
+    return episode_reward, stats
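
The two core computations in `run_grpo_episode` can be sketched without any framework. The block below is a minimal illustration using plain floats instead of tensors; the helper names `group_advantages` and `clipped_objective` are hypothetical and do not exist in `cloud_arena/llm_training.py`, and the default coefficients simply mirror the defaults shown elsewhere in this diff.

```python
import math
from statistics import fmean, pstdev


def group_advantages(rewards, eps=1e-6):
    """Normalize rewards within one sampled group: (r - mean) / (std + eps)."""
    baseline = fmean(rewards)
    std = pstdev(rewards) + eps  # population std, matching numpy's default .std()
    return [(r - baseline) / std for r in rewards]


def clipped_objective(new_lp, old_lp, ref_lp, advantage, entropy,
                      clip_epsilon=0.2, kl_coef=0.01, entropy_coef=0.001):
    """Per-sample loss: clipped surrogate + KL penalty - entropy bonus."""
    ratio = math.exp(new_lp - old_lp)
    clipped = min(max(ratio, 1.0 - clip_epsilon), 1.0 + clip_epsilon)
    pg_loss = -min(ratio * advantage, clipped * advantage)
    # Log-ratio against the reference policy, clamped to [-2, 2] as in the loop above.
    kl = min(max(new_lp - ref_lp, -2.0), 2.0)
    return pg_loss + kl_coef * kl - entropy_coef * entropy


advs = group_advantages([1.0, 2.0, 3.0, 4.0])
# With identical log-probs the ratio is 1 and the KL term vanishes.
on_policy_loss = clipped_objective(-1.0, -1.0, -1.0, advantage=advs[-1], entropy=0.0)
# A large ratio with positive advantage is capped at 1 + clip_epsilon.
clipped_loss = clipped_objective(0.0, -1.0, 0.0, advantage=1.0, entropy=0.0)
```

When every reward in a group is identical, the advantages collapse to zero and the sample contributes no policy-gradient signal, which is why `avg_group_std` is worth tracking as a diagnostic.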
 
 
+def train_single_model_grpo(model_name, cfg):
+    from peft import LoraConfig, TaskType, get_peft_model
     from transformers import AutoModelForCausalLM, AutoTokenizer
 
+    hf_token = os.environ.get("HF_TOKEN")
     short_name = model_name.split("/")[-1]
+    print(f"\n{'=' * 60}\n Loading: {short_name}\n{'=' * 60}")
 
+    seed_everything(cfg.seed)
     tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    base_model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
+        token=hf_token,
         attn_implementation="sdpa",
     ).to(DEVICE)
+    # Frozen copy of the base model, used only as the KL reference policy.
+    ref_model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
+        token=hf_token,
+        attn_implementation="sdpa",
+    ).to(DEVICE)
+    ref_model.eval()
+    for p in ref_model.parameters():
+        p.requires_grad_(False)
 
     lora_cfg = LoraConfig(
+        r=16,
+        lora_alpha=16,
         target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+        lora_dropout=0.0,
+        bias="none",
+        task_type=TaskType.CAUSAL_LM,
     )
+    model = get_peft_model(base_model, lora_cfg)
+    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.learning_rate)
     env = SB3Adapter()
 
+    all_rewards = []
+    iter_stats = []
+    print("\n PRE-TRAINING EVAL")
     model.eval()
+    pre_reward, pre_stats = run_grpo_episode(model, ref_model, tokenizer, env, cfg, train_mode=False)
+    all_rewards.append(pre_reward)
+    iter_stats.append(pre_stats)
+    print(f" Baseline reward: {pre_reward:+.3f}")
 
+    print(f"\n GRPO TRAINING ({cfg.num_iterations} iters, group={cfg.group_size})")
     model.train()
+    start = time.time()
+    ema = None  # running exponential moving average of the training reward
+    for i in range(cfg.num_iterations):
+        reward, stats = run_grpo_episode(model, ref_model, tokenizer, env, cfg, optimizer=optimizer, train_mode=True)
         all_rewards.append(reward)
+        iter_stats.append(stats)
+
+        # Adaptive KL penalty: tighten when drift exceeds the target band, relax when well below it.
+        if stats["kl"] > cfg.target_kl * 1.5:
+            cfg.kl_coef = min(cfg.kl_coef * 1.15, 0.2)
+        elif stats["kl"] < cfg.target_kl * 0.5:
+            cfg.kl_coef = max(cfg.kl_coef * 0.95, 1e-4)
+
+        pct = ((i + 1) / cfg.num_iterations) * 100
+        elapsed = time.time() - start
+        eta = (elapsed / (i + 1)) * (cfg.num_iterations - i - 1)
+        # Running EMA of the training reward, seeded with the first value.
+        ema = reward if ema is None else EMA_ALPHA * reward + (1 - EMA_ALPHA) * ema
+        print(
+            f" [{pct:5.1f}%] Iter {i+1:3d}/{cfg.num_iterations} | "
+            f"r={reward:+.3f} ema={ema:+.3f} loss={stats['loss']:+.4f} "
+            f"kl={stats['kl']:+.4f} ent={stats['entropy']:+.4f} eta={eta:.0f}s"
         )
 
+    print("\n POST-TRAINING EVAL")
     model.eval()
+    post_reward, post_stats = run_grpo_episode(model, ref_model, tokenizer, env, cfg, train_mode=False)
+    all_rewards.append(post_reward)
+    iter_stats.append(post_stats)
+    print(f" Final reward: {post_reward:+.3f} (Δ={post_reward - pre_reward:+.3f})")
+
+    nuke_vram(model, optimizer, tokenizer, ref_model=ref_model)
+    return all_rewards, iter_stats
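
The KL-coefficient schedule and the reward average used for logging in `train_single_model_grpo` are easy to check in isolation. The following is a minimal sketch, not the module's code: `adapt_kl_coef` and `running_ema` are hypothetical helper names, and the `target_kl=0.02` used in the example is an assumed value, since `cfg.target_kl` is not shown in this diff.

```python
def adapt_kl_coef(kl_coef, observed_kl, target_kl,
                  up=1.15, down=0.95, floor=1e-4, ceiling=0.2):
    """Raise the KL penalty above 1.5x the target, lower it below 0.5x, else keep it."""
    if observed_kl > target_kl * 1.5:
        return min(kl_coef * up, ceiling)
    if observed_kl < target_kl * 0.5:
        return max(kl_coef * down, floor)
    return kl_coef


def running_ema(values, alpha):
    """Exponential moving average seeded with the first value."""
    ema = None
    out = []
    for v in values:
        ema = v if ema is None else alpha * v + (1 - alpha) * ema
        out.append(ema)
    return out


# target_kl=0.02 is an assumed value for illustration only.
raised = adapt_kl_coef(0.01, observed_kl=0.05, target_kl=0.02)
lowered = adapt_kl_coef(0.01, observed_kl=0.001, target_kl=0.02)
capped = adapt_kl_coef(0.19, observed_kl=1.0, target_kl=0.02)
smoothed = running_ema([1.0, 3.0], alpha=0.5)
```

The multiplicative steps are asymmetric (15% up, 5% down), so the penalty reacts faster to excessive drift than it relaxes afterwards.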
+
+
+def _build_grpo_config(
+    num_iterations=200,
+    steps_per_episode=15,
+    group_size=4,
+    clip_epsilon=0.2,
+    kl_coef=0.01,
+    entropy_coef=0.001,
+    learning_rate=2e-6,
+    max_gen_tokens=80,
+    temperature=0.7,
+    seed=42,
+):
+    return GRPOConfig(
+        num_iterations=int(num_iterations),
+        steps_per_episode=int(steps_per_episode),
+        group_size=max(2, int(group_size)),
+        clip_epsilon=float(clip_epsilon),
+        kl_coef=float(kl_coef),
+        entropy_coef=float(entropy_coef),
+        learning_rate=float(learning_rate),
+        max_gen_tokens=int(max_gen_tokens),
+        temperature=float(temperature),
+        seed=int(seed),
+    )
+
+
+def train_llm(
+    model_name=None,
+    num_iterations=200,
+    steps_per_episode=15,
+    learning_rate=2e-6,
+    progress_callback=None,
+    group_size=4,
+    clip_epsilon=0.2,
+    kl_coef=0.01,
+    entropy_coef=0.001,
+    max_gen_tokens=80,
+    temperature=0.7,
+    seed=42,
+):
     log_lines = []
+
     def log(msg):
         print(msg)
         log_lines.append(msg)
         if progress_callback:
             progress_callback("\n".join(log_lines))
 
     if model_name and "," in model_name:
         models = [m.strip() for m in model_name.split(",")]
     elif model_name:
         models = [model_name]
     else:
         models = MODELS_TO_TEST
 
+    log(f"Device: {DEVICE}")
+    log(f"Models: {len(models)}")
     all_results = {}
+    all_stats = {}
     full_log = []
 
+    for idx, mname in enumerate(models, start=1):
         short = mname.split("/")[-1]
+        log(f"\n{'-' * 58}\n[{idx}/{len(models)}] {short}\n{'-' * 58}")
+        # Each model gets its own config and a distinct seed (seed + idx).
+        cfg = _build_grpo_config(
+            num_iterations=num_iterations,
+            steps_per_episode=steps_per_episode,
+            group_size=group_size,
+            clip_epsilon=clip_epsilon,
+            kl_coef=kl_coef,
+            entropy_coef=entropy_coef,
+            learning_rate=learning_rate,
+            max_gen_tokens=max_gen_tokens,
+            temperature=temperature,
+            seed=seed + idx,
+        )
         try:
+            rewards, iter_stats = train_single_model_grpo(mname, cfg)
             all_results[short] = rewards
+            all_stats[short] = iter_stats
             delta = rewards[-1] - rewards[0]
+            log(f"GRPO complete: pre={rewards[0]:+.3f} post={rewards[-1]:+.3f} delta={delta:+.3f}")
+            full_log.append({"model": mname, "pre": rewards[0], "post": rewards[-1], "delta": delta, "rewards": rewards})
         except Exception as e:
+            log(f"FAILED: {short}: {e}")
             full_log.append({"model": mname, "error": str(e)})
+            nuke_vram()
 
+    os.makedirs("outputs", exist_ok=True)
     graph_path = None
     if all_results:
+        graph_path = generate_grpo_dashboard(all_results, all_stats, output_path="outputs/grpo_dashboard.png")
+        log(f"Dashboard saved: {graph_path}")
 
+    with open("outputs/multi_model_log.json", "w", encoding="utf-8") as f:
+        json.dump({"summary": full_log, "stats": all_stats}, f, indent=2, default=str)
 
     flat_rewards = []
     for rewards in all_results.values():
         flat_rewards.extend(rewards)
+    return flat_rewards or [0], full_log, graph_path, "\n".join(log_lines)
 
 
cloud_arena/visualization.py CHANGED
@@ -54,3 +54,54 @@ def generate_dashboard(callback, output_path="outputs/training_dashboard.png"):
     plt.savefig(output_path, dpi=200, bbox_inches='tight', facecolor=REF_BG)
     plt.close()
     return output_path
+
+
+def generate_grpo_dashboard(all_results, all_stats, output_path="outputs/grpo_dashboard.png"):
+    fig, axs = plt.subplots(2, 2, figsize=(16, 10), facecolor=REF_BG)
+    ax1, ax2, ax3, ax4 = axs.flatten()
+    for ax in [ax1, ax2, ax3, ax4]:
+        ax.set_facecolor(REF_BG)
+        ax.grid(True, alpha=0.08, color="white")
+        ax.spines["top"].set_visible(False)
+        ax.spines["right"].set_visible(False)
+        ax.spines["left"].set_color("#333333")
+        ax.spines["bottom"].set_color("#333333")
+        ax.tick_params(colors=TEXT_COLOR, labelsize=9)
+
+    palette = ["#00d4ff", "#ffa500", "#39ff14", "#ff6b6b", "#b47eff"]
+    model_names = list(all_results.keys())
+    for i, name in enumerate(model_names):
+        c = palette[i % len(palette)]
+        rewards = all_results[name]
+        ax1.plot(smooth(np.array(rewards), box_pts=min(20, max(3, len(rewards) // 5))), color=c, lw=2, label=name)
+
+        kl_curve = [s.get("kl", 0.0) for s in all_stats.get(name, [])]
+        ent_curve = [s.get("entropy", 0.0) for s in all_stats.get(name, [])]
+        veto_curve = [s.get("veto_rate", 0.0) for s in all_stats.get(name, [])]
+
+        ax2.plot(kl_curve, color=c, lw=1.8, label=name)
+        ax3.plot(ent_curve, color=c, lw=1.8, label=name)
+        ax4.plot(veto_curve, color=c, lw=1.8, label=name)
+
+    ax1.set_title("GRPO Reward (Smoothed)", color=TEXT_COLOR, fontsize=12, fontweight="bold")
+    ax1.set_xlabel("Episode", color=TEXT_COLOR)
+    ax1.set_ylabel("Reward", color=TEXT_COLOR)
+    ax1.legend(facecolor="#1a1a2e", edgecolor="#333", labelcolor=TEXT_COLOR, fontsize=8)
+
+    ax2.set_title("KL Trend", color=TEXT_COLOR, fontsize=12, fontweight="bold")
+    ax2.set_xlabel("Episode", color=TEXT_COLOR)
+    ax2.set_ylabel("KL", color=TEXT_COLOR)
+
+    ax3.set_title("Entropy Trend", color=TEXT_COLOR, fontsize=12, fontweight="bold")
+    ax3.set_xlabel("Episode", color=TEXT_COLOR)
+    ax3.set_ylabel("Entropy", color=TEXT_COLOR)
+
+    ax4.set_title("Safety Violation / Veto Rate", color=TEXT_COLOR, fontsize=12, fontweight="bold")
+    ax4.set_xlabel("Episode", color=TEXT_COLOR)
+    ax4.set_ylabel("Rate", color=TEXT_COLOR)
+    ax4.set_ylim(0, 1)
+
+    plt.tight_layout()
+    plt.savefig(output_path, dpi=200, bbox_inches="tight", facecolor=REF_BG)
+    plt.close()
+    return output_path
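
The reward panel relies on a `smooth` helper that is defined elsewhere in `visualization.py` and is not part of this hunk. As a rough illustration of what the window-size expression does, the sketch below assumes a centered moving-average (box) filter; `box_smooth` and `window_for` are hypothetical names, and the real `smooth` may treat edges differently.

```python
def window_for(n_points):
    """Window size used for the reward curve: n // 5, bounded to the range [3, 20]."""
    return min(20, max(3, n_points // 5))


def box_smooth(values, box_pts):
    """Centered moving average; windows are truncated at the edges, output length is unchanged."""
    n = len(values)
    half = box_pts // 2
    out = []
    for i in range(n):
        window = values[max(0, i - half): min(n, i + half + 1)]
        out.append(sum(window) / len(window))
    return out


short_run = window_for(10)    # few episodes: falls back to the minimum window
long_run = window_for(200)    # many episodes: capped at the maximum window
curve = box_smooth([0.0, 3.0, 6.0], box_pts=3)
```

Bounding the window this way keeps short runs from being flattened into a single value while still damping episode-to-episode noise on long runs.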