# umsr_reasoner_trainer / train_worker.py
# Uploaded by NorthernTribe-Research
# Commit: "Guard auto-resume against adapter-only checkpoints" (e590ff1, verified)
#!/usr/bin/env python3
"""Autonomous lightweight trainer for UMSR Reasoner Space."""
from __future__ import annotations
import argparse
import inspect
import json
import math
import os
import platform
import re
import shutil
import subprocess
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
try:
import torch
import torch.nn.functional as F
except Exception:
torch = None
F = None
try:
from datasets import load_dataset
except Exception:
load_dataset = None
try:
from huggingface_hub import HfApi
except Exception:
HfApi = None
try:
import transformers as transformers_pkg
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
DataCollatorForLanguageModeling,
TrainerCallback,
Trainer,
TrainingArguments,
set_seed,
)
except Exception:
transformers_pkg = None
AutoConfig = None
AutoModelForCausalLM = None
AutoTokenizer = None
BitsAndBytesConfig = None
DataCollatorForLanguageModeling = None
TrainerCallback = None
Trainer = None
TrainingArguments = None
set_seed = None
try:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
except Exception:
LoraConfig = None
get_peft_model = None
prepare_model_for_kbit_training = None
try:
import accelerate
except Exception:
accelerate = None
# Fallback base class keeps this module importable when transformers is missing.
TrainerBase = Trainer if Trainer is not None else object
# System prompt prepended to every formatted training example (see format_text).
SYSTEM_PROMPT = (
    "You are a rigorous reasoning assistant. "
    "Solve the task step by step. "
    "For programming tasks, provide a correct and runnable code block. "
    "Then provide only the final answer inside "
    "<final_answer>...</final_answer>."
)
# Hub owner prefix that marks a model reference as in-house.
INHOUSE_OWNER_PREFIX = "NorthernTribe-Research/"
# Substrings that disqualify a model reference outright (case-insensitive match).
BANNED_MODEL_TOKENS = ("gpt2",)
# Keyword heuristic for spotting programming-related tasks in prompts/options.
CODE_TASK_HINT_RE = re.compile(
    r"\b("
    r"code|python|program|programming|function|class|method|bug|debug|algorithm|"
    r"runtime|complexity|compile|leetcode|unit test|sql|regex|script"
    r")\b",
    re.IGNORECASE,
)
# Filesystem anchors: the Space app directory and the repo root two levels up.
APP_DIR = Path(__file__).resolve().parent
REPO_ROOT = APP_DIR.parent.parent
def require_dependency(name: str, available: bool) -> None:
    """Raise a clear RuntimeError when an optional dependency failed to import."""
    if available:
        return
    raise RuntimeError(f"Missing dependency '{name}'. Install requirements.txt in the Space.")
def to_text(value: Any) -> str:
    """Coerce any value to a whitespace-stripped string; None becomes ""."""
    return "" if value is None else str(value).strip()
def parse_options(value: Any) -> list[str]:
    """Normalize an options field into a list of non-empty strings.

    Lists are cleaned element-wise; strings are split on the '||' delimiter.
    A plain string without the delimiter yields no options.
    """
    if isinstance(value, list):
        cleaned = (to_text(entry) for entry in value)
        return [entry for entry in cleaned if entry]
    raw = to_text(value)
    if raw and "||" in raw:
        pieces = (piece.strip() for piece in raw.split("||"))
        return [piece for piece in pieces if piece]
    return []
def parse_float(value: Any, default: float) -> float:
    """Convert value to a finite float, falling back to default on any failure."""
    try:
        number = float(value)
    except Exception:
        return default
    # NaN/inf are rejected so downstream comparisons stay well-defined.
    return default if (math.isnan(number) or math.isinf(number)) else number
def env_bool(name: str, default: bool) -> bool:
    """Read a boolean environment flag; an unset variable yields default."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    truthy = {"1", "true", "yes", "on"}
    return raw.strip().lower() in truthy
def is_local_model_path(model_ref: str) -> bool:
    """Return True when model_ref names an existing local filesystem path."""
    try:
        # Short-circuit keeps Path() away from empty/falsy references.
        return bool(model_ref) and Path(model_ref).exists()
    except Exception:
        return False
def is_inhouse_model_ref(model_ref: str) -> bool:
    """True for repos under the in-house NorthernTribe-Research namespace."""
    normalized = to_text(model_ref)
    return normalized.startswith(INHOUSE_OWNER_PREFIX)
def validate_model_reference(model_ref: str, role: str, enforce_inhouse: bool) -> None:
    """Reject empty, banned, or (optionally) external model references.

    Raises ValueError with an actionable message; returns None when valid.
    """
    resolved = to_text(model_ref)
    if not resolved:
        raise ValueError(f"{role} model reference is empty.")
    lowered = resolved.lower()
    for token in BANNED_MODEL_TOKENS:
        if token in lowered:
            raise ValueError(
                f"{role} model '{resolved}' is blocked. "
                "Use an in-house NorthernTribe-Research model or a local checkpoint path."
            )
    if not enforce_inhouse:
        return
    if is_inhouse_model_ref(resolved) or is_local_model_path(resolved):
        return
    raise ValueError(
        f"{role} model '{resolved}' is external. "
        "In-house enforcement is enabled; use NorthernTribe-Research/* or a local path."
    )
def quality_ok(row: dict[str, Any], min_quality: float) -> bool:
    """A row qualifies when problem and answer are non-empty and its
    quality_score (defaulting to 1.0) meets min_quality."""
    has_problem = bool(to_text(row.get("problem")))
    has_answer = bool(to_text(row.get("answer")))
    if not (has_problem and has_answer):
        return False
    return parse_float(row.get("quality_score", 1.0), 1.0) >= min_quality
def looks_like_code_task(problem: str, domain: str, options: list[str] | None = None) -> bool:
    """Heuristically decide whether a task is programming-related.

    A code-like domain label wins outright; otherwise the problem, domain,
    and options text is scanned against CODE_TASK_HINT_RE.
    """
    code_domains = {"code", "coding", "programming", "software", "computer_science"}
    if to_text(domain).lower() in code_domains:
        return True
    pieces = [to_text(problem), to_text(domain)]
    pieces.extend(to_text(option) for option in options or [])
    combined = " ".join(pieces).strip()
    return bool(combined) and bool(CODE_TASK_HINT_RE.search(combined))
def build_user_prompt(problem: str, options: list[str], domain: str) -> str:
    """Assemble the USER prompt: problem, optional options/domain, instructions.

    Code-flavored tasks get an extra instruction asking for a runnable
    ```python``` block before the mandatory <final_answer> tag.
    """
    sections = [f"Problem:\n{problem}"]
    if options:
        option_lines = "\n".join(f"- {item}" for item in options)
        sections.append("Options:\n" + option_lines)
    if domain:
        sections.append(f"Domain: {domain}")
    steps = ["1) Think step by step."]
    if looks_like_code_task(problem=problem, domain=domain, options=options):
        steps.append("2) If code is required, output one runnable ```python``` block.")
        steps.append("3) End with <final_answer>...</final_answer>.")
    else:
        steps.append("2) End with <final_answer>...</final_answer>.")
    sections.append("Instructions:\n" + "\n".join(steps))
    return "\n\n".join(sections)
def format_text(row: dict[str, Any]) -> str:
    """Render one dataset row into the SYSTEM/USER/ASSISTANT training string.

    The assistant completion holds an optional <reasoning> block followed by
    the mandatory <final_answer> tag that downstream parsers rely on.
    """
    problem = to_text(row.get("problem"))
    answer = to_text(row.get("answer"))
    reasoning = to_text(row.get("reasoning_text"))
    domain = to_text(row.get("domain"))
    options = parse_options(row.get("options"))
    prompt = build_user_prompt(problem=problem, options=options, domain=domain)
    completion_chunks: list[str] = []
    if reasoning:
        completion_chunks.append(f"<reasoning>\n{reasoning}\n</reasoning>")
    completion_chunks.append(f"<final_answer>{answer}</final_answer>")
    completion = "\n\n".join(completion_chunks)
    # Three-section layout matches the contract promised by SYSTEM_PROMPT.
    return (
        f"SYSTEM:\n{SYSTEM_PROMPT}\n\n"
        f"USER:\n{prompt}\n\n"
        f"ASSISTANT:\n{completion}"
    )
def dataset_cache_dir() -> str:
    """Return the datasets cache dir (HF_DATASETS_CACHE or a local default),
    creating it if necessary."""
    configured = os.environ.get("HF_DATASETS_CACHE")
    target = Path(configured) if configured else Path(".hf_cache/datasets")
    target.mkdir(parents=True, exist_ok=True)
    return str(target)
def model_card_text(
    dataset_id: str,
    model_id: str,
    base_model: str,
    teacher_models: list[str],
    distill_enabled: bool,
) -> str:
    """Render the README/model-card markdown pushed alongside the trained model.

    The returned text starts with YAML front matter (flush-left inside the
    triple-quoted string on purpose) and interpolates the dataset, model,
    base-model, and teacher references.
    """
    mode_text = "teacher-student distillation" if distill_enabled else "supervised fine-tuning"
    teacher_line = ", ".join(teacher_models) if teacher_models else "n/a"
    return f"""---
