Spaces:
Running
Running
Harden GRPO generation stability on CUDA: bf16 + eager attention + invalid-logit guards.
Browse files — ultimate_sota_training.py (+17 −3)
ultimate_sota_training.py
CHANGED
|
@@ -355,11 +355,14 @@ def run_sota_train():
|
|
| 355 |
|
| 356 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 357 |
tokenizer.pad_token = tokenizer.eos_token
|
| 358 |
-
|
|
|
|
|
|
|
| 359 |
model = AutoModelForCausalLM.from_pretrained(
|
| 360 |
MODEL_NAME,
|
| 361 |
torch_dtype=torch_dtype,
|
| 362 |
device_map="auto",
|
|
|
|
| 363 |
)
|
| 364 |
|
| 365 |
train_dataset = make_real_dataset()
|
|
@@ -378,7 +381,10 @@ def run_sota_train():
|
|
| 378 |
**inputs,
|
| 379 |
max_new_tokens=256,
|
| 380 |
do_sample=True,
|
| 381 |
-
temperature=0.7,
|
|
|
|
|
|
|
|
|
|
| 382 |
pad_token_id=tokenizer.eos_token_id,
|
| 383 |
)
|
| 384 |
completions.append(tokenizer.decode(out[0], skip_special_tokens=True))
|
|
@@ -415,7 +421,15 @@ def run_sota_train():
|
|
| 415 |
gradient_accumulation_steps=grad_accum,
|
| 416 |
num_generations=num_gen,
|
| 417 |
max_completion_length=int(os.environ.get("GRPO_MAX_COMPLETION_LEN", "256")),
|
| 418 |
-
temperature=float(os.environ.get("GRPO_TEMPERATURE", "0.7")),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
num_train_epochs=int(os.environ.get("TRAIN_NUM_EPOCHS", "1")),
|
| 420 |
max_steps=max_steps,
|
| 421 |
logging_steps=int(os.environ.get("LOGGING_STEPS", "1")),
|
|
|
|
| 355 |
|
| 356 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 357 |
tokenizer.pad_token = tokenizer.eos_token
|
| 358 |
+
use_cuda = torch.cuda.is_available()
|
| 359 |
+
# L4/A10/A100 are typically more numerically stable with bf16 than fp16 for RL-style sampling.
|
| 360 |
+
torch_dtype = torch.bfloat16 if use_cuda else torch.float32
|
| 361 |
model = AutoModelForCausalLM.from_pretrained(
|
| 362 |
MODEL_NAME,
|
| 363 |
torch_dtype=torch_dtype,
|
| 364 |
device_map="auto",
|
| 365 |
+
attn_implementation=os.environ.get("ATTN_IMPLEMENTATION", "eager"),
|
| 366 |
)
|
| 367 |
|
| 368 |
train_dataset = make_real_dataset()
|
|
|
|
| 381 |
**inputs,
|
| 382 |
max_new_tokens=256,
|
| 383 |
do_sample=True,
|
| 384 |
+
temperature=float(os.environ.get("EVAL_TEMPERATURE", "0.7")),
|
| 385 |
+
top_p=float(os.environ.get("EVAL_TOP_P", "0.9")),
|
| 386 |
+
renormalize_logits=True,
|
| 387 |
+
remove_invalid_values=True,
|
| 388 |
pad_token_id=tokenizer.eos_token_id,
|
| 389 |
)
|
| 390 |
completions.append(tokenizer.decode(out[0], skip_special_tokens=True))
|
|
|
|
| 421 |
gradient_accumulation_steps=grad_accum,
|
| 422 |
num_generations=num_gen,
|
| 423 |
max_completion_length=int(os.environ.get("GRPO_MAX_COMPLETION_LEN", "256")),
|
| 424 |
+
temperature=float(os.environ.get("GRPO_TEMPERATURE", "0.7")),
|
| 425 |
+
top_p=float(os.environ.get("GRPO_TOP_P", "0.9")),
|
| 426 |
+
# Keep generation numerically safe in long sampling loops.
|
| 427 |
+
generation_kwargs={
|
| 428 |
+
"remove_invalid_values": True,
|
| 429 |
+
"renormalize_logits": True,
|
| 430 |
+
},
|
| 431 |
+
bf16=bool(use_cuda),
|
| 432 |
+
fp16=False,
|
| 433 |
num_train_epochs=int(os.environ.get("TRAIN_NUM_EPOCHS", "1")),
|
| 434 |
max_steps=max_steps,
|
| 435 |
logging_steps=int(os.environ.get("LOGGING_STEPS", "1")),
|