kavin57447 commited on
Commit
af6bbef
·
1 Parent(s): deef82c

Multi-model benchmark pipeline: VRAM cleanup + EMA graph + detailed output

Browse files
Files changed (3) hide show
  1. app.py +7 -4
  2. cloud_arena/llm_training.py +253 -140
  3. requirements.txt +1 -0
app.py CHANGED
@@ -98,10 +98,13 @@ with gr.Blocks(title="Cloud Arena RL") as demo:
98
  eval_btn.click(run_math_evaluation, outputs=eval_output)
99
 
100
  with gr.Tab("🧠 LLM RL"):
101
- gr.Markdown("### LLM Model Llama 3.1 8B + REINFORCE + LoRA")
102
- gr.Markdown("> ⚠️ Requires `HF_TOKEN` secret set in Space settings + accepted model license")
103
- llm_model = gr.Textbox(value="meta-llama/Llama-3.1-8B", label="Model Name")
104
- llm_iters = gr.Number(value=200, label="Training Iterations")
 
 
 
105
  llm_steps = gr.Number(value=15, label="Steps per Episode")
106
  llm_btn = gr.Button("🚀 Start LLM Training", variant="primary")
107
  llm_output = gr.Textbox(label="Training Log", lines=15)
 
98
  eval_btn.click(run_math_evaluation, outputs=eval_output)
99
 
100
  with gr.Tab("🧠 LLM RL"):
101
+ gr.Markdown("### Multi-Model RL Benchmark REINFORCE + LoRA")
102
+ gr.Markdown("> Comma-separate model names to benchmark multiple models sequentially")
103
+ llm_model = gr.Textbox(
104
+ value="unsloth/Qwen2.5-Math-7B-Instruct-bnb-4bit, unsloth/gemma-2b-it-bnb-4bit, unsloth/llama-3-8b-Instruct-bnb-4bit",
105
+ label="Model(s) — comma-separated for multi-model benchmark"
106
+ )
107
+ llm_iters = gr.Number(value=200, label="Training Iterations per Model")
108
  llm_steps = gr.Number(value=15, label="Steps per Episode")
109
  llm_btn = gr.Button("🚀 Start LLM Training", variant="primary")
110
  llm_output = gr.Textbox(label="Training Log", lines=15)
cloud_arena/llm_training.py CHANGED
@@ -1,59 +1,59 @@
1
  # ============================================================
2
- # LLM RL Training — Llama 3.1 8B + REINFORCE + LoRA
3
- # This is the LLM model, SEPARATE from the mathematical model.
4
- # Uses AWSCostEnv (llm_environment.py), NOT CloudArenaEnv.
5
  # ============================================================
6
 
7
- import os
8
- import re
9
- import json
10
- import time
11
- import warnings
12
  import numpy as np
13
  import torch
14
  import torch.nn.functional as F
15
  import matplotlib
16
  matplotlib.use("Agg")
17
  import matplotlib.pyplot as plt
18
-
19
  warnings.filterwarnings("ignore", category=UserWarning)
20
  warnings.filterwarnings("ignore", category=FutureWarning)
21
 
22
  from cloud_arena.llm_environment import SB3Adapter, Action, AWSCostEnv
23
 
24
- # ── Constants ────────────────────────────────────────────────────────────────
 
 
 
 
 
 
25
 
26
  ACTION_NAMES = {0: "NOOP", 1: "CHECK_DEPS", 2: "RESIZE", 3: "STOP", 4: "DELETE"}
27
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
28
 
29
- # ── GPU Optimization Constants ────────────────────────────────────────────────
30
- GRAD_ACCUM_STEPS = 4 # accumulate gradients over N episodes before stepping
31
- MAX_SEQ_LEN = 512 # shorter context = O(N²) attention is 4× faster than 1024
32
- MAX_GEN_TOKENS = 80 # enough room for reasoning + ACTION line, not enough to ramble
33
 
 
34
 
35
  def format_prompt(state_dict):
36
  resources_text = ""
37
  for r in state_dict["resources"]:
38
  status = "ACTIVE" if r["active"] else "STOPPED"
39
- tag = "PRODUCTION" if r["is_prod"] else "Temporary"
40
  resources_text += (
41
  f" - {r['name']} [{status}] ({tag}): "
42
- f"Cost=${r['cost_per_hr']:.2f}/hr, CPU={r['cpu_pct']}%, "
43
  f"Deps={len(r['dependencies'])}\n"
44
  )
45
  savings_pct = state_dict.get("savings_pct", 0.0)
