| |
| """Autonomous lightweight trainer for UMSR Reasoner Space.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import inspect |
| import json |
| import math |
| import os |
| import platform |
| import re |
| import shutil |
| import subprocess |
| import sys |
| import time |
| from datetime import datetime, timezone |
| from pathlib import Path |
| from typing import Any |
|
|
| try: |
| import torch |
| import torch.nn.functional as F |
| except Exception: |
| torch = None |
| F = None |
|
|
| try: |
| from datasets import load_dataset |
| except Exception: |
| load_dataset = None |
|
|
| try: |
| from huggingface_hub import HfApi |
| except Exception: |
| HfApi = None |
|
|
| try: |
| import transformers as transformers_pkg |
| from transformers import ( |
| AutoConfig, |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| BitsAndBytesConfig, |
| DataCollatorForLanguageModeling, |
| TrainerCallback, |
| Trainer, |
| TrainingArguments, |
| set_seed, |
| ) |
| except Exception: |
| transformers_pkg = None |
| AutoConfig = None |
| AutoModelForCausalLM = None |
| AutoTokenizer = None |
| BitsAndBytesConfig = None |
| DataCollatorForLanguageModeling = None |
| TrainerCallback = None |
| Trainer = None |
| TrainingArguments = None |
| set_seed = None |
|
|
| try: |
| from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training |
| except Exception: |
| LoraConfig = None |
| get_peft_model = None |
| prepare_model_for_kbit_training = None |
|
|
| try: |
| import accelerate |
| except Exception: |
| accelerate = None |
|
|
| TrainerBase = Trainer if Trainer is not None else object |
|
|
| SYSTEM_PROMPT = ( |
| "You are a rigorous reasoning assistant. " |
| "Solve the task step by step. " |
| "For programming tasks, provide a correct and runnable code block. " |
| "Then provide only the final answer inside " |
| "<final_answer>...</final_answer>." |
| ) |
|
|
| INHOUSE_OWNER_PREFIX = "NorthernTribe-Research/" |
| BANNED_MODEL_TOKENS = ("gpt2",) |
| CODE_TASK_HINT_RE = re.compile( |
| r"\b(" |
| r"code|python|program|programming|function|class|method|bug|debug|algorithm|" |
| r"runtime|complexity|compile|leetcode|unit test|sql|regex|script" |
| r")\b", |
| re.IGNORECASE, |
| ) |
| APP_DIR = Path(__file__).resolve().parent |
| REPO_ROOT = APP_DIR.parent.parent |
|
|
|
|
def require_dependency(name: str, available: bool) -> None:
    """Raise RuntimeError when an optional dependency failed to import."""
    if available:
        return
    raise RuntimeError(f"Missing dependency '{name}'. Install requirements.txt in the Space.")
|
|
|
|
def to_text(value: Any) -> str:
    """Coerce *value* to a stripped string; None becomes the empty string."""
    return "" if value is None else str(value).strip()
|
|
|
|
def parse_options(value: Any) -> list[str]:
    """Normalize *value* into a list of non-empty option strings.

    Lists are cleaned element-wise; strings are split on the '||' separator.
    Anything else yields an empty list.
    """
    if isinstance(value, list):
        return [entry for entry in (to_text(item) for item in value) if entry]
    text = to_text(value)
    if text and "||" in text:
        return [piece.strip() for piece in text.split("||") if piece.strip()]
    return []
|
|
|
|
def parse_float(value: Any, default: float) -> float:
    """Parse *value* as a finite float, returning *default* on any failure.

    NaN and infinities are treated as failures, matching the quality-score
    filtering this helper backs.
    """
    try:
        parsed = float(value)
    except Exception:
        return default
    return default if (math.isnan(parsed) or math.isinf(parsed)) else parsed
|
|
|
|
def env_bool(name: str, default: bool) -> bool:
    """Read a boolean environment flag; unset variables fall back to *default*."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}
|
|
|
|
def is_local_model_path(model_ref: str) -> bool:
    """Return True when *model_ref* points at an existing local filesystem path."""
    if not model_ref:
        return False
    try:
        exists = Path(model_ref).exists()
    except Exception:
        # Malformed paths (e.g. embedded NULs) must not crash validation.
        exists = False
    return exists
|
|
|
|
def is_inhouse_model_ref(model_ref: str) -> bool:
    """Return True when the reference lives under the in-house owner namespace."""
    normalized = to_text(model_ref)
    return normalized.startswith(INHOUSE_OWNER_PREFIX)
|
|
|
|
def validate_model_reference(model_ref: str, role: str, enforce_inhouse: bool) -> None:
    """Validate a model reference, raising ValueError when it is not allowed.

    Rejects empty references, references containing banned tokens, and — when
    *enforce_inhouse* is set — anything that is neither an in-house repo nor
    an existing local checkpoint path.
    """
    resolved = to_text(model_ref)
    if not resolved:
        raise ValueError(f"{role} model reference is empty.")

    lowered = resolved.lower()
    for token in BANNED_MODEL_TOKENS:
        if token in lowered:
            raise ValueError(
                f"{role} model '{resolved}' is blocked. "
                "Use an in-house NorthernTribe-Research model or a local checkpoint path."
            )

    if not enforce_inhouse:
        return
    if is_inhouse_model_ref(resolved) or is_local_model_path(resolved):
        return
    raise ValueError(
        f"{role} model '{resolved}' is external. "
        "In-house enforcement is enabled; use NorthernTribe-Research/* or a local path."
    )
|
|
|
|
def quality_ok(row: dict[str, Any], min_quality: float) -> bool:
    """Return True when a row has a problem, an answer, and sufficient quality.

    Rows missing a quality_score default to 1.0 (accepted unless min_quality
    exceeds that).
    """
    has_required = bool(to_text(row.get("problem"))) and bool(to_text(row.get("answer")))
    if not has_required:
        return False
    return parse_float(row.get("quality_score", 1.0), 1.0) >= min_quality
|
|
|
|
def looks_like_code_task(problem: str, domain: str, options: list[str] | None = None) -> bool:
    """Heuristically decide whether a task is programming-related.

    A known coding domain is an immediate match; otherwise the problem text,
    domain, and options are scanned for code-related keywords.
    """
    coding_domains = {"code", "coding", "programming", "software", "computer_science"}
    if to_text(domain).lower() in coding_domains:
        return True

    pieces = [to_text(problem), to_text(domain)]
    pieces.extend(to_text(option) for option in options or [])
    haystack = " ".join(pieces).strip()
    return bool(haystack) and bool(CODE_TASK_HINT_RE.search(haystack))
|
|
|
|
def build_user_prompt(problem: str, options: list[str], domain: str) -> str:
    """Assemble the USER message: problem, optional options/domain, instructions.

    Code-flavored tasks get an extra instruction demanding a runnable
    ```python``` block before the final-answer tag.
    """
    sections = [f"Problem:\n{problem}"]
    if options:
        option_lines = "\n".join(f"- {item}" for item in options)
        sections.append("Options:\n" + option_lines)
    if domain:
        sections.append(f"Domain: {domain}")

    steps = ["1) Think step by step."]
    if looks_like_code_task(problem=problem, domain=domain, options=options):
        steps.append("2) If code is required, output one runnable ```python``` block.")
        steps.append("3) End with <final_answer>...</final_answer>.")
    else:
        steps.append("2) End with <final_answer>...</final_answer>.")
    sections.append("Instructions:\n" + "\n".join(steps))
    return "\n\n".join(sections)
|
|
|
|
def format_text(row: dict[str, Any]) -> str:
    """Render one dataset row as a SYSTEM/USER/ASSISTANT training string.

    The assistant turn carries an optional <reasoning> section followed by the
    mandatory <final_answer> tag.
    """
    problem = to_text(row.get("problem"))
    answer = to_text(row.get("answer"))
    reasoning = to_text(row.get("reasoning_text"))
    domain = to_text(row.get("domain"))
    options = parse_options(row.get("options"))

    prompt = build_user_prompt(problem=problem, options=options, domain=domain)
    parts: list[str] = []
    if reasoning:
        parts.append(f"<reasoning>\n{reasoning}\n</reasoning>")
    parts.append(f"<final_answer>{answer}</final_answer>")
    assistant_turn = "\n\n".join(parts)
    return (
        f"SYSTEM:\n{SYSTEM_PROMPT}\n\n"
        f"USER:\n{prompt}\n\n"
        f"ASSISTANT:\n{assistant_turn}"
    )
|
|
|
|
def dataset_cache_dir() -> str:
    """Ensure the datasets cache directory exists and return it as a string.

    Honors the HF_DATASETS_CACHE environment override; otherwise a repo-local
    .hf_cache/datasets directory is used.
    """
    configured = os.environ.get("HF_DATASETS_CACHE")
    cache_path = Path(configured) if configured else Path(".hf_cache/datasets")
    cache_path.mkdir(parents=True, exist_ok=True)
    return str(cache_path)
|
|
|
|
def model_card_text(
    dataset_id: str,
    model_id: str,
    base_model: str,
    teacher_models: list[str],
    distill_enabled: bool,
) -> str:
    """Render the Hugging Face model card (YAML front matter + markdown body).

    The returned text is pushed verbatim as README.md for the trained model.
    *distill_enabled* selects the training-mode wording and *teacher_models*
    fills the teacher line ("n/a" when empty).
    """
    mode_text = "teacher-student distillation" if distill_enabled else "supervised fine-tuning"
    teacher_line = ", ".join(teacher_models) if teacher_models else "n/a"
    # NOTE: the leading "---" block is YAML metadata parsed by the Hub; keep
    # its structure intact when editing.
    return f"""---
language:
- en
library_name: transformers
pipeline_tag: text-generation
datasets:
- {dataset_id}
tags:
- reasoning
- structured-output
- instruction-following
- math
- logic
- science
---

# UMSR-Reasoner-7B

## Purpose

UMSR-Reasoner-7B is a general reasoning model designed for structured problem solving and consistent answer formatting in production and research workflows.

Model repository: `https://huggingface.co/{model_id}`
Primary dataset: `https://huggingface.co/datasets/{dataset_id}`

## Intended Use

Use this model for tasks that require:

- multi-step quantitative reasoning
- logic and strategy-style question answering
- science and technical problem decomposition
- deterministic final-answer formatting for downstream parsers

## Core Capabilities

- Produces step-aware reasoning outputs for complex prompts
- Handles open-form and exam-style tasks across math, logic, and science domains
- Supports structured response contracts for automation pipelines
- Works well in teacher-student continuous improvement loops

## Recommended Prompting

For highest reliability, use explicit instructions about reasoning depth and enforce a final-answer tag in every response.

Suggested system instruction:

`Solve step by step and end with <final_answer>...</final_answer>.`

## Output Contract

Required final output tag:

`<final_answer>...</final_answer>`

Optional reasoning tag:

`<reasoning>...</reasoning>`

## Training Profile

- Student model: `{base_model}`
- Training mode: `{mode_text}`
- Teacher model(s): `{teacher_line}`

## Operational Guidance

- Prefer lower sampling temperature for deterministic workflows
- Validate final answers for high-stakes usage
- Run domain-specific evaluation before production rollout

## Limitations

- May produce plausible but incorrect reasoning traces
- Performance varies with prompt quality and task domain
- Not a substitute for expert review in legal, medical, financial, or safety-critical decisions
"""
|
|
|
|
def write_json(path: Path, payload: dict[str, Any]) -> None:
    """Serialize *payload* as pretty, key-sorted JSON at *path*, creating parents."""
    path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, indent=2, sort_keys=True)
    path.write_text(serialized + "\n", encoding="utf-8")