language:
- en
library_name: transformers
pipeline_tag: text-generation
datasets:
- {dataset_id}
tags:
- reasoning
- structured-output
- instruction-following
- math
- logic
- science
---
# UMSR-Reasoner-7B
## Purpose
UMSR-Reasoner-7B is a general reasoning model designed for structured problem solving and consistent answer formatting in production and research workflows.
Model repository: `https://huggingface.co/{model_id}`
Primary dataset: `https://huggingface.co/datasets/{dataset_id}`
## Intended Use
Use this model for tasks that require:
- multi-step quantitative reasoning
- logic and strategy-style question answering
- science and technical problem decomposition
- deterministic final-answer formatting for downstream parsers
## Core Capabilities
- Produces step-aware reasoning outputs for complex prompts
- Handles open-form and exam-style tasks across math, logic, and science domains
- Supports structured response contracts for automation pipelines
- Works well in teacher-student continuous improvement loops
## Recommended Prompting
For highest reliability, use explicit instructions about reasoning depth and enforce a final-answer tag in every response.
Suggested system instruction:
`Solve step by step and end with <final_answer>...</final_answer>.`
## Output Contract
Required final output tag:
`<final_answer>...</final_answer>`
Optional reasoning tag:
`<reasoning>...</reasoning>`
## Training Profile
- Student model: `{base_model}`
- Training mode: `{mode_text}`
- Teacher model(s): `{teacher_line}`
## Operational Guidance
- Prefer lower sampling temperature for deterministic workflows
- Validate final answers for high-stakes usage
- Run domain-specific evaluation before production rollout
## Limitations
- May produce plausible but incorrect reasoning traces
- Performance varies with prompt quality and task domain
- Not a substitute for expert review in legal, medical, financial, or safety-critical decisions
"""
def write_json(path: Path, payload: dict[str, Any]) -> None:
    """Write payload as pretty, key-sorted JSON (with trailing newline),
    creating parent directories as needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, indent=2, sort_keys=True)
    path.write_text(serialized + "\n", encoding="utf-8")
def append_jsonl(path: Path, payload: dict[str, Any]) -> None:
    """Append payload as one key-sorted JSON line, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    line = json.dumps(payload, sort_keys=True)
    with open(path, "a", encoding="utf-8") as sink:
        sink.write(line + "\n")
def read_json(path: Path) -> dict[str, Any]:
    """Best-effort read of a JSON object; missing file, parse failure, or a
    non-dict top-level value all yield {}."""
    if not path.exists():
        return {}
    try:
        loaded = json.loads(path.read_text(encoding="utf-8"))
    except Exception:
        return {}
    if isinstance(loaded, dict):
        return loaded
    return {}
def finalize_live_progress(output_dir: Path, message: str) -> None:
    """Mark live_progress.json as completed while preserving prior metrics."""
    progress_path = output_dir / "live_progress.json"
    existing = read_json(progress_path)
    metrics_field = existing.get("metrics")
    metrics = metrics_field if isinstance(metrics_field, dict) else {}
    existing.update(
        {
            "updated_at": datetime.now(timezone.utc).isoformat(),
            "status": "completed",
            "message": to_text(message) or "training finished",
            "metrics": metrics,
        }
    )
    write_json(progress_path, existing)
def safe_float(value: Any) -> float | None:
if value is None:
return None
try:
if torch is not None and isinstance(value, torch.Tensor):
value = value.detach().float().item()
return float(value)
except Exception:
return None
class LiveProgressCallback(TrainerCallback if TrainerCallback is not None else object):
    """Trainer callback mirroring training state to files read by the Space UI.

    Writes a rolling snapshot to live_progress.json and an append-only event
    stream to live_events.jsonl inside output_dir. The conditional base class
    keeps the module importable when transformers is absent.
    """
    def __init__(
        self,
        output_dir: Path,
        distill_enabled: bool,
        runtime_hardware: dict[str, Any] | None = None,
        runtime_system: dict[str, Any] | None = None,
    ):
        self.output_dir = output_dir
        self.distill_enabled = bool(distill_enabled)
        # Copies so later caller mutations don't leak into progress payloads.
        self.runtime_hardware = dict(runtime_hardware or {})
        self.runtime_system = dict(runtime_system or {})
        self.progress_path = output_dir / "live_progress.json"
        self.events_path = output_dir / "live_events.jsonl"
        # Most recent metric values, merged across log/evaluate callbacks.
        self.latest_metrics: dict[str, float] = {}
    def _sync_progress(self, state: Any, status: str, message: str) -> None:
        """Overwrite the progress snapshot with current trainer state."""
        payload = {
            "updated_at": datetime.now(timezone.utc).isoformat(),
            "status": status,
            "message": to_text(message),
            "distill_enabled": self.distill_enabled,
            "runtime_system": self.runtime_system,
            "runtime_hardware": self.runtime_hardware,
            "global_step": int(getattr(state, "global_step", 0) or 0),
            "max_steps": int(getattr(state, "max_steps", 0) or 0),
            "epoch": safe_float(getattr(state, "epoch", None)),
            "metrics": self.latest_metrics,
        }
        write_json(self.progress_path, payload)
    def _append_event(self, state: Any, event_type: str, payload: dict[str, Any]) -> None:
        """Append one timestamped event record to the JSONL stream."""
        event = {
            "ts": datetime.now(timezone.utc).isoformat(),
            "event": event_type,
            "global_step": int(getattr(state, "global_step", 0) or 0),
            "epoch": safe_float(getattr(state, "epoch", None)),
            "payload": payload,
        }
        append_jsonl(self.events_path, event)
    @staticmethod
    def _extract_metrics(logs: dict[str, Any]) -> dict[str, float]:
        """Pick the known numeric metrics out of a trainer log dict."""
        keys = [
            "loss",
            "eval_loss",
            "learning_rate",
            "grad_norm",
            "epoch",
            "distill_ce_loss",
            "distill_kd_loss",
            "distill_temperature",
            "distill_ce_weight",
            "distill_kd_weight",
        ]
        metrics: dict[str, float] = {}
        for key in keys:
            value = safe_float(logs.get(key))
            if value is not None:
                metrics[key] = value
        return metrics
    def on_train_begin(self, args: Any, state: Any, control: Any, **kwargs: Any) -> None:
        """Record the start of training."""
        del args, control, kwargs
        self._sync_progress(state=state, status="running", message="training started")
        self._append_event(state=state, event_type="train_begin", payload={})
    def on_log(self, args: Any, state: Any, control: Any, logs: dict[str, Any] | None = None, **kwargs: Any) -> None:
        """Merge logged metrics and refresh the step-progress message."""
        del args, control, kwargs
        payload = logs or {}
        metrics = self._extract_metrics(payload)
        if metrics:
            self.latest_metrics.update(metrics)
        step = int(getattr(state, "global_step", 0) or 0)
        max_steps = int(getattr(state, "max_steps", 0) or 0)
        message = f"step {step}/{max_steps}" if max_steps > 0 else f"step {step}"
        self._sync_progress(state=state, status="running", message=message)
        self._append_event(state=state, event_type="log", payload=metrics)
    def on_save(self, args: Any, state: Any, control: Any, **kwargs: Any) -> None:
        """Record that a checkpoint was written."""
        del args, control, kwargs
        self._sync_progress(state=state, status="running", message="checkpoint saved")
        self._append_event(state=state, event_type="save", payload={})
    def on_evaluate(
        self,
        args: Any,
        state: Any,
        control: Any,
        metrics: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Merge evaluation metrics into the rolling snapshot."""
        del args, control, kwargs
        values = self._extract_metrics(metrics or {})
        if values:
            self.latest_metrics.update(values)
        self._sync_progress(state=state, status="running", message="evaluation completed")
        self._append_event(state=state, event_type="evaluate", payload=values)
    def on_train_end(self, args: Any, state: Any, control: Any, **kwargs: Any) -> None:
        """Record the end of training."""
        del args, control, kwargs
        self._sync_progress(state=state, status="completed", message="training finished")
        self._append_event(state=state, event_type="train_end", payload={})
def parse_version_tuple(version_text: str) -> tuple[int, int, int]:
    """Parse up to three leading-integer components from a dotted version.

    A component with a non-numeric prefix (e.g. 'v1', 'dev0') counts as 0;
    missing components are padded with 0.
    """
    components: list[int] = []
    for segment in str(version_text).split("."):
        match = re.match(r"\d+", segment)
        components.append(int(match.group()) if match else 0)
        if len(components) == 3:
            break
    while len(components) < 3:
        components.append(0)
    return (components[0], components[1], components[2])
