NorthernTribe-Research committed on
Commit
aef97c9
·
verified ·
1 Parent(s): 41a7495

Keep training stack parity: CPU fallback path and mixed-precision safeguards.

Browse files
Files changed (1) hide show
  1. scripts/train_sota.py +19 -7
scripts/train_sota.py CHANGED
@@ -432,8 +432,12 @@ def build_model_and_tokenizer(model_cfg: Dict[str, Any], training_defaults: Dict
432
  if not base_model:
433
  raise ValueError("model.base_model is required.")
434
 
435
- use_bf16 = bool(model_cfg.get("use_bf16", True))
436
- dtype = torch.bfloat16 if use_bf16 else torch.float16
 
 
 
 
437
 
438
  tokenizer = build_tokenizer(model_cfg)
439
 
@@ -445,10 +449,11 @@ def build_model_and_tokenizer(model_cfg: Dict[str, Any], training_defaults: Dict
445
  if attn_impl:
446
  model_kwargs["attn_implementation"] = attn_impl
447
 
448
- load_in_4bit = bool(model_cfg.get("load_in_4bit", True))
 
 
 
449
  if load_in_4bit:
450
- if not torch.cuda.is_available():
451
- raise RuntimeError("4-bit loading requested but CUDA is not available.")
452
  model_kwargs["quantization_config"] = BitsAndBytesConfig(
453
  load_in_4bit=True,
454
  bnb_4bit_quant_type=as_text(model_cfg.get("bnb_4bit_quant_type")) or "nf4",
@@ -565,6 +570,9 @@ def build_training_args(
565
  has_eval_split: bool,
566
  ) -> TrainingArguments:
567
  output_dir.mkdir(parents=True, exist_ok=True)
 
 
 
568
  return TrainingArguments(
569
  output_dir=str(output_dir),
570
  num_train_epochs=as_float(training_cfg.get("num_train_epochs"), 1.0),
@@ -582,8 +590,8 @@ def build_training_args(
582
  save_total_limit=as_int(training_cfg.get("save_total_limit"), 3),
583
  dataloader_num_workers=as_int(training_cfg.get("dataloader_num_workers"), 0),
584
  seed=as_int(training_cfg.get("seed"), 17),
585
- bf16=use_bf16,
586
- fp16=not use_bf16,
587
  remove_unused_columns=False,
588
  report_to="none",
589
  evaluation_strategy="steps" if has_eval_split else "no",
@@ -860,6 +868,10 @@ def main() -> None:
860
  model = None
861
  else:
862
  model, tokenizer = build_model_and_tokenizer(cfg["model"], cfg.get("training_defaults", {}))
 
 
 
 
863
 
864
  data_cfg = cfg["data"]
865
  stage_reports: List[Dict[str, Any]] = []
 
432
  if not base_model:
433
  raise ValueError("model.base_model is required.")
434
 
435
+ use_cuda = torch.cuda.is_available()
436
+ requested_bf16 = bool(model_cfg.get("use_bf16", True))
437
+ if use_cuda:
438
+ dtype = torch.bfloat16 if requested_bf16 else torch.float16
439
+ else:
440
+ dtype = torch.float32
441
 
442
  tokenizer = build_tokenizer(model_cfg)
443
 
 
449
  if attn_impl:
450
  model_kwargs["attn_implementation"] = attn_impl
451
 
452
+ requested_load_in_4bit = bool(model_cfg.get("load_in_4bit", True))
453
+ load_in_4bit = requested_load_in_4bit and use_cuda
454
+ if requested_load_in_4bit and not load_in_4bit:
455
+ print("CUDA unavailable. Disabling 4-bit loading and using full-precision CPU fallback.")
456
  if load_in_4bit:
 
 
457
  model_kwargs["quantization_config"] = BitsAndBytesConfig(
458
  load_in_4bit=True,
459
  bnb_4bit_quant_type=as_text(model_cfg.get("bnb_4bit_quant_type")) or "nf4",
 
570
  has_eval_split: bool,
571
  ) -> TrainingArguments:
572
  output_dir.mkdir(parents=True, exist_ok=True)
573
+ use_cuda = torch.cuda.is_available()
574
+ bf16_runtime = bool(use_cuda and use_bf16)
575
+ fp16_runtime = bool(use_cuda and not bf16_runtime)
576
  return TrainingArguments(
577
  output_dir=str(output_dir),
578
  num_train_epochs=as_float(training_cfg.get("num_train_epochs"), 1.0),
 
590
  save_total_limit=as_int(training_cfg.get("save_total_limit"), 3),
591
  dataloader_num_workers=as_int(training_cfg.get("dataloader_num_workers"), 0),
592
  seed=as_int(training_cfg.get("seed"), 17),
593
+ bf16=bf16_runtime,
594
+ fp16=fp16_runtime,
595
  remove_unused_columns=False,
596
  report_to="none",
597
  evaluation_strategy="steps" if has_eval_split else "no",
 
868
  model = None
869
  else:
870
  model, tokenizer = build_model_and_tokenizer(cfg["model"], cfg.get("training_defaults", {}))
871
+ if torch.cuda.is_available():
872
+ print("Compute mode: GPU")
873
+ else:
874
+ print("Compute mode: CPU fallback (no CUDA detected)")
875
 
876
  data_cfg = cfg["data"]
877
  stage_reports: List[Dict[str, Any]] = []