|
|
|
|
def append_jsonl(path: Path, payload: dict[str, Any]) -> None:
    """Append *payload* as one key-sorted JSON line to *path*, creating parents."""
    path.parent.mkdir(parents=True, exist_ok=True)
    line = json.dumps(payload, sort_keys=True) + "\n"
    with open(path, "a", encoding="utf-8") as sink:
        sink.write(line)
|
|
|
|
def read_json(path: Path) -> dict[str, Any]:
    """Best-effort read of a JSON object from *path*.

    Missing files, parse errors, and non-dict payloads all yield {} so callers
    never have to handle exceptions.
    """
    if not path.exists():
        return {}
    try:
        loaded = json.loads(path.read_text(encoding="utf-8"))
    except Exception:
        return {}
    return loaded if isinstance(loaded, dict) else {}
|
|
|
|
def finalize_live_progress(output_dir: Path, message: str) -> None:
    """Mark live_progress.json as completed, preserving any recorded metrics."""
    progress_path = output_dir / "live_progress.json"
    existing = read_json(progress_path)
    previous_metrics = existing.get("metrics")
    if not isinstance(previous_metrics, dict):
        previous_metrics = {}
    existing.update(
        {
            "updated_at": datetime.now(timezone.utc).isoformat(),
            "status": "completed",
            "message": to_text(message) or "training finished",
            "metrics": previous_metrics,
        }
    )
    write_json(progress_path, existing)
