parth-1 committed on
Commit
08bdaa0
·
verified ·
1 Parent(s): 7de0535

Update grpo_train.py

Browse files
Files changed (1) hide show
  1. grpo_train.py +2 -2
grpo_train.py CHANGED
@@ -351,8 +351,8 @@ else:
351
  _name = "CPU"
352
  _cc = (0, 0)
353
 
354
- USE_4BIT = _vram < 40 * 1024**3 # T4 (15 GB), L4 (24 GB) → 4-bit; A100 (80 GB) → bf16
355
- USE_BF16 = _cc >= (8, 0) # Ampere+ (A100, L4) support bf16; Turing (T4) does not
356
 
357
  # #region agent log
358
  _dlog("A", "grpo_train.py:gpu_detect", "GPU config resolved", {"name":_name,"vram_gb":round(_vram/1024**3,1),"cc":list(_cc),"USE_4BIT":USE_4BIT,"USE_BF16":USE_BF16})
 
351
  _name = "CPU"
352
  _cc = (0, 0)
353
 
354
+ USE_4BIT = _vram < 40 * 1024**3 # T4 (15 GB), L4 (24 GB) → 4-bit; A100 (80 GB) → full
355
+ USE_BF16 = _cc >= (8, 0) and not USE_4BIT # bf16 only when full-precision; 4-bit LoRA uses fp16 internally
356
 
357
  # #region agent log
358
  _dlog("A", "grpo_train.py:gpu_detect", "GPU config resolved", {"name":_name,"vram_gb":round(_vram/1024**3,1),"cc":list(_cc),"USE_4BIT":USE_4BIT,"USE_BF16":USE_BF16})