kavin57447 commited on
Commit
8d95050
·
1 Parent(s): 27c9425

Switch to Llama 3.1 8B + fix low-timestep crash (min 5000)

Browse files
Files changed (2) hide show
  1. app.py +6 -3
  2. cloud_arena/llm_training.py +2 -2
app.py CHANGED
@@ -18,7 +18,10 @@ os.makedirs("./outputs", exist_ok=True)
18
  def run_math_training(timesteps):
19
  from cloud_arena.training import train_model
20
  try:
21
- model, callback, _ = train_model(total_timesteps=int(timesteps))
 
 
 
22
  from cloud_arena.visualization import generate_dashboard
23
  img_path = generate_dashboard(callback, "outputs/dashboard.png")
24
  summary = (
@@ -94,9 +97,9 @@ with gr.Blocks(title="Cloud Arena RL") as demo:
94
  eval_btn.click(run_math_evaluation, outputs=eval_output)
95
 
96
  with gr.Tab("🧠 LLM RL"):
97
- gr.Markdown("### LLM Model — Gemma 7B + REINFORCE + LoRA")
98
  gr.Markdown("> ⚠️ Requires `HF_TOKEN` secret set in Space settings + accepted model license")
99
- llm_model = gr.Textbox(value="google/gemma-7b-it", label="Model Name")
100
  llm_iters = gr.Number(value=10, label="Training Iterations")
101
  llm_steps = gr.Number(value=5, label="Steps per Episode")
102
  llm_btn = gr.Button("🚀 Start LLM Training", variant="primary")
 
18
  def run_math_training(timesteps):
19
  from cloud_arena.training import train_model
20
  try:
21
+ ts = max(int(timesteps), 5000) # minimum 5000 to avoid sampling errors
22
+ if int(timesteps) < 5000:
23
+ print(f"⚠️ Timesteps too low ({int(timesteps)}), using minimum 5000")
24
+ model, callback, _ = train_model(total_timesteps=ts)
25
  from cloud_arena.visualization import generate_dashboard
26
  img_path = generate_dashboard(callback, "outputs/dashboard.png")
27
  summary = (
 
97
  eval_btn.click(run_math_evaluation, outputs=eval_output)
98
 
99
  with gr.Tab("🧠 LLM RL"):
100
+ gr.Markdown("### LLM Model — Llama 3.1 8B + REINFORCE + LoRA")
101
  gr.Markdown("> ⚠️ Requires `HF_TOKEN` secret set in Space settings + accepted model license")
102
+ llm_model = gr.Textbox(value="meta-llama/Llama-3.1-8B", label="Model Name")
103
  llm_iters = gr.Number(value=10, label="Training Iterations")
104
  llm_steps = gr.Number(value=5, label="Steps per Episode")
105
  llm_btn = gr.Button("🚀 Start LLM Training", variant="primary")
cloud_arena/llm_training.py CHANGED
@@ -1,5 +1,5 @@
1
  # ============================================================
2
- # LLM RL Training — LLaMA 3.1 8B + REINFORCE + LoRA
3
  # This is the LLM model, SEPARATE from the mathematical model.
4
  # Uses AWSCostEnv (llm_environment.py), NOT CloudArenaEnv.
5
  # ============================================================
@@ -147,7 +147,7 @@ def run_episode(model, tokenizer, env, is_training=False, optimizer=None,
147
  return episode_reward, reasoning_log
148
 
149
 
150
- def train_llm(model_name="google/gemma-7b-it",
151
  num_iterations=10, steps_per_episode=5, learning_rate=5e-5,
152
  progress_callback=None):
153
  """
 
1
  # ============================================================
2
+ # LLM RL Training — Llama 3.1 8B + REINFORCE + LoRA
3
  # This is the LLM model, SEPARATE from the mathematical model.
4
  # Uses AWSCostEnv (llm_environment.py), NOT CloudArenaEnv.
5
  # ============================================================
 
147
  return episode_reward, reasoning_log
148
 
149
 
150
+ def train_llm(model_name="meta-llama/Llama-3.1-8B",
151
  num_iterations=10, steps_per_episode=5, learning_rate=5e-5,
152
  progress_callback=None):
153
  """