46
  return (
47
- f"You are a Cloud FinOps AI. Reduce cloud cost by >=20% without breaking production.\n\n"
48
  f"Actions: 0=NOOP, 1=CHECK_DEPS, 2=RESIZE, 3=STOP, 4=DELETE\n\n"
49
  f"Resources:\n{resources_text}\n"
50
  f"Baseline: ${state_dict['baseline_cost']:.2f}/hr | "
51
- f"Current: ${state_dict['current_cost']:.2f}/hr | "
52
- f"Savings: {savings_pct:.1f}%\n\n"
53
- f"Rules:\n"
54
- f"- Never delete/stop prod resources or those with >=5 deps\n"
55
- f"- Temp resources with 0-1 deps are safe to delete\n"
56
- f"- RESIZE is always safe\n\n"
57
  f"CRITICAL: Output ONLY a brief reason then ACTION: <number 0-4>. Nothing else.\n\n"
58
  f"REASONING:"
59
  )
@@ -62,18 +62,15 @@ def format_prompt(state_dict):
62
  def extract_action_and_reasoning(response_text):
63
  """Regex safety net: extracts action even from truncated/malformed output."""
64
  reasoning = response_text.strip()
65
- action = 2 # Default to RESIZE (safest action)
66
 
67
- # Try explicit ACTION: N format first
68
- action_match = re.search(r'ACTION:\s*([0-4])', response_text, re.IGNORECASE)
69
- if action_match:
70
- return int(action_match.group(1)), reasoning
71
 
72
- # Try JSON format: {"action": N} or {"action": "DELETE"}
73
  json_match = re.search(r'\{.*?\}', response_text, re.DOTALL)
74
  if json_match:
75
  try:
76
- import json
77
  parsed = json.loads(json_match.group(0))
78
  a = parsed.get("action", 2)
79
  if isinstance(a, int) and 0 <= a <= 4:
@@ -81,53 +78,46 @@ def extract_action_and_reasoning(response_text):
81
  except (json.JSONDecodeError, ValueError):
82
  pass
83
 
84
- # Last resort: any digit 0-4 near the end
85
- digit_matches = re.findall(r'\b([0-4])\b', response_text[-30:])
86
- if digit_matches:
87
- action = int(digit_matches[-1])
88
-
89
  return action, reasoning
90
 
91
 
 
 
92
  def compute_pg_loss(model, tokenizer, prompt, response_text, reward):
93
- """Compute REINFORCE loss WITHOUT stepping optimizer (for gradient accumulation)."""
94
  full_text = prompt + response_text
95
- encodings = tokenizer(full_text, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN).to(DEVICE)
96
- prompt_encodings = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN)
97
- prompt_len = prompt_encodings["input_ids"].shape[1]
98
 
99
- outputs = model(**encodings, labels=encodings["input_ids"])
100
  logits = outputs.logits[:, prompt_len-1:-1, :]
101
- targets = encodings["input_ids"][:, prompt_len:]
102
 
103
  if targets.shape[1] == 0 or logits.shape[1] == 0:
104
  return 0.0
105
 
106
- min_len = min(logits.shape[1], targets.shape[1])
107
- logits = logits[:, :min_len, :]
108
- targets = targets[:, :min_len]
109
 
110
- log_probs = F.log_softmax(logits, dim=-1)
111
- token_log_probs = log_probs.gather(2, targets.unsqueeze(-1)).squeeze(-1)
112
- avg_log_prob = token_log_probs.mean()
113
-
114
- scaled_reward = reward / 10.0
115
- loss = -scaled_reward * avg_log_prob
116
- loss.backward() # accumulate gradient, don't step yet
117
  return loss.item()
118
 
119
 
 
 
120
  def run_episode(model, tokenizer, env, is_training=False, optimizer=None,
121
- steps_per_episode=15):
122
  obs, info = env.reset()
123
  state_dict = env.core._get_internal_state()
124
  done = False
125
  episode_reward = 0.0
126
  step_count = 0
127
  reasoning_log = []
128
- losses = []
129
 
130
- # Accumulate gradients across all steps in the episode
131
  if is_training and optimizer is not None:
132
  optimizer.zero_grad()
133
 
@@ -135,175 +125,298 @@ def run_episode(model, tokenizer, env, is_training=False, optimizer=None,
135
  prompt = format_prompt(state_dict)
136
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN)
137
  input_ids = inputs["input_ids"].to(DEVICE)
138
- attention_mask = inputs["attention_mask"].to(DEVICE)
139
 
140
  with torch.no_grad():
141
- gen_outputs = model.generate(
142
- input_ids, attention_mask=attention_mask,
143
  max_new_tokens=MAX_GEN_TOKENS,
144
  do_sample=True, temperature=0.7, top_p=0.95,
145
  pad_token_id=tokenizer.pad_token_id,
146
  )
