Keep training stack parity: CPU fallback path and mixed-precision safeguards.
Browse files — scripts/train_sota.py (+19 −7)
scripts/train_sota.py
CHANGED
|
@@ -432,8 +432,12 @@ def build_model_and_tokenizer(model_cfg: Dict[str, Any], training_defaults: Dict
|
|
| 432 |
if not base_model:
|
| 433 |
raise ValueError("model.base_model is required.")
|
| 434 |
|
| 435 |
-
|
| 436 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
tokenizer = build_tokenizer(model_cfg)
|
| 439 |
|
|
@@ -445,10 +449,11 @@ def build_model_and_tokenizer(model_cfg: Dict[str, Any], training_defaults: Dict
|
|
| 445 |
if attn_impl:
|
| 446 |
model_kwargs["attn_implementation"] = attn_impl
|
| 447 |
|
| 448 |
-
|
|
|
|
|
|
|
|
|
|
| 449 |
if load_in_4bit:
|
| 450 |
-
if not torch.cuda.is_available():
|
| 451 |
-
raise RuntimeError("4-bit loading requested but CUDA is not available.")
|
| 452 |
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 453 |
load_in_4bit=True,
|
| 454 |
bnb_4bit_quant_type=as_text(model_cfg.get("bnb_4bit_quant_type")) or "nf4",
|
|
@@ -565,6 +570,9 @@ def build_training_args(
|
|
| 565 |
has_eval_split: bool,
|
| 566 |
) -> TrainingArguments:
|
| 567 |
output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
| 568 |
return TrainingArguments(
|
| 569 |
output_dir=str(output_dir),
|
| 570 |
num_train_epochs=as_float(training_cfg.get("num_train_epochs"), 1.0),
|
|
@@ -582,8 +590,8 @@ def build_training_args(
|
|
| 582 |
save_total_limit=as_int(training_cfg.get("save_total_limit"), 3),
|
| 583 |
dataloader_num_workers=as_int(training_cfg.get("dataloader_num_workers"), 0),
|
| 584 |
seed=as_int(training_cfg.get("seed"), 17),
|
| 585 |
-
bf16=
|
| 586 |
-
fp16=
|
| 587 |
remove_unused_columns=False,
|
| 588 |
report_to="none",
|
| 589 |
evaluation_strategy="steps" if has_eval_split else "no",
|
|
@@ -860,6 +868,10 @@ def main() -> None:
|
|
| 860 |
model = None
|
| 861 |
else:
|
| 862 |
model, tokenizer = build_model_and_tokenizer(cfg["model"], cfg.get("training_defaults", {}))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 863 |
|
| 864 |
data_cfg = cfg["data"]
|
| 865 |
stage_reports: List[Dict[str, Any]] = []
|
|
|
|
| 432 |
if not base_model:
|
| 433 |
raise ValueError("model.base_model is required.")
|
| 434 |
|
| 435 |
+
use_cuda = torch.cuda.is_available()
|
| 436 |
+
requested_bf16 = bool(model_cfg.get("use_bf16", True))
|
| 437 |
+
if use_cuda:
|
| 438 |
+
dtype = torch.bfloat16 if requested_bf16 else torch.float16
|
| 439 |
+
else:
|
| 440 |
+
dtype = torch.float32
|
| 441 |
|
| 442 |
tokenizer = build_tokenizer(model_cfg)
|
| 443 |
|
|
|
|
| 449 |
if attn_impl:
|
| 450 |
model_kwargs["attn_implementation"] = attn_impl
|
| 451 |
|
| 452 |
+
requested_load_in_4bit = bool(model_cfg.get("load_in_4bit", True))
|
| 453 |
+
load_in_4bit = requested_load_in_4bit and use_cuda
|
| 454 |
+
if requested_load_in_4bit and not load_in_4bit:
|
| 455 |
+
print("CUDA unavailable. Disabling 4-bit loading and using full-precision CPU fallback.")
|
| 456 |
if load_in_4bit:
|
|
|
|
|
|
|
| 457 |
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 458 |
load_in_4bit=True,
|
| 459 |
bnb_4bit_quant_type=as_text(model_cfg.get("bnb_4bit_quant_type")) or "nf4",
|
|
|
|
| 570 |
has_eval_split: bool,
|
| 571 |
) -> TrainingArguments:
|
| 572 |
output_dir.mkdir(parents=True, exist_ok=True)
|
| 573 |
+
use_cuda = torch.cuda.is_available()
|
| 574 |
+
bf16_runtime = bool(use_cuda and use_bf16)
|
| 575 |
+
fp16_runtime = bool(use_cuda and not bf16_runtime)
|
| 576 |
return TrainingArguments(
|
| 577 |
output_dir=str(output_dir),
|
| 578 |
num_train_epochs=as_float(training_cfg.get("num_train_epochs"), 1.0),
|
|
|
|
| 590 |
save_total_limit=as_int(training_cfg.get("save_total_limit"), 3),
|
| 591 |
dataloader_num_workers=as_int(training_cfg.get("dataloader_num_workers"), 0),
|
| 592 |
seed=as_int(training_cfg.get("seed"), 17),
|
| 593 |
+
bf16=bf16_runtime,
|
| 594 |
+
fp16=fp16_runtime,
|
| 595 |
remove_unused_columns=False,
|
| 596 |
report_to="none",
|
| 597 |
evaluation_strategy="steps" if has_eval_split else "no",
|
|
|
|
| 868 |
model = None
|
| 869 |
else:
|
| 870 |
model, tokenizer = build_model_and_tokenizer(cfg["model"], cfg.get("training_defaults", {}))
|
| 871 |
+
if torch.cuda.is_available():
|
| 872 |
+
print("Compute mode: GPU")
|
| 873 |
+
else:
|
| 874 |
+
print("Compute mode: CPU fallback (no CUDA detected)")
|
| 875 |
|
| 876 |
data_cfg = cfg["data"]
|
| 877 |
stage_reports: List[Dict[str, Any]] = []
|