def probe_runtime_hardware() -> dict[str, Any]:
    """Snapshot torch/CUDA/MPS availability for logging and live progress.

    Each probe is individually guarded so a broken or partial torch install
    degrades to conservative defaults instead of crashing the trainer.
    """
    info: dict[str, Any] = {
        "torch_available": bool(torch is not None),
        "torch_version": to_text(getattr(torch, "__version__", "unknown")) if torch is not None else "missing",
        "cuda_available": False,
        "cuda_device_count": 0,
        "cuda_device_0": "",
        "cuda_compute_capability_0": "",
        "cuda_total_memory_gb_0": None,
        "mps_available": False,
    }
    if torch is None:
        return info
    try:
        cuda_available = bool(torch.cuda.is_available())
    except Exception:
        cuda_available = False
    info["cuda_available"] = cuda_available
    try:
        device_count = int(torch.cuda.device_count())
    except Exception:
        device_count = 0
    info["cuda_device_count"] = max(0, device_count)
    if cuda_available and device_count > 0:
        try:
            info["cuda_device_0"] = to_text(torch.cuda.get_device_name(0))
        except Exception:
            info["cuda_device_0"] = ""
        try:
            props = torch.cuda.get_device_properties(0)
            info["cuda_compute_capability_0"] = f"{int(props.major)}.{int(props.minor)}"
            # Bytes -> GiB, rounded for display.
            info["cuda_total_memory_gb_0"] = round(float(props.total_memory) / float(1024 ** 3), 2)
        except Exception:
            info["cuda_compute_capability_0"] = ""
            info["cuda_total_memory_gb_0"] = None
    try:
        info["mps_available"] = bool(hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
    except Exception:
        info["mps_available"] = False
    return info
def log_runtime_hardware(info: dict[str, Any]) -> None:
    """Print the hardware snapshot from probe_runtime_hardware.

    Emits both human-readable lines and grep-friendly
    '[train_worker][runtime]' tagged lines; formats are consumed as-is by
    log readers, so keep them stable.
    """
    cuda_available = bool(info.get("cuda_available", False))
    device_count = int(info.get("cuda_device_count", 0) or 0)
    gpu_name = to_text(info.get("cuda_device_0"))
    print(f"CUDA available: {cuda_available}")
    print(f"CUDA device count: {device_count}")
    if cuda_available and device_count > 0:
        print(f"GPU: {gpu_name or 'unknown'}")
    print(
        f"[train_worker][runtime] torch={to_text(info.get('torch_version')) or 'unknown'} "
        f"cuda={cuda_available} devices={device_count} "
        f"mps={bool(info.get('mps_available', False))}"
    )
    if gpu_name:
        details = [f"name={gpu_name}"]
        capability = to_text(info.get("cuda_compute_capability_0"))
        if capability:
            details.append(f"sm={capability}")
        memory_gb = info.get("cuda_total_memory_gb_0")
        if memory_gb is not None:
            details.append(f"vram_gb={memory_gb}")
        print("[train_worker][runtime] gpu0 " + " ".join(details))
def preferred_loader_dtype_key() -> str:
    """Return the model-loading dtype kwarg name: 'dtype' for
    transformers >= 5.0.0, 'torch_dtype' for older releases."""
    installed = parse_version_tuple(getattr(transformers_pkg, "__version__", "0.0.0"))
    return "dtype" if installed >= (5, 0, 0) else "torch_dtype"
def dtype_from_name(name: str) -> Any:
    """Map a dtype alias string to the torch dtype object.

    Raises ValueError for unknown aliases and RuntimeError when torch is
    unavailable.
    """
    require_dependency("torch", torch is not None)
    aliases = {
        "float16": torch.float16,
        "fp16": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }
    normalized = to_text(name).lower()
    if normalized not in aliases:
        raise ValueError(f"Unsupported model dtype '{name}'.")
    return aliases[normalized]
def parse_target_modules(text: str) -> list[str]:
    """Parse a comma-separated LoRA target-module list.

    Empty input or the sentinels 'auto'/'default' mean auto-detect ([]).
    """
    raw = to_text(text)
    if not raw or raw.lower() in {"auto", "default"}:
        return []
    entries = (entry.strip() for entry in raw.split(","))
    return [entry for entry in entries if entry]
def parse_model_list(value: str) -> list[str]:
    """Split a comma-separated model list into trimmed, non-empty entries."""
    cleaned = to_text(value)
    if not cleaned:
        return []
    entries = (entry.strip() for entry in cleaned.split(","))
    return [entry for entry in entries if entry]
def parse_csv_values(value: str) -> list[str]:
    """Split a comma-separated value into trimmed, non-empty chunks.

    NOTE(review): intentionally mirrors parse_model_list; kept separate so
    the two call sites can diverge independently.
    """
    cleaned = to_text(value)
    if not cleaned:
        return []
    chunks = (chunk.strip() for chunk in cleaned.split(","))
    return [chunk for chunk in chunks if chunk]
def read_os_release() -> dict[str, str]:
    """Parse /etc/os-release into an UPPER-cased key -> unquoted value map.

    Returns {} when the file is missing or unreadable.
    """
    source = Path("/etc/os-release")
    if not source.exists():
        return {}
    try:
        content = source.read_text(encoding="utf-8", errors="replace")
    except Exception:
        return {}
    parsed: dict[str, str] = {}
    for raw_line in content.splitlines():
        stripped = raw_line.strip()
        if not stripped or stripped.startswith("#") or "=" not in stripped:
            continue
        key, _, value = stripped.partition("=")
        normalized_key = key.strip().upper()
        if normalized_key:
            parsed[normalized_key] = value.strip().strip('"').strip("'")
    return parsed
def mem_total_gb() -> float | None:
path = Path("/proc/meminfo")
if not path.exists():
return None
try:
for line in path.read_text(encoding="utf-8", errors="replace").splitlines():
if not line.startswith("MemTotal:"):
continue
parts = line.split()
if len(parts) < 2:
return None
kb = float(parts[1])
return round(kb / (1024.0 * 1024.0), 2)
except Exception:
return None
return None
def collect_runtime_system_snapshot(required_bins: list[str], native_mode: bool) -> dict[str, Any]:
    """Collect an OS/python/CPU/disk/GPU-tooling snapshot as a JSON-safe dict.

    Args:
        required_bins: native binaries that must resolve on PATH.
        native_mode: whether the caller intends to run native tooling.

    All probes degrade gracefully to None/empty values rather than raising.
    """
    os_release = read_os_release()
    os_id = to_text(os_release.get("ID")).lower()
    os_name = to_text(os_release.get("PRETTY_NAME")) or platform.platform()
    kernel = to_text(platform.release())
    arch = to_text(platform.machine())
    python_exe = to_text(sys.executable)
    python_version = to_text(platform.python_version())
    # A base_prefix differing from prefix indicates a venv interpreter.
    in_venv = bool(getattr(sys, "base_prefix", "") != getattr(sys, "prefix", ""))
    cpu_count = int(os.cpu_count() or 0)
    memory_gb = mem_total_gb()
    try:
        disk_usage = shutil.disk_usage(str(Path.cwd()))
        disk_total_gb = round(float(disk_usage.total) / float(1024 ** 3), 2)
        disk_free_gb = round(float(disk_usage.free) / float(1024 ** 3), 2)
    except Exception:
        disk_total_gb = None
        disk_free_gb = None
    binaries: dict[str, str] = {}
    missing_required_bins: list[str] = []
    for binary in required_bins:
        resolved = to_text(shutil.which(binary))
        binaries[binary] = resolved
        if not resolved:
            missing_required_bins.append(binary)
    nvidia_smi_present = bool(to_text(shutil.which("nvidia-smi")))
    nvidia_smi_output = ""
    if nvidia_smi_present:
        try:
            result = subprocess.run(
                ["nvidia-smi", "--query-gpu=name,driver_version,memory.total", "--format=csv,noheader"],
                capture_output=True,
                text=True,
                timeout=5,
                check=False,
            )
            # Prefer stdout; fall back to stderr so failures are still visible.
            nvidia_smi_output = (result.stdout or result.stderr or "").strip()
        except Exception as exc:
            nvidia_smi_output = f"unavailable: {exc}"
    is_ubuntu = os_id == "ubuntu" or "ubuntu" in os_name.lower()
    # Native mode is "ready" only when every required binary resolved.
    native_ready = len(missing_required_bins) == 0
    return {
        "collected_at": datetime.now(timezone.utc).isoformat(),
        "native_mode": bool(native_mode),
        "native_ready": bool(native_ready),
        "is_ubuntu": bool(is_ubuntu),
        "os_id": os_id or "unknown",
        "os_name": os_name or "unknown",
        "kernel": kernel or "unknown",
        "arch": arch or "unknown",
        "python_executable": python_exe or "unknown",
        "python_version": python_version or "unknown",
        "in_venv": bool(in_venv),
        "cpu_count": cpu_count,
        "memory_gb": memory_gb,
        "disk_total_gb": disk_total_gb,
        "disk_free_gb": disk_free_gb,
        "required_bins": required_bins,
        "binaries": binaries,
        "missing_required_bins": missing_required_bins,
        "nvidia_smi_present": nvidia_smi_present,
        "nvidia_smi_output": nvidia_smi_output,
    }
def log_runtime_system_snapshot(snapshot: dict[str, Any]) -> None:
    """Print the system snapshot from collect_runtime_system_snapshot.

    Emits '[train_worker][system]' tagged lines plus a warning when required
    native binaries are missing; line formats are parsed by log readers.
    """
    os_name = to_text(snapshot.get("os_name")) or "unknown"
    kernel = to_text(snapshot.get("kernel")) or "unknown"
    arch = to_text(snapshot.get("arch")) or "unknown"
    python_version = to_text(snapshot.get("python_version")) or "unknown"
    python_executable = to_text(snapshot.get("python_executable")) or "unknown"
    native_mode = bool(snapshot.get("native_mode", False))
    native_ready = bool(snapshot.get("native_ready", False))
    mode_text = "on" if native_mode else "off"
    ready_text = "ready" if native_ready else "degraded"
    print(
        f"[train_worker][system] native_mode={mode_text} state={ready_text} "
        f"os='{os_name}' kernel={kernel} arch={arch}"
    )
    print(
        f"[train_worker][system] python={python_version} executable={python_executable} "
        f"venv={bool(snapshot.get('in_venv', False))}"
    )
    print(
        f"[train_worker][system] cpu_count={int(snapshot.get('cpu_count', 0) or 0)} "
        f"memory_gb={snapshot.get('memory_gb')} disk_free_gb={snapshot.get('disk_free_gb')}"
    )
    missing = snapshot.get("missing_required_bins")
    if isinstance(missing, list) and missing:
        print("[train_worker][warn] missing required native binaries: " + ",".join(str(item) for item in missing))
    else:
        required = snapshot.get("required_bins")
        if isinstance(required, list) and required:
            print("[train_worker][system] required native binaries detected: " + ",".join(str(item) for item in required))
    nvidia_text = to_text(snapshot.get("nvidia_smi_output"))
    if nvidia_text:
        print("[train_worker][system] nvidia-smi: " + nvidia_text)
def build_quant_config(use_4bit: bool, dtype: Any) -> Any:
    """Build an NF4 double-quantized BitsAndBytesConfig, or None when 4-bit is off.

    Args:
        use_4bit: enable 4-bit weight loading.
        dtype: compute dtype passed as bnb_4bit_compute_dtype.
    """
    if not use_4bit:
        return None
    require_dependency("BitsAndBytesConfig", BitsAndBytesConfig is not None)
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=dtype,
    )