147
 
148
- response_ids = gen_outputs[0][input_ids.shape[1]:]
149
- response_text = tokenizer.decode(response_ids, skip_special_tokens=True)
150
  action, reasoning = extract_action_and_reasoning(response_text)
151
 
152
  next_obs, reward, terminated, truncated, next_info = env.step(action)
153
  done = terminated or truncated
154
  episode_reward += reward
155
 
 
 
 
 
 
 
 
156
  reasoning_log.append({
157
- "step": step_count + 1,
158
- "reasoning": reasoning[:300],
159
- "action": action,
160
- "action_name": ACTION_NAMES.get(action, "UNKNOWN"),
161
  "reward": round(reward, 4),
 
162
  "message": next_info.get("msg", ""),
163
  })
164
 
165
  if is_training and optimizer is not None:
166
- loss = compute_pg_loss(model, tokenizer, prompt, response_text, reward)
167
- losses.append(loss)
168
 
169
  obs = next_obs
170
  state_dict = env.core._get_internal_state()
171
  step_count += 1
172
 
173
- return episode_reward, reasoning_log, losses
174
 
175
 
176
- def train_llm(model_name="meta-llama/Llama-3.1-8B",
177
- num_iterations=200, steps_per_episode=15, learning_rate=2e-6,
178
- progress_callback=None):
179
- """
180
- Full LLM RL training pipeline. Returns (all_rewards, full_log, graph_path).
181
- """
182
- hf_token = os.environ.get("HF_TOKEN")
 
 
 
 
 
 
 
 
 
 
 
 
183
 
 
 
 
 
184
  from transformers import AutoModelForCausalLM, AutoTokenizer
185
  from peft import get_peft_model, LoraConfig, TaskType
186
 
187
- log_lines = []
188
- def log(msg):
189
- print(msg)
190
- log_lines.append(msg)
191
- if progress_callback:
192
- progress_callback("\n".join(log_lines))
193
-
194
- log(f"🖥️ Device: {DEVICE}")
195
- log(f"🧠 Model: {model_name}")
196
- log(f"🔁 Iterations: {num_iterations}")
197
- log("📦 Loading model and tokenizer...")
198
 
199
  tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
200
  model = AutoModelForCausalLM.from_pretrained(
201
  model_name, torch_dtype=torch.bfloat16, token=hf_token,
202
- attn_implementation="sdpa", # PyTorch built-in, no flash-attn package needed
203
  ).to(DEVICE)
204
 
205
- lora_config = LoraConfig(
206
  r=16, lora_alpha=16,
207
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
208
- lora_dropout=0.0, bias="none",
209
- task_type=TaskType.CAUSAL_LM,
210
  )
211
- model = get_peft_model(model, lora_config)
212
-
213
  if tokenizer.pad_token is None:
214
  tokenizer.pad_token = tokenizer.eos_token
215
 
216
  trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
217
  total = sum(p.numel() for p in model.parameters())
218
- log(f"✅ Model loaded. Trainable: {trainable:,} / {total:,} params")
 
 
 
 
219
 
220
  optimizer = torch.optim.AdamW(
221
  filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate
222
  )
223
  env = SB3Adapter()
224
-
225
  all_rewards = []
226
- full_log = []
227
 
228
  # Pre-training eval
229
- log("\n▶ PRE-TRAINING EVAL")
230
  model.eval()
231
- pre_reward, pre_log_data, _ = run_episode(model, tokenizer, env, steps_per_episode=steps_per_episode)
232
- all_rewards.append(pre_reward)
233
- full_log.append({"phase": "pre-training", "reward": pre_reward, "reasoning": pre_log_data})
234
- log(f" Reward: {pre_reward:+.3f}")
235
 
236
- # Training
237
- log(f"\n▶ TRAINING ({num_iterations} iterations, grad_accum={GRAD_ACCUM_STEPS})")
238
  model.train()
 
 
239
  for i in range(num_iterations):
240
- reward, train_log_data, ep_losses = run_episode(
241
  model, tokenizer, env, is_training=True, optimizer=optimizer,
242
  steps_per_episode=steps_per_episode,
 
243
  )
244
  all_rewards.append(reward)
245
- full_log.append({"phase": f"training-{i+1}", "reward": reward, "reasoning": train_log_data})
246
 
247
- # Step optimizer every GRAD_ACCUM_STEPS episodes (batched update)
248
  if (i + 1) % GRAD_ACCUM_STEPS == 0:
249
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
250
  optimizer.step()
251
  optimizer.zero_grad()
252
 
253
- log(f" Iter {i+1}/{num_iterations}: reward={reward:+.3f}")
 
 
 
 
 
 
 
 
 
