md896 committed on
Commit
af54ccd
·
1 Parent(s): e5262a1

Fix GRPO batch/generation mismatch: auto-adjust num_generations; set launcher default to 2.

Browse files
Files changed (2) hide show
  1. launch_job.py +2 -2
  2. ultimate_sota_training.py +17 -3
launch_job.py CHANGED
@@ -18,7 +18,7 @@ Environment (optional):
18
  TRAIN_REPO_GIT_URL, OPENENV_BASE_URL
19
  TRAIN_MAX_STEPS default: 80 (faster run; raise for stronger fit)
20
  ROWS_PER_TASK default: 32
21
- GRPO_NUM_GENERATIONS default: 6
22
  SKIP_HUB_PUSH default: 0
23
  """
24
  from __future__ import annotations
@@ -33,7 +33,7 @@ _REPO_URL = os.environ.get("TRAIN_REPO_GIT_URL", _DEFAULT_REPO)
33
  _OPENENV = os.environ.get("OPENENV_BASE_URL", "https://md896-sql-debug-env.hf.space")
34
  _MAX_STEPS = os.environ.get("TRAIN_MAX_STEPS", "80")
35
  _ROWS = os.environ.get("ROWS_PER_TASK", "32")
36
- _NUM_GEN = os.environ.get("GRPO_NUM_GENERATIONS", "6")
37
  _SKIP_PUSH = os.environ.get("SKIP_HUB_PUSH", "0")
38
  _TIMEOUT = os.environ.get("HF_JOB_TIMEOUT", "8h")
39
  # l4x1: newer GPU, good for Unsloth; use HF_JOB_FLAVOR=t4-small if queue or cost is better for you
 
18
  TRAIN_REPO_GIT_URL, OPENENV_BASE_URL
19
  TRAIN_MAX_STEPS default: 80 (faster run; raise for stronger fit)
20
  ROWS_PER_TASK default: 32
21
+ GRPO_NUM_GENERATIONS default: 2
22
  SKIP_HUB_PUSH default: 0
23
  """
24
  from __future__ import annotations
 
33
  _OPENENV = os.environ.get("OPENENV_BASE_URL", "https://md896-sql-debug-env.hf.space")
34
  _MAX_STEPS = os.environ.get("TRAIN_MAX_STEPS", "80")
35
  _ROWS = os.environ.get("ROWS_PER_TASK", "32")
36
+ _NUM_GEN = os.environ.get("GRPO_NUM_GENERATIONS", "2")
37
  _SKIP_PUSH = os.environ.get("SKIP_HUB_PUSH", "0")
38
  _TIMEOUT = os.environ.get("HF_JOB_TIMEOUT", "8h")
39
  # l4x1: newer GPU, good for Unsloth; use HF_JOB_FLAVOR=t4-small if queue or cost is better for you
ultimate_sota_training.py CHANGED
@@ -394,12 +394,26 @@ def run_sota_train():
394
  if report_to == "tensorboard":
395
  _ensure_dir(tb_dir)
396
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397
  _cfg: Dict[str, Any] = dict(
398
  output_dir=out_dir,
399
  learning_rate=float(os.environ.get("TRAIN_LR", "5e-6")),
400
- per_device_train_batch_size=int(os.environ.get("PER_DEVICE_TRAIN_BS", "1")),
401
- gradient_accumulation_steps=int(os.environ.get("GRAD_ACCUM", "2")),
402
- num_generations=int(os.environ.get("GRPO_NUM_GENERATIONS", "8")),
403
  max_completion_length=int(os.environ.get("GRPO_MAX_COMPLETION_LEN", "256")),
404
  temperature=float(os.environ.get("GRPO_TEMPERATURE", "0.9")),
405
  num_train_epochs=int(os.environ.get("TRAIN_NUM_EPOCHS", "1")),
 
394
  if report_to == "tensorboard":
395
  _ensure_dir(tb_dir)
396
 
397
+ per_device_bs = int(os.environ.get("PER_DEVICE_TRAIN_BS", "1"))
398
+ grad_accum = int(os.environ.get("GRAD_ACCUM", "2"))
399
+ requested_num_gen = int(os.environ.get("GRPO_NUM_GENERATIONS", "8"))
400
+ effective_bs = max(1, per_device_bs * grad_accum)
401
+ if effective_bs % requested_num_gen != 0:
402
+ valid = [d for d in range(2, effective_bs + 1) if effective_bs % d == 0]
403
+ num_gen = valid[-1] if valid else 2
404
+ print(
405
+ f"Adjusting GRPO_NUM_GENERATIONS from {requested_num_gen} to {num_gen} "
406
+ f"for effective batch size {effective_bs}."
407
+ )
408
+ else:
409
+ num_gen = requested_num_gen
410
+
411
  _cfg: Dict[str, Any] = dict(
412
  output_dir=out_dir,
413
  learning_rate=float(os.environ.get("TRAIN_LR", "5e-6")),
414
+ per_device_train_batch_size=per_device_bs,
415
+ gradient_accumulation_steps=grad_accum,
416
+ num_generations=num_gen,
417
  max_completion_length=int(os.environ.get("GRPO_MAX_COMPLETION_LEN", "256")),
418
  temperature=float(os.environ.get("GRPO_TEMPERATURE", "0.9")),
419
  num_train_epochs=int(os.environ.get("TRAIN_NUM_EPOCHS", "1")),