def _available_module_suffixes(model: Any) -> set[str]:
suffixes: set[str] = set()
for name, module in model.named_modules():
if not name:
continue
parts = name.split(".")
lowered_parts = [part.lower() for part in parts]
# When modules are already LoRA-wrapped, recover the original layer suffix.
if lowered_parts and lowered_parts[-1] == "base_layer" and len(parts) >= 2:
suffixes.add(parts[-2])
for marker in ("lora_a", "lora_b", "lora_embedding_a", "lora_embedding_b", "lora_magnitude_vector"):
if marker in lowered_parts:
marker_index = lowered_parts.index(marker)
if marker_index > 0:
suffixes.add(parts[marker_index - 1])
if len(list(module.children())) > 0:
continue
suffixes.add(name.split(".")[-1])
return suffixes
def model_has_existing_lora(model: Any) -> bool:
    """Detect whether a model already carries PEFT/LoRA wrapping.

    Checks the peft_config attribute, the class name, LoRA-style submodule
    names, and lora_A/lora_B attributes on any module.
    """
    if getattr(model, "peft_config", None):
        return True
    if "peft" in type(model).__name__.lower():
        return True
    for name, module in model.named_modules():
        lowered = name.lower()
        named_like_lora = ".lora_" in lowered or lowered.endswith(("lora_a", "lora_b"))
        if named_like_lora or hasattr(module, "lora_A") or hasattr(module, "lora_B"):
            return True
    return False
def set_lora_only_trainable(model: Any) -> tuple[int, int]:
    """Freeze every parameter except LoRA/adapter ones.

    Returns:
        (trainable_param_count, total_param_count) after the pass.
    """
    adapter_markers = (
        "lora_A",
        "lora_B",
        "lora_embedding_A",
        "lora_embedding_B",
        "lora_magnitude_vector",
        "modules_to_save",
    )
    trainable_total = 0
    overall_total = 0
    for param_name, param in model.named_parameters():
        count = int(param.numel())
        overall_total += count
        keep = any(marker in param_name for marker in adapter_markers)
        param.requires_grad = keep
        if keep:
            trainable_total += count
    return trainable_total, overall_total
def resolve_lora_target_modules(model: Any, requested: list[str]) -> list[str]:
    """Choose LoRA target-module suffixes valid for the loaded base model.

    Requested names are filtered against the model's actual leaf-module
    suffixes; when none survive (or none were requested), a priority-ordered
    list of common attention/MLP projection names is tried instead.

    Raises:
        RuntimeError: when no compatible target module exists at all.
    """
    available_suffixes = _available_module_suffixes(model)
    # Preferred order: decoder-style proj names, then GPT/BLOOM/Falcon-style names.
    fallback_priority: list[str] = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
        "c_attn",
        "c_proj",
        "c_fc",
    ]
    if requested:
        selected = [name for name in requested if name in available_suffixes]
        if selected:
            # Warn (but proceed) when only some requested targets exist.
            missing = sorted(set(requested) - set(selected))
            if missing:
                print(
                    "[train_worker][warn] dropping unavailable LoRA target modules: "
                    + ",".join(missing)
                )
            return selected
        print(
            "[train_worker][warn] requested LoRA target modules not found for this base model; "
            "falling back to auto-detected targets."
        )
    auto_selected = [name for name in fallback_priority if name in available_suffixes]
    if auto_selected:
        print(
            "[train_worker][info] using auto-resolved LoRA target modules: "
            + ",".join(auto_selected)
        )
        return auto_selected
    raise RuntimeError(
        "LoRA is enabled, but no compatible target modules were found in the base model. "
        f"Available leaf module suffixes: {sorted(available_suffixes)}"
    )
def load_split_dataset(dataset_id: str, split: str, cache_dir: str) -> Any:
    """Load a Hub dataset split with retries and a local parquet fallback.

    Transient network errors are retried up to 3 times with linear backoff.
    On final failure, a repo-local parquet file for the split is used when
    present; otherwise a RuntimeError is raised chaining the last error.
    """
    require_dependency("datasets", load_dataset is not None)
    def is_retryable_dataset_error(exc: Exception) -> bool:
        """True when the error message looks like a transient network failure."""
        message = to_text(exc).lower()
        if not message:
            return False
        retry_markers = (
            "client has been closed",
            "connection aborted",
            "connection reset",
            "connection refused",
            "network is unreachable",
            "temporary failure",
            "timed out",
            "timeout",
            "503",
            "429",
        )
        return any(marker in message for marker in retry_markers)
    def resolve_local_fallback(split_name: str) -> Path | None:
        """Find a local parquet file for the split, checking cwd then repo root."""
        name = to_text(split_name)
        if not name:
            return None
        candidates = [
            Path("data_processed_v2/parquet") / f"{name}.parquet",
            REPO_ROOT / "data_processed_v2" / "parquet" / f"{name}.parquet",
            REPO_ROOT / "data_processed" / "parquet" / f"{name}.parquet",
        ]
        for candidate in candidates:
            if candidate.exists():
                return candidate
        return None
    attempts = 3
    last_exc: Exception | None = None
    for attempt in range(1, attempts + 1):
        try:
            return load_dataset(dataset_id, split=split, cache_dir=cache_dir)
        except Exception as exc:
            last_exc = exc
            # Linear backoff: 2s, 4s, ... but only for retryable errors.
            if attempt < attempts and is_retryable_dataset_error(exc):
                wait_seconds = attempt * 2
                print(
                    f"[train_worker][warn] failed to load '{dataset_id}:{split}' "
                    f"(attempt {attempt}/{attempts}): {exc}; retrying in {wait_seconds}s"
                )
                time.sleep(wait_seconds)
                continue
            break
    fallback = resolve_local_fallback(split)
    if fallback is not None:
        print(
            f"[train_worker][warn] failed to load '{dataset_id}:{split}' ({last_exc}); "
            f"using local fallback {fallback}"
        )
        return load_dataset(
            "parquet",
            data_files={split: str(fallback)},
            split=split,
            cache_dir=cache_dir,
        )
    raise RuntimeError(
        f"Unable to load dataset split '{dataset_id}:{split}' and no local fallback exists."
    ) from last_exc
def latest_checkpoint_dir(output_dir: Path) -> Path | None:
checkpoints: list[tuple[int, Path]] = []
for candidate in output_dir.glob("checkpoint-*"):
if not candidate.is_dir():
continue
suffix = candidate.name.replace("checkpoint-", "", 1)
try:
step = int(suffix)
except Exception:
continue
checkpoints.append((step, candidate))
if not checkpoints:
return None
checkpoints.sort(key=lambda item: item[0])
return checkpoints[-1][1]
def latest_checkpoint_in_sibling_runs(output_dir: Path) -> Path | None:
runs_root = output_dir.parent
if not runs_root.exists():
return None
checkpoints: list[tuple[float, int, Path]] = []
for run_dir in runs_root.iterdir():
if not run_dir.is_dir() or run_dir == output_dir:
continue
for candidate in run_dir.glob("checkpoint-*"):
if not candidate.is_dir():
continue
suffix = candidate.name.replace("checkpoint-", "", 1)
try:
step = int(suffix)
except Exception:
continue
try:
mtime = candidate.stat().st_mtime
except Exception:
mtime = 0.0
checkpoints.append((mtime, step, candidate))
if not checkpoints:
return None
checkpoints.sort(key=lambda item: (item[0], item[1]))
return checkpoints[-1][2]
def checkpoint_resume_compatible(checkpoint_dir: Path) -> tuple[bool, str]:
    """Decide whether Trainer can resume from checkpoint_dir.

    Returns (True, "") for full-model checkpoints; adapter-only checkpoints
    and anything else yield (False, reason).
    """
    if not checkpoint_dir.exists():
        return False, "path does not exist"
    if not checkpoint_dir.is_dir():
        return False, "path is not a directory"
    full_model_markers = (
        "model.safetensors",
        "pytorch_model.bin",
        "model.safetensors.index.json",
        "pytorch_model.bin.index.json",
    )
    for marker in full_model_markers:
        if (checkpoint_dir / marker).exists():
            return True, ""
    adapter_markers = (
        "adapter_model.safetensors",
        "adapter_model.bin",
        "adapter_config.json",
    )
    for marker in adapter_markers:
        if (checkpoint_dir / marker).exists():
            return (
                False,
                "adapter-only checkpoint (missing full-model checkpoint files required by Trainer resume)",
            )
    return False, "missing model checkpoint files"
def resolve_resume_checkpoint(value: str | None, output_dir: Path) -> str | None:
    """Resolve the resume_from_checkpoint setting to a concrete path or None.

    'none'/'false'/'no'/empty disable resume. 'auto'/'latest' pick the newest
    resume-compatible checkpoint in output_dir, then fall back to sibling run
    directories; incompatible (e.g. adapter-only) checkpoints are skipped with
    a warning. Any other value is treated as an explicit path (relative paths
    resolve under output_dir) and must exist and be resume-compatible,
    otherwise RuntimeError is raised.
    """
    requested = to_text(value).lower()
    if requested in {"", "none", "false", "no"}:
        return None
    if requested in {"auto", "latest"}:
        latest = latest_checkpoint_dir(output_dir)
        if latest is not None:
            compatible, reason = checkpoint_resume_compatible(latest)
            if compatible:
                return str(latest)
            print(
                "[train_worker][warn] auto-resume skipped latest checkpoint "
                f"'{latest}' ({reason})."
            )
        sibling_latest = latest_checkpoint_in_sibling_runs(output_dir=output_dir)
        if sibling_latest is not None:
            compatible, reason = checkpoint_resume_compatible(sibling_latest)
            if compatible:
                print(
                    "[train_worker][info] auto-resume fallback selected sibling checkpoint: "
                    f"{sibling_latest}"
                )
                return str(sibling_latest)
            print(
                "[train_worker][warn] auto-resume skipped sibling checkpoint "
                f"'{sibling_latest}' ({reason})."
            )
        return None
    candidate = Path(to_text(value))
    if not candidate.is_absolute():
        candidate = output_dir / candidate
    if candidate.exists():
        compatible, reason = checkpoint_resume_compatible(candidate)
        if compatible:
            return str(candidate)
        raise RuntimeError(
            f"Requested resume checkpoint is not trainer-resume compatible ({reason}): {candidate}"
        )
    raise RuntimeError(f"Requested resume checkpoint does not exist: {candidate}")