254
 
255
- # Final optimizer step for remaining accumulated gradients
256
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
257
  optimizer.step()
258
 
259
  # Post-training eval
260
- log("\n▶ POST-TRAINING EVAL")
261
  model.eval()
262
- post_reward, post_log_data, _ = run_episode(model, tokenizer, env, steps_per_episode=steps_per_episode)
263
- all_rewards.append(post_reward)
264
- full_log.append({"phase": "post-training", "reward": post_reward, "reasoning": post_log_data})
265
- log(f" Reward: {post_reward:+.3f}")
 
266
 
267
- delta = all_rewards[-1] - all_rewards[0]
268
- log(f"\n✅ DONE | Pre: {all_rewards[0]:+.3f} → Post: {all_rewards[-1]:+.3f} | Δ={delta:+.3f}")
269
 
270
- # Save log
271
- with open("outputs/llm_training_log.json", "w") as f:
272
- json.dump(full_log, f, indent=2, default=str)
273
 
274
- # Generate graph
275
- graph_path = _generate_graph(all_rewards, num_iterations, model_name)
276
 
277
- return all_rewards, full_log, graph_path, "\n".join(log_lines)
278
 
 
 
 
 
 
279
 
280
- def _generate_graph(all_rewards, num_iterations, model_name):
281
- labels = ["Before"] + [f"Iter {i+1}" for i in range(num_iterations)] + ["After"]
282
- colors = ["#ef4444"] + ["#3b82f6"] * num_iterations + ["#22c55e"]
283
 
284
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6), facecolor="#0e1117")
 
 
 
 
285
  for ax in [ax1, ax2]:
286
- ax.set_facecolor("#0e1117")
287
- ax.tick_params(colors="#e6e6e6")
288
- ax.grid(axis="y", alpha=0.1, color="white")
289
- for s in ['top','right']:
290
  ax.spines[s].set_visible(False)
291
- for s in ['left','bottom']:
292
  ax.spines[s].set_color('#333')
293
 
294
- ax1.bar(range(len(all_rewards)), all_rewards, color=colors, edgecolor="white", lw=1.5, width=0.6)
295
- ax1.set_xticks(range(len(labels)))
296
- ax1.set_xticklabels(labels, fontsize=8, color="#e6e6e6", rotation=45)
297
- ax1.set_title(f"LLM RL: {model_name.split('/')[-1]}", color="#e6e6e6", fontsize=13, fontweight="bold")
298
- ax1.set_ylabel("Reward", color="#e6e6e6")
299
-
300
- comp = [all_rewards[0], all_rewards[-1]]
301
- ax2.bar(["Before", "After"], comp, color=["#ef4444", "#22c55e"], edgecolor="white", lw=2, width=0.5)
302
- ax2.set_title("Before vs After", color="#e6e6e6", fontsize=13, fontweight="bold")
303
- ax2.set_ylabel("Reward", color="#e6e6e6")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
305
  plt.tight_layout()
306
- path = "outputs/llm_training_results.png"
307
- plt.savefig(path, dpi=200, bbox_inches="tight", facecolor="#0e1117")
308
  plt.close()
309
- return path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # ============================================================
2
+ # Multi-Model RL Benchmarking Pipeline
3
+ # Sequential training of multiple LLMs with VRAM cleanup
4
+ # REINFORCE + LoRA on Cloud FinOps Environment
5
  # ============================================================
6
 
7
+ import os, re, json, time, gc
 
 
 
 
8
  import numpy as np
9
  import torch
10
  import torch.nn.functional as F
11
  import matplotlib
12
  matplotlib.use("Agg")
13
  import matplotlib.pyplot as plt
14
+ import warnings
15
  warnings.filterwarnings("ignore", category=UserWarning)
16
  warnings.filterwarnings("ignore", category=FutureWarning)
17
 
18
  from cloud_arena.llm_environment import SB3Adapter, Action, AWSCostEnv
19
 
20
+ # ── Configuration ─────────────────────────────────────────────────────────────
21
+
22
+ MODELS_TO_TEST = [
23
+ "unsloth/Qwen2.5-Math-7B-Instruct-bnb-4bit",
24
+ "unsloth/gemma-2b-it-bnb-4bit",
25
+ "unsloth/llama-3-8b-Instruct-bnb-4bit",
26
+ ]
27
 
28
  ACTION_NAMES = {0: "NOOP", 1: "CHECK_DEPS", 2: "RESIZE", 3: "STOP", 4: "DELETE"}
29
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
+ GRAD_ACCUM_STEPS = 4
31
+ MAX_SEQ_LEN = 512
32
+ MAX_GEN_TOKENS = 80
33
+ EMA_ALPHA = 0.3 # EMA smoothing factor for reward graph
34
 
 
 
 
 
