Spaces:
Runtime error
Runtime error
Fix hang: remove use_cpu parameter, reduce generations to 2, batch to 2, steps to 20
Browse files- train_arithmetic.py +3 -4
train_arithmetic.py
CHANGED
|
@@ -19,7 +19,7 @@ from trl import GRPOConfig, GRPOTrainer
|
|
| 19 |
|
| 20 |
BASE_MODEL = "Qwen/Qwen3-0.6B-Base"
|
| 21 |
OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic"
|
| 22 |
-
MAX_STEPS =
|
| 23 |
NUM_SAMPLES = 500 # Training samples
|
| 24 |
EVAL_SAMPLES = 20 # For baseline test
|
| 25 |
|
|
@@ -184,8 +184,8 @@ def main():
|
|
| 184 |
training_args = GRPOConfig(
|
| 185 |
output_dir="./outputs",
|
| 186 |
max_steps=MAX_STEPS,
|
| 187 |
-
per_device_train_batch_size=
|
| 188 |
-
num_generations=
|
| 189 |
learning_rate=2e-4,
|
| 190 |
beta=0.0, # No KL penalty for this task
|
| 191 |
bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
|
|
@@ -196,7 +196,6 @@ def main():
|
|
| 196 |
save_steps=MAX_STEPS, # Save at end
|
| 197 |
push_to_hub=False, # We'll push manually
|
| 198 |
report_to="none",
|
| 199 |
-
use_cpu=is_cpu, # Explicitly tell trainer to use CPU
|
| 200 |
)
|
| 201 |
|
| 202 |
print("π Starting GRPO Training...")
|
|
|
|
| 19 |
|
| 20 |
BASE_MODEL = "Qwen/Qwen3-0.6B-Base"
|
| 21 |
OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic"
|
| 22 |
+
MAX_STEPS = 20 # Reduced for CPU testing
|
| 23 |
NUM_SAMPLES = 500 # Training samples
|
| 24 |
EVAL_SAMPLES = 20 # For baseline test
|
| 25 |
|
|
|
|
| 184 |
training_args = GRPOConfig(
|
| 185 |
output_dir="./outputs",
|
| 186 |
max_steps=MAX_STEPS,
|
| 187 |
+
per_device_train_batch_size=2, # Reduced for CPU
|
| 188 |
+
num_generations=2, # Reduced for CPU (faster)
|
| 189 |
learning_rate=2e-4,
|
| 190 |
beta=0.0, # No KL penalty for this task
|
| 191 |
bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
|
|
|
|
| 196 |
save_steps=MAX_STEPS, # Save at end
|
| 197 |
push_to_hub=False, # We'll push manually
|
| 198 |
report_to="none",
|
|
|
|
| 199 |
)
|
| 200 |
|
| 201 |
print("π Starting GRPO Training...")
|