def resolve_schedule_weights(
    ce_weight_start: float,
    ce_weight_end: float,
    kd_weight_start: float,
    kd_weight_end: float,
) -> tuple[float, float, float, float]:
    """Normalize CE/KD weight pairs so the start pair and end pair each sum to 1.

    Returns:
        (ce_start, ce_end, kd_start, kd_end) after normalization.
    Raises:
        ValueError: when either pair sums to a non-positive value.
    """
    start_sum = float(ce_weight_start) + float(kd_weight_start)
    end_sum = float(ce_weight_end) + float(kd_weight_end)
    if start_sum <= 0 or end_sum <= 0:
        raise ValueError("Distillation CE/KD weights must sum to positive values.")
    ce_start = float(ce_weight_start) / start_sum
    kd_start = float(kd_weight_start) / start_sum
    ce_end = float(ce_weight_end) / end_sum
    kd_end = float(kd_weight_end) / end_sum
    return ce_start, ce_end, kd_start, kd_end
class DistillationTrainer(TrainerBase):
    def __init__(
        self,
        *args: Any,
        teacher_models: list[Any],
        temperature_start: float,
        temperature_end: float,
        ce_weight_start: float,
        ce_weight_end: float,
        kd_weight_start: float,
        kd_weight_end: float,
        **kwargs: Any,
    ):
        """Wrap Trainer with teacher-ensemble distillation settings.

        The temperature and CE/KD weight start/end pairs are linearly
        interpolated over training progress (see _progress/_interp).
        Teachers are switched to eval mode and frozen here.
        """
        require_dependency("Trainer", Trainer is not None)
        super().__init__(*args, **kwargs)
        if not teacher_models:
            raise ValueError("teacher_models must contain at least one teacher.")
        self.teacher_models = teacher_models
        self.temperature_start = float(temperature_start)
        self.temperature_end = float(temperature_end)
        self.ce_weight_start = float(ce_weight_start)
        self.ce_weight_end = float(ce_weight_end)
        self.kd_weight_start = float(kd_weight_start)
        self.kd_weight_end = float(kd_weight_end)
        # Last computed per-step distillation metrics, merged into logs.
        self._latest_distill_metrics: dict[str, float] = {}
        for teacher in self.teacher_models:
            teacher.eval()
            for param in teacher.parameters():
                param.requires_grad = False
    def _progress(self) -> float:
        """Training progress in [0, 1] from global_step / max_steps.

        Returns 0.0 when max_steps is unknown or <= 1.
        """
        max_steps = int(getattr(self.state, "max_steps", 0) or 0)
        if max_steps <= 1:
            return 0.0
        step = float(getattr(self.state, "global_step", 0) or 0)
        return max(0.0, min(1.0, step / float(max_steps)))
    @staticmethod
    def _interp(start: float, end: float, progress: float) -> float:
        """Linear interpolation between start and end at the given progress."""
        return start + (end - start) * progress
    def _teacher_forward(self, teacher: Any, input_ids: Any, attention_mask: Any) -> Any:
        """Run one teacher and return its logits.

        Falls back to moving the batch onto the teacher's device (and the
        logits back to the batch's device) when the direct call fails —
        presumably a cross-device placement error; the broad except also
        retries other failures once on the teacher's device.
        """
        try:
            return teacher(input_ids=input_ids, attention_mask=attention_mask).logits
        except Exception:
            teacher_device = next(teacher.parameters()).device
            out = teacher(
                input_ids=input_ids.to(teacher_device),
                attention_mask=attention_mask.to(teacher_device),
            ).logits
            return out.to(input_ids.device)
def _teacher_logits(self, input_ids: Any, attention_mask: Any) -> Any:
accum = None
for teacher in self.teacher_models:
logits = self._teacher_forward(teacher, input_ids, attention_mask)
accum = logits if accum is None else (accum + logits)
return accum / float(len(self.teacher_models))
def compute_loss(
self,
model: Any,
inputs: dict[str, Any],
return_outputs: bool = False,
num_items_in_batch: int | None = None,
) -> Any:
del num_items_in_batch
require_dependency("torch.nn.functional", F is not None)
labels = inputs["labels"]
progress = self._progress()
temperature = max(1e-6, self._interp(self.temperature_start, self.temperature_end, progress))
ce_weight = self._interp(self.ce_weight_start, self.ce_weight_end, progress)
kd_weight = self._interp(self.kd_weight_start, self.kd_weight_end, progress)
student_out = model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
labels=labels,
)
ce_loss = student_out.loss
with torch.no_grad():
teacher_logits = self._teacher_logits(inputs["input_ids"], inputs["attention_mask"])
student_logits = student_out.logits
student_shift = student_logits[:, :-1, :]
teacher_shift = teacher_logits[:, :-1, :]
labels_shift = labels[:, 1:]
if student_shift.shape[-1] != teacher_shift.shape[-1]:
vocab = min(int(student_shift.shape[-1]), int(teacher_shift.shape[-1]))
student_shift = student_shift[:, :, :vocab]
teacher_shift = teacher_shift[:, :, :vocab]
active = labels_shift.ne(-100)
if active.any():
s = student_shift[active]
t = teacher_shift[active]
kd_loss = F.kl_div(
F.log_softmax(s / temperature, dim=-1),
F.softmax(t / temperature, dim=-1),
reduction="batchmean",
) * (temperature * temperature)
else:
kd_loss = torch.tensor(0.0, device=ce_loss.device)
self._latest_distill_metrics = {
"distill_ce_loss": float(ce_loss.detach().float().item()),
"distill_kd_loss": float(kd_loss.detach().float().item()),
"distill_temperature": float(temperature),
"distill_ce_weight": float(ce_weight),
"distill_kd_weight": float(kd_weight),
}
loss = ce_weight * ce_loss + kd_weight * kd_loss
if return_outputs:
student_out.loss = loss
return loss, student_out
return loss
def log(self, logs: dict[str, float], *args: Any, **kwargs: Any) -> None:
merged = dict(logs)
if self._latest_distill_metrics:
merged.update(self._latest_distill_metrics)
super().log(merged, *args, **kwargs)
def estimate_total_train_steps(
    train_rows: int,
    batch_size: int,
    grad_accum: int,
    epochs: float,
) -> int:
    """Estimate the total number of optimizer steps for a run.

    Samples are consumed in effective batches of
    ``batch_size * grad_accum``; partial batches and partial epochs still
    count as full steps, and the result is never less than one.
    """
    samples_per_step = max(1, int(batch_size) * int(grad_accum))
    steps_in_one_epoch = max(1, math.ceil(int(train_rows) / samples_per_step))
    total = math.ceil(steps_in_one_epoch * float(epochs))
    return max(1, total)
def build_teacher_models(
    teacher_names: list[str],
    teacher_dtype: Any,
    trust_remote_code: bool,
    attn_implementation: str,
    use_4bit: bool,
    using_cuda: bool,
) -> list[Any]:
    """Load one frozen, eval-mode teacher model per configured name.

    Teachers are optionally 4-bit quantized (with automatic device
    placement) and fall back to the default attention backend when the
    requested ``attn_implementation`` is rejected by the model.
    """
    require_dependency("AutoModelForCausalLM", AutoModelForCausalLM is not None)
    if not teacher_names:
        raise RuntimeError("Distillation is enabled but no teacher model was configured.")
    dtype_key = preferred_loader_dtype_key()
    loaded: list[Any] = []
    for name in teacher_names:
        load_kwargs: dict[str, Any] = {"trust_remote_code": bool(trust_remote_code)}
        load_kwargs[dtype_key] = teacher_dtype
        quant = build_quant_config(use_4bit=bool(use_4bit), dtype=teacher_dtype)
        if quant is not None:
            load_kwargs["quantization_config"] = quant
            load_kwargs["device_map"] = "auto"
        elif using_cuda:
            load_kwargs["device_map"] = "auto"
        if attn_implementation:
            load_kwargs["attn_implementation"] = attn_implementation
        try:
            teacher = AutoModelForCausalLM.from_pretrained(name, **load_kwargs)
        except Exception as exc:
            # Some backbones reject explicit attention backends; only then
            # retry once with the library default.
            if not (attn_implementation and "attn_implementation" in load_kwargs):
                raise
            print(
                f"[train_worker][warn] teacher '{name}' failed with "
                f"attn_implementation='{attn_implementation}' ({exc}); "
                "retrying with default attention backend"
            )
            load_kwargs.pop("attn_implementation", None)
            teacher = AutoModelForCausalLM.from_pretrained(name, **load_kwargs)
        # Teachers are inference-only: eval mode, no gradients.
        teacher.eval()
        for param in teacher.parameters():
            param.requires_grad = False
        loaded.append(teacher)
    return loaded