35
 
36
+ # ── Prompt & Parser ───────────────────────────────────────────────────────────
37
 
38
  def format_prompt(state_dict):
39
  resources_text = ""
40
  for r in state_dict["resources"]:
41
  status = "ACTIVE" if r["active"] else "STOPPED"
42
+ tag = "PROD" if r["is_prod"] else "TEMP"
43
  resources_text += (
44
  f" - {r['name']} [{status}] ({tag}): "
45
+ f"${r['cost_per_hr']:.2f}/hr, CPU={r['cpu_pct']}%, "
46
  f"Deps={len(r['dependencies'])}\n"
47
  )
48
  savings_pct = state_dict.get("savings_pct", 0.0)
49
  return (
50
+ f"You are a Cloud FinOps AI. Reduce cost by >=20% without breaking production.\n\n"
51
  f"Actions: 0=NOOP, 1=CHECK_DEPS, 2=RESIZE, 3=STOP, 4=DELETE\n\n"
52
  f"Resources:\n{resources_text}\n"
53
  f"Baseline: ${state_dict['baseline_cost']:.2f}/hr | "
54
+ f"Current: ${state_dict['current_cost']:.2f}/hr | Savings: {savings_pct:.1f}%\n\n"
55
+ f"Rules:\n- Never delete/stop prod resources or those with >=5 deps\n"
56
+ f"- Temp resources with 0-1 deps are safe to delete\n- RESIZE is always safe\n\n"
 
 
 
57
  f"CRITICAL: Output ONLY a brief reason then ACTION: <number 0-4>. Nothing else.\n\n"
58
  f"REASONING:"
59
  )
 
62
  def extract_action_and_reasoning(response_text):
63
  """Regex safety net: extracts action even from truncated/malformed output."""
64
  reasoning = response_text.strip()
65
+ action = 2 # Default RESIZE
66
 
67
+ match = re.search(r'ACTION:\s*([0-4])', response_text, re.IGNORECASE)
68
+ if match:
69
+ return int(match.group(1)), reasoning
 
70
 
 
71
  json_match = re.search(r'\{.*?\}', response_text, re.DOTALL)
72
  if json_match:
73
  try:
 
74
  parsed = json.loads(json_match.group(0))
75
  a = parsed.get("action", 2)
76
  if isinstance(a, int) and 0 <= a <= 4:
 
78
  except (json.JSONDecodeError, ValueError):
79
  pass
80
 
81
+ digits = re.findall(r'\b([0-4])\b', response_text[-30:])
82
+ if digits:
83
+ action = int(digits[-1])
 
 
84
  return action, reasoning
85
 
86
 
87
+ # ── REINFORCE Loss ────────────────────────────────────────────────────────────
88
+
89
  def compute_pg_loss(model, tokenizer, prompt, response_text, reward):
 
90
  full_text = prompt + response_text
91
+ enc = tokenizer(full_text, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN).to(DEVICE)
92
+ prompt_len = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN)["input_ids"].shape[1]
 
93
 
94
+ outputs = model(**enc, labels=enc["input_ids"])
95
  logits = outputs.logits[:, prompt_len-1:-1, :]
96
+ targets = enc["input_ids"][:, prompt_len:]
97
 
98
  if targets.shape[1] == 0 or logits.shape[1] == 0:
99
  return 0.0
100
 
101
+ ml = min(logits.shape[1], targets.shape[1])
102
+ log_probs = F.log_softmax(logits[:, :ml, :], dim=-1)
103
+ token_lp = log_probs.gather(2, targets[:, :ml].unsqueeze(-1)).squeeze(-1)
104
 
105
+ loss = -(reward / 10.0) * token_lp.mean()
106
+ loss.backward()
 
 
 
 
 
107
  return loss.item()
108
 
109
 
110
+ # ── Episode Runner ────────────────────────────────────────────────────────────
111
+
112
  def run_episode(model, tokenizer, env, is_training=False, optimizer=None,
113
+ steps_per_episode=15, iteration_num=0, total_iters=0):
114
  obs, info = env.reset()
115
  state_dict = env.core._get_internal_state()
116
  done = False
117
  episode_reward = 0.0
118
  step_count = 0
119
  reasoning_log = []
 
120
 
 
121
  if is_training and optimizer is not None:
122
  optimizer.zero_grad()
123
 
 
125
  prompt = format_prompt(state_dict)
126
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN)
127
  input_ids = inputs["input_ids"].to(DEVICE)
128
+ attn_mask = inputs["attention_mask"].to(DEVICE)
129
 
130
  with torch.no_grad():
