walidsobhie-code committed on
Commit
b098bb5
·
1 Parent(s): e78785b

fix: load model in FP32 to avoid AMP gradient scaling conflict

Browse files

- Non-quantized models now load with torch_dtype=torch.float32
- Trainer fp16=True handles casting during training (AMP)
- Prevents 'GradScaler expects FP32 master weights' errors
- Keeps gradient checkpointing and LoRA intact

Files changed (1) hide show
  1. train_simple_nobnb.py +3 -4
train_simple_nobnb.py CHANGED
@@ -68,12 +68,11 @@ def load_model_and_tokenizer(
68
  torch_dtype=torch.bfloat16,
69
  )
70
  else:
71
- # No quantization - load in fp16 for Kaggle T4/P100 (bf16 not supported)
72
- # Model dtype MUST match training dtype to avoid GradScaler conflicts
73
- load_dtype = torch.float16 if use_fp16 else torch.bfloat16
74
  model = AutoModelForCausalLM.from_pretrained(
75
  model_name,
76
- torch_dtype=load_dtype,
77
  trust_remote_code=trust_remote_code,
78
  device_map="auto",
79
  use_cache=False,
 
68
  torch_dtype=torch.bfloat16,
69
  )
70
  else:
71
+ # No quantization - load in FP32 for AMP compatibility
72
+ # Trainer with fp16=True will handle casting during training
 
73
  model = AutoModelForCausalLM.from_pretrained(
74
  model_name,
75
+ torch_dtype=torch.float32,
76
  trust_remote_code=trust_remote_code,
77
  device_map="auto",
78
  use_cache=False,