def parse_args() -> argparse.Namespace:
    """Build the CLI for the trainer.

    Every option has an environment-variable default (``UMSR_*``) so the
    worker can run fully unattended inside a Space; CLI flags override
    the environment.  Boolean flags come in enable/disable pairs whose
    default is resolved from the environment via ``env_bool``.
    """
    # Resolve boolean defaults from the environment up front so both the
    # positive and negative flag of each pair share one default.
    default_use_4bit = env_bool("UMSR_USE_4BIT", True)
    default_use_4bit_teacher = env_bool("UMSR_USE_4BIT_TEACHER", True)
    default_lora_enabled = env_bool("UMSR_LORA_ENABLED", True)
    default_grad_ckpt = env_bool("UMSR_GRADIENT_CHECKPOINTING", True)
    default_distill_enabled = env_bool("UMSR_DISTILL_ENABLED", True)
    default_enforce_inhouse_models = env_bool("UMSR_ENFORCE_INHOUSE_MODELS", True)
    default_native_trainer_mode = env_bool("UMSR_NATIVE_TRAINER_MODE", True)
    default_native_strict_mode = env_bool("UMSR_NATIVE_STRICT_MODE", False)
    parser = argparse.ArgumentParser(description="Train and optionally push an autonomous UMSR run")
    # Dataset selection and quality filtering.
    parser.add_argument("--dataset-id", default=os.environ.get("UMSR_DATASET_ID", "NorthernTribe-Research/UMSR-v1"))
    parser.add_argument("--train-split", default=os.environ.get("UMSR_TRAIN_SPLIT", "train"))
    parser.add_argument("--eval-split", default=os.environ.get("UMSR_EVAL_SPLIT", "validation"))
    parser.add_argument("--min-quality", type=float, default=float(os.environ.get("UMSR_MIN_QUALITY", "0.72")))
    # Student and teacher model selection (teacher may be a comma list).
    parser.add_argument("--model-name", default=os.environ.get("UMSR_BASE_MODEL", "NorthernTribe-Research/UMSR-Reasoner-7B"))
    parser.add_argument(
        "--teacher-model",
        default=os.environ.get("UMSR_TEACHER_MODEL", "NorthernTribe-Research/UMSR-Reasoner-7B"),
    )
    parser.add_argument("--model-dtype", default=os.environ.get("UMSR_MODEL_DTYPE", "bfloat16"))
    parser.add_argument("--teacher-dtype", default=os.environ.get("UMSR_TEACHER_DTYPE", "bfloat16"))
    parser.add_argument(
        "--attn-implementation",
        default=os.environ.get("UMSR_ATTN_IMPLEMENTATION", ""),
    )
    # Distillation and model-provenance toggles.
    parser.add_argument("--distill-enabled", dest="distill_enabled", action="store_true")
    parser.add_argument("--no-distill-enabled", dest="distill_enabled", action="store_false")
    parser.set_defaults(distill_enabled=default_distill_enabled)
    parser.add_argument("--enforce-inhouse-models", dest="enforce_inhouse_models", action="store_true")
    parser.add_argument("--allow-external-models", dest="enforce_inhouse_models", action="store_false")
    parser.set_defaults(enforce_inhouse_models=default_enforce_inhouse_models)
    # Quantization toggles (effective only on CUDA; see main()).
    parser.add_argument("--use-4bit", dest="use_4bit", action="store_true")
    parser.add_argument("--no-use-4bit", dest="use_4bit", action="store_false")
    parser.set_defaults(use_4bit=default_use_4bit)
    parser.add_argument("--use-4bit-teacher", dest="use_4bit_teacher", action="store_true")
    parser.add_argument("--no-use-4bit-teacher", dest="use_4bit_teacher", action="store_false")
    parser.set_defaults(use_4bit_teacher=default_use_4bit_teacher)
    # LoRA adapter configuration.
    parser.add_argument("--lora-enabled", dest="lora_enabled", action="store_true")
    parser.add_argument("--no-lora-enabled", dest="lora_enabled", action="store_false")
    parser.set_defaults(lora_enabled=default_lora_enabled)
    parser.add_argument("--lora-r", type=int, default=int(os.environ.get("UMSR_LORA_R", "32")))
    parser.add_argument("--lora-alpha", type=int, default=int(os.environ.get("UMSR_LORA_ALPHA", "64")))
    parser.add_argument("--lora-dropout", type=float, default=float(os.environ.get("UMSR_LORA_DROPOUT", "0.05")))
    parser.add_argument(
        "--lora-target-modules",
        default=os.environ.get(
            "UMSR_LORA_TARGET_MODULES",
            "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
        ),
    )
    # Output locations and sampling caps.
    parser.add_argument("--repo-id", default=os.environ.get("UMSR_MODEL_REPO_ID", "NorthernTribe-Research/UMSR-Reasoner-7B"))
    parser.add_argument("--output-dir", default="runs/latest")
    parser.add_argument("--max-train-samples", type=int, default=int(os.environ.get("UMSR_MAX_TRAIN_SAMPLES", "256")))
    parser.add_argument("--max-eval-samples", type=int, default=int(os.environ.get("UMSR_MAX_EVAL_SAMPLES", "64")))
    # Core optimizer / schedule hyperparameters.
    parser.add_argument("--num-train-epochs", type=float, default=float(os.environ.get("UMSR_NUM_TRAIN_EPOCHS", "1")))
    parser.add_argument("--learning-rate", type=float, default=float(os.environ.get("UMSR_LEARNING_RATE", "1e-4")))
    parser.add_argument("--weight-decay", type=float, default=float(os.environ.get("UMSR_WEIGHT_DECAY", "0.0")))
    parser.add_argument("--warmup-ratio", type=float, default=float(os.environ.get("UMSR_WARMUP_RATIO", "0.03")))
    parser.add_argument("--warmup-steps", type=int, default=int(os.environ.get("UMSR_WARMUP_STEPS", "0")))
    parser.add_argument("--per-device-train-batch-size", type=int, default=int(os.environ.get("UMSR_BATCH_SIZE", "1")))
    parser.add_argument("--per-device-eval-batch-size", type=int, default=int(os.environ.get("UMSR_EVAL_BATCH_SIZE", "1")))
    parser.add_argument("--gradient-accumulation-steps", type=int, default=int(os.environ.get("UMSR_GRAD_ACCUM", "1")))
    parser.add_argument("--max-length", type=int, default=int(os.environ.get("UMSR_MAX_LENGTH", "512")))
    parser.add_argument("--logging-steps", type=int, default=int(os.environ.get("UMSR_LOGGING_STEPS", "1")))
    parser.add_argument("--eval-steps", type=int, default=int(os.environ.get("UMSR_EVAL_STEPS", "25")))
    parser.add_argument("--save-steps", type=int, default=int(os.environ.get("UMSR_SAVE_STEPS", "25")))
    parser.add_argument("--save-total-limit", type=int, default=int(os.environ.get("UMSR_SAVE_TOTAL_LIMIT", "4")))
    parser.add_argument("--seed", type=int, default=int(os.environ.get("UMSR_SEED", "42")))
    # Distillation schedule endpoints (linearly interpolated per step).
    parser.add_argument("--temperature-start", type=float, default=float(os.environ.get("UMSR_TEMPERATURE_START", "2.5")))
    parser.add_argument("--temperature-end", type=float, default=float(os.environ.get("UMSR_TEMPERATURE_END", "1.2")))
    parser.add_argument("--ce-weight-start", type=float, default=float(os.environ.get("UMSR_CE_WEIGHT_START", "0.35")))
    parser.add_argument("--ce-weight-end", type=float, default=float(os.environ.get("UMSR_CE_WEIGHT_END", "0.5")))
    parser.add_argument("--kd-weight-start", type=float, default=float(os.environ.get("UMSR_KD_WEIGHT_START", "0.65")))
    parser.add_argument("--kd-weight-end", type=float, default=float(os.environ.get("UMSR_KD_WEIGHT_END", "0.5")))
    parser.add_argument(
        "--resume-from-checkpoint",
        default=os.environ.get("UMSR_RESUME_FROM_CHECKPOINT", "auto"),
        help="Use 'auto' to continue from the latest checkpoint in output-dir if available.",
    )
    parser.add_argument("--gradient-checkpointing", dest="gradient_checkpointing", action="store_true")
    parser.add_argument("--no-gradient-checkpointing", dest="gradient_checkpointing", action="store_false")
    parser.set_defaults(gradient_checkpointing=default_grad_ckpt)
    parser.add_argument(
        "--tie-word-embeddings",
        action="store_true",
        help="Keep embedding and lm_head weights tied in the model config.",
    )
    # Native runtime preflight configuration.
    parser.add_argument("--native-trainer-mode", dest="native_trainer_mode", action="store_true")
    parser.add_argument("--no-native-trainer-mode", dest="native_trainer_mode", action="store_false")
    parser.set_defaults(native_trainer_mode=default_native_trainer_mode)
    parser.add_argument("--native-strict-mode", dest="native_strict_mode", action="store_true")
    parser.add_argument("--no-native-strict-mode", dest="native_strict_mode", action="store_false")
    parser.set_defaults(native_strict_mode=default_native_strict_mode)
    parser.add_argument(
        "--required-bins",
        default=os.environ.get("UMSR_REQUIRED_BINS", "bash,python3,git,curl"),
        help="Comma-separated native binaries required for the runtime preflight.",
    )
    # Hub publishing.
    parser.add_argument("--token-env", default=os.environ.get("UMSR_TOKEN_ENV", "HF_TOKEN"))
    parser.add_argument("--push-to-hub", action="store_true")
    return parser.parse_args()
