Guard auto-resume against adapter-only checkpoints
train_worker.py  +51 -5
train_worker.py
CHANGED
@@ -993,6 +993,35 @@ def latest_checkpoint_in_sibling_runs(output_dir: Path) -> Path | None:
     return checkpoints[-1][2]


+def checkpoint_resume_compatible(checkpoint_dir: Path) -> tuple[bool, str]:
+    if not checkpoint_dir.exists():
+        return False, "path does not exist"
+    if not checkpoint_dir.is_dir():
+        return False, "path is not a directory"
+
+    full_model_markers = (
+        "model.safetensors",
+        "pytorch_model.bin",
+        "model.safetensors.index.json",
+        "pytorch_model.bin.index.json",
+    )
+    if any((checkpoint_dir / marker).exists() for marker in full_model_markers):
+        return True, ""
+
+    adapter_markers = (
+        "adapter_model.safetensors",
+        "adapter_model.bin",
+        "adapter_config.json",
+    )
+    if any((checkpoint_dir / marker).exists() for marker in adapter_markers):
+        return (
+            False,
+            "adapter-only checkpoint (missing full-model checkpoint files required by Trainer resume)",
+        )
+
+    return False, "missing model checkpoint files"
+
+
 def resolve_resume_checkpoint(value: str | None, output_dir: Path) -> str | None:
     requested = to_text(value).lower()
     if requested in {"", "none", "false", "no"}:
@@ -1000,20 +1029,37 @@ def resolve_resume_checkpoint(value: str | None, output_dir: Path) -> str | None
     if requested in {"auto", "latest"}:
         latest = latest_checkpoint_dir(output_dir)
         if latest is not None:
-            return str(latest)
+            compatible, reason = checkpoint_resume_compatible(latest)
+            if compatible:
+                return str(latest)
+            print(
+                "[train_worker][warn] auto-resume skipped latest checkpoint "
+                f"'{latest}' ({reason})."
+            )
         sibling_latest = latest_checkpoint_in_sibling_runs(output_dir=output_dir)
         if sibling_latest is not None:
+            compatible, reason = checkpoint_resume_compatible(sibling_latest)
+            if compatible:
+                print(
+                    "[train_worker][info] auto-resume fallback selected sibling checkpoint: "
+                    f"{sibling_latest}"
+                )
+                return str(sibling_latest)
             print(
-                "[train_worker][info] auto-resume fallback selected sibling checkpoint: "
-                f"{sibling_latest}"
+                "[train_worker][warn] auto-resume skipped sibling checkpoint "
+                f"'{sibling_latest}' ({reason})."
             )
-            return str(sibling_latest)
         return None
     candidate = Path(to_text(value))
     if not candidate.is_absolute():
         candidate = output_dir / candidate
     if candidate.exists():
-        return str(candidate)
+        compatible, reason = checkpoint_resume_compatible(candidate)
+        if compatible:
+            return str(candidate)
+        raise RuntimeError(
+            f"Requested resume checkpoint is not trainer-resume compatible ({reason}): {candidate}"
+        )
     raise RuntimeError(f"Requested resume checkpoint does not exist: {candidate}")
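For context, a minimal sketch of how the new checkpoint_resume_compatible guard classifies checkpoint directories: a directory with full-model weight files (model.safetensors, pytorch_model.bin, or their sharded index files) is accepted, while an adapter-only (PEFT/LoRA-style) save is rejected with a reason string. The snippet assumes train_worker.py is importable from the repository root; the checkpoint directory names are illustrative only.

# Sketch: exercise the new guard on synthetic checkpoint layouts.
# Assumes train_worker.py is importable from the repository root.
import tempfile
from pathlib import Path

from train_worker import checkpoint_resume_compatible

with tempfile.TemporaryDirectory() as tmp:
    # Adapter-only checkpoint (e.g. a PEFT/LoRA save): skipped by auto-resume.
    adapter_only = Path(tmp) / "checkpoint-500"
    adapter_only.mkdir()
    (adapter_only / "adapter_model.safetensors").touch()
    (adapter_only / "adapter_config.json").touch()
    print(checkpoint_resume_compatible(adapter_only))
    # -> (False, "adapter-only checkpoint (missing full-model checkpoint files required by Trainer resume)")

    # Full-model checkpoint: eligible for Trainer resume.
    full = Path(tmp) / "checkpoint-1000"
    full.mkdir()
    (full / "model.safetensors").touch()
    print(checkpoint_resume_compatible(full))
    # -> (True, "")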