md896 committed on
Commit
948530a
·
1 Parent(s): af54ccd

Harden GRPO generation stability on CUDA: bf16 + eager attention + invalid-logit guards.

Browse files
Files changed (1) hide show
  1. ultimate_sota_training.py +17 -3
ultimate_sota_training.py CHANGED
@@ -355,11 +355,14 @@ def run_sota_train():
355
 
356
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
357
  tokenizer.pad_token = tokenizer.eos_token
358
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
 
359
  model = AutoModelForCausalLM.from_pretrained(
360
  MODEL_NAME,
361
  torch_dtype=torch_dtype,
362
  device_map="auto",
 
363
  )
364
 
365
  train_dataset = make_real_dataset()
@@ -378,7 +381,10 @@ def run_sota_train():
378
  **inputs,
379
  max_new_tokens=256,
380
  do_sample=True,
381
- temperature=0.7,
 
 
 
382
  pad_token_id=tokenizer.eos_token_id,
383
  )
384
  completions.append(tokenizer.decode(out[0], skip_special_tokens=True))
@@ -415,7 +421,15 @@ def run_sota_train():
415
  gradient_accumulation_steps=grad_accum,
416
  num_generations=num_gen,
417
  max_completion_length=int(os.environ.get("GRPO_MAX_COMPLETION_LEN", "256")),
418
- temperature=float(os.environ.get("GRPO_TEMPERATURE", "0.9")),
 
 
 
 
 
 
 
 
419
  num_train_epochs=int(os.environ.get("TRAIN_NUM_EPOCHS", "1")),
420
  max_steps=max_steps,
421
  logging_steps=int(os.environ.get("LOGGING_STEPS", "1")),
 
355
 
356
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
357
  tokenizer.pad_token = tokenizer.eos_token
358
+ use_cuda = torch.cuda.is_available()
359
+ # L4/A10/A100 are typically more numerically stable with bf16 than fp16 for RL-style sampling.
360
+ torch_dtype = torch.bfloat16 if use_cuda else torch.float32
361
  model = AutoModelForCausalLM.from_pretrained(
362
  MODEL_NAME,
363
  torch_dtype=torch_dtype,
364
  device_map="auto",
365
+ attn_implementation=os.environ.get("ATTN_IMPLEMENTATION", "eager"),
366
  )
367
 
368
  train_dataset = make_real_dataset()
 
381
  **inputs,
382
  max_new_tokens=256,
383
  do_sample=True,
384
+ temperature=float(os.environ.get("EVAL_TEMPERATURE", "0.7")),
385
+ top_p=float(os.environ.get("EVAL_TOP_P", "0.9")),
386
+ renormalize_logits=True,
387
+ remove_invalid_values=True,
388
  pad_token_id=tokenizer.eos_token_id,
389
  )
390
  completions.append(tokenizer.decode(out[0], skip_special_tokens=True))
 
421
  gradient_accumulation_steps=grad_accum,
422
  num_generations=num_gen,
423
  max_completion_length=int(os.environ.get("GRPO_MAX_COMPLETION_LEN", "256")),
424
+ temperature=float(os.environ.get("GRPO_TEMPERATURE", "0.7")),
425
+ top_p=float(os.environ.get("GRPO_TOP_P", "0.9")),
426
+ # Keep generation numerically safe in long sampling loops.
427
+ generation_kwargs={
428
+ "remove_invalid_values": True,
429
+ "renormalize_logits": True,
430
+ },
431
+ bf16=bool(use_cuda),
432
+ fp16=False,
433
  num_train_epochs=int(os.environ.get("TRAIN_NUM_EPOCHS", "1")),
434
  max_steps=max_steps,
435
  logging_steps=int(os.environ.get("LOGGING_STEPS", "1")),