def main() -> None:
    """End-to-end training entry point.

    Validates the runtime, prepares datasets, loads student (and optional
    teacher) models, trains with either the distillation trainer or plain
    SFT, writes metrics/summary artifacts, and optionally pushes the
    result to the Hugging Face Hub.
    """
    args = parse_args()
    # Fail fast on any hard dependency that failed to import at module load.
    require_dependency("torch", torch is not None)
    require_dependency("datasets", load_dataset is not None)
    require_dependency("transformers", all(
        dep is not None
        for dep in [AutoConfig, AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments]
    ))
    require_dependency("accelerate>=1.1.0", accelerate is not None)
    if accelerate is not None:
        current_accelerate = parse_version_tuple(getattr(accelerate, "__version__", "0.0.0"))
        if current_accelerate < (1, 1, 0):
            raise RuntimeError(
                "accelerate>=1.1.0 is required by Trainer. "
                f"Found accelerate=={getattr(accelerate, '__version__', 'unknown')}."
            )
    if args.push_to_hub:
        require_dependency("huggingface_hub", HfApi is not None)
    if set_seed is not None:
        set_seed(int(args.seed))
    # Native-runtime preflight: snapshot the system and, in strict mode,
    # refuse to run when required binaries are missing.
    required_bins = parse_csv_values(args.required_bins)
    runtime_system = collect_runtime_system_snapshot(
        required_bins=required_bins,
        native_mode=bool(args.native_trainer_mode),
    )
    log_runtime_system_snapshot(runtime_system)
    if bool(args.native_trainer_mode) and bool(args.native_strict_mode):
        missing = runtime_system.get("missing_required_bins")
        if isinstance(missing, list) and missing:
            raise RuntimeError(
                "Native strict mode failed: missing required binaries: " + ",".join(str(item) for item in missing)
            )
    # Enforce model provenance policy for both student and teachers.
    teacher_names = parse_model_list(args.teacher_model)
    validate_model_reference(
        args.model_name,
        role="base",
        enforce_inhouse=bool(args.enforce_inhouse_models),
    )
    for teacher_name in teacher_names:
        validate_model_reference(
            teacher_name,
            role="teacher",
            enforce_inhouse=bool(args.enforce_inhouse_models),
        )
    # Hardware probe decides device, precision, and quantization behavior.
    runtime_hardware = probe_runtime_hardware()
    log_runtime_hardware(runtime_hardware)
    using_cuda = bool(runtime_hardware.get("cuda_available", False))
    using_mps = bool(runtime_hardware.get("mps_available", False))
    requested_use_4bit = bool(args.use_4bit)
    # 4-bit quantization is only honored on CUDA.
    effective_use_4bit = bool(requested_use_4bit and using_cuda)
    bf16_supported = bool(
        using_cuda
        and hasattr(torch.cuda, "is_bf16_supported")
        and torch.cuda.is_bf16_supported()
    )
    # Prefer bf16 when the GPU supports it; otherwise fall back to fp16.
    fp16_enabled = bool(using_cuda and not bf16_supported)
    bf16_enabled = bool(using_cuda and bf16_supported)
    device_label = "cuda" if using_cuda else ("mps" if using_mps else "cpu")
    if using_cuda:
        # Improve CUDA throughput where supported, while staying numerically stable.
        try:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        except Exception:
            pass
    elif requested_use_4bit:
        print("[train_worker][warn] 4-bit quantization requested without CUDA; using non-quantized model load.")
    # Prepare the output directory and initial telemetry files before any
    # long-running work, so observers can see the run is alive.
    cache_dir = dataset_cache_dir()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    write_json(output_dir / "system_snapshot.json", runtime_system)
    write_json(
        output_dir / "live_progress.json",
        {
            "updated_at": datetime.now(timezone.utc).isoformat(),
            "status": "initializing",
            "message": "preparing datasets and models",
            "distill_enabled": bool(args.distill_enabled),
            "runtime_system": runtime_system,
            "runtime_hardware": runtime_hardware,
            "global_step": 0,
            "max_steps": 0,
            "epoch": 0.0,
            "metrics": {},
        },
    )
    # Load, quality-filter, and (optionally) subsample the train split.
    train_ds = load_split_dataset(args.dataset_id, split=args.train_split, cache_dir=cache_dir)
    train_ds = train_ds.filter(lambda row: quality_ok(row, float(args.min_quality)))
    if args.max_train_samples > 0 and len(train_ds) > args.max_train_samples:
        train_ds = train_ds.shuffle(seed=int(args.seed)).select(range(args.max_train_samples))
    eval_ds = None
    if args.eval_split:
        eval_ds = load_split_dataset(args.dataset_id, split=args.eval_split, cache_dir=cache_dir)
        eval_ds = eval_ds.filter(lambda row: quality_ok(row, float(args.min_quality)))
        if args.max_eval_samples > 0 and len(eval_ds) > args.max_eval_samples:
            eval_ds = eval_ds.shuffle(seed=int(args.seed)).select(range(args.max_eval_samples))
    # Collapse each row to a single formatted "text" column.
    train_ds = train_ds.map(
        lambda row: {"text": format_text(row)},
        remove_columns=train_ds.column_names,
    )
    if eval_ds is not None:
        eval_ds = eval_ds.map(
            lambda row: {"text": format_text(row)},
            remove_columns=eval_ds.column_names,
        )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    # Causal-LM tokenizers often lack a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    def tokenize_batch(batch: dict[str, Any]) -> dict[str, Any]:
        # Dynamic padding happens in the collator, so pad=False here.
        return tokenizer(
            batch["text"],
            truncation=True,
            max_length=int(args.max_length),
            padding=False,
        )
    train_tokenized = train_ds.map(tokenize_batch, batched=True, remove_columns=["text"])
    eval_tokenized = None
    if eval_ds is not None:
        eval_tokenized = eval_ds.map(tokenize_batch, batched=True, remove_columns=["text"])
    # Resolve dtypes, LoRA targets, and teacher quantization.
    model_dtype = dtype_from_name(args.model_dtype)
    teacher_dtype = dtype_from_name(args.teacher_dtype)
    lora_target_modules = parse_target_modules(args.lora_target_modules)
    requested_teacher_use_4bit = bool(args.use_4bit_teacher)
    effective_teacher_use_4bit = bool(requested_teacher_use_4bit and using_cuda)
    if requested_teacher_use_4bit and not using_cuda:
        print("[train_worker][warn] teacher 4-bit quantization requested without CUDA; using non-quantized teacher load.")
    # Normalize the CE/KD schedule so each endpoint's weights sum to 1.
    ce_weight_start, ce_weight_end, kd_weight_start, kd_weight_end = resolve_schedule_weights(
        ce_weight_start=float(args.ce_weight_start),
        ce_weight_end=float(args.ce_weight_end),
        kd_weight_start=float(args.kd_weight_start),
        kd_weight_end=float(args.kd_weight_end),
    )
    # Build the student model load kwargs (config, dtype, quantization,
    # attention backend, device placement).
    model_config = AutoConfig.from_pretrained(args.model_name)
    if hasattr(model_config, "tie_word_embeddings"):
        model_config.tie_word_embeddings = bool(args.tie_word_embeddings)
    model_kwargs: dict[str, Any] = {"config": model_config}
    dtype_key = preferred_loader_dtype_key()
    attn_impl = to_text(args.attn_implementation)
    if attn_impl:
        model_kwargs["attn_implementation"] = attn_impl
    student_quant_config = build_quant_config(use_4bit=effective_use_4bit, dtype=model_dtype)
    if student_quant_config is not None:
        model_kwargs["quantization_config"] = student_quant_config
        model_kwargs["device_map"] = "auto"
    else:
        model_kwargs[dtype_key] = model_dtype
        if using_cuda:
            model_kwargs["device_map"] = "auto"
    try:
        model = AutoModelForCausalLM.from_pretrained(args.model_name, **model_kwargs)
    except Exception as exc:
        # Retry once without an explicit attention backend; some models
        # reject specific attn_implementation values.
        if attn_impl and "attn_implementation" in model_kwargs:
            print(
                f"[train_worker][warn] failed with attn_implementation='{attn_impl}' ({exc}); "
                "retrying with default attention backend"
            )
            model_kwargs.pop("attn_implementation", None)
            model = AutoModelForCausalLM.from_pretrained(args.model_name, **model_kwargs)
        else:
            raise
    # KV cache is incompatible with gradient checkpointing during training.
    model.config.use_cache = False
    model.config.pad_token_id = tokenizer.pad_token_id
    if bool(args.gradient_checkpointing):
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    if effective_use_4bit:
        require_dependency("peft.prepare_model_for_kbit_training", prepare_model_for_kbit_training is not None)
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=bool(args.gradient_checkpointing),
        )
    # Attach LoRA adapters, or reuse adapters already baked into the model
    # (e.g. when resuming from an adapter-only checkpoint).
    if bool(args.lora_enabled):
        require_dependency("peft.LoraConfig", LoraConfig is not None)
        require_dependency("peft.get_peft_model", get_peft_model is not None)
        if model_has_existing_lora(model):
            print(
                "[train_worker][warn] base model already has LoRA adapters attached; "
                "skipping adapter reinjection and training existing adapters."
            )
            trainable, total = set_lora_only_trainable(model)
            pct = (100.0 * float(trainable) / float(total)) if total > 0 else 0.0
            print(
                f"[train_worker][info] trainable params set to existing LoRA adapters: "
                f"{trainable}/{total} ({pct:.4f}%)"
            )
            if not lora_target_modules:
                lora_target_modules = ["preloaded-adapter"]
        else:
            lora_target_modules = resolve_lora_target_modules(model=model, requested=lora_target_modules)
            lora_cfg = LoraConfig(
                r=int(args.lora_r),
                lora_alpha=int(args.lora_alpha),
                lora_dropout=float(args.lora_dropout),
                bias="none",
                task_type="CAUSAL_LM",
                target_modules=lora_target_modules,
            )
            model = get_peft_model(model, lora_cfg)
        if hasattr(model, "print_trainable_parameters"):
            model.print_trainable_parameters()
    # Load teachers only when distillation is on; otherwise run CE-only SFT.
    teacher_models: list[Any] = []
    if bool(args.distill_enabled):
        require_dependency("torch.nn.functional", F is not None)
        teacher_models = build_teacher_models(
            teacher_names=teacher_names,
            teacher_dtype=teacher_dtype,
            trust_remote_code=False,
            attn_implementation=attn_impl,
            use_4bit=effective_teacher_use_4bit,
            using_cuda=using_cuda,
        )
    elif teacher_names:
        print("[train_worker][info] teacher model configured but distillation is disabled; running CE-only SFT mode.")
    # Derive warmup steps: explicit --warmup-steps wins over the ratio.
    total_steps = estimate_total_train_steps(
        train_rows=len(train_tokenized),
        batch_size=int(args.per_device_train_batch_size),
        grad_accum=int(args.gradient_accumulation_steps),
        epochs=float(args.num_train_epochs),
    )
    requested_warmup_steps = max(0, int(args.warmup_steps))
    warmup_ratio = max(0.0, float(args.warmup_ratio))
    derived_warmup_steps = max(0, int(round(total_steps * warmup_ratio))) if warmup_ratio > 0 else 0
    effective_warmup_steps = requested_warmup_steps if requested_warmup_steps > 0 else derived_warmup_steps
    training_kwargs: dict[str, Any] = {
        "output_dir": str(output_dir),
        "run_name": "umsr-autonomous-space",
        "num_train_epochs": float(args.num_train_epochs),
        "learning_rate": float(args.learning_rate),
        "weight_decay": float(args.weight_decay),
        "warmup_steps": int(effective_warmup_steps),
        "per_device_train_batch_size": int(args.per_device_train_batch_size),
        "per_device_eval_batch_size": int(args.per_device_eval_batch_size),
        "gradient_accumulation_steps": int(args.gradient_accumulation_steps),
        "logging_steps": int(args.logging_steps),
        "disable_tqdm": True,
        "save_steps": int(args.save_steps),
        "save_total_limit": max(1, int(args.save_total_limit)),
        "report_to": ["none"],
        "remove_unused_columns": False,
        "seed": int(args.seed),
        "fp16": fp16_enabled,
        "bf16": bf16_enabled,
        "gradient_checkpointing": bool(args.gradient_checkpointing),
        "dataloader_pin_memory": using_cuda,
        # Paged 8-bit optimizer pairs with 4-bit quantized weights.
        "optim": "paged_adamw_8bit" if effective_use_4bit else "adamw_torch",
    }
    # Support both old (evaluation_strategy) and new (eval_strategy)
    # transformers argument names by inspecting the signature.
    training_arg_params = set(inspect.signature(TrainingArguments.__init__).parameters.keys())
    eval_key = "eval_strategy" if "eval_strategy" in training_arg_params else "evaluation_strategy"
    if "logging_first_step" in training_arg_params:
        training_kwargs["logging_first_step"] = True
    if eval_tokenized is not None:
        training_kwargs[eval_key] = "steps"
        training_kwargs["eval_steps"] = int(args.eval_steps)
    else:
        training_kwargs[eval_key] = "no"
    train_args = TrainingArguments(**training_kwargs)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    callbacks: list[Any] | None = None
    if TrainerCallback is not None:
        callbacks = [
            LiveProgressCallback(
                output_dir=output_dir,
                distill_enabled=bool(args.distill_enabled),
                runtime_hardware=runtime_hardware,
                runtime_system=runtime_system,
            )
        ]
    else:
        print("[train_worker][warn] TrainerCallback unavailable; live progress telemetry disabled.")
    # Auto-resume from the most recent usable checkpoint when requested.
    resume_checkpoint = resolve_resume_checkpoint(args.resume_from_checkpoint, output_dir=output_dir)
    if resume_checkpoint:
        print(f"[train_worker][info] resuming from checkpoint: {resume_checkpoint}")
    # Persist the fully-resolved run configuration for reproducibility.
    run_config = {
        "dataset_id": args.dataset_id,
        "train_split": args.train_split,
        "eval_split": args.eval_split,
        "min_quality": float(args.min_quality),
        "student_model": args.model_name,
        "teacher_models": teacher_names,
        "distill_enabled": bool(args.distill_enabled),
        "enforce_inhouse_models": bool(args.enforce_inhouse_models),
        "model_dtype": to_text(args.model_dtype).lower(),
        "teacher_dtype": to_text(args.teacher_dtype).lower(),
        "use_4bit_student_requested": requested_use_4bit,
        "use_4bit_student_effective": effective_use_4bit,
        "use_4bit_teacher_requested": requested_teacher_use_4bit,
        "use_4bit_teacher_effective": effective_teacher_use_4bit,
        "temperature_start": float(args.temperature_start),
        "temperature_end": float(args.temperature_end),
        "ce_weight_start": float(ce_weight_start),
        "ce_weight_end": float(ce_weight_end),
        "kd_weight_start": float(kd_weight_start),
        "kd_weight_end": float(kd_weight_end),
        "lora_enabled": bool(args.lora_enabled),
        "lora_r": int(args.lora_r),
        "lora_alpha": int(args.lora_alpha),
        "lora_dropout": float(args.lora_dropout),
        "lora_target_modules": lora_target_modules,
        "save_total_limit": max(1, int(args.save_total_limit)),
        "resume_from_checkpoint": resume_checkpoint or "",
        "output_dir": str(output_dir),
        "system_snapshot_path": str(output_dir / "system_snapshot.json"),
        "target_repo_id": args.repo_id,
        "native_trainer_mode": bool(args.native_trainer_mode),
        "native_strict_mode": bool(args.native_strict_mode),
        "required_bins": required_bins,
        "runtime_system": runtime_system,
        "runtime_hardware": runtime_hardware,
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    write_json(output_dir / "effective_run_config.json", run_config)
    # Instantiate the distillation trainer or a plain SFT Trainer.
    if bool(args.distill_enabled):
        trainer = DistillationTrainer(
            model=model,
            teacher_models=teacher_models,
            args=train_args,
            train_dataset=train_tokenized,
            eval_dataset=eval_tokenized,
            data_collator=data_collator,
            callbacks=callbacks,
            temperature_start=max(1e-6, float(args.temperature_start)),
            temperature_end=max(1e-6, float(args.temperature_end)),
            ce_weight_start=float(ce_weight_start),
            ce_weight_end=float(ce_weight_end),
            kd_weight_start=float(kd_weight_start),
            kd_weight_end=float(kd_weight_end),
        )
    else:
        trainer = Trainer(
            model=model,
            args=train_args,
            train_dataset=train_tokenized,
            eval_dataset=eval_tokenized,
            data_collator=data_collator,
            callbacks=callbacks,
        )
    if resume_checkpoint:
        train_result = trainer.train(resume_from_checkpoint=resume_checkpoint)
    else:
        train_result = trainer.train()
    # Persist model, tokenizer, and trainer state alongside metrics.
    trainer.save_model()
    tokenizer.save_pretrained(str(output_dir))
    trainer.save_state()
    train_metrics = dict(train_result.metrics)
    train_metrics["train_samples"] = len(train_tokenized)
    train_metrics["distill_enabled"] = bool(args.distill_enabled)
    train_metrics["teacher_count"] = len(teacher_models)
    train_metrics["ce_weight_start"] = float(ce_weight_start)
    train_metrics["ce_weight_end"] = float(ce_weight_end)
    train_metrics["kd_weight_start"] = float(kd_weight_start)
    train_metrics["kd_weight_end"] = float(kd_weight_end)
    train_metrics["temperature_start"] = float(args.temperature_start)
    train_metrics["temperature_end"] = float(args.temperature_end)
    write_json(output_dir / "metrics" / "train_metrics.json", train_metrics)
    eval_metrics: dict[str, Any] = {}
    if eval_tokenized is not None:
        eval_metrics = dict(trainer.evaluate())
        eval_metrics["eval_samples"] = len(eval_tokenized)
        write_json(output_dir / "metrics" / "eval_metrics.json", eval_metrics)
    # Human-readable run summary mirroring the effective configuration.
    summary = {
        "dataset_id": args.dataset_id,
        "train_rows": len(train_tokenized),
        "eval_rows": len(eval_tokenized) if eval_tokenized is not None else 0,
        "output_dir": str(output_dir),
        "system_snapshot_path": str(output_dir / "system_snapshot.json"),
        "live_progress_path": str(output_dir / "live_progress.json"),
        "live_events_path": str(output_dir / "live_events.jsonl"),
        "base_model": args.model_name,
        "target_repo_id": args.repo_id,
        "native_trainer_mode": bool(args.native_trainer_mode),
        "native_strict_mode": bool(args.native_strict_mode),
        "required_bins": required_bins,
        "runtime_system": runtime_system,
        "runtime_hardware": runtime_hardware,
        "device": device_label,
        "cuda_available": using_cuda,
        "mps_available": using_mps,
        "fp16": fp16_enabled,
        "bf16": bf16_enabled,
        "model_dtype": to_text(args.model_dtype).lower(),
        "teacher_dtype": to_text(args.teacher_dtype).lower(),
        "attn_implementation": attn_impl,
        "distill_enabled": bool(args.distill_enabled),
        "enforce_inhouse_models": bool(args.enforce_inhouse_models),
        "teacher_models": teacher_names,
        "teacher_count": len(teacher_models),
        "temperature_start": float(args.temperature_start),
        "temperature_end": float(args.temperature_end),
        "ce_weight_start": float(ce_weight_start),
        "ce_weight_end": float(ce_weight_end),
        "kd_weight_start": float(kd_weight_start),
        "kd_weight_end": float(kd_weight_end),
        "use_4bit_requested": requested_use_4bit,
        "use_4bit_effective": effective_use_4bit,
        "use_4bit_teacher_requested": requested_teacher_use_4bit,
        "use_4bit_teacher_effective": effective_teacher_use_4bit,
        "lora_enabled": bool(args.lora_enabled),
        "lora_r": int(args.lora_r),
        "lora_alpha": int(args.lora_alpha),
        "lora_dropout": float(args.lora_dropout),
        "lora_target_modules": lora_target_modules,
        "save_total_limit": max(1, int(args.save_total_limit)),
        "gradient_checkpointing": bool(args.gradient_checkpointing),
        "warmup_ratio": float(warmup_ratio),
        "requested_warmup_steps": int(requested_warmup_steps),
        "warmup_steps": int(effective_warmup_steps),
        "total_train_steps_estimate": int(total_steps),
        "tie_word_embeddings": bool(getattr(model.config, "tie_word_embeddings", False)),
        "resume_from_checkpoint": resume_checkpoint or "",
        "finished_at": datetime.now(timezone.utc).isoformat(),
    }
    write_json(output_dir / "run_summary.json", summary)
    finalize_live_progress(output_dir=output_dir, message="training finished")
    (output_dir / "README.md").write_text(
        model_card_text(
            dataset_id=args.dataset_id,
            model_id=args.repo_id,
            base_model=args.model_name,
            teacher_models=teacher_names,
            distill_enabled=bool(args.distill_enabled),
        ),
        encoding="utf-8",
    )
    # Optional Hub push; checkpoints and optimizer state stay local.
    if args.push_to_hub:
        token = os.environ.get(args.token_env, "")
        if not token:
            raise RuntimeError(f"Missing token in environment variable ${args.token_env}")
        api = HfApi(token=token)
        api.create_repo(repo_id=args.repo_id, repo_type="model", private=False, exist_ok=True)
        api.upload_folder(
            repo_id=args.repo_id,
            repo_type="model",
            folder_path=str(output_dir),
            commit_message="Autonomous Space trainer update",
            ignore_patterns=["checkpoint-*", "optimizer.pt", "scheduler.pt", "rng_state.pth"],
        )
    print(json.dumps(summary, indent=2, sort_keys=True))
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    main()