|
|
|
|
| def safe_float(value: Any) -> float | None: |
| if value is None: |
| return None |
| try: |
| if torch is not None and isinstance(value, torch.Tensor): |
| value = value.detach().float().item() |
| return float(value) |
| except Exception: |
| return None |
|
|
|
|
class LiveProgressCallback(TrainerCallback if TrainerCallback is not None else object):
    """Trainer callback that mirrors training progress to disk.

    Writes a full snapshot to ``live_progress.json`` (overwritten on every
    update) and appends one record per event to ``live_events.jsonl`` so a
    separate UI process can tail training state without sharing memory with
    the trainer.  Falls back to subclassing ``object`` when transformers is
    not installed so the module still imports.
    """

    def __init__(
        self,
        output_dir: Path,
        distill_enabled: bool,
        runtime_hardware: dict[str, Any] | None = None,
        runtime_system: dict[str, Any] | None = None,
    ):
        self.output_dir = output_dir
        self.distill_enabled = bool(distill_enabled)
        # Copied so later mutation by the caller cannot change what we report.
        self.runtime_hardware = dict(runtime_hardware or {})
        self.runtime_system = dict(runtime_system or {})
        self.progress_path = output_dir / "live_progress.json"
        self.events_path = output_dir / "live_events.jsonl"
        # Rolling union of the most recent value seen per metric key.
        self.latest_metrics: dict[str, float] = {}

    def _sync_progress(self, state: Any, status: str, message: str) -> None:
        """Overwrite the progress snapshot with the current trainer state."""
        payload = {
            "updated_at": datetime.now(timezone.utc).isoformat(),
            "status": status,
            "message": to_text(message),
            "distill_enabled": self.distill_enabled,
            "runtime_system": self.runtime_system,
            "runtime_hardware": self.runtime_hardware,
            "global_step": int(getattr(state, "global_step", 0) or 0),
            "max_steps": int(getattr(state, "max_steps", 0) or 0),
            "epoch": safe_float(getattr(state, "epoch", None)),
            "metrics": self.latest_metrics,
        }
        write_json(self.progress_path, payload)

    def _append_event(self, state: Any, event_type: str, payload: dict[str, Any]) -> None:
        """Append one timestamped event record to the JSONL event log."""
        event = {
            "ts": datetime.now(timezone.utc).isoformat(),
            "event": event_type,
            "global_step": int(getattr(state, "global_step", 0) or 0),
            "epoch": safe_float(getattr(state, "epoch", None)),
            "payload": payload,
        }
        append_jsonl(self.events_path, event)

    @staticmethod
    def _extract_metrics(logs: dict[str, Any]) -> dict[str, float]:
        """Pull the known numeric metric keys out of a trainer log dict.

        Non-numeric or missing values are silently dropped (safe_float
        returns None for them).
        """
        keys = [
            "loss",
            "eval_loss",
            "learning_rate",
            "grad_norm",
            "epoch",
            "distill_ce_loss",
            "distill_kd_loss",
            "distill_temperature",
            "distill_ce_weight",
            "distill_kd_weight",
        ]
        metrics: dict[str, float] = {}
        for key in keys:
            value = safe_float(logs.get(key))
            if value is not None:
                metrics[key] = value
        return metrics

    def on_train_begin(self, args: Any, state: Any, control: Any, **kwargs: Any) -> None:
        """Record that training started."""
        del args, control, kwargs
        self._sync_progress(state=state, status="running", message="training started")
        self._append_event(state=state, event_type="train_begin", payload={})

    def on_log(self, args: Any, state: Any, control: Any, logs: dict[str, Any] | None = None, **kwargs: Any) -> None:
        """Merge new log metrics and refresh the step-progress message."""
        del args, control, kwargs
        payload = logs or {}
        metrics = self._extract_metrics(payload)
        if metrics:
            self.latest_metrics.update(metrics)
        step = int(getattr(state, "global_step", 0) or 0)
        max_steps = int(getattr(state, "max_steps", 0) or 0)
        # max_steps can be 0 before the trainer computes it; omit the total then.
        message = f"step {step}/{max_steps}" if max_steps > 0 else f"step {step}"
        self._sync_progress(state=state, status="running", message=message)
        self._append_event(state=state, event_type="log", payload=metrics)

    def on_save(self, args: Any, state: Any, control: Any, **kwargs: Any) -> None:
        """Record that a checkpoint was written."""
        del args, control, kwargs
        self._sync_progress(state=state, status="running", message="checkpoint saved")
        self._append_event(state=state, event_type="save", payload={})

    def on_evaluate(
        self,
        args: Any,
        state: Any,
        control: Any,
        metrics: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Fold evaluation metrics into the rolling snapshot."""
        del args, control, kwargs
        values = self._extract_metrics(metrics or {})
        if values:
            self.latest_metrics.update(values)
        self._sync_progress(state=state, status="running", message="evaluation completed")
        self._append_event(state=state, event_type="evaluate", payload=values)

    def on_train_end(self, args: Any, state: Any, control: Any, **kwargs: Any) -> None:
        """Record completion of the training run."""
        del args, control, kwargs
        self._sync_progress(state=state, status="completed", message="training finished")
        self._append_event(state=state, event_type="train_end", payload={})
|
|
|
|
def parse_version_tuple(version_text: str) -> tuple[int, int, int]:
    """Parse a loose version string into a (major, minor, patch) int tuple.

    Each dot-separated segment contributes its leading digit run (0 when the
    segment does not start with a digit); segments past the third are ignored
    and missing segments default to 0.  E.g. "5.0.0.dev0" -> (5, 0, 0).
    """
    parts: list[int] = []
    for segment in str(version_text).split("."):
        leading_digits = re.match(r"\d+", segment)
        parts.append(int(leading_digits.group()) if leading_digits else 0)
        if len(parts) == 3:
            break
    parts.extend([0] * (3 - len(parts)))
    return (parts[0], parts[1], parts[2])
|
|
|
|
def probe_runtime_hardware() -> dict[str, Any]:
    """Collect a best-effort snapshot of torch/CUDA/MPS availability.

    Every probe is wrapped in try/except so a broken or partially-installed
    torch can never crash startup; failures simply leave the defaults below.
    """
    info: dict[str, Any] = {
        "torch_available": bool(torch is not None),
        "torch_version": to_text(getattr(torch, "__version__", "unknown")) if torch is not None else "missing",
        "cuda_available": False,
        "cuda_device_count": 0,
        "cuda_device_0": "",
        "cuda_compute_capability_0": "",
        "cuda_total_memory_gb_0": None,
        "mps_available": False,
    }
    # Without torch, none of the accelerator probes below can run.
    if torch is None:
        return info

    try:
        cuda_available = bool(torch.cuda.is_available())
    except Exception:
        cuda_available = False
    info["cuda_available"] = cuda_available

    try:
        device_count = int(torch.cuda.device_count())
    except Exception:
        device_count = 0
    info["cuda_device_count"] = max(0, device_count)

    # Only query device-0 details when CUDA is actually usable.
    if cuda_available and device_count > 0:
        try:
            info["cuda_device_0"] = to_text(torch.cuda.get_device_name(0))
        except Exception:
            info["cuda_device_0"] = ""
        try:
            props = torch.cuda.get_device_properties(0)
            info["cuda_compute_capability_0"] = f"{int(props.major)}.{int(props.minor)}"
            info["cuda_total_memory_gb_0"] = round(float(props.total_memory) / float(1024 ** 3), 2)
        except Exception:
            info["cuda_compute_capability_0"] = ""
            info["cuda_total_memory_gb_0"] = None

    # Apple-silicon backend; the hasattr guard covers older torch builds.
    try:
        info["mps_available"] = bool(hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
    except Exception:
        info["mps_available"] = False

    return info
|
|
|
|
def log_runtime_hardware(info: dict[str, Any]) -> None:
    """Print a human-readable summary of the probed GPU/torch runtime."""
    cuda_available = bool(info.get("cuda_available", False))
    device_count = int(info.get("cuda_device_count", 0) or 0)
    gpu_name = to_text(info.get("cuda_device_0"))
    print(f"CUDA available: {cuda_available}")
    print(f"CUDA device count: {device_count}")
    if cuda_available and device_count > 0:
        print(f"GPU: {gpu_name or 'unknown'}")

    torch_version = to_text(info.get("torch_version")) or "unknown"
    print(
        f"[train_worker][runtime] torch={torch_version} "
        f"cuda={cuda_available} devices={device_count} "
        f"mps={bool(info.get('mps_available', False))}"
    )
    if not gpu_name:
        return
    details = [f"name={gpu_name}"]
    capability = to_text(info.get("cuda_compute_capability_0"))
    if capability:
        details.append(f"sm={capability}")
    memory_gb = info.get("cuda_total_memory_gb_0")
    if memory_gb is not None:
        details.append(f"vram_gb={memory_gb}")
    print("[train_worker][runtime] gpu0 " + " ".join(details))
|
|
|
|
def preferred_loader_dtype_key() -> str:
    """Pick the model-loading dtype kwarg name for the installed transformers.

    transformers v5 renamed ``torch_dtype`` to ``dtype``.
    """
    installed = parse_version_tuple(getattr(transformers_pkg, "__version__", "0.0.0"))
    if installed >= (5, 0, 0):
        return "dtype"
    return "torch_dtype"
|
|
|
|
def dtype_from_name(name: str) -> Any:
    """Map a dtype alias (e.g. 'bf16', 'float32') to the torch dtype object.

    Raises RuntimeError when torch is missing and ValueError for unknown
    aliases.
    """
    require_dependency("torch", torch is not None)
    alias_map = {
        "float16": torch.float16,
        "fp16": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }
    normalized = to_text(name).lower()
    if normalized not in alias_map:
        raise ValueError(f"Unsupported model dtype '{name}'.")
    return alias_map[normalized]
|
|
|
|
def parse_target_modules(text: str) -> list[str]:
    """Parse a comma-separated LoRA target list; 'auto'/'default' mean auto-detect."""
    raw = to_text(text)
    if not raw or raw.lower() in {"auto", "default"}:
        return []
    return [entry.strip() for entry in raw.split(",") if entry.strip()]
|
|
|
|
def parse_model_list(value: str) -> list[str]:
    """Split a comma-separated model list into trimmed, non-empty entries.

    Delegates to parse_csv_values: the two helpers previously duplicated the
    exact same splitting logic, so a single implementation is kept.
    """
    return parse_csv_values(value)
|
|
|
|
def parse_csv_values(value: str) -> list[str]:
    """Split a comma-separated string into trimmed, non-empty values."""
    raw = to_text(value)
    if not raw:
        return []
    pieces = (chunk.strip() for chunk in raw.split(","))
    return [piece for piece in pieces if piece]
|
|
|
|
def read_os_release() -> dict[str, str]:
    """Parse /etc/os-release into an upper-cased KEY -> value mapping.

    Returns {} when the file is missing or unreadable; surrounding quotes are
    stripped from values.
    """
    source = Path("/etc/os-release")
    if not source.exists():
        return {}
    try:
        content = source.read_text(encoding="utf-8", errors="replace")
    except Exception:
        return {}
    entries: dict[str, str] = {}
    for raw_line in content.splitlines():
        stripped = raw_line.strip()
        if not stripped or stripped.startswith("#") or "=" not in stripped:
            continue
        key, _, value = stripped.partition("=")
        key = key.strip().upper()
        if key:
            entries[key] = value.strip().strip('"').strip("'")
    return entries
|
|
|
|
def mem_total_gb() -> float | None:
    """Read MemTotal from /proc/meminfo and return it in GiB (2 decimals).

    Returns None off Linux, on read errors, or when the line is malformed.
    """
    meminfo = Path("/proc/meminfo")
    if not meminfo.exists():
        return None
    try:
        for line in meminfo.read_text(encoding="utf-8", errors="replace").splitlines():
            if not line.startswith("MemTotal:"):
                continue
            fields = line.split()
            if len(fields) < 2:
                return None
            total_kb = float(fields[1])
            return round(total_kb / (1024.0 * 1024.0), 2)
    except Exception:
        return None
    return None
|
|
|
|
def collect_runtime_system_snapshot(required_bins: list[str], native_mode: bool) -> dict[str, Any]:
    """Gather an OS/python/disk/binary snapshot for diagnostics and live progress.

    *required_bins* are native executables resolved via PATH; any that are
    missing mark the snapshot as not native-ready.  All probes are best-effort
    and must not raise.
    """
    os_release = read_os_release()
    os_id = to_text(os_release.get("ID")).lower()
    os_name = to_text(os_release.get("PRETTY_NAME")) or platform.platform()
    kernel = to_text(platform.release())
    arch = to_text(platform.machine())
    python_exe = to_text(sys.executable)
    python_version = to_text(platform.python_version())
    # Differing base_prefix/prefix indicates a virtualenv/venv interpreter.
    in_venv = bool(getattr(sys, "base_prefix", "") != getattr(sys, "prefix", ""))
    cpu_count = int(os.cpu_count() or 0)
    memory_gb = mem_total_gb()
    try:
        disk_usage = shutil.disk_usage(str(Path.cwd()))
        disk_total_gb = round(float(disk_usage.total) / float(1024 ** 3), 2)
        disk_free_gb = round(float(disk_usage.free) / float(1024 ** 3), 2)
    except Exception:
        disk_total_gb = None
        disk_free_gb = None

    # Resolve each required binary; empty string means "not found on PATH".
    binaries: dict[str, str] = {}
    missing_required_bins: list[str] = []
    for binary in required_bins:
        resolved = to_text(shutil.which(binary))
        binaries[binary] = resolved
        if not resolved:
            missing_required_bins.append(binary)

    nvidia_smi_present = bool(to_text(shutil.which("nvidia-smi")))
    nvidia_smi_output = ""
    if nvidia_smi_present:
        try:
            # check=False: a failing nvidia-smi still yields useful stderr text.
            result = subprocess.run(
                ["nvidia-smi", "--query-gpu=name,driver_version,memory.total", "--format=csv,noheader"],
                capture_output=True,
                text=True,
                timeout=5,
                check=False,
            )
            nvidia_smi_output = (result.stdout or result.stderr or "").strip()
        except Exception as exc:
            nvidia_smi_output = f"unavailable: {exc}"

    is_ubuntu = os_id == "ubuntu" or "ubuntu" in os_name.lower()
    native_ready = len(missing_required_bins) == 0

    return {
        "collected_at": datetime.now(timezone.utc).isoformat(),
        "native_mode": bool(native_mode),
        "native_ready": bool(native_ready),
        "is_ubuntu": bool(is_ubuntu),
        "os_id": os_id or "unknown",
        "os_name": os_name or "unknown",
        "kernel": kernel or "unknown",
        "arch": arch or "unknown",
        "python_executable": python_exe or "unknown",
        "python_version": python_version or "unknown",
        "in_venv": bool(in_venv),
        "cpu_count": cpu_count,
        "memory_gb": memory_gb,
        "disk_total_gb": disk_total_gb,
        "disk_free_gb": disk_free_gb,
        "required_bins": required_bins,
        "binaries": binaries,
        "missing_required_bins": missing_required_bins,
        "nvidia_smi_present": nvidia_smi_present,
        "nvidia_smi_output": nvidia_smi_output,
    }
|
|
|
|
def log_runtime_system_snapshot(snapshot: dict[str, Any]) -> None:
    """Print the collected system snapshot in the [train_worker][system] log format."""
    os_name = to_text(snapshot.get("os_name")) or "unknown"
    kernel = to_text(snapshot.get("kernel")) or "unknown"
    arch = to_text(snapshot.get("arch")) or "unknown"
    python_version = to_text(snapshot.get("python_version")) or "unknown"
    python_executable = to_text(snapshot.get("python_executable")) or "unknown"
    mode_text = "on" if bool(snapshot.get("native_mode", False)) else "off"
    ready_text = "ready" if bool(snapshot.get("native_ready", False)) else "degraded"

    print(
        f"[train_worker][system] native_mode={mode_text} state={ready_text} "
        f"os='{os_name}' kernel={kernel} arch={arch}"
    )
    print(
        f"[train_worker][system] python={python_version} executable={python_executable} "
        f"venv={bool(snapshot.get('in_venv', False))}"
    )
    print(
        f"[train_worker][system] cpu_count={int(snapshot.get('cpu_count', 0) or 0)} "
        f"memory_gb={snapshot.get('memory_gb')} disk_free_gb={snapshot.get('disk_free_gb')}"
    )

    # Warn on missing binaries; otherwise confirm the required set was found.
    missing = snapshot.get("missing_required_bins")
    if isinstance(missing, list) and missing:
        print("[train_worker][warn] missing required native binaries: " + ",".join(str(item) for item in missing))
    else:
        required = snapshot.get("required_bins")
        if isinstance(required, list) and required:
            print("[train_worker][system] required native binaries detected: " + ",".join(str(item) for item in required))

    nvidia_text = to_text(snapshot.get("nvidia_smi_output"))
    if nvidia_text:
        print("[train_worker][system] nvidia-smi: " + nvidia_text)
|
|
|
|
def build_quant_config(use_4bit: bool, dtype: Any) -> Any:
    """Return a 4-bit NF4 BitsAndBytes quantization config, or None when disabled."""
    if not use_4bit:
        return None
    require_dependency("BitsAndBytesConfig", BitsAndBytesConfig is not None)
    quant_kwargs = {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_compute_dtype": dtype,
    }
    return BitsAndBytesConfig(**quant_kwargs)
|
|
|
|
| def _available_module_suffixes(model: Any) -> set[str]: |
| suffixes: set[str] = set() |
| for name, module in model.named_modules(): |
| if not name: |
| continue |
| parts = name.split(".") |
| lowered_parts = [part.lower() for part in parts] |
|
|
| |
| if lowered_parts and lowered_parts[-1] == "base_layer" and len(parts) >= 2: |
| suffixes.add(parts[-2]) |
| for marker in ("lora_a", "lora_b", "lora_embedding_a", "lora_embedding_b", "lora_magnitude_vector"): |
| if marker in lowered_parts: |
| marker_index = lowered_parts.index(marker) |
| if marker_index > 0: |
| suffixes.add(parts[marker_index - 1]) |
|
|
| if len(list(module.children())) > 0: |
| continue |
| suffixes.add(name.split(".")[-1]) |
| return suffixes |
|
|
|
|
def model_has_existing_lora(model: Any) -> bool:
    """Heuristically detect whether *model* already has LoRA adapters attached.

    Checks a non-empty peft_config, a PEFT wrapper class name, LoRA-style
    submodule names, and lora_A/lora_B attributes on any submodule.
    """
    if getattr(model, "peft_config", None):
        return True
    if "peft" in type(model).__name__.lower():
        return True

    for name, module in model.named_modules():
        lowered = name.lower()
        lora_named = (
            ".lora_" in lowered
            or lowered.endswith("lora_a")
            or lowered.endswith("lora_b")
        )
        if lora_named or hasattr(module, "lora_A") or hasattr(module, "lora_B"):
            return True
    return False
|
|
|
|
def set_lora_only_trainable(model: Any) -> tuple[int, int]:
    """Freeze every parameter except LoRA/adapter weights.

    A parameter stays trainable when its name contains any of the known
    adapter markers.  Returns (trainable_param_count, total_param_count).
    """
    adapter_markers = (
        "lora_A",
        "lora_B",
        "lora_embedding_A",
        "lora_embedding_B",
        "lora_magnitude_vector",
        "modules_to_save",
    )
    trainable_total = 0
    overall_total = 0
    for name, param in model.named_parameters():
        count = int(param.numel())
        overall_total += count
        keep_trainable = any(marker in name for marker in adapter_markers)
        param.requires_grad = keep_trainable
        if keep_trainable:
            trainable_total += count
    return trainable_total, overall_total
|
|
|
|
def resolve_lora_target_modules(model: Any, requested: list[str]) -> list[str]:
    """Choose LoRA target module names for *model*.

    Requested names are filtered against what the model actually contains;
    when none survive (or none were requested), a priority list of common
    attention/MLP projection names is auto-detected instead.  Raises
    RuntimeError when no compatible target exists at all.
    """
    available = _available_module_suffixes(model)

    # Common projection names across Llama/Falcon/GPT-style architectures.
    common_targets: list[str] = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
        "c_attn",
        "c_proj",
        "c_fc",
    ]

    if requested:
        usable = [name for name in requested if name in available]
        if usable:
            dropped = sorted(set(requested) - set(usable))
            if dropped:
                print(
                    "[train_worker][warn] dropping unavailable LoRA target modules: "
                    + ",".join(dropped)
                )
            return usable
        print(
            "[train_worker][warn] requested LoRA target modules not found for this base model; "
            "falling back to auto-detected targets."
        )

    detected = [name for name in common_targets if name in available]
    if detected:
        print(
            "[train_worker][info] using auto-resolved LoRA target modules: "
            + ",".join(detected)
        )
        return detected

    raise RuntimeError(
        "LoRA is enabled, but no compatible target modules were found in the base model. "
        f"Available leaf module suffixes: {sorted(available)}"
    )
|
|
|
|
def load_split_dataset(dataset_id: str, split: str, cache_dir: str) -> Any:
    """Load one dataset split from the Hub with retries and a local fallback.

    Transient network errors are retried up to three times with linear
    backoff; on final failure a repo-local parquet file for the split is used
    when one exists, otherwise a RuntimeError chained to the last Hub error
    is raised.
    """
    require_dependency("datasets", load_dataset is not None)

    def is_retryable_dataset_error(exc: Exception) -> bool:
        # Substring matching on the stringified error: crude, but hub client
        # exception types vary across versions.
        message = to_text(exc).lower()
        if not message:
            return False
        retry_markers = (
            "client has been closed",
            "connection aborted",
            "connection reset",
            "connection refused",
            "network is unreachable",
            "temporary failure",
            "timed out",
            "timeout",
            "503",
            "429",
        )
        return any(marker in message for marker in retry_markers)

    def resolve_local_fallback(split_name: str) -> Path | None:
        # Search cwd-relative first, then repo-root processed data dirs.
        name = to_text(split_name)
        if not name:
            return None
        candidates = [
            Path("data_processed_v2/parquet") / f"{name}.parquet",
            REPO_ROOT / "data_processed_v2" / "parquet" / f"{name}.parquet",
            REPO_ROOT / "data_processed" / "parquet" / f"{name}.parquet",
        ]
        for candidate in candidates:
            if candidate.exists():
                return candidate
        return None

    attempts = 3
    last_exc: Exception | None = None
    for attempt in range(1, attempts + 1):
        try:
            return load_dataset(dataset_id, split=split, cache_dir=cache_dir)
        except Exception as exc:
            last_exc = exc
            if attempt < attempts and is_retryable_dataset_error(exc):
                # Linear backoff: 2s, then 4s.
                wait_seconds = attempt * 2
                print(
                    f"[train_worker][warn] failed to load '{dataset_id}:{split}' "
                    f"(attempt {attempt}/{attempts}): {exc}; retrying in {wait_seconds}s"
                )
                time.sleep(wait_seconds)
                continue
            # Non-retryable error (or retries exhausted): fall through to fallback.
            break

    fallback = resolve_local_fallback(split)
    if fallback is not None:
        print(
            f"[train_worker][warn] failed to load '{dataset_id}:{split}' ({last_exc}); "
            f"using local fallback {fallback}"
        )
        return load_dataset(
            "parquet",
            data_files={split: str(fallback)},
            split=split,
            cache_dir=cache_dir,
        )
    raise RuntimeError(
        f"Unable to load dataset split '{dataset_id}:{split}' and no local fallback exists."
    ) from last_exc
|
|
|
|
| def latest_checkpoint_dir(output_dir: Path) -> Path | None: |
| checkpoints: list[tuple[int, Path]] = [] |
| for candidate in output_dir.glob("checkpoint-*"): |
| if not candidate.is_dir(): |
| continue |
| suffix = candidate.name.replace("checkpoint-", "", 1) |
| try: |
| step = int(suffix) |
| except Exception: |
| continue |
| checkpoints.append((step, candidate)) |
| if not checkpoints: |
| return None |
| checkpoints.sort(key=lambda item: item[0]) |
| return checkpoints[-1][1] |
|
|
|
|
| def latest_checkpoint_in_sibling_runs(output_dir: Path) -> Path | None: |
| runs_root = output_dir.parent |
| if not runs_root.exists(): |
| return None |
| checkpoints: list[tuple[float, int, Path]] = [] |
| for run_dir in runs_root.iterdir(): |
| if not run_dir.is_dir() or run_dir == output_dir: |
| continue |
| for candidate in run_dir.glob("checkpoint-*"): |
| if not candidate.is_dir(): |
| continue |
| suffix = candidate.name.replace("checkpoint-", "", 1) |
| try: |
| step = int(suffix) |
| except Exception: |
| continue |
| try: |
| mtime = candidate.stat().st_mtime |
| except Exception: |
| mtime = 0.0 |
| checkpoints.append((mtime, step, candidate)) |
| if not checkpoints: |
| return None |
| checkpoints.sort(key=lambda item: (item[0], item[1])) |
| return checkpoints[-1][2] |
|
|
|
|
def checkpoint_resume_compatible(checkpoint_dir: Path) -> tuple[bool, str]:
    """Check whether a checkpoint directory supports Trainer resume.

    Returns (ok, reason); *reason* is empty when resume is possible.
    """
    if not checkpoint_dir.exists():
        return False, "path does not exist"
    if not checkpoint_dir.is_dir():
        return False, "path is not a directory"

    def any_present(filenames: tuple[str, ...]) -> bool:
        return any((checkpoint_dir / filename).exists() for filename in filenames)

    # Full-model weight files are required for Trainer.train(resume_from_checkpoint=...).
    full_model_markers = (
        "model.safetensors",
        "pytorch_model.bin",
        "model.safetensors.index.json",
        "pytorch_model.bin.index.json",
    )
    if any_present(full_model_markers):
        return True, ""

    # Adapter-only checkpoints can be loaded, but Trainer cannot resume from them.
    adapter_markers = (
        "adapter_model.safetensors",
        "adapter_model.bin",
        "adapter_config.json",
    )
    if any_present(adapter_markers):
        return (
            False,
            "adapter-only checkpoint (missing full-model checkpoint files required by Trainer resume)",
        )

    return False, "missing model checkpoint files"
|
|
|
|
def resolve_resume_checkpoint(value: str | None, output_dir: Path) -> str | None:
    """Resolve the ``--resume-from-checkpoint`` option into a concrete path.

    ``""``/``none``/``false``/``no`` disable resuming.  ``auto``/``latest``
    pick the newest compatible checkpoint in *output_dir*, falling back to
    sibling run directories.  Any other value is treated as an explicit path
    (relative paths resolve under *output_dir*) and must exist and be
    resume-compatible, otherwise a ``RuntimeError`` is raised.
    """
    normalized = to_text(value).lower()
    if normalized in {"", "none", "false", "no"}:
        return None

    if normalized not in {"auto", "latest"}:
        # Explicit path request: strict, failure raises.
        candidate = Path(to_text(value))
        if not candidate.is_absolute():
            candidate = output_dir / candidate
        if not candidate.exists():
            raise RuntimeError(f"Requested resume checkpoint does not exist: {candidate}")
        compatible, reason = checkpoint_resume_compatible(candidate)
        if not compatible:
            raise RuntimeError(
                f"Requested resume checkpoint is not trainer-resume compatible ({reason}): {candidate}"
            )
        return str(candidate)

    # Auto mode: best-effort, failure just means "start fresh".
    latest = latest_checkpoint_dir(output_dir)
    if latest is not None:
        compatible, reason = checkpoint_resume_compatible(latest)
        if compatible:
            return str(latest)
        print(
            "[train_worker][warn] auto-resume skipped latest checkpoint "
            f"'{latest}' ({reason})."
        )
    sibling_latest = latest_checkpoint_in_sibling_runs(output_dir=output_dir)
    if sibling_latest is not None:
        compatible, reason = checkpoint_resume_compatible(sibling_latest)
        if compatible:
            print(
                "[train_worker][info] auto-resume fallback selected sibling checkpoint: "
                f"{sibling_latest}"
            )
            return str(sibling_latest)
        print(
            "[train_worker][warn] auto-resume skipped sibling checkpoint "
            f"'{sibling_latest}' ({reason})."
        )
    return None
|
|
|
|
def resolve_schedule_weights(
    ce_weight_start: float,
    ce_weight_end: float,
    kd_weight_start: float,
    kd_weight_end: float,
) -> tuple[float, float, float, float]:
    """Normalize CE/KD loss weights so each schedule endpoint sums to 1.

    The start pair is divided by its own total and the end pair by its own
    total, returning ``(ce_start, ce_end, kd_start, kd_end)``.

    Raises:
        ValueError: if either the start or end pair sums to a non-positive value.
    """
    ce_start, ce_end = float(ce_weight_start), float(ce_weight_end)
    kd_start, kd_end = float(kd_weight_start), float(kd_weight_end)
    start_total = ce_start + kd_start
    end_total = ce_end + kd_end
    if start_total <= 0 or end_total <= 0:
        raise ValueError("Distillation CE/KD weights must sum to positive values.")
    return (ce_start / start_total, ce_end / end_total, kd_start / start_total, kd_end / end_total)
|
|
|
|
class DistillationTrainer(TrainerBase):
    """Trainer subclass that blends the student's CE loss with a KD loss.

    The knowledge-distillation loss is the KL divergence between the
    temperature-softened student and (ensemble-averaged) teacher logits,
    computed only over label positions.  Temperature and the CE/KD mix are
    linearly interpolated from their start to end values over training.
    """

    def __init__(
        self,
        *args: Any,
        teacher_models: list[Any],
        temperature_start: float,
        temperature_end: float,
        ce_weight_start: float,
        ce_weight_end: float,
        kd_weight_start: float,
        kd_weight_end: float,
        **kwargs: Any,
    ):
        """Store the distillation schedule and freeze all teacher models.

        Raises:
            ValueError: if ``teacher_models`` is empty.
        """
        require_dependency("Trainer", Trainer is not None)
        super().__init__(*args, **kwargs)
        if not teacher_models:
            raise ValueError("teacher_models must contain at least one teacher.")


        self.teacher_models = teacher_models
        self.temperature_start = float(temperature_start)
        self.temperature_end = float(temperature_end)
        self.ce_weight_start = float(ce_weight_start)
        self.ce_weight_end = float(ce_weight_end)
        self.kd_weight_start = float(kd_weight_start)
        self.kd_weight_end = float(kd_weight_end)
        # Most recent per-step distillation metrics, merged into logs by log().
        self._latest_distill_metrics: dict[str, float] = {}


        # Teachers are inference-only: eval mode, no gradients.
        for teacher in self.teacher_models:
            teacher.eval()
            for param in teacher.parameters():
                param.requires_grad = False


    def _progress(self) -> float:
        """Return training progress in [0, 1] based on Trainer state.

        Returns 0.0 when max_steps is not yet known (<= 1).
        """
        max_steps = int(getattr(self.state, "max_steps", 0) or 0)
        if max_steps <= 1:
            return 0.0
        step = float(getattr(self.state, "global_step", 0) or 0)
        return max(0.0, min(1.0, step / float(max_steps)))


    @staticmethod
    def _interp(start: float, end: float, progress: float) -> float:
        """Linearly interpolate between *start* and *end* at *progress*."""
        return start + (end - start) * progress


    def _teacher_forward(self, teacher: Any, input_ids: Any, attention_mask: Any) -> Any:
        """Run one teacher forward pass, tolerating device mismatches.

        First tries the inputs as-is; on failure, moves them to the teacher's
        device (e.g. for device_map-sharded teachers) and moves the logits
        back to the inputs' device.
        """
        try:
            return teacher(input_ids=input_ids, attention_mask=attention_mask).logits
        except Exception:
            teacher_device = next(teacher.parameters()).device
            out = teacher(
                input_ids=input_ids.to(teacher_device),
                attention_mask=attention_mask.to(teacher_device),
            ).logits
            return out.to(input_ids.device)


    def _teacher_logits(self, input_ids: Any, attention_mask: Any) -> Any:
        """Average logits over the teacher ensemble."""
        accum = None
        for teacher in self.teacher_models:
            logits = self._teacher_forward(teacher, input_ids, attention_mask)
            accum = logits if accum is None else (accum + logits)
        return accum / float(len(self.teacher_models))


    def compute_loss(
        self,
        model: Any,
        inputs: dict[str, Any],
        return_outputs: bool = False,
        num_items_in_batch: int | None = None,
    ) -> Any:
        """Return the blended CE + KD loss for one batch.

        ``num_items_in_batch`` is accepted for newer transformers Trainer
        signatures but intentionally unused.
        """
        del num_items_in_batch
        require_dependency("torch.nn.functional", F is not None)
        labels = inputs["labels"]


        # Schedule the temperature and loss mix by training progress.
        progress = self._progress()
        temperature = max(1e-6, self._interp(self.temperature_start, self.temperature_end, progress))
        ce_weight = self._interp(self.ce_weight_start, self.ce_weight_end, progress)
        kd_weight = self._interp(self.kd_weight_start, self.kd_weight_end, progress)


        student_out = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            labels=labels,
        )
        ce_loss = student_out.loss


        # Teachers never need gradients.
        with torch.no_grad():
            teacher_logits = self._teacher_logits(inputs["input_ids"], inputs["attention_mask"])


        # Shift so position t predicts token t+1 (causal LM convention).
        student_logits = student_out.logits
        student_shift = student_logits[:, :-1, :]
        teacher_shift = teacher_logits[:, :-1, :]
        labels_shift = labels[:, 1:]


        # Student and teacher vocabularies may differ; truncate to the overlap.
        if student_shift.shape[-1] != teacher_shift.shape[-1]:
            vocab = min(int(student_shift.shape[-1]), int(teacher_shift.shape[-1]))
            student_shift = student_shift[:, :, :vocab]
            teacher_shift = teacher_shift[:, :, :vocab]


        # KD only over supervised positions (labels != -100).
        active = labels_shift.ne(-100)
        if active.any():
            s = student_shift[active]
            t = teacher_shift[active]
            # Temperature^2 rescales gradients to the softened scale (Hinton et al.).
            kd_loss = F.kl_div(
                F.log_softmax(s / temperature, dim=-1),
                F.softmax(t / temperature, dim=-1),
                reduction="batchmean",
            ) * (temperature * temperature)
        else:
            kd_loss = torch.tensor(0.0, device=ce_loss.device)


        # Stash metrics before blending so log() can report them.
        self._latest_distill_metrics = {
            "distill_ce_loss": float(ce_loss.detach().float().item()),
            "distill_kd_loss": float(kd_loss.detach().float().item()),
            "distill_temperature": float(temperature),
            "distill_ce_weight": float(ce_weight),
            "distill_kd_weight": float(kd_weight),
        }


        loss = ce_weight * ce_loss + kd_weight * kd_loss
        if return_outputs:
            # Keep the reported loss consistent with the blended training loss.
            student_out.loss = loss
            return loss, student_out
        return loss


    def log(self, logs: dict[str, float], *args: Any, **kwargs: Any) -> None:
        """Merge the latest distillation metrics into every Trainer log entry."""
        merged = dict(logs)
        if self._latest_distill_metrics:
            merged.update(self._latest_distill_metrics)
        super().log(merged, *args, **kwargs)
|
|
|
|
def estimate_total_train_steps(
    train_rows: int,
    batch_size: int,
    grad_accum: int,
    epochs: float,
) -> int:
    """Estimate the number of optimizer steps for a training run.

    Uses ceil division of rows by the effective batch (batch * grad accum),
    scaled by the epoch count; every intermediate value is clamped to >= 1.
    """
    rows = int(train_rows)
    effective = int(batch_size) * int(grad_accum)
    if effective < 1:
        effective = 1
    # -(-a // b) is integer ceil division.
    steps_per_epoch = max(1, -(-rows // effective))
    return max(1, math.ceil(steps_per_epoch * float(epochs)))
|
|
|
|
def build_teacher_models(
    teacher_names: list[str],
    teacher_dtype: Any,
    trust_remote_code: bool,
    attn_implementation: str,
    use_4bit: bool,
    using_cuda: bool,
) -> list[Any]:
    """Load, freeze, and return the teacher models used for distillation.

    Each teacher is loaded with the requested dtype, optional 4-bit
    quantization, and optional attention backend; a failing attention
    backend triggers one retry with the default backend.  Teachers are
    returned in eval mode with all parameters frozen.

    Raises:
        RuntimeError: if ``teacher_names`` is empty.
    """
    require_dependency("AutoModelForCausalLM", AutoModelForCausalLM is not None)
    if not teacher_names:
        raise RuntimeError("Distillation is enabled but no teacher model was configured.")

    dtype_key = preferred_loader_dtype_key()

    def _load(name: str) -> Any:
        # Build per-teacher loader kwargs; dtype_key differs across
        # transformers versions ("dtype" vs "torch_dtype").
        kwargs: dict[str, Any] = {
            "trust_remote_code": bool(trust_remote_code),
            dtype_key: teacher_dtype,
        }
        quant_cfg = build_quant_config(use_4bit=bool(use_4bit), dtype=teacher_dtype)
        if quant_cfg is not None:
            kwargs["quantization_config"] = quant_cfg
            kwargs["device_map"] = "auto"
        elif using_cuda:
            kwargs["device_map"] = "auto"
        if attn_implementation:
            kwargs["attn_implementation"] = attn_implementation
        try:
            return AutoModelForCausalLM.from_pretrained(name, **kwargs)
        except Exception as exc:
            if not (attn_implementation and "attn_implementation" in kwargs):
                raise
            # Retry once without the custom attention backend.
            print(
                f"[train_worker][warn] teacher '{name}' failed with "
                f"attn_implementation='{attn_implementation}' ({exc}); "
                "retrying with default attention backend"
            )
            kwargs.pop("attn_implementation", None)
            return AutoModelForCausalLM.from_pretrained(name, **kwargs)

    teachers: list[Any] = []
    for teacher_name in teacher_names:
        teacher = _load(teacher_name)
        teacher.eval()
        for param in teacher.parameters():
            param.requires_grad = False
        teachers.append(teacher)
    return teachers
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for the training worker.

    Every default is backed by a ``UMSR_*`` environment variable, so the
    Space can be configured without changing the launch command.  Boolean
    toggles come in ``--x`` / ``--no-x`` pairs whose default is resolved via
    ``env_bool``.
    """
    # Boolean-toggle defaults, resolved from env vars up front so the
    # matching set_defaults() calls below stay readable.
    default_use_4bit = env_bool("UMSR_USE_4BIT", True)
    default_use_4bit_teacher = env_bool("UMSR_USE_4BIT_TEACHER", True)
    default_lora_enabled = env_bool("UMSR_LORA_ENABLED", True)
    default_grad_ckpt = env_bool("UMSR_GRADIENT_CHECKPOINTING", True)
    default_distill_enabled = env_bool("UMSR_DISTILL_ENABLED", True)
    default_enforce_inhouse_models = env_bool("UMSR_ENFORCE_INHOUSE_MODELS", True)
    default_native_trainer_mode = env_bool("UMSR_NATIVE_TRAINER_MODE", True)
    default_native_strict_mode = env_bool("UMSR_NATIVE_STRICT_MODE", False)


    parser = argparse.ArgumentParser(description="Train and optionally push an autonomous UMSR run")
    # Dataset and model selection.
    parser.add_argument("--dataset-id", default=os.environ.get("UMSR_DATASET_ID", "NorthernTribe-Research/UMSR-v1"))
    parser.add_argument("--train-split", default=os.environ.get("UMSR_TRAIN_SPLIT", "train"))
    parser.add_argument("--eval-split", default=os.environ.get("UMSR_EVAL_SPLIT", "validation"))
    parser.add_argument("--min-quality", type=float, default=float(os.environ.get("UMSR_MIN_QUALITY", "0.72")))
    parser.add_argument("--model-name", default=os.environ.get("UMSR_BASE_MODEL", "NorthernTribe-Research/UMSR-Reasoner-7B"))
    parser.add_argument(
        "--teacher-model",
        default=os.environ.get("UMSR_TEACHER_MODEL", "NorthernTribe-Research/UMSR-Reasoner-7B"),
    )
    parser.add_argument("--model-dtype", default=os.environ.get("UMSR_MODEL_DTYPE", "bfloat16"))
    parser.add_argument("--teacher-dtype", default=os.environ.get("UMSR_TEACHER_DTYPE", "bfloat16"))
    parser.add_argument(
        "--attn-implementation",
        default=os.environ.get("UMSR_ATTN_IMPLEMENTATION", ""),
    )
    # Distillation and model-provenance toggles.
    parser.add_argument("--distill-enabled", dest="distill_enabled", action="store_true")
    parser.add_argument("--no-distill-enabled", dest="distill_enabled", action="store_false")
    parser.set_defaults(distill_enabled=default_distill_enabled)
    parser.add_argument("--enforce-inhouse-models", dest="enforce_inhouse_models", action="store_true")
    parser.add_argument("--allow-external-models", dest="enforce_inhouse_models", action="store_false")
    parser.set_defaults(enforce_inhouse_models=default_enforce_inhouse_models)
    # Quantization toggles (student and teacher independently).
    parser.add_argument("--use-4bit", dest="use_4bit", action="store_true")
    parser.add_argument("--no-use-4bit", dest="use_4bit", action="store_false")
    parser.set_defaults(use_4bit=default_use_4bit)
    parser.add_argument("--use-4bit-teacher", dest="use_4bit_teacher", action="store_true")
    parser.add_argument("--no-use-4bit-teacher", dest="use_4bit_teacher", action="store_false")
    parser.set_defaults(use_4bit_teacher=default_use_4bit_teacher)
    # LoRA adapter configuration.
    parser.add_argument("--lora-enabled", dest="lora_enabled", action="store_true")
    parser.add_argument("--no-lora-enabled", dest="lora_enabled", action="store_false")
    parser.set_defaults(lora_enabled=default_lora_enabled)
    parser.add_argument("--lora-r", type=int, default=int(os.environ.get("UMSR_LORA_R", "32")))
    parser.add_argument("--lora-alpha", type=int, default=int(os.environ.get("UMSR_LORA_ALPHA", "64")))
    parser.add_argument("--lora-dropout", type=float, default=float(os.environ.get("UMSR_LORA_DROPOUT", "0.05")))
    parser.add_argument(
        "--lora-target-modules",
        default=os.environ.get(
            "UMSR_LORA_TARGET_MODULES",
            "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
        ),
    )
    # Output / hub destination and sampling limits.
    parser.add_argument("--repo-id", default=os.environ.get("UMSR_MODEL_REPO_ID", "NorthernTribe-Research/UMSR-Reasoner-7B"))
    parser.add_argument("--output-dir", default="runs/latest")
    parser.add_argument("--max-train-samples", type=int, default=int(os.environ.get("UMSR_MAX_TRAIN_SAMPLES", "256")))
    parser.add_argument("--max-eval-samples", type=int, default=int(os.environ.get("UMSR_MAX_EVAL_SAMPLES", "64")))
    # Optimizer / schedule hyperparameters.
    parser.add_argument("--num-train-epochs", type=float, default=float(os.environ.get("UMSR_NUM_TRAIN_EPOCHS", "1")))
    parser.add_argument("--learning-rate", type=float, default=float(os.environ.get("UMSR_LEARNING_RATE", "1e-4")))
    parser.add_argument("--weight-decay", type=float, default=float(os.environ.get("UMSR_WEIGHT_DECAY", "0.0")))
    parser.add_argument("--warmup-ratio", type=float, default=float(os.environ.get("UMSR_WARMUP_RATIO", "0.03")))
    parser.add_argument("--warmup-steps", type=int, default=int(os.environ.get("UMSR_WARMUP_STEPS", "0")))
    parser.add_argument("--per-device-train-batch-size", type=int, default=int(os.environ.get("UMSR_BATCH_SIZE", "1")))
    parser.add_argument("--per-device-eval-batch-size", type=int, default=int(os.environ.get("UMSR_EVAL_BATCH_SIZE", "1")))
    parser.add_argument("--gradient-accumulation-steps", type=int, default=int(os.environ.get("UMSR_GRAD_ACCUM", "1")))
    parser.add_argument("--max-length", type=int, default=int(os.environ.get("UMSR_MAX_LENGTH", "512")))
    parser.add_argument("--logging-steps", type=int, default=int(os.environ.get("UMSR_LOGGING_STEPS", "1")))
    parser.add_argument("--eval-steps", type=int, default=int(os.environ.get("UMSR_EVAL_STEPS", "25")))
    parser.add_argument("--save-steps", type=int, default=int(os.environ.get("UMSR_SAVE_STEPS", "25")))
    parser.add_argument("--save-total-limit", type=int, default=int(os.environ.get("UMSR_SAVE_TOTAL_LIMIT", "4")))
    parser.add_argument("--seed", type=int, default=int(os.environ.get("UMSR_SEED", "42")))
    # Distillation schedule (interpolated start -> end over training).
    parser.add_argument("--temperature-start", type=float, default=float(os.environ.get("UMSR_TEMPERATURE_START", "2.5")))
    parser.add_argument("--temperature-end", type=float, default=float(os.environ.get("UMSR_TEMPERATURE_END", "1.2")))
    parser.add_argument("--ce-weight-start", type=float, default=float(os.environ.get("UMSR_CE_WEIGHT_START", "0.35")))
    parser.add_argument("--ce-weight-end", type=float, default=float(os.environ.get("UMSR_CE_WEIGHT_END", "0.5")))
    parser.add_argument("--kd-weight-start", type=float, default=float(os.environ.get("UMSR_KD_WEIGHT_START", "0.65")))
    parser.add_argument("--kd-weight-end", type=float, default=float(os.environ.get("UMSR_KD_WEIGHT_END", "0.5")))
    # Checkpoint resume behavior; interpreted by resolve_resume_checkpoint().
    parser.add_argument(
        "--resume-from-checkpoint",
        default=os.environ.get("UMSR_RESUME_FROM_CHECKPOINT", "auto"),
        help="Use 'auto' to continue from the latest checkpoint in output-dir if available.",
    )
    parser.add_argument("--gradient-checkpointing", dest="gradient_checkpointing", action="store_true")
    parser.add_argument("--no-gradient-checkpointing", dest="gradient_checkpointing", action="store_false")
    parser.set_defaults(gradient_checkpointing=default_grad_ckpt)
    parser.add_argument(
        "--tie-word-embeddings",
        action="store_true",
        help="Keep embedding and lm_head weights tied in the model config.",
    )
    # Native runtime preflight options.
    parser.add_argument("--native-trainer-mode", dest="native_trainer_mode", action="store_true")
    parser.add_argument("--no-native-trainer-mode", dest="native_trainer_mode", action="store_false")
    parser.set_defaults(native_trainer_mode=default_native_trainer_mode)
    parser.add_argument("--native-strict-mode", dest="native_strict_mode", action="store_true")
    parser.add_argument("--no-native-strict-mode", dest="native_strict_mode", action="store_false")
    parser.set_defaults(native_strict_mode=default_native_strict_mode)
    parser.add_argument(
        "--required-bins",
        default=os.environ.get("UMSR_REQUIRED_BINS", "bash,python3,git,curl"),
        help="Comma-separated native binaries required for the runtime preflight.",
    )
    # Hugging Face Hub options.
    parser.add_argument("--token-env", default=os.environ.get("UMSR_TOKEN_ENV", "HF_TOKEN"))
    parser.add_argument("--push-to-hub", action="store_true")
    return parser.parse_args()
|
|
|
|
| def main() -> None: |
| args = parse_args() |
|
|
| require_dependency("torch", torch is not None) |
| require_dependency("datasets", load_dataset is not None) |
| require_dependency("transformers", all( |
| dep is not None |
| for dep in [AutoConfig, AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments] |
| )) |
| require_dependency("accelerate>=1.1.0", accelerate is not None) |
| if accelerate is not None: |
| current_accelerate = parse_version_tuple(getattr(accelerate, "__version__", "0.0.0")) |
| if current_accelerate < (1, 1, 0): |
| raise RuntimeError( |
| "accelerate>=1.1.0 is required by Trainer. " |
| f"Found accelerate=={getattr(accelerate, '__version__', 'unknown')}." |
| ) |
|
|
| if args.push_to_hub: |
| require_dependency("huggingface_hub", HfApi is not None) |
|
|
| if set_seed is not None: |
| set_seed(int(args.seed)) |
|
|
| required_bins = parse_csv_values(args.required_bins) |
| runtime_system = collect_runtime_system_snapshot( |
| required_bins=required_bins, |
| native_mode=bool(args.native_trainer_mode), |
| ) |
| log_runtime_system_snapshot(runtime_system) |
| if bool(args.native_trainer_mode) and bool(args.native_strict_mode): |
| missing = runtime_system.get("missing_required_bins") |
| if isinstance(missing, list) and missing: |
| raise RuntimeError( |
| "Native strict mode failed: missing required binaries: " + ",".join(str(item) for item in missing) |
| ) |
|
|
| teacher_names = parse_model_list(args.teacher_model) |
| validate_model_reference( |
| args.model_name, |
| role="base", |
| enforce_inhouse=bool(args.enforce_inhouse_models), |
| ) |
| for teacher_name in teacher_names: |
| validate_model_reference( |
| teacher_name, |
| role="teacher", |
| enforce_inhouse=bool(args.enforce_inhouse_models), |
| ) |
|
|
| runtime_hardware = probe_runtime_hardware() |
| log_runtime_hardware(runtime_hardware) |
|
|
| using_cuda = bool(runtime_hardware.get("cuda_available", False)) |
| using_mps = bool(runtime_hardware.get("mps_available", False)) |
| requested_use_4bit = bool(args.use_4bit) |
| effective_use_4bit = bool(requested_use_4bit and using_cuda) |
| bf16_supported = bool( |
| using_cuda |
| and hasattr(torch.cuda, "is_bf16_supported") |
| and torch.cuda.is_bf16_supported() |
| ) |
| fp16_enabled = bool(using_cuda and not bf16_supported) |
| bf16_enabled = bool(using_cuda and bf16_supported) |
| device_label = "cuda" if using_cuda else ("mps" if using_mps else "cpu") |
|
|
| if using_cuda: |
| |
| try: |
| torch.backends.cuda.matmul.allow_tf32 = True |
| torch.backends.cudnn.allow_tf32 = True |
| except Exception: |
| pass |
| elif requested_use_4bit: |
| print("[train_worker][warn] 4-bit quantization requested without CUDA; using non-quantized model load.") |
|
|
| cache_dir = dataset_cache_dir() |
| output_dir = Path(args.output_dir) |
| output_dir.mkdir(parents=True, exist_ok=True) |
| write_json(output_dir / "system_snapshot.json", runtime_system) |
| write_json( |
| output_dir / "live_progress.json", |
| { |
| "updated_at": datetime.now(timezone.utc).isoformat(), |
| "status": "initializing", |
| "message": "preparing datasets and models", |
| "distill_enabled": bool(args.distill_enabled), |
| "runtime_system": runtime_system, |
| "runtime_hardware": runtime_hardware, |
| "global_step": 0, |
| "max_steps": 0, |
| "epoch": 0.0, |
| "metrics": {}, |
| }, |
| ) |
|
|
| train_ds = load_split_dataset(args.dataset_id, split=args.train_split, cache_dir=cache_dir) |
| train_ds = train_ds.filter(lambda row: quality_ok(row, float(args.min_quality))) |
| if args.max_train_samples > 0 and len(train_ds) > args.max_train_samples: |
| train_ds = train_ds.shuffle(seed=int(args.seed)).select(range(args.max_train_samples)) |
|
|
| eval_ds = None |
| if args.eval_split: |
| eval_ds = load_split_dataset(args.dataset_id, split=args.eval_split, cache_dir=cache_dir) |
| eval_ds = eval_ds.filter(lambda row: quality_ok(row, float(args.min_quality))) |
| if args.max_eval_samples > 0 and len(eval_ds) > args.max_eval_samples: |
| eval_ds = eval_ds.shuffle(seed=int(args.seed)).select(range(args.max_eval_samples)) |
|
|
| train_ds = train_ds.map( |
| lambda row: {"text": format_text(row)}, |
| remove_columns=train_ds.column_names, |
| ) |
| if eval_ds is not None: |
| eval_ds = eval_ds.map( |
| lambda row: {"text": format_text(row)}, |
| remove_columns=eval_ds.column_names, |
| ) |
|
|
| tokenizer = AutoTokenizer.from_pretrained(args.model_name) |
| if tokenizer.pad_token is None: |
| tokenizer.pad_token = tokenizer.eos_token |
|
|
| def tokenize_batch(batch: dict[str, Any]) -> dict[str, Any]: |
| return tokenizer( |
| batch["text"], |
| truncation=True, |
| max_length=int(args.max_length), |
| padding=False, |
| ) |
|
|
| train_tokenized = train_ds.map(tokenize_batch, batched=True, remove_columns=["text"]) |
| eval_tokenized = None |
| if eval_ds is not None: |
| eval_tokenized = eval_ds.map(tokenize_batch, batched=True, remove_columns=["text"]) |
|
|
| model_dtype = dtype_from_name(args.model_dtype) |
| teacher_dtype = dtype_from_name(args.teacher_dtype) |
| lora_target_modules = parse_target_modules(args.lora_target_modules) |
| requested_teacher_use_4bit = bool(args.use_4bit_teacher) |
| effective_teacher_use_4bit = bool(requested_teacher_use_4bit and using_cuda) |
| if requested_teacher_use_4bit and not using_cuda: |
| print("[train_worker][warn] teacher 4-bit quantization requested without CUDA; using non-quantized teacher load.") |
|
|
| ce_weight_start, ce_weight_end, kd_weight_start, kd_weight_end = resolve_schedule_weights( |
| ce_weight_start=float(args.ce_weight_start), |
| ce_weight_end=float(args.ce_weight_end), |
| kd_weight_start=float(args.kd_weight_start), |
| kd_weight_end=float(args.kd_weight_end), |
| ) |
|
|
| model_config = AutoConfig.from_pretrained(args.model_name) |
| if hasattr(model_config, "tie_word_embeddings"): |
| model_config.tie_word_embeddings = bool(args.tie_word_embeddings) |
|
|
| model_kwargs: dict[str, Any] = {"config": model_config} |
| dtype_key = preferred_loader_dtype_key() |
| attn_impl = to_text(args.attn_implementation) |
| if attn_impl: |
| model_kwargs["attn_implementation"] = attn_impl |
|
|
| student_quant_config = build_quant_config(use_4bit=effective_use_4bit, dtype=model_dtype) |
| if student_quant_config is not None: |
| model_kwargs["quantization_config"] = student_quant_config |
| model_kwargs["device_map"] = "auto" |
| else: |
| model_kwargs[dtype_key] = model_dtype |
| if using_cuda: |
| model_kwargs["device_map"] = "auto" |
|
|
| try: |
| model = AutoModelForCausalLM.from_pretrained(args.model_name, **model_kwargs) |
| except Exception as exc: |
| if attn_impl and "attn_implementation" in model_kwargs: |
| print( |
| f"[train_worker][warn] failed with attn_implementation='{attn_impl}' ({exc}); " |
| "retrying with default attention backend" |
| ) |
| model_kwargs.pop("attn_implementation", None) |
| model = AutoModelForCausalLM.from_pretrained(args.model_name, **model_kwargs) |
| else: |
| raise |
| model.config.use_cache = False |
| model.config.pad_token_id = tokenizer.pad_token_id |
| if bool(args.gradient_checkpointing): |
| model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) |
|
|
| if effective_use_4bit: |
| require_dependency("peft.prepare_model_for_kbit_training", prepare_model_for_kbit_training is not None) |
| model = prepare_model_for_kbit_training( |
| model, |
| use_gradient_checkpointing=bool(args.gradient_checkpointing), |
| ) |
|
|
| if bool(args.lora_enabled): |
| require_dependency("peft.LoraConfig", LoraConfig is not None) |
| require_dependency("peft.get_peft_model", get_peft_model is not None) |
| if model_has_existing_lora(model): |
| print( |
| "[train_worker][warn] base model already has LoRA adapters attached; " |
| "skipping adapter reinjection and training existing adapters." |
| ) |
| trainable, total = set_lora_only_trainable(model) |
| pct = (100.0 * float(trainable) / float(total)) if total > 0 else 0.0 |
| print( |
| f"[train_worker][info] trainable params set to existing LoRA adapters: " |
| f"{trainable}/{total} ({pct:.4f}%)" |
| ) |
| if not lora_target_modules: |
| lora_target_modules = ["preloaded-adapter"] |
| else: |
| lora_target_modules = resolve_lora_target_modules(model=model, requested=lora_target_modules) |
| lora_cfg = LoraConfig( |
| r=int(args.lora_r), |
| lora_alpha=int(args.lora_alpha), |
| lora_dropout=float(args.lora_dropout), |
| bias="none", |
| task_type="CAUSAL_LM", |
| target_modules=lora_target_modules, |
| ) |
| model = get_peft_model(model, lora_cfg) |
| if hasattr(model, "print_trainable_parameters"): |
| model.print_trainable_parameters() |
|
|
| teacher_models: list[Any] = [] |
| if bool(args.distill_enabled): |
| require_dependency("torch.nn.functional", F is not None) |
| teacher_models = build_teacher_models( |
| teacher_names=teacher_names, |
| teacher_dtype=teacher_dtype, |
| trust_remote_code=False, |
| attn_implementation=attn_impl, |
| use_4bit=effective_teacher_use_4bit, |
| using_cuda=using_cuda, |
| ) |
| elif teacher_names: |
| print("[train_worker][info] teacher model configured but distillation is disabled; running CE-only SFT mode.") |
|
|
| total_steps = estimate_total_train_steps( |
| train_rows=len(train_tokenized), |
| batch_size=int(args.per_device_train_batch_size), |
| grad_accum=int(args.gradient_accumulation_steps), |
| epochs=float(args.num_train_epochs), |
| ) |
| requested_warmup_steps = max(0, int(args.warmup_steps)) |
| warmup_ratio = max(0.0, float(args.warmup_ratio)) |
| derived_warmup_steps = max(0, int(round(total_steps * warmup_ratio))) if warmup_ratio > 0 else 0 |
| effective_warmup_steps = requested_warmup_steps if requested_warmup_steps > 0 else derived_warmup_steps |
|
|
| training_kwargs: dict[str, Any] = { |
| "output_dir": str(output_dir), |
| "run_name": "umsr-autonomous-space", |
| "num_train_epochs": float(args.num_train_epochs), |
| "learning_rate": float(args.learning_rate), |
| "weight_decay": float(args.weight_decay), |
| "warmup_steps": int(effective_warmup_steps), |
| "per_device_train_batch_size": int(args.per_device_train_batch_size), |
| "per_device_eval_batch_size": int(args.per_device_eval_batch_size), |
| "gradient_accumulation_steps": int(args.gradient_accumulation_steps), |
| "logging_steps": int(args.logging_steps), |
| "disable_tqdm": True, |
| "save_steps": int(args.save_steps), |
| "save_total_limit": max(1, int(args.save_total_limit)), |
| "report_to": ["none"], |
| "remove_unused_columns": False, |
| "seed": int(args.seed), |
| "fp16": fp16_enabled, |
| "bf16": bf16_enabled, |
| "gradient_checkpointing": bool(args.gradient_checkpointing), |
| "dataloader_pin_memory": using_cuda, |
| "optim": "paged_adamw_8bit" if effective_use_4bit else "adamw_torch", |
| } |
| training_arg_params = set(inspect.signature(TrainingArguments.__init__).parameters.keys()) |
| eval_key = "eval_strategy" if "eval_strategy" in training_arg_params else "evaluation_strategy" |
| if "logging_first_step" in training_arg_params: |
| training_kwargs["logging_first_step"] = True |
| if eval_tokenized is not None: |
| training_kwargs[eval_key] = "steps" |
| training_kwargs["eval_steps"] = int(args.eval_steps) |
| else: |
| training_kwargs[eval_key] = "no" |
|
|
| train_args = TrainingArguments(**training_kwargs) |
| data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) |
| callbacks: list[Any] | None = None |
| if TrainerCallback is not None: |
| callbacks = [ |
| LiveProgressCallback( |
| output_dir=output_dir, |
| distill_enabled=bool(args.distill_enabled), |
| runtime_hardware=runtime_hardware, |
| runtime_system=runtime_system, |
| ) |
| ] |
| else: |
| print("[train_worker][warn] TrainerCallback unavailable; live progress telemetry disabled.") |
|
|
| resume_checkpoint = resolve_resume_checkpoint(args.resume_from_checkpoint, output_dir=output_dir) |
| if resume_checkpoint: |
| print(f"[train_worker][info] resuming from checkpoint: {resume_checkpoint}") |
|
|
| run_config = { |
| "dataset_id": args.dataset_id, |
| "train_split": args.train_split, |
| "eval_split": args.eval_split, |
| "min_quality": float(args.min_quality), |
| "student_model": args.model_name, |
| "teacher_models": teacher_names, |
| "distill_enabled": bool(args.distill_enabled), |
| "enforce_inhouse_models": bool(args.enforce_inhouse_models), |
| "model_dtype": to_text(args.model_dtype).lower(), |
| "teacher_dtype": to_text(args.teacher_dtype).lower(), |
| "use_4bit_student_requested": requested_use_4bit, |
| "use_4bit_student_effective": effective_use_4bit, |
| "use_4bit_teacher_requested": requested_teacher_use_4bit, |
| "use_4bit_teacher_effective": effective_teacher_use_4bit, |
| "temperature_start": float(args.temperature_start), |
| "temperature_end": float(args.temperature_end), |
| "ce_weight_start": float(ce_weight_start), |
| "ce_weight_end": float(ce_weight_end), |
| "kd_weight_start": float(kd_weight_start), |
| "kd_weight_end": float(kd_weight_end), |
| "lora_enabled": bool(args.lora_enabled), |
| "lora_r": int(args.lora_r), |
| "lora_alpha": int(args.lora_alpha), |
| "lora_dropout": float(args.lora_dropout), |
| "lora_target_modules": lora_target_modules, |
| "save_total_limit": max(1, int(args.save_total_limit)), |
| "resume_from_checkpoint": resume_checkpoint or "", |
| "output_dir": str(output_dir), |
| "system_snapshot_path": str(output_dir / "system_snapshot.json"), |
| "target_repo_id": args.repo_id, |
| "native_trainer_mode": bool(args.native_trainer_mode), |
| "native_strict_mode": bool(args.native_strict_mode), |
| "required_bins": required_bins, |
| "runtime_system": runtime_system, |
| "runtime_hardware": runtime_hardware, |
| "created_at": datetime.now(timezone.utc).isoformat(), |
| } |
| write_json(output_dir / "effective_run_config.json", run_config) |
|
|
| if bool(args.distill_enabled): |
| trainer = DistillationTrainer( |
| model=model, |
| teacher_models=teacher_models, |
| args=train_args, |
| train_dataset=train_tokenized, |
| eval_dataset=eval_tokenized, |
| data_collator=data_collator, |
| callbacks=callbacks, |
| temperature_start=max(1e-6, float(args.temperature_start)), |
| temperature_end=max(1e-6, float(args.temperature_end)), |
| ce_weight_start=float(ce_weight_start), |
| ce_weight_end=float(ce_weight_end), |
| kd_weight_start=float(kd_weight_start), |
| kd_weight_end=float(kd_weight_end), |
| ) |
| else: |
| trainer = Trainer( |
| model=model, |
| args=train_args, |
| train_dataset=train_tokenized, |
| eval_dataset=eval_tokenized, |
| data_collator=data_collator, |
| callbacks=callbacks, |
| ) |
|
|
| if resume_checkpoint: |
| train_result = trainer.train(resume_from_checkpoint=resume_checkpoint) |
| else: |
| train_result = trainer.train() |
| trainer.save_model() |
| tokenizer.save_pretrained(str(output_dir)) |
| trainer.save_state() |
|
|
| train_metrics = dict(train_result.metrics) |
| train_metrics["train_samples"] = len(train_tokenized) |
| train_metrics["distill_enabled"] = bool(args.distill_enabled) |
| train_metrics["teacher_count"] = len(teacher_models) |
| train_metrics["ce_weight_start"] = float(ce_weight_start) |
| train_metrics["ce_weight_end"] = float(ce_weight_end) |
| train_metrics["kd_weight_start"] = float(kd_weight_start) |
| train_metrics["kd_weight_end"] = float(kd_weight_end) |
| train_metrics["temperature_start"] = float(args.temperature_start) |
| train_metrics["temperature_end"] = float(args.temperature_end) |
| write_json(output_dir / "metrics" / "train_metrics.json", train_metrics) |
|
|
| eval_metrics: dict[str, Any] = {} |
| if eval_tokenized is not None: |
| eval_metrics = dict(trainer.evaluate()) |
| eval_metrics["eval_samples"] = len(eval_tokenized) |
| write_json(output_dir / "metrics" / "eval_metrics.json", eval_metrics) |
|
|
| summary = { |
| "dataset_id": args.dataset_id, |
| "train_rows": len(train_tokenized), |
| "eval_rows": len(eval_tokenized) if eval_tokenized is not None else 0, |
| "output_dir": str(output_dir), |
| "system_snapshot_path": str(output_dir / "system_snapshot.json"), |
| "live_progress_path": str(output_dir / "live_progress.json"), |
| "live_events_path": str(output_dir / "live_events.jsonl"), |
| "base_model": args.model_name, |
| "target_repo_id": args.repo_id, |
| "native_trainer_mode": bool(args.native_trainer_mode), |
| "native_strict_mode": bool(args.native_strict_mode), |
| "required_bins": required_bins, |
| "runtime_system": runtime_system, |
| "runtime_hardware": runtime_hardware, |
| "device": device_label, |
| "cuda_available": using_cuda, |
| "mps_available": using_mps, |
| "fp16": fp16_enabled, |
| "bf16": bf16_enabled, |
| "model_dtype": to_text(args.model_dtype).lower(), |
| "teacher_dtype": to_text(args.teacher_dtype).lower(), |
| "attn_implementation": attn_impl, |
| "distill_enabled": bool(args.distill_enabled), |
| "enforce_inhouse_models": bool(args.enforce_inhouse_models), |
| "teacher_models": teacher_names, |
| "teacher_count": len(teacher_models), |
| "temperature_start": float(args.temperature_start), |
| "temperature_end": float(args.temperature_end), |
| "ce_weight_start": float(ce_weight_start), |
| "ce_weight_end": float(ce_weight_end), |
| "kd_weight_start": float(kd_weight_start), |
| "kd_weight_end": float(kd_weight_end), |
| "use_4bit_requested": requested_use_4bit, |
| "use_4bit_effective": effective_use_4bit, |
| "use_4bit_teacher_requested": requested_teacher_use_4bit, |
| "use_4bit_teacher_effective": effective_teacher_use_4bit, |
| "lora_enabled": bool(args.lora_enabled), |
| "lora_r": int(args.lora_r), |
| "lora_alpha": int(args.lora_alpha), |
| "lora_dropout": float(args.lora_dropout), |
| "lora_target_modules": lora_target_modules, |
| "save_total_limit": max(1, int(args.save_total_limit)), |
| "gradient_checkpointing": bool(args.gradient_checkpointing), |
| "warmup_ratio": float(warmup_ratio), |
| "requested_warmup_steps": int(requested_warmup_steps), |
| "warmup_steps": int(effective_warmup_steps), |
| "total_train_steps_estimate": int(total_steps), |
| "tie_word_embeddings": bool(getattr(model.config, "tie_word_embeddings", False)), |
| "resume_from_checkpoint": resume_checkpoint or "", |
| "finished_at": datetime.now(timezone.utc).isoformat(), |
| } |
| write_json(output_dir / "run_summary.json", summary) |
| finalize_live_progress(output_dir=output_dir, message="training finished") |
|
|
| (output_dir / "README.md").write_text( |
| model_card_text( |
| dataset_id=args.dataset_id, |
| model_id=args.repo_id, |
| base_model=args.model_name, |
| teacher_models=teacher_names, |
| distill_enabled=bool(args.distill_enabled), |
| ), |
| encoding="utf-8", |
| ) |
|
|
| if args.push_to_hub: |
| token = os.environ.get(args.token_env, "") |
| if not token: |
| raise RuntimeError(f"Missing token in environment variable ${args.token_env}") |
|
|
| api = HfApi(token=token) |
| api.create_repo(repo_id=args.repo_id, repo_type="model", private=False, exist_ok=True) |
| api.upload_folder( |
| repo_id=args.repo_id, |
| repo_type="model", |
| folder_path=str(output_dir), |
| commit_message="Autonomous Space trainer update", |
| ignore_patterns=["checkpoint-*", "optimizer.pt", "scheduler.pt", "rng_state.pth"], |
| ) |
|
|
| print(json.dumps(summary, indent=2, sort_keys=True)) |
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the trainer CLI only when executed directly,
    # not when this module is imported.
    main()
|
|