md896 committed on
Commit
6083a40
·
1 Parent(s): 948530a

Fix TRL 0.18 compatibility: remove unsupported generation_kwargs; set safety flags on model.generation_config.

Browse files
Files changed (1) hide show
  1. ultimate_sota_training.py +5 -5
ultimate_sota_training.py CHANGED
@@ -364,6 +364,11 @@ def run_sota_train():
364
  device_map="auto",
365
  attn_implementation=os.environ.get("ATTN_IMPLEMENTATION", "eager"),
366
  )
 
 
 
 
 
367
 
368
  train_dataset = make_real_dataset()
369
 
@@ -423,11 +428,6 @@ def run_sota_train():
423
  max_completion_length=int(os.environ.get("GRPO_MAX_COMPLETION_LEN", "256")),
424
  temperature=float(os.environ.get("GRPO_TEMPERATURE", "0.7")),
425
  top_p=float(os.environ.get("GRPO_TOP_P", "0.9")),
426
- # Keep generation numerically safe in long sampling loops.
427
- generation_kwargs={
428
- "remove_invalid_values": True,
429
- "renormalize_logits": True,
430
- },
431
  bf16=bool(use_cuda),
432
  fp16=False,
433
  num_train_epochs=int(os.environ.get("TRAIN_NUM_EPOCHS", "1")),
 
364
  device_map="auto",
365
  attn_implementation=os.environ.get("ATTN_IMPLEMENTATION", "eager"),
366
  )
367
+ # Runtime generation safety defaults (used by both eval and GRPO generate path).
368
+ model.generation_config.remove_invalid_values = True
369
+ model.generation_config.renormalize_logits = True
370
+ model.generation_config.top_p = float(os.environ.get("GRPO_TOP_P", "0.9"))
371
+ model.generation_config.temperature = float(os.environ.get("GRPO_TEMPERATURE", "0.7"))
372
 
373
  train_dataset = make_real_dataset()
374
 
 
428
  max_completion_length=int(os.environ.get("GRPO_MAX_COMPLETION_LEN", "256")),
429
  temperature=float(os.environ.get("GRPO_TEMPERATURE", "0.7")),
430
  top_p=float(os.environ.get("GRPO_TOP_P", "0.9")),
 
 
 
 
 
431
  bf16=bool(use_cuda),
432
  fp16=False,
433
  num_train_epochs=int(os.environ.get("TRAIN_NUM_EPOCHS", "1")),