Spaces:
Paused
Paused
Commit · 8d95050
1 Parent(s): 27c9425
Switch to Llama 3.1 8B + fix low-timestep crash (min 5000)
Browse files
- app.py +6 -3
- cloud_arena/llm_training.py +2 -2
app.py
CHANGED
|
@@ -18,7 +18,10 @@ os.makedirs("./outputs", exist_ok=True)
|
|
| 18 |
def run_math_training(timesteps):
|
| 19 |
from cloud_arena.training import train_model
|
| 20 |
try:
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
| 22 |
from cloud_arena.visualization import generate_dashboard
|
| 23 |
img_path = generate_dashboard(callback, "outputs/dashboard.png")
|
| 24 |
summary = (
|
|
@@ -94,9 +97,9 @@ with gr.Blocks(title="Cloud Arena RL") as demo:
|
|
| 94 |
eval_btn.click(run_math_evaluation, outputs=eval_output)
|
| 95 |
|
| 96 |
with gr.Tab("🧠 LLM RL"):
|
| 97 |
-
gr.Markdown("### LLM Model —
|
| 98 |
gr.Markdown("> ⚠️ Requires `HF_TOKEN` secret set in Space settings + accepted model license")
|
| 99 |
-
llm_model = gr.Textbox(value="
|
| 100 |
llm_iters = gr.Number(value=10, label="Training Iterations")
|
| 101 |
llm_steps = gr.Number(value=5, label="Steps per Episode")
|
| 102 |
llm_btn = gr.Button("🚀 Start LLM Training", variant="primary")
|
|
|
|
| 18 |
def run_math_training(timesteps):
|
| 19 |
from cloud_arena.training import train_model
|
| 20 |
try:
|
| 21 |
+
ts = max(int(timesteps), 5000) # minimum 5000 to avoid sampling errors
|
| 22 |
+
if int(timesteps) < 5000:
|
| 23 |
+
print(f"⚠️ Timesteps too low ({int(timesteps)}), using minimum 5000")
|
| 24 |
+
model, callback, _ = train_model(total_timesteps=ts)
|
| 25 |
from cloud_arena.visualization import generate_dashboard
|
| 26 |
img_path = generate_dashboard(callback, "outputs/dashboard.png")
|
| 27 |
summary = (
|
|
|
|
| 97 |
eval_btn.click(run_math_evaluation, outputs=eval_output)
|
| 98 |
|
| 99 |
with gr.Tab("🧠 LLM RL"):
|
| 100 |
+
gr.Markdown("### LLM Model — Llama 3.1 8B + REINFORCE + LoRA")
|
| 101 |
gr.Markdown("> ⚠️ Requires `HF_TOKEN` secret set in Space settings + accepted model license")
|
| 102 |
+
llm_model = gr.Textbox(value="meta-llama/Llama-3.1-8B", label="Model Name")
|
| 103 |
llm_iters = gr.Number(value=10, label="Training Iterations")
|
| 104 |
llm_steps = gr.Number(value=5, label="Steps per Episode")
|
| 105 |
llm_btn = gr.Button("🚀 Start LLM Training", variant="primary")
|
cloud_arena/llm_training.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
# ============================================================
|
| 2 |
-
# LLM RL Training —
|
| 3 |
# This is the LLM model, SEPARATE from the mathematical model.
|
| 4 |
# Uses AWSCostEnv (llm_environment.py), NOT CloudArenaEnv.
|
| 5 |
# ============================================================
|
|
@@ -147,7 +147,7 @@ def run_episode(model, tokenizer, env, is_training=False, optimizer=None,
|
|
| 147 |
return episode_reward, reasoning_log
|
| 148 |
|
| 149 |
|
| 150 |
-
def train_llm(model_name="
|
| 151 |
num_iterations=10, steps_per_episode=5, learning_rate=5e-5,
|
| 152 |
progress_callback=None):
|
| 153 |
"""
|
|
|
|
| 1 |
# ============================================================
|
| 2 |
+
# LLM RL Training — Llama 3.1 8B + REINFORCE + LoRA
|
| 3 |
# This is the LLM model, SEPARATE from the mathematical model.
|
| 4 |
# Uses AWSCostEnv (llm_environment.py), NOT CloudArenaEnv.
|
| 5 |
# ============================================================
|
|
|
|
| 147 |
return episode_reward, reasoning_log
|
| 148 |
|
| 149 |
|
| 150 |
+
def train_llm(model_name="meta-llama/Llama-3.1-8B",
|
| 151 |
num_iterations=10, steps_per_episode=5, learning_rate=5e-5,
|
| 152 |
progress_callback=None):
|
| 153 |
"""
|