131
+ gen = model.generate(
132
+ input_ids, attention_mask=attn_mask,
133
  max_new_tokens=MAX_GEN_TOKENS,
134
  do_sample=True, temperature=0.7, top_p=0.95,
135
  pad_token_id=tokenizer.pad_token_id,
136
  )
137
 
138
+ response_text = tokenizer.decode(gen[0][input_ids.shape[1]:], skip_special_tokens=True)
 
139
  action, reasoning = extract_action_and_reasoning(response_text)
140
 
141
  next_obs, reward, terminated, truncated, next_info = env.step(action)
142
  done = terminated or truncated
143
  episode_reward += reward
144
 
145
+ # ── Detailed per-step terminal output ──
146
+ if is_training and total_iters > 0:
147
+ pct = (iteration_num / total_iters) * 100
148
+ print(f" [{pct:5.1f}%] Ep {iteration_num} Step {step_count+1}: "
149
+ f"{ACTION_NAMES.get(action,'?')} → r={reward:+.3f} | "
150
+ f"💬 {reasoning[:80]}")
151
+
152
  reasoning_log.append({
153
+ "step": step_count + 1, "action": action,
154
+ "action_name": ACTION_NAMES.get(action, "?"),
 
 
155
  "reward": round(reward, 4),
156
+ "reasoning": reasoning[:200],
157
  "message": next_info.get("msg", ""),
158
  })
159
 
160
  if is_training and optimizer is not None:
161
+ compute_pg_loss(model, tokenizer, prompt, response_text, reward)
 
162
 
163
  obs = next_obs
164
  state_dict = env.core._get_internal_state()
165
  step_count += 1
166
 
167
+ return episode_reward, reasoning_log
168
 
169
 
170
+ # ── VRAM Cleanup ──────────────────────────────────────────────────────────────
171
+
172
+ def nuke_vram(model=None, optimizer=None, tokenizer=None):
173
+ """Aggressively free VRAM between model runs."""
174
+ if model is not None:
175
+ del model
176
+ if optimizer is not None:
177
+ del optimizer
178
+ if tokenizer is not None:
179
+ del tokenizer
180
+ gc.collect()
181
+ if torch.cuda.is_available():
182
+ torch.cuda.empty_cache()
183
+ torch.cuda.synchronize()
184
+ vram = torch.cuda.memory_allocated() / 1e9
185
+ print(f" 🧹 VRAM after cleanup: {vram:.2f} GB")
186
+
187
+
188
+ # ── Single Model Training ────────────────────────────────────────────────────
189
 
190
+ def train_single_model(model_name, num_iterations=200, steps_per_episode=15,
191
+ learning_rate=2e-6):
192
+ """Train one model, return rewards list."""
193
+ hf_token = os.environ.get("HF_TOKEN")
194
  from transformers import AutoModelForCausalLM, AutoTokenizer
195
  from peft import get_peft_model, LoraConfig, TaskType
196
 
197
+ short_name = model_name.split("/")[-1]
198
+ print(f"\n{'='*60}")
199
+ print(f" 🧠 Loading: {short_name}")
200
+ print(f"{'='*60}")
 
 
 
 
 
 
 
201
 
202
  tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
203
  model = AutoModelForCausalLM.from_pretrained(
204
  model_name, torch_dtype=torch.bfloat16, token=hf_token,
205
+ attn_implementation="sdpa",
206
  ).to(DEVICE)
207
 
208
+ lora_cfg = LoraConfig(
209
  r=16, lora_alpha=16,
210
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
211
+ lora_dropout=0.0, bias="none", task_type=TaskType.CAUSAL_LM,
 
212
  )
213
+ model = get_peft_model(model, lora_cfg)
 
214
  if tokenizer.pad_token is None:
215
  tokenizer.pad_token = tokenizer.eos_token
216
 
217
  trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
218
  total = sum(p.numel() for p in model.parameters())
219
+ print(f" Loaded | Trainable: {trainable:,} / {total:,}")
220
+
221
+ if torch.cuda.is_available():
222
+ vram = torch.cuda.memory_allocated() / 1e9
223
+ print(f" 📊 VRAM used: {vram:.2f} GB")
224
 
225
  optimizer = torch.optim.AdamW(
226
  filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate
227
  )
228
  env = SB3Adapter()
 
229
  all_rewards = []
 
230
 
231
  # Pre-training eval
232
+ print(f"\n ▶ PRE-TRAINING EVAL")
233
  model.eval()
234
+ pre_r, _ = run_episode(model, tokenizer, env, steps_per_episode=steps_per_episode)
235
+ all_rewards.append(pre_r)
236
+ print(f" Baseline reward: {pre_r:+.3f}")
 
