Spaces:
Running
Running
Fix TRL 0.18 compatibility: remove unsupported generation_kwargs; set safety flags on model.generation_config.
Browse files
ultimate_sota_training.py
CHANGED
|
@@ -364,6 +364,11 @@ def run_sota_train():
|
|
| 364 |
device_map="auto",
|
| 365 |
attn_implementation=os.environ.get("ATTN_IMPLEMENTATION", "eager"),
|
| 366 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
|
| 368 |
train_dataset = make_real_dataset()
|
| 369 |
|
|
@@ -423,11 +428,6 @@ def run_sota_train():
|
|
| 423 |
max_completion_length=int(os.environ.get("GRPO_MAX_COMPLETION_LEN", "256")),
|
| 424 |
temperature=float(os.environ.get("GRPO_TEMPERATURE", "0.7")),
|
| 425 |
top_p=float(os.environ.get("GRPO_TOP_P", "0.9")),
|
| 426 |
-
# Keep generation numerically safe in long sampling loops.
|
| 427 |
-
generation_kwargs={
|
| 428 |
-
"remove_invalid_values": True,
|
| 429 |
-
"renormalize_logits": True,
|
| 430 |
-
},
|
| 431 |
bf16=bool(use_cuda),
|
| 432 |
fp16=False,
|
| 433 |
num_train_epochs=int(os.environ.get("TRAIN_NUM_EPOCHS", "1")),
|
|
|
|
| 364 |
device_map="auto",
|
| 365 |
attn_implementation=os.environ.get("ATTN_IMPLEMENTATION", "eager"),
|
| 366 |
)
|
| 367 |
+
# Runtime generation safety defaults (used by both eval and GRPO generate path).
|
| 368 |
+
model.generation_config.remove_invalid_values = True
|
| 369 |
+
model.generation_config.renormalize_logits = True
|
| 370 |
+
model.generation_config.top_p = float(os.environ.get("GRPO_TOP_P", "0.9"))
|
| 371 |
+
model.generation_config.temperature = float(os.environ.get("GRPO_TEMPERATURE", "0.7"))
|
| 372 |
|
| 373 |
train_dataset = make_real_dataset()
|
| 374 |
|
|
|
|
| 428 |
max_completion_length=int(os.environ.get("GRPO_MAX_COMPLETION_LEN", "256")),
|
| 429 |
temperature=float(os.environ.get("GRPO_TEMPERATURE", "0.7")),
|
| 430 |
top_p=float(os.environ.get("GRPO_TOP_P", "0.9")),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
bf16=bool(use_cuda),
|
| 432 |
fp16=False,
|
| 433 |
num_train_epochs=int(os.environ.get("TRAIN_NUM_EPOCHS", "1")),
|