237
 
238
+ # Training loop
239
+ print(f"\n ▶ TRAINING ({num_iterations} iters, accum={GRAD_ACCUM_STEPS})")
240
  model.train()
241
+ t0 = time.time()
242
+
243
  for i in range(num_iterations):
244
+ reward, log_data = run_episode(
245
  model, tokenizer, env, is_training=True, optimizer=optimizer,
246
  steps_per_episode=steps_per_episode,
247
+ iteration_num=i+1, total_iters=num_iterations,
248
  )
249
  all_rewards.append(reward)
 
250
 
 
251
  if (i + 1) % GRAD_ACCUM_STEPS == 0:
252
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
253
  optimizer.step()
254
  optimizer.zero_grad()
255
 
256
+ # Per-iteration summary
257
+ pct = ((i+1) / num_iterations) * 100
258
+ elapsed = time.time() - t0
259
+ eta = (elapsed / (i+1)) * (num_iterations - i - 1)
260
+ ema = all_rewards[-1] if len(all_rewards) < 3 else (
261
+ EMA_ALPHA * all_rewards[-1] + (1 - EMA_ALPHA) * all_rewards[-2]
262
+ )
263
+ print(f" ┃ [{pct:5.1f}%] Iter {i+1:3d}/{num_iterations} │ "
264
+ f"r={reward:+.3f} │ EMA={ema:+.3f} │ "
265
+ f"ETA={eta:.0f}s")
266
 
267
+ # Final grad step
268
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
269
  optimizer.step()
270
 
271
  # Post-training eval
272
+ print(f"\n ▶ POST-TRAINING EVAL")
273
  model.eval()
274
+ post_r, _ = run_episode(model, tokenizer, env, steps_per_episode=steps_per_episode)
275
+ all_rewards.append(post_r)
276
+ delta = post_r - pre_r
277
+ print(f" Final reward: {post_r:+.3f} (Δ={delta:+.3f})")
278
+ print(f" Time: {time.time()-t0:.0f}s")
279
 
280
+ # Cleanup VRAM
281
+ nuke_vram(model, optimizer, tokenizer)
282
 
283
+ return all_rewards
 
 
284
 
 
 
285
 
286
+ # ── EMA Graph ─────────────────────────────────────────────────────────────────
287
 
288
+ def compute_ema(rewards, alpha=EMA_ALPHA):
289
+ ema = [rewards[0]]
290
+ for r in rewards[1:]:
291
+ ema.append(alpha * r + (1 - alpha) * ema[-1])
292
+ return ema
293
 
 
 
 
294
 
295
+ def generate_comparison_graph(all_results, output_path="outputs/multi_model_comparison.png"):
296
+ BG = '#0e1117'
297
+ COLORS = ['#00d4ff', '#ffa500', '#39ff14', '#ff6b6b', '#b47eff']
298
+
299
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 7), facecolor=BG)
300
  for ax in [ax1, ax2]:
301
+ ax.set_facecolor(BG)
302
+ ax.tick_params(colors='#e6e6e6', labelsize=10)
303
+ ax.grid(True, alpha=0.08, color='white')
304
+ for s in ['top', 'right']:
305
  ax.spines[s].set_visible(False)
306
+ for s in ['left', 'bottom']:
307
  ax.spines[s].set_color('#333')
308
 
309
+ # Left: EMA reward curves
310
+ for idx, (name, rewards) in enumerate(all_results.items()):
311
+ color = COLORS[idx % len(COLORS)]
312
+ ema = compute_ema(rewards)
313
+ ax1.plot(ema, color=color, lw=2.5, label=name, alpha=0.9)
314
+ ax1.plot(rewards, color=color, lw=0.5, alpha=0.2)
315
+
316
+ ax1.set_title("Training Reward (EMA Smoothed)", color='#e6e6e6', fontsize=14, fontweight='bold')
317
+ ax1.set_xlabel("Episode", color='#e6e6e6', fontsize=11)
318
+ ax1.set_ylabel("Reward", color='#e6e6e6', fontsize=11)
319
+ ax1.legend(facecolor='#1a1a2e', edgecolor='#333', labelcolor='#e6e6e6', fontsize=9)
320
+ ax1.axhline(y=0, color='gray', linestyle='--', alpha=0.3)
321
+
322
+ # Right: Before vs After comparison bars
323
+ names = list(all_results.keys())
324
+ pre_scores = [all_results[n][0] for n in names]
325
+ post_scores = [all_results[n][-1] for n in names]
326
+
327
+ x = np.arange(len(names))
328
+ w = 0.35
329
+ bars1 = ax2.bar(x - w/2, pre_scores, w, label='Before', color='#ef4444', edgecolor='white', lw=1)
330
+ bars2 = ax2.bar(x + w/2, post_scores, w, label='After', color='#22c55e', edgecolor='white', lw=1)
331
+
332
+ ax2.set_xticks(x)
333
+ ax2.set_xticklabels(names, fontsize=8, color='#e6e6e6', rotation=15)
334
+ ax2.set_title("Pre vs Post Training", color='#e6e6e6', fontsize=14, fontweight='bold')
335
+ ax2.set_ylabel("Reward", color='#e6e6e6', fontsize=11)
336
+ ax2.legend(facecolor='#1a1a2e', edgecolor='#333', labelcolor='#e6e6e6')
337
+ ax2.axhline(y=0, color='gray', linestyle='--', alpha=0.3)
338
+
339
+ for bar, val in zip(list(bars1) + list(bars2), pre_scores + post_scores):
340
+ ax2.text(bar.get_x() + bar.get_width()/2, val + 0.1,
341
+ f"{val:+.1f}", ha='center', va='bottom', fontsize=9,
342
+ color='#e6e6e6', fontweight='bold')
343
 
344
  plt.tight_layout()
345
+ plt.savefig(output_path, dpi=200, bbox_inches='tight', facecolor=BG)
 
346
  plt.close()
347
+ return output_path
348
+
349
+
350
+ # ── Main Pipeline ─────────────────────────────────────────────────────────────
351
+
352
+ def train_llm(model_name=None, num_iterations=200, steps_per_episode=15,
353
+ learning_rate=2e-6, progress_callback=None):
354
+ """
355
+ Multi-model or single-model training pipeline.
356
+ If model_name contains commas, runs multi-model benchmark.
357
+ """
358
+ log_lines = []
359
+ def log(msg):
360
+ print(msg)
361
+ log_lines.append(msg)
362
+ if progress_callback:
363
+ progress_callback("\n".join(log_lines))
364
+
365
+ # Determine model list
366
+ if model_name and "," in model_name:
367
+ models = [m.strip() for m in model_name.split(",")]
368
+ elif model_name:
369
+ models = [model_name.strip()]
370
+ else:
371
+ models = MODELS_TO_TEST
372
+
373
+ log(f"🖥️ Device: {DEVICE}")
374
+ log(f"🔁 Models to test: {len(models)}")
375
+ for m in models:
376
+ log(f" • {m}")
377
+
378
+ all_results = {}
379
+ full_log = []
380
+
381
+ for model_idx, mname in enumerate(models):
382
+ short = mname.split("/")[-1]
383
+ log(f"\n{'━'*60}")
384
+ log(f" [{model_idx+1}/{len(models)}] {short}")
385
+ log(f"{'━'*60}")
386
+
387
+ try:
388
+ rewards = train_single_model(
389
+ mname, num_iterations=num_iterations,
390
+ steps_per_episode=steps_per_episode,
391
+ learning_rate=learning_rate,
392
+ )
393
+ all_results[short] = rewards
394
+ delta = rewards[-1] - rewards[0]
395
+ log(f" ✅ {short}: Pre={rewards[0]:+.3f} → Post={rewards[-1]:+.3f} (Δ={delta:+.3f})")
396
+ full_log.append({
397
+ "model": mname, "pre": rewards[0], "post": rewards[-1],
398
+ "delta": delta, "all_rewards": rewards,
399
+ })
400
+ except Exception as e:
401
+ log(f" ❌ {short} FAILED: {e}")
402
+ full_log.append({"model": mname, "error": str(e)})
403
+ nuke_vram() # cleanup even on failure
404
+
405
+ # Generate comparison graph
406
+ graph_path = None
407
+ if all_results:
408
+ os.makedirs("outputs", exist_ok=True)
409
+ graph_path = generate_comparison_graph(all_results)
410
+ log(f"\n📊 Comparison graph saved to {graph_path}")
411
+
412
+ # Save log
413
+ with open("outputs/multi_model_log.json", "w") as f:
414
+ json.dump(full_log, f, indent=2, default=str)
415
+
416
+ # Build flat reward list for backward compat
417
+ flat_rewards = []
418
+ for rewards in all_results.values():
419
+ flat_rewards.extend(rewards)
420
+
421
+ log_text = "\n".join(log_lines)
422
+ return flat_rewards or [0], full_log, graph_path, log_text
requirements.txt CHANGED
@@ -14,3 +14,4 @@ peft==0.12.0
14
  accelerate==0.33.0
15
  bitsandbytes>=0.43.0
16
  sentencepiece
 
 
14
  accelerate==0.33.0
15
  bitsandbytes>=0.43.0
16
  sentencepiece
17
+ unsloth