| """Modeļa apmācība ar HuggingFace Trainer.""" |
|
|
| from __future__ import annotations |
|
|
| import asyncio |
| import hashlib |
| import importlib.util |
| import inspect |
| import json |
| import logging |
| import math |
| import os |
| import re |
| import subprocess |
| import sys |
| import tempfile |
| import time |
| from contextlib import suppress |
| from dataclasses import dataclass, replace |
| from datetime import UTC, datetime |
| from importlib.metadata import PackageNotFoundError |
| from importlib.metadata import version as get_installed_package_version |
| from pathlib import Path |
| from typing import Any |
|
|
| from maris_core.data.datasets import HFDatasetError, load_hf_dataset |
| from maris_core.data.preprocessing import record_to_training_text |
| from maris_core.data.quality import ( |
| DatasetQualityGateConfig, |
| DatasetQualityReport, |
| DatasetQualitySplitReport, |
| apply_quality_gate_to_records, |
| build_dataset_quality_report, |
| ) |
| from maris_core.data.scoring import ( |
| DatasetBenchmarkFeedback, |
| DatasetScoringConfig, |
| DatasetScoringReport, |
| DatasetScoringSplitReport, |
| apply_scoring_to_records, |
| build_benchmark_feedback_artifact, |
| build_dataset_scoring_report, |
| load_benchmark_feedback, |
| ) |
| from maris_core.orchestrator.routing import build_system_prompt |
| from maris_core.text.benchmark import ( |
| build_chat_benchmark_history_artifact, |
| build_chat_benchmark_manifest, |
| build_chat_benchmark_regression_report, |
| load_chat_benchmark_dataset, |
| run_chat_benchmark_with_responder, |
| select_chat_benchmark_cases, |
| ) |
| from maris_core.text.generate import ( |
| _extract_response_text, |
| _extract_usage_tokens, |
| call_generation_pipeline, |
| ) |
| from maris_core.training.config import ( |
| DEFAULT_BENCHMARK_NAME, |
| DEFAULT_BRANCH_DATASET_FILTER_RULES, |
| TrainingConfig, |
| load_training_config, |
| ) |
| from maris_core.training.hf_compat import ( |
| MARIS_COMPATIBILITY_ARTIFACT_NAME, |
| apply_maris_compatibility_identity, |
| maris_hf_compatible_path, |
| write_maris_compatibility_artifact, |
| ) |
| from maris_core.training.preferences import ( |
| PreferenceExample, |
| build_blind_side_by_side_artifact, |
| build_human_eval_summary, |
| load_preference_dataset, |
| summarize_preference_dataset, |
| ) |
| from maris_core.utils.env import get_env_any |
|
|
| logger = logging.getLogger(__name__) |
| TEXT_TRAINABLE_BRANCHES = {"master", "coder", "planner"} |
| MAX_LOSS_FOR_PERPLEXITY = 20.0 |
| MARIS_ORIGIN_NAME = "Maris AI" |
| MARIS_FRAMEWORK_NAME = "maris-ai-core" |
| TOKENIZER_LOAD_RETRY_EXCEPTIONS = (ImportError, OSError, RuntimeError, TypeError, ValueError) |
| SANITIZED_ARTIFACT_KEYS = ( |
| "_name_or_path", |
| "name_or_path", |
| "base_model_name_or_path", |
| ) |
| IDENTITY_TEXT_KEYS = frozenset( |
| { |
| "base_model_lineage", |
| "base_model_name", |
| "chat_template", |
| "default_system_prompt", |
| "description", |
| "model_name", |
| "system_prompt", |
| } |
| ) |
| TEXT_SANITIZED_ARTIFACT_EXTENSIONS = frozenset({".jinja", ".md", ".template", ".txt"}) |
| FOREIGN_MODEL_REFERENCE_RE = re.compile( |
| r"(?i)\b(?:" |
| r"allenai|anthropic|cohereforai|deepseek-ai|google|huggingfaceh4|meta-llama|" |
| r"microsoft|mistralai|nousresearch|openai|qwen|stabilityai|tiiuae|tinyllama" |
| r")/[A-Za-z0-9][\w.-]*\b" |
| ) |
| FOREIGN_AI_BRAND_PATTERNS = ( |
| re.compile(r"(?i)\bTinyLlama\b"), |
| re.compile(r"(?i)\bDeepSeek\b"), |
| re.compile(r"(?i)\bMistral\b"), |
| re.compile(r"(?i)\bLlama\b"), |
| re.compile(r"(?i)\bQwen\b"), |
| re.compile(r"(?i)\bChatGPT\b"), |
| re.compile(r"(?i)\bClaude\b"), |
| re.compile(r"(?i)\bGemini\b"), |
| re.compile(r"(?i)\bOpenAI\b"), |
| re.compile(r"(?i)\bAnthropic\b"), |
| ) |
| MARIS_IDENTITY_VARIANT_RE = re.compile( |
| rf"{re.escape(MARIS_ORIGIN_NAME)}(?:[-_/][A-Za-z0-9][\w.-]*)+" |
| ) |
| TRAINING_ARGUMENT_ALIASES = { |
| "evaluation_strategy": "eval_strategy", |
| } |
| MODEL_CARD_TAGS = ("maris-ai", "maris-origin", "conversational-ai") |
| PEFT_ADAPTER_TYPES = {"lora", "qlora"} |
| PREFERENCE_OPTIMIZATION_TYPES = {"dpo", "orpo"} |
| MODEL_SIZE_HINT_RE = re.compile(r"(\d+(?:\.\d+)?)B") |
| LARGE_MODEL_RESOURCE_THRESHOLD_B = 30.0 |
| GIANT_MODEL_AUTO_QLORA_THRESHOLD_B = 70.0 |
| LONG_CONTEXT_MIN_SEQ_LENGTH = 32_768 |
| GIANT_MODEL_NAME_HINTS = ( |
| "deepseek-v3", |
| "deepseek-r1", |
| "qwen3-coder-480b", |
| "405b", |
| "480b", |
| "671b", |
| ) |
| |
| |
| LOCAL_TRAINING_ARTIFACT_FILES = ( |
| "config.json", |
| "adapter_config.json", |
| "model.safetensors", |
| "pytorch_model.bin", |
| ) |
| MODEL_SOURCE_FINGERPRINT_KEY = "model_source_fingerprint" |
| REPO_ROOT = Path(__file__).resolve().parents[3] |
|
|
|
|
| def _emit_training_progress_event(event: str, **payload: Any) -> None: |
| body = { |
| "maris_training_event": True, |
| "event": event, |
| **{key: value for key, value in payload.items() if value is not None}, |
| } |
| print(json.dumps(body, ensure_ascii=False), flush=True) |
|
|
|
|
| def _build_training_progress_label( |
| *, |
| stage: str, |
| epoch: float | None = None, |
| total_epochs: int | None = None, |
| step: int | None = None, |
| total_steps: int | None = None, |
| ) -> str: |
| if stage == "saving": |
| return "Saglabā modeļa artefaktus" |
| if stage == "evaluating": |
| return "Veic eval metriku aprēķinu" |
| if stage == "publishing": |
| return "Publicē modeli origin repozitorijā" |
| if stage == "benchmarking": |
| return "Palaiž benchmark un release gate pārbaudes" |
| if stage == "starting": |
| return "Inicializē treniņu" |
| if step is not None and total_steps: |
| return f"Trenē modeli · solis {step}/{total_steps}" |
| if epoch is not None and total_epochs: |
| return f"Trenē modeli · epoha {epoch:g}/{total_epochs}" |
| if epoch is not None: |
| return f"Trenē modeli · epoha {epoch:g}" |
| return "Trenē modeli" |
|
|
|
|
class MarisTrainingProgressCallback:
    """Emit structured progress events for Space UI monitoring.

    Provides HuggingFace ``TrainerCallback``-style hooks (``on_train_begin``,
    ``on_log``, ``on_evaluate``, ``on_save``, ``on_train_end``); each hook prints
    one JSON event through ``_emit_training_progress_event``.
    """

    def __init__(self, *, total_epochs: int | None = None) -> None:
        self.total_epochs = total_epochs
        # Wall-clock origin and global step captured in on_train_begin. The ETA is
        # derived from steps completed *since* that point, so a run resumed from a
        # checkpoint (global_step already > 0 at start) is not mis-estimated.
        self.started_monotonic: float | None = None
        self.started_step: int = 0

    @staticmethod
    def _resolve_total_steps(state: Any) -> int | None:
        """Return ``state.max_steps`` as int, or None when missing or zero."""
        max_steps = getattr(state, "max_steps", 0)
        return int(max_steps) if max_steps else None

    def on_train_begin(self, args, state, control, **kwargs) -> None:
        """Record the start time/step and emit a ``train_begin`` event."""
        self.started_monotonic = time.monotonic()
        begin_step = _coerce_int(getattr(state, "global_step", None))
        self.started_step = max(begin_step or 0, 0)
        _emit_training_progress_event(
            "train_begin",
            stage="starting",
            label=_build_training_progress_label(stage="starting"),
            epoch=_coerce_float(getattr(state, "epoch", None)),
            total_epochs=self.total_epochs,
            step=begin_step,
            total_steps=self._resolve_total_steps(state),
            learning_rate=getattr(args, "learning_rate", None),
        )

    def on_log(self, args, state, control, logs=None, **kwargs) -> None:
        """Emit a ``log`` event with loss, learning rate and ETA."""
        logs = dict(logs or {})
        step = _coerce_int(getattr(state, "global_step", None))
        total_steps = self._resolve_total_steps(state)
        epoch = _coerce_float(logs.get("epoch", getattr(state, "epoch", None)))
        learning_rate = logs.get("learning_rate", getattr(args, "learning_rate", None))
        _emit_training_progress_event(
            "log",
            stage="training",
            label=_build_training_progress_label(
                stage="training",
                epoch=epoch,
                total_epochs=self.total_epochs,
                step=step,
                total_steps=total_steps,
            ),
            epoch=epoch,
            total_epochs=self.total_epochs,
            step=step,
            total_steps=total_steps,
            loss=_coerce_float(logs.get("loss")),
            eval_loss=_coerce_float(logs.get("eval_loss")),
            learning_rate=_coerce_float(learning_rate),
            eta_seconds=self._estimate_eta_seconds(step=step, total_steps=total_steps),
        )

    def on_evaluate(self, args, state, control, metrics=None, **kwargs) -> None:
        """Emit an ``evaluate`` event carrying ``eval_loss`` from *metrics*."""
        metrics = dict(metrics or {})
        step = _coerce_int(getattr(state, "global_step", None))
        total_steps = self._resolve_total_steps(state)
        _emit_training_progress_event(
            "evaluate",
            stage="evaluating",
            label=_build_training_progress_label(stage="evaluating"),
            epoch=_coerce_float(getattr(state, "epoch", None)),
            total_epochs=self.total_epochs,
            step=step,
            total_steps=total_steps,
            eval_loss=_coerce_float(metrics.get("eval_loss")),
            eta_seconds=self._estimate_eta_seconds(step=step, total_steps=total_steps),
        )

    def on_save(self, args, state, control, **kwargs) -> None:
        """Emit a ``save`` event including the trainer output directory."""
        _emit_training_progress_event(
            "save",
            stage="saving",
            label=_build_training_progress_label(stage="saving"),
            epoch=_coerce_float(getattr(state, "epoch", None)),
            total_epochs=self.total_epochs,
            step=_coerce_int(getattr(state, "global_step", None)),
            total_steps=self._resolve_total_steps(state),
            output_dir=getattr(args, "output_dir", None),
        )

    def on_train_end(self, args, state, control, **kwargs) -> None:
        """Emit a ``train_end`` event."""
        _emit_training_progress_event(
            "train_end",
            stage="saving",
            label="Treniņš pabeigts, sagatavo artefaktus",
            epoch=_coerce_float(getattr(state, "epoch", None)),
            total_epochs=self.total_epochs,
            step=_coerce_int(getattr(state, "global_step", None)),
            total_steps=self._resolve_total_steps(state),
        )

    def _estimate_eta_seconds(self, *, step: int | None, total_steps: int | None) -> float | None:
        """Estimate the remaining seconds from throughput since ``on_train_begin``.

        Returns None when training has not started, step information is missing,
        or no steps have been completed since the start; 0.0 once
        ``step >= total_steps``.
        """
        if self.started_monotonic is None or not step or not total_steps or step <= 0:
            return None
        elapsed = time.monotonic() - self.started_monotonic
        if elapsed <= 0:
            return None
        remaining_steps = max(total_steps - step, 0)
        if remaining_steps <= 0:
            return 0.0
        # Only steps run in *this* process contribute to the measured elapsed time.
        steps_done = step - (getattr(self, "started_step", 0) or 0)
        if steps_done <= 0:
            return None
        return elapsed / steps_done * remaining_steps
|
|
|
|
| def _coerce_float(value: Any) -> float | None: |
| if isinstance(value, bool) or value in (None, ""): |
| return None |
| try: |
| return float(value) |
| except (TypeError, ValueError): |
| return None |
|
|
|
|
| def _coerce_int(value: Any) -> int | None: |
| if isinstance(value, bool) or value in (None, ""): |
| return None |
| try: |
| return int(float(value)) |
| except (TypeError, ValueError): |
| return None |
|
|
|
|
| @dataclass(slots=True, frozen=True) |
| class DatasetBranchFilterReport: |
| split_name: str |
| branch_name: str |
| input_records: int |
| kept_records: int |
| dropped_records: int |
|
|
|
|
| @dataclass(slots=True, frozen=True) |
| class BranchFilterSignals: |
| explicit_branches: frozenset[str] = frozenset() |
| record_types: frozenset[str] = frozenset() |
| task_types: frozenset[str] = frozenset() |
| languages: frozenset[str] = frozenset() |
| repo_context_terms: frozenset[str] = frozenset() |
| presence_keys: frozenset[str] = frozenset() |
|
|
|
|
| def _get_split(dataset: Any, *names: str) -> Any: |
| for name in names: |
| if isinstance(dataset, dict) and name in dataset: |
| return dataset[name] |
|
|
| getter = getattr(dataset, "get", None) |
| if callable(getter): |
| candidate = getter(name) |
| if candidate is not None: |
| return candidate |
|
|
| return None |
|
|
|
|
| def _dataset_split_names(dataset: Any) -> list[str]: |
| if isinstance(dataset, dict): |
| return [str(name) for name in dataset] |
| keys = getattr(dataset, "keys", None) |
| if callable(keys): |
| try: |
| return [str(name) for name in keys()] |
| except Exception: |
| pass |
| return [ |
| name |
| for name in ("train", "validation", "eval", "test") |
| if _get_split(dataset, name) is not None |
| ] |
|
|
|
|
| def _merge_dataset_splits(splits: list[Any]) -> Any: |
| merged_splits = [split for split in splits if split is not None] |
| if not merged_splits: |
| return None |
| if len(merged_splits) == 1: |
| return merged_splits[0] |
| if all(type(split).__module__.startswith("datasets") for split in merged_splits): |
| try: |
| from datasets import concatenate_datasets |
|
|
| return concatenate_datasets(merged_splits) |
| except Exception as exc: |
| logger.warning( |
| "HF split concatenation neizdevās; pārslēdzamies uz ierakstu līmeņa merge: %s", |
| exc, |
| ) |
| records: list[dict[str, Any]] = [] |
| for split in merged_splits: |
| records.extend(_materialize_split_records(split)) |
| return _rebuild_split_like(merged_splits[0], records) |
|
|
|
|
def _load_combined_hf_dataset(repo_ids: list[str]) -> Any:
    """Load several HF dataset repos and merge them split by split.

    Split names are collected in first-seen order across all repos. Each output
    split merges whatever each repo provides for that name.

    Raises:
        ValueError: when none of the repos exposes any split.
    """
    loaded_datasets = [load_hf_dataset(repo_id) for repo_id in repo_ids]
    # dict.fromkeys keeps first-seen order while dropping duplicates.
    ordered_split_names = list(
        dict.fromkeys(
            split_name
            for dataset in loaded_datasets
            for split_name in _dataset_split_names(dataset)
        )
    )
    if not ordered_split_names:
        raise ValueError("Neviens no norādītajiem dataset repo nesatur pieejamus splitus.")
    return {
        split_name: _merge_dataset_splits(
            [_get_split(dataset, split_name) for dataset in loaded_datasets]
        )
        for split_name in ordered_split_names
    }
|
|
|
|
| def _resolve_training_dataset_repos(config: TrainingConfig) -> list[str]: |
| return list(config.dataset_repos or [config.dataset_repo]) |
|
|
|
|
| def _resolve_eval_dataset_repos(config: TrainingConfig) -> list[str]: |
| if config.eval_dataset_repos: |
| return list(config.eval_dataset_repos) |
| if config.eval_dataset_repo: |
| return [config.eval_dataset_repo] |
| return [] |
|
|
|
|
def _resolve_primary_eval_dataset_repo(config: TrainingConfig) -> str:
    """Return the first configured eval repo, falling back to the training repo."""
    configured = _resolve_eval_dataset_repos(config)
    if configured:
        return configured[0]
    return config.eval_dataset_repo or config.dataset_repo
|
|
|
|
def _prepare_train_eval_splits(
    dataset: Any,
    config: TrainingConfig,
) -> tuple[Any, Any | None]:
    """Return ``(train_split, eval_split_or_None)`` for training.

    An existing validation/eval/test split is used as-is. Otherwise, when
    ``config.validation_split_ratio`` is positive and the train split supports
    ``train_test_split`` and has at least 2 rows, a seeded hold-out is carved out.

    Raises:
        ValueError: when the dataset has no ``train`` split.
    """
    train_split = _get_split(dataset, "train")
    if train_split is None:
        raise ValueError("Datasetam nav pieejams 'train' split apmācībai.")

    existing_eval = _get_split(dataset, "validation", "eval", "test")
    if existing_eval is not None:
        return train_split, existing_eval

    cannot_derive_eval = (
        config.validation_split_ratio <= 0
        or not hasattr(train_split, "train_test_split")
        or len(train_split) < 2
    )
    if cannot_derive_eval:
        return train_split, None

    partitions = train_split.train_test_split(
        test_size=config.validation_split_ratio,
        seed=config.seed,
    )
    return partitions["train"], partitions["test"]
|
|
|
|
def _select_eval_split(dataset: Any, config: TrainingConfig, *, allow_train_fallback: bool) -> Any:
    """Pick the safest split for evaluation or an external benchmark dataset.

    Order of preference: an explicit validation/eval/test split; the train split
    when ``allow_train_fallback`` is set; otherwise a hold-out derived by
    ``_prepare_train_eval_splits``.

    Raises:
        ValueError: when no train split exists, or no eval split can be derived.
    """
    train_split = _get_split(dataset, "train")
    explicit_eval = _get_split(dataset, "validation", "eval", "test")

    if explicit_eval is not None:
        return explicit_eval
    if train_split is None:
        raise ValueError("Datasetam nav pieejams ne 'train', ne eval split novērtēšanai.")
    if allow_train_fallback:
        return train_split

    derived_eval = _prepare_train_eval_splits(dataset, config)[1]
    if derived_eval is None:
        raise ValueError("Datasetam nav pieejams stabils eval split novērtēšanai.")
    return derived_eval
|
|
|
|
| def _tokenize_dataset(split: Any, tokenizer: Any, max_seq_length: int) -> Any: |
| if not hasattr(split, "map"): |
| raise TypeError("HF apmācībai nepieciešams datasets.Dataset objekts ar map() atbalstu.") |
|
|
| remove_columns = list(getattr(split, "column_names", [])) |
|
|
| def tokenize_batch(batch: dict[str, list[Any]]) -> dict[str, Any]: |
| batch_size = len(next(iter(batch.values()))) if batch else 0 |
| texts = [] |
| for index in range(batch_size): |
| record = {key: values[index] for key, values in batch.items()} |
| texts.append(record_to_training_text(record, max_chars=max_seq_length * 8)) |
|
|
| return tokenizer( |
| texts, |
| truncation=True, |
| max_length=max_seq_length, |
| padding=False, |
| ) |
|
|
| return split.map( |
| tokenize_batch, |
| batched=True, |
| remove_columns=remove_columns, |
| desc="Tokenizing training data", |
| ) |
|
|
|
|
| def _configure_tokenizer(tokenizer: Any, config: TrainingConfig) -> None: |
| if getattr(tokenizer, "pad_token", None) is None: |
| fallback_token = getattr(tokenizer, "eos_token", None) or getattr( |
| tokenizer, "unk_token", None |
| ) |
| if fallback_token is not None: |
| tokenizer.pad_token = fallback_token |
|
|
| if ( |
| getattr(tokenizer, "pad_token_id", None) is None |
| and getattr(tokenizer, "eos_token_id", None) is not None |
| ): |
| tokenizer.pad_token_id = tokenizer.eos_token_id |
|
|
| current_model_max_length = getattr(tokenizer, "model_max_length", None) |
| if ( |
| isinstance(current_model_max_length, int) |
| and 0 < current_model_max_length < config.max_seq_length |
| and ( |
| _is_large_training_model(config.model_name) |
| or config.max_seq_length >= LONG_CONTEXT_MIN_SEQ_LENGTH |
| ) |
| ): |
| tokenizer.model_max_length = config.max_seq_length |
|
|
|
|
def _uses_peft(config: TrainingConfig) -> bool:
    """True when the configured adapter type is handled through PEFT (LoRA/QLoRA)."""
    adapter = config.adapter_type
    return adapter in PEFT_ADAPTER_TYPES
|
|
|
|
| def _uses_qlora(config: TrainingConfig) -> bool: |
| return config.adapter_type == "qlora" |
|
|
|
|
| def _uses_preference_optimization(config: TrainingConfig) -> bool: |
| return ( |
| bool(config.preference_dataset_path) |
| and config.preference_optimization in PREFERENCE_OPTIMIZATION_TYPES |
| ) |
|
|
|
|
| def _resolve_torch_dtype(dtype_name: str) -> Any | None: |
| normalized = dtype_name.strip().lower() |
| if not normalized: |
| return None |
| try: |
| import torch |
| except ImportError: |
| return None |
| return getattr(torch, normalized, None) |
|
|
|
|
def _build_model_load_kwargs(config: TrainingConfig) -> dict[str, Any]:
    """Build the keyword arguments for loading the base causal-LM model.

    Always enables ``trust_remote_code`` and ``low_cpu_mem_usage``. Large models
    and QLoRA runs also get ``device_map="auto"`` with an offload folder under
    the output dir. QLoRA runs get a 4-bit ``BitsAndBytesConfig``.

    Raises:
        ImportError: QLoRA requested but ``BitsAndBytesConfig`` is unavailable.
    """
    load_kwargs: dict[str, Any] = {"trust_remote_code": True, "low_cpu_mem_usage": True}
    wants_qlora = _uses_qlora(config)

    if _is_large_training_model(config.model_name) or wants_qlora:
        load_kwargs["device_map"] = "auto"
        load_kwargs["offload_folder"] = str(Path(config.output_dir) / ".offload")

    if not wants_qlora:
        return load_kwargs

    try:
        from transformers import BitsAndBytesConfig
    except ImportError as exc:
        raise ImportError(
            "QLoRA apmācībai vajag transformers BitsAndBytesConfig un bitsandbytes atkarību."
        ) from exc

    bnb_options: dict[str, Any] = {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": config.qlora_quant_type,
        "bnb_4bit_use_double_quant": config.qlora_use_double_quant,
    }
    compute_dtype = _resolve_torch_dtype(config.qlora_compute_dtype)
    if compute_dtype is not None:
        bnb_options["bnb_4bit_compute_dtype"] = compute_dtype
    load_kwargs["quantization_config"] = BitsAndBytesConfig(**bnb_options)
    return load_kwargs
|
|
|
|
def _extract_model_size_billions(model_name: str) -> float | None:
    """Guess a model's parameter count, in billions, from its name.

    Returns the largest "<number>B" hint (e.g. "Qwen2.5-72B" -> 72.0). If there
    is none, lowercase hints such as "meta-llama/Llama-2-70b-hf" or "falcon-40b"
    are recognised by a stricter fallback. That fallback needs a
    non-alphanumeric neighbour on both sides, so hex digests in local snapshot
    paths (".../032b27...") are not read as sizes. Known giant families with no
    numeric hint map to ``GIANT_MODEL_AUTO_QLORA_THRESHOLD_B``. Returns None
    when nothing matches.
    """
    matches = [float(match) for match in MODEL_SIZE_HINT_RE.findall(model_name)]
    if not matches:
        # Previously lowercase-"b" names got no hint at all, so 70b-class models
        # skipped the auto-QLoRA / large-model runtime adjustments.
        matches = [
            float(match)
            for match in re.findall(
                r"(?<![A-Za-z0-9.])(\d+(?:\.\d+)?)b(?![A-Za-z0-9])", model_name
            )
        ]
    if matches:
        return max(matches)
    normalized = model_name.strip().lower()
    if any(hint in normalized for hint in GIANT_MODEL_NAME_HINTS):
        return GIANT_MODEL_AUTO_QLORA_THRESHOLD_B
    return None
|
|
|
|
def _is_large_training_model(model_name: str) -> bool:
    """True when the name's size hint reaches ``LARGE_MODEL_RESOURCE_THRESHOLD_B``."""
    hint = _extract_model_size_billions(model_name)
    if hint is None:
        return False
    return bool(hint >= LARGE_MODEL_RESOURCE_THRESHOLD_B)
|
|
|
|
def _build_hf_auth_kwargs(callable_obj: Any) -> dict[str, str]:
    """Build the auth keyword that *callable_obj* accepts, if a token is configured.

    Reads ``MARIS_REPO_TOKEN`` / ``MARIS_TOKEN`` / ``HF_TOKEN`` via
    ``get_env_any``. Prefers ``token``, then the legacy ``use_auth_token``.
    Falls back to ``token`` when the signature accepts ``**kwargs`` or cannot be
    inspected. Returns ``{}`` when no token is set or no suitable parameter exists.
    """
    token = get_env_any("MARIS_REPO_TOKEN", "MARIS_TOKEN", "HF_TOKEN")
    if not token:
        return {}

    try:
        params = inspect.signature(callable_obj).parameters
    except (TypeError, ValueError):
        return {"token": token}

    if "token" in params:
        keyword = "token"
    elif "use_auth_token" in params:
        keyword = "use_auth_token"
    elif any(param.kind is inspect.Parameter.VAR_KEYWORD for param in params.values()):
        keyword = "token"
    else:
        return {}
    return {keyword: token}
|
|
|
|
def _normalize_training_runtime_config(config: TrainingConfig) -> TrainingConfig:
    """Adjust *config* to the runtime and the model's size hint.

    - No precision flag set: enable bf16 when CUDA supports it, else fp16 on CUDA.
    - Giant models (>= GIANT threshold): force QLoRA and gradient accumulation
      >= 16. With long context they also get DeepSpeed (if no strategy is set)
      and accelerate.
    - Large models (>= LARGE threshold): gradient checkpointing and batch size 1.

    Returns the original object when nothing changes; otherwise a
    ``dataclasses.replace`` copy, after logging the applied updates.
    """
    updates: dict[str, Any] = {}
    if not (config.fp16 or config.bf16):
        if _runtime_supports_bf16():
            updates["bf16"] = True
        elif _runtime_has_cuda():
            updates["fp16"] = True

    size_hint = _extract_model_size_billions(config.model_name)
    if size_hint is not None:
        is_large = size_hint >= LARGE_MODEL_RESOURCE_THRESHOLD_B
        is_giant = size_hint >= GIANT_MODEL_AUTO_QLORA_THRESHOLD_B

        if is_giant and config.adapter_type != "qlora":
            updates["adapter_type"] = "qlora"
        if is_giant and config.max_seq_length >= LONG_CONTEXT_MIN_SEQ_LENGTH:
            if config.distributed_strategy == "none":
                updates["distributed_strategy"] = "deepspeed"
            if not config.use_accelerate:
                updates["use_accelerate"] = True
        if is_large:
            if not config.gradient_checkpointing:
                updates["gradient_checkpointing"] = True
            if config.per_device_train_batch_size != 1:
                updates["per_device_train_batch_size"] = 1
            if config.per_device_eval_batch_size != 1:
                updates["per_device_eval_batch_size"] = 1
        if is_giant and config.gradient_accumulation_steps < 16:
            updates["gradient_accumulation_steps"] = 16

    if not updates:
        return config

    normalized = replace(config, **updates)
    summary = ", ".join(f"{key}={value}" for key, value in sorted(updates.items()))
    logger.info("Pielāgoju training runtime %s: %s", config.model_name, summary)
    return normalized
|
|
|
|
| def _get_torch_module() -> Any | None: |
| try: |
| import torch |
| except ImportError: |
| return None |
| return torch |
|
|
|
|
def _runtime_has_cuda() -> bool:
    """True when torch is installed and ``torch.cuda.is_available()`` is true."""
    cuda = getattr(_get_torch_module(), "cuda", None)
    if cuda is None:
        return False
    availability_probe = getattr(cuda, "is_available", None)
    return bool(callable(availability_probe) and availability_probe())
|
|
|
|
def _runtime_supports_bf16() -> bool:
    """True when CUDA is available and reports bf16 support."""
    cuda = getattr(_get_torch_module(), "cuda", None)
    if cuda is None:
        return False
    availability_probe = getattr(cuda, "is_available", None)
    if not (callable(availability_probe) and availability_probe()):
        return False
    bf16_probe = getattr(cuda, "is_bf16_supported", None)
    return bool(callable(bf16_probe) and bf16_probe())
|
|
|
|
def _runtime_has_accelerator() -> bool:
    """True when CUDA or Apple MPS is available to torch."""
    if _runtime_has_cuda():
        return True
    backends = getattr(_get_torch_module(), "backends", None)
    mps_probe = getattr(getattr(backends, "mps", None), "is_available", None)
    return bool(callable(mps_probe) and mps_probe())
|
|
|
|
def _build_runtime_training_argument_overrides() -> dict[str, Any]:
    """TrainingArguments overrides derived from the current process environment.

    Turns off dataloader pin-memory without an accelerator. Turns off tqdm when
    stderr is not a TTY. Always logs the first step.
    """
    overrides: dict[str, Any] = {}
    if not _runtime_has_accelerator():
        overrides["dataloader_pin_memory"] = False
    isatty = getattr(getattr(sys, "stderr", None), "isatty", None)
    if not (callable(isatty) and isatty()):
        overrides["disable_tqdm"] = True
    overrides["logging_first_step"] = True
    return overrides
|
|
|
|
| def _resolve_repo_relative_path(path_value: str) -> str: |
| normalized = path_value.strip() |
| if not normalized: |
| return "" |
| candidate = Path(normalized) |
| if candidate.is_absolute(): |
| return str(candidate) |
| return str((REPO_ROOT / candidate).resolve()) |
|
|
|
|
| def _resolve_distributed_config_path(config: TrainingConfig) -> str: |
| if config.distributed_config_path: |
| return _resolve_repo_relative_path(config.distributed_config_path) |
| if config.distributed_strategy == "fsdp": |
| return str((REPO_ROOT / "huggingface" / "fsdp-config.json").resolve()) |
| if config.distributed_strategy == "deepspeed": |
| return str((REPO_ROOT / "huggingface" / "deepspeed-zero3.json").resolve()) |
| return "" |
|
|
|
|
| def _load_json_object(path_value: str, *, label: str) -> dict[str, Any]: |
| resolved_path = Path(path_value) |
| if not resolved_path.is_file(): |
| raise FileNotFoundError(f"{label} nav atrasts: {resolved_path}") |
| payload = json.loads(resolved_path.read_text(encoding="utf-8")) |
| if not isinstance(payload, dict): |
| raise ValueError(f"{label} jābūt JSON objektam.") |
| return payload |
|
|
|
|
| def _require_runtime_package(package_name: str, *, context_label: str) -> None: |
| try: |
| get_installed_package_version(package_name) |
| except (PackageNotFoundError, StopIteration) as exc: |
| raise ImportError( |
| f"{context_label} nepieciešams instalēt '{package_name}' Python pakotni." |
| ) from exc |
|
|
|
|
| def _build_distributed_training_argument_overrides(config: TrainingConfig) -> dict[str, Any]: |
| overrides: dict[str, Any] = {} |
| if config.use_accelerate and config.num_processes > 1: |
| overrides["ddp_find_unused_parameters"] = False |
|
|
| if config.distributed_strategy == "fsdp": |
| fsdp_config_path = _resolve_distributed_config_path(config) |
| fsdp_config = ( |
| _load_json_object(fsdp_config_path, label="FSDP konfigurācija") |
| if fsdp_config_path |
| else {} |
| ) |
| fsdp_config.setdefault("min_num_params", config.fsdp_min_num_params) |
| if config.fsdp_transformer_layer_cls_to_wrap: |
| fsdp_config["transformer_layer_cls_to_wrap"] = list( |
| config.fsdp_transformer_layer_cls_to_wrap |
| ) |
| overrides.update( |
| { |
| "fsdp": "full_shard auto_wrap", |
| "fsdp_config": fsdp_config, |
| "ddp_find_unused_parameters": False, |
| } |
| ) |
| elif config.distributed_strategy == "deepspeed": |
| deepspeed_config_path = _resolve_distributed_config_path(config) |
| if not deepspeed_config_path: |
| raise ValueError( |
| "DeepSpeed režīmam vajag distributed_config_path vai repo default konfigurāciju." |
| ) |
| if not Path(deepspeed_config_path).is_file(): |
| raise FileNotFoundError(f"DeepSpeed konfigurācija nav atrasta: {deepspeed_config_path}") |
| _require_runtime_package("deepspeed", context_label="DeepSpeed režīms") |
| overrides.update( |
| { |
| "deepspeed": deepspeed_config_path, |
| "ddp_find_unused_parameters": False, |
| } |
| ) |
|
|
| return overrides |
|
|
|
|
def _load_tokenizer(model_name: str, config: TrainingConfig) -> Any:
    """Load the tokenizer for *model_name* through a Maris-compatible path.

    Loads via ``_load_tokenizer_from_path``, which has its own slow-tokenizer
    fallbacks. If loading fails with an error whose message points to a missing
    ``sentencepiece``/``tiktoken`` backend, the missing backends are pip-installed
    and the load is retried once. If nothing was actually installed (both
    backends were already importable), the original error is re-raised:
    retrying would only fail the same way.
    """
    from transformers import AutoTokenizer

    load_kwargs = _filter_supported_kwargs(
        AutoTokenizer.from_pretrained,
        {
            "trust_remote_code": True,
            **_build_hf_auth_kwargs(AutoTokenizer.from_pretrained),
            "use_fast": True,
        },
    )
    with maris_hf_compatible_path(model_name, allow_remote_snapshot=True) as compatible_model_path:
        try:
            return _load_tokenizer_from_path(
                compatible_model_path,
                config=config,
                auto_tokenizer_class=AutoTokenizer,
                load_kwargs=load_kwargs,
            )
        except TOKENIZER_LOAD_RETRY_EXCEPTIONS as exc:
            if not _should_retry_after_installing_tokenizer_backends(exc):
                raise
            if not _install_missing_tokenizer_backends():
                # Nothing new is importable, so a retry cannot succeed.
                raise
            logger.warning(
                "Retrying tokenizer load for %s after installing missing runtime backends.",
                config.model_name,
            )
            return _load_tokenizer_from_path(
                compatible_model_path,
                config=config,
                auto_tokenizer_class=AutoTokenizer,
                load_kwargs=load_kwargs,
            )
|
|
|
|
def _load_tokenizer_from_path(
    compatible_model_path: str,
    *,
    config: TrainingConfig,
    auto_tokenizer_class: Any,
    load_kwargs: dict[str, Any],
) -> Any:
    """Load a tokenizer from *compatible_model_path*, degrading to slow tokenizers.

    1. ``auto_tokenizer_class.from_pretrained`` with *load_kwargs*.
    2. If that fails and ``use_fast`` was passed, retry with ``use_fast=False``.
    3. If that also fails, try an explicit slow tokenizer class named in local
       artifacts. Re-raise the step-2 error when none can be loaded.
    """
    loader = auto_tokenizer_class.from_pretrained
    try:
        return loader(compatible_model_path, **load_kwargs)
    except TOKENIZER_LOAD_RETRY_EXCEPTIONS as fast_error:
        if "use_fast" not in load_kwargs:
            raise
        slow_kwargs = {**load_kwargs, "use_fast": False}
        logger.warning(
            "Fast tokenizer load failed for %s; retrying with use_fast=False: %s",
            config.model_name,
            fast_error,
        )
        try:
            return loader(compatible_model_path, **slow_kwargs)
        except TOKENIZER_LOAD_RETRY_EXCEPTIONS:
            explicit_tokenizer = _load_explicit_slow_tokenizer(compatible_model_path, slow_kwargs)
            if explicit_tokenizer is None:
                raise
            return explicit_tokenizer
|
|
|
|
| def _should_retry_after_installing_tokenizer_backends(exc: BaseException) -> bool: |
| message = str(exc).lower() |
| missing_backend_markers = ( |
| "you need to have sentencepiece or tiktoken installed", |
| "requires the sentencepiece library", |
| "requires the tiktoken library", |
| "no module named 'sentencepiece'", |
| 'no module named "sentencepiece"', |
| "no module named 'tiktoken'", |
| 'no module named "tiktoken"', |
| ) |
| return any(marker in message for marker in missing_backend_markers) |
|
|
|
|
def _install_missing_tokenizer_backends() -> bool:
    """pip-install whichever of ``sentencepiece``/``tiktoken`` is not importable.

    Returns False when both are already present, True after a successful install.
    pip runs with the current interpreter and its combined output is logged.

    Raises:
        RuntimeError: the pip install exits with a non-zero status.
    """
    missing = [
        name for name in ("sentencepiece", "tiktoken") if importlib.util.find_spec(name) is None
    ]
    if not missing:
        return False
    logger.warning(
        "Installing missing tokenizer runtime backends for training: %s",
        ", ".join(missing),
    )
    command = [sys.executable, "-m", "pip", "install", "--no-cache-dir", *missing]
    try:
        result = subprocess.run(
            command,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
    except subprocess.CalledProcessError as exc:
        failure_output = (exc.stdout or "").strip()
        if failure_output:
            logger.error("Tokenizer backend install failed: %s", failure_output)
        raise RuntimeError("Couldn't install missing tokenizer backends automatically.") from exc
    install_output = (result.stdout or "").strip()
    if install_output:
        logger.info("Tokenizer backend install output: %s", install_output)
    # Make freshly installed packages discoverable by the import system.
    importlib.invalidate_caches()
    return True
|
|
|
|
| def _load_local_tokenizer_artifact(path: Path) -> dict[str, Any] | None: |
| """Nolasa lokālu tokenizer JSON artefaktu un atgriež dict vai None.""" |
| if not path.is_file(): |
| return None |
| try: |
| payload = json.loads(path.read_text(encoding="utf-8")) |
| except (OSError, json.JSONDecodeError) as exc: |
| logger.warning("Couldn't read tokenizer artifact %s: %s", path, exc) |
| return None |
| return payload if isinstance(payload, dict) else None |
|
|
|
|
| def _collect_explicit_slow_tokenizer_classes(model_name_or_path: str) -> list[str]: |
| """Savāc deduplicētus slow tokenizer klašu nosaukumus no lokāliem HF artefaktiem.""" |
| model_dir = Path(model_name_or_path) |
| if not model_dir.is_dir(): |
| return [] |
|
|
| class_names: list[str] = [] |
| seen: set[str] = set() |
|
|
| def add_class_name(candidate: Any) -> None: |
| if not isinstance(candidate, str): |
| return |
| normalized = candidate.strip() |
| if not normalized: |
| return |
| if normalized.endswith("Fast"): |
| normalized = normalized[:-4] |
| if not normalized or normalized in seen: |
| return |
| seen.add(normalized) |
| class_names.append(normalized) |
|
|
| for artifact_name in ("tokenizer_config.json", "config.json"): |
| payload = _load_local_tokenizer_artifact(model_dir / artifact_name) |
| if payload is None: |
| continue |
| add_class_name(payload.get("tokenizer_class")) |
| auto_map = payload.get("auto_map") |
| if not isinstance(auto_map, dict): |
| continue |
| auto_tokenizer_entry = auto_map.get("AutoTokenizer") |
| if isinstance(auto_tokenizer_entry, list): |
| for candidate in auto_tokenizer_entry: |
| add_class_name(candidate) |
| else: |
| add_class_name(auto_tokenizer_entry) |
|
|
| return class_names |
|
|
|
|
def _load_explicit_slow_tokenizer(
    model_name_or_path: str, load_kwargs: dict[str, Any]
) -> Any | None:
    """Try explicit slow tokenizer classes after the AutoTokenizer fallbacks failed.

    Each class name from ``_collect_explicit_slow_tokenizer_classes`` that exists
    on the ``transformers`` module is tried in order, with *load_kwargs* minus
    ``use_fast``. Returns the first tokenizer that loads, or None.
    """
    candidates = _collect_explicit_slow_tokenizer_classes(model_name_or_path)
    if not candidates:
        return None

    import transformers

    slow_kwargs = {key: value for key, value in load_kwargs.items() if key != "use_fast"}

    for class_name in candidates:
        tokenizer_class = getattr(transformers, class_name, None)
        if tokenizer_class is None:
            continue
        try:
            logger.warning(
                "AutoTokenizer retry still failed for %s; loading explicit slow tokenizer %s.",
                model_name_or_path,
                class_name,
            )
            return tokenizer_class.from_pretrained(model_name_or_path, **slow_kwargs)
        except TOKENIZER_LOAD_RETRY_EXCEPTIONS as exc:
            logger.warning(
                "Explicit slow tokenizer %s failed for %s: %s",
                class_name,
                model_name_or_path,
                exc,
            )
    return None
|
|
|
|
def _load_model(
    model_name_or_path: str,
    config: TrainingConfig,
    *,
    trainable_adapter: bool = False,
) -> Any:
    """Load a causal LM, or a PEFT adapter plus its base model.

    The source is first mapped through ``maris_hf_compatible_path`` (remote
    snapshots allowed), and loading happens while that context is open.
    ``peft.AutoPeftModelForCausalLM`` is used when all of these hold:

    * *config* enables PEFT;
    * *trainable_adapter* is requested or the resolved directory contains
      ``adapter_config.json``;
    * the installed peft provides that class.

    Otherwise ``AutoModelForCausalLM`` is loaded with
    ``_build_model_load_kwargs(config)`` plus HF auth kwargs. Unsupported kwargs
    are filtered out with ``_filter_supported_kwargs``.

    NOTE(review): the AutoPeft branch does not forward
    ``_build_model_load_kwargs(config)`` to the base model load. Confirm this is
    intended (e.g. for quantized QLoRA continuation).
    """
    from transformers import AutoModelForCausalLM

    with maris_hf_compatible_path(
        model_name_or_path, allow_remote_snapshot=True
    ) as compatible_model_path:
        resolved_dir = Path(compatible_model_path)
        adapter_config_present = (
            resolved_dir.exists() and resolved_dir.joinpath("adapter_config.json").is_file()
        )

        auto_peft_cls = None
        if _uses_peft(config) and (trainable_adapter or adapter_config_present):
            try:
                import peft
            except ImportError:
                auto_peft_cls = None
            else:
                auto_peft_cls = getattr(peft, "AutoPeftModelForCausalLM", None)

        if auto_peft_cls is not None:
            peft_loader = auto_peft_cls.from_pretrained
            peft_kwargs: dict[str, Any] = {}
            if "is_trainable" in inspect.signature(peft_loader).parameters:
                peft_kwargs["is_trainable"] = trainable_adapter
            peft_kwargs.update(_build_hf_auth_kwargs(peft_loader))
            # Only defaults trust_remote_code; an explicit auth-provided value wins.
            peft_kwargs.setdefault("trust_remote_code", True)
            return peft_loader(
                compatible_model_path,
                **_filter_supported_kwargs(peft_loader, peft_kwargs),
            )

        base_loader = AutoModelForCausalLM.from_pretrained
        base_kwargs = {
            **_build_model_load_kwargs(config),
            **_build_hf_auth_kwargs(base_loader),
        }
        return base_loader(
            compatible_model_path,
            **_filter_supported_kwargs(base_loader, base_kwargs),
        )
|
|
|
|
def _apply_peft_adapter(model: Any, config: TrainingConfig) -> Any:
    """Wrap *model* in a new LoRA adapter when *config* enables PEFT.

    For QLoRA configs the model first goes through
    ``peft.prepare_model_for_kbit_training`` (honouring
    ``config.gradient_checkpointing``). Target modules are
    ``config.peft_target_modules`` when given, otherwise ``"all-linear"``.

    Returns:
        *model* unchanged when PEFT is disabled, else the PEFT-wrapped model.

    Raises:
        ImportError: if PEFT is requested but ``peft`` is not installed.
    """
    if not _uses_peft(config):
        return model

    try:
        from peft import (
            LoraConfig,
            TaskType,
            get_peft_model,
            prepare_model_for_kbit_training,
        )
    except ImportError as exc:
        raise ImportError("LoRA/QLoRA apmācībai vajag peft atkarību.") from exc

    base_model = (
        prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=config.gradient_checkpointing,
        )
        if _uses_qlora(config)
        else model
    )

    targets: str | list[str] = (
        list(config.peft_target_modules) if config.peft_target_modules else "all-linear"
    )
    lora_config = LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        bias=config.lora_bias,
        task_type=TaskType.CAUSAL_LM,
        target_modules=targets,
    )
    wrapped = get_peft_model(base_model, lora_config)
    if hasattr(wrapped, "print_trainable_parameters"):
        wrapped.print_trainable_parameters()
    return wrapped
|
|
|
|
| def _supports_gradient_checkpointing_kwargs(method: Any) -> bool: |
| with suppress(TypeError, ValueError): |
| signature = inspect.signature(method) |
| return "gradient_checkpointing_kwargs" in signature.parameters or any( |
| parameter.kind is inspect.Parameter.VAR_KEYWORD |
| for parameter in signature.parameters.values() |
| ) |
| return False |
|
|
|
|
def _prepare_training_model(model_name: str, tokenizer: Any, config: TrainingConfig) -> Any:
    """Load the base model and prepare it for training.

    Steps, in order:

    1. If the model config has no ``pad_token_id``, copy it from the tokenizer.
    2. Disable ``use_cache`` when gradient checkpointing is on.
    3. Enable gradient checkpointing, forwarding ``use_reentrant`` only when it
       is configured and the runtime accepts keyword overrides.
    4. Attach the PEFT adapter via ``_apply_peft_adapter``.
    """
    model = _load_model(model_name, config)

    model_config = getattr(model, "config", None)
    if model_config is not None:
        if getattr(model_config, "pad_token_id", None) is None:
            model_config.pad_token_id = getattr(tokenizer, "pad_token_id", None)
        if hasattr(model_config, "use_cache"):
            # The KV cache is disabled whenever gradient checkpointing is on.
            model_config.use_cache = not config.gradient_checkpointing

    if config.gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        enable_checkpointing = model.gradient_checkpointing_enable
        use_reentrant = config.gradient_checkpointing_use_reentrant
        if use_reentrant is not None and _supports_gradient_checkpointing_kwargs(
            enable_checkpointing
        ):
            enable_checkpointing(gradient_checkpointing_kwargs={"use_reentrant": use_reentrant})
        else:
            if use_reentrant is not None:
                logger.warning(
                    "Ignoring explicit gradient_checkpointing_use_reentrant=%s because "
                    "gradient_checkpointing_enable() does not accept keyword overrides in "
                    "the current runtime.",
                    use_reentrant,
                )
            enable_checkpointing()

    return _apply_peft_adapter(model, config)
|
|
|
|
def _load_preference_training_dataset(config: TrainingConfig) -> Any:
    """Load branch-filtered preference pairs as rows for a TRL preference trainer.

    Examples come from ``_load_filtered_preference_examples``, the same loader
    used for the preference summary artifacts. Training and reporting
    therefore always see the same branch-filtered set. Optional text fields are
    normalized to ``""`` and sequence fields are copied into lists.

    Returns:
        A ``datasets.Dataset`` built from the rows, or the plain list of row
        dicts in either of these cases:

        * ``datasets`` is not installed;
        * ``Dataset.from_list`` fails with the known ``PreTrainedTokenizerBase``
          AttributeError.

    Raises:
        ValueError: if a text-trainable branch has no matching examples.
        AttributeError: any other attribute error from ``Dataset.from_list``.
    """
    examples = _load_filtered_preference_examples(config)
    rows = [
        {
            "prompt": example.prompt,
            "chosen": example.chosen,
            "rejected": example.rejected,
            "source": example.source,
            "annotator": example.annotator or "",
            "edit_target": example.edit_target or "",
            "context": example.context or "",
            "branch": example.branch or "",
            "task_type": example.task_type or "",
            "language": example.language or "",
            "repo_context": list(example.repo_context),
            "execution_required": example.execution_required,
            "tags": list(example.tags),
        }
        for example in examples
    ]
    try:
        from datasets import Dataset
    except ImportError:
        return rows
    try:
        return Dataset.from_list(rows)
    except AttributeError as exc:
        # datasets fingerprinting can touch transformers internals; a partial
        # transformers install surfaces as this AttributeError.
        if "PreTrainedTokenizerBase" in str(exc):
            logger.warning(
                "Falling back to plain preference rows because datasets fingerprinting needs a fuller transformers install: %s",
                exc,
            )
            return rows
        raise
|
|
|
|
def _filter_preference_examples_for_branch(
    examples: list[PreferenceExample],
    *,
    branch_name: str,
    branch_filter_rules: dict[str, dict[str, Any]] | None = None,
) -> list[PreferenceExample]:
    """Keep only the preference examples that match *branch_name*'s filter rule.

    Branches not in ``TEXT_TRAINABLE_BRANCHES`` are not filtered: *examples* is
    returned as-is (the same list object).

    Raises:
        ValueError: if filtering a text-trainable branch leaves no examples.
    """
    if branch_name not in TEXT_TRAINABLE_BRANCHES:
        return examples

    def keeps(example: PreferenceExample) -> bool:
        return _matches_branch_filter_rule(
            _preference_branch_filter_signals(example),
            branch_name=branch_name,
            branch_filter_rules=branch_filter_rules,
        )

    kept = list(filter(keeps, examples))
    if not kept:
        raise ValueError(
            f"Preference dataset filter neatgrieza nevienu piemēru atzaram '{branch_name}'."
        )
    return kept
|
|
|
|
def _is_local_training_artifact_dir(path: Path) -> bool:
    """Return True when *path* is a directory holding a reusable saved model or adapter.

    A directory qualifies if at least one file named in
    ``LOCAL_TRAINING_ARTIFACT_FILES`` exists directly inside it.
    """
    if not path.is_dir():
        return False
    for artifact_name in LOCAL_TRAINING_ARTIFACT_FILES:
        if (path / artifact_name).is_file():
            return True
    return False
|
|
|
|
| def _build_model_source_fingerprint(model_name: str) -> str: |
| normalized = model_name.strip().casefold().encode("utf-8") |
| return hashlib.sha256(normalized).hexdigest() |
|
|
|
|
def _load_local_model_source_fingerprint(path: Path) -> str | None:
    """Return the normalized base-model fingerprint saved in *path*'s training config.

    Reads ``training-config.json`` inside *path* and returns the value stored
    under ``MODEL_SOURCE_FINGERPRINT_KEY``, stripped and case-folded.

    Returns:
        None in any of these cases:

        * the file is missing, unreadable or not valid JSON;
        * its top-level value is not a JSON object;
        * it has no non-empty string fingerprint.

        An unreadable or corrupt file (e.g. a truncated one) is logged and
        treated like missing metadata. Auto-continue then skips the directory
        instead of aborting the whole run.
    """
    config_path = path / "training-config.json"
    try:
        training_config = _load_json_if_exists(config_path)
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning("Couldn't read training config %s: %s", config_path, exc)
        return None
    if not isinstance(training_config, dict):
        return None
    fingerprint = training_config.get(MODEL_SOURCE_FINGERPRINT_KEY)
    if not isinstance(fingerprint, str):
        return None
    return fingerprint.strip().casefold() or None
|
|
|
|
def _resolve_training_model_source(config: TrainingConfig) -> str:
    """Resolve the model to start training from, preferring saved local artifacts.

    When ``config.continue_from_latest_artifact`` is off, ``config.model_name``
    is returned unchanged. Otherwise these candidates are checked in order:

    * ``config.continue_model_path``, if set. It is trusted without a
      fingerprint check.
    * ``config.output_dir``. It is used only if its saved
      ``training-config.json`` fingerprint matches ``config.model_name``.

    The first candidate that is a local training-artifact directory wins, and
    its resolved path is returned. If none qualifies, ``config.model_name`` is
    returned.

    If an explicitly configured ``continue_model_path`` cannot be resolved or
    holds no saved artifact, a warning is logged instead of skipping it silently.
    """
    if not config.continue_from_latest_artifact:
        return config.model_name

    candidates: list[tuple[Path, bool]] = []
    if config.continue_model_path:
        candidates.append((Path(config.continue_model_path), True))
    candidates.append((Path(config.output_dir), False))
    expected_fingerprint = _build_model_source_fingerprint(config.model_name)

    seen: set[Path] = set()
    for candidate, explicit_candidate in candidates:
        try:
            resolved = candidate.expanduser().resolve()
        except OSError as exc:
            if explicit_candidate:
                logger.warning(
                    "Cannot resolve continue_model_path %s (%s); falling back to other model sources.",
                    candidate,
                    exc,
                )
            continue
        if resolved in seen:
            continue
        seen.add(resolved)

        if not _is_local_training_artifact_dir(resolved):
            if explicit_candidate:
                logger.warning(
                    "continue_model_path %s holds no saved model or adapter artifact; "
                    "falling back to other model sources.",
                    resolved,
                )
            continue

        if not explicit_candidate:
            # Auto-discovered output dirs must have been trained from the same base model.
            saved_fingerprint = _load_local_model_source_fingerprint(resolved)
            if saved_fingerprint != expected_fingerprint:
                if saved_fingerprint:
                    logger.info(
                        "Izlaižu auto-continue no %s, jo saglabātā modeļa bāze neatbilst izvēlētajam modelim.",
                        resolved,
                    )
                else:
                    logger.info(
                        "Izlaižu auto-continue no %s, jo trūkst modeļa saderības metadata.",
                        resolved,
                    )
                continue

        logger.info("Turpinu treniņu no lokālā artefakta: %s", resolved)
        return str(resolved)

    return config.model_name
|
|
|
|
| def _resolve_runtime_uid_suffix() -> str: |
| """Best-effort UID suffix for runtime fallback paths and user labels.""" |
| uid = "unknown" |
| getuid = getattr(os, "getuid", None) |
| if callable(getuid): |
| with suppress(OSError): |
| uid = str(getuid()) |
| return uid |
|
|
|
|
def _ensure_runtime_home_dir() -> str:
    """Make sure HOME and the user identity variables are set; return HOME.

    This targets containers running without a passwd entry.

    * If HOME is missing or blank, ``<tempdir>/maris-home-<uid>`` is created
      and exported as HOME.
    * If none of USER/LOGNAME/USERNAME is set, ``maris-<uid>`` is chosen as the
      user name (the log message says this avoids getpwuid errors).
    * Any of the three identity variables that is still blank is then set to
      that user name.

    Returns:
        The whitespace-stripped HOME value in effect after any fallback.
    """
    identity_vars = ("USER", "LOGNAME", "USERNAME")
    home = os.environ.get("HOME", "").strip()
    uid = _resolve_runtime_uid_suffix()

    if not home:
        fallback_home = (Path(tempfile.gettempdir()) / f"maris-home-{uid}").resolve()
        fallback_home.mkdir(parents=True, exist_ok=True)
        home = str(fallback_home)
        os.environ["HOME"] = home
        logger.warning(
            "HOME nav iestatīts; izmantojam fallback runtime home %s, lai treniņš strādātu konteineros bez passwd ieraksta.",
            fallback_home,
        )

    present_values = (os.environ.get(var, "").strip() for var in identity_vars)
    runtime_user = next((value for value in present_values if value), "")
    if not runtime_user:
        runtime_user = f"maris-{uid}"
        logger.warning(
            "Lietotāja vides mainīgie nav iestatīti; izmantojam fallback runtime lietotāju %s, lai izvairītos no getpwuid kļūdām.",
            runtime_user,
        )
    for var in identity_vars:
        if not os.environ.get(var, "").strip():
            os.environ[var] = runtime_user

    return home
|
|
|
|
| def _filter_supported_kwargs(callable_obj: Any, kwargs: dict[str, Any]) -> dict[str, Any]: |
| signature = inspect.signature(callable_obj) |
| parameters = signature.parameters |
| if any(parameter.kind is inspect.Parameter.VAR_KEYWORD for parameter in parameters.values()): |
| return kwargs |
| supported_names = { |
| name |
| for name, parameter in parameters.items() |
| if name != "self" |
| and parameter.kind |
| in {inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY} |
| } |
| return {name: value for name, value in kwargs.items() if name in supported_names} |
|
|
|
|
def _build_preference_training_arguments(preference_config_cls: Any, config: TrainingConfig) -> Any:
    """Build the preference-stage trainer config (e.g. a TRL config) from *config*.

    General trainer options mirror the main run. Evaluation is turned off
    (``evaluation_strategy="no"``), ``remove_unused_columns`` is forced to
    False, and the DPO/ORPO-specific ``beta``/``max_length``/``max_prompt_length``
    values are added. ``_build_training_arguments`` drops or renames anything
    *preference_config_cls* does not support.
    """
    trainer_options: dict[str, Any] = {
        "output_dir": config.output_dir,
        "overwrite_output_dir": True,
        "num_train_epochs": config.num_epochs,
        "learning_rate": config.learning_rate,
        "per_device_train_batch_size": config.per_device_train_batch_size,
        "per_device_eval_batch_size": config.per_device_eval_batch_size,
        "gradient_accumulation_steps": config.gradient_accumulation_steps,
        "warmup_ratio": config.warmup_ratio,
        "weight_decay": config.weight_decay,
        "logging_steps": config.logging_steps,
        "save_steps": config.save_steps,
        "eval_steps": config.eval_steps,
        "save_total_limit": config.save_total_limit,
        "lr_scheduler_type": config.lr_scheduler_type,
        "seed": config.seed,
        "fp16": config.fp16,
        "bf16": config.bf16,
        "report_to": config.report_to,
        "save_safetensors": config.save_safetensors,
        "remove_unused_columns": False,
        "evaluation_strategy": "no",
    }
    preference_options: dict[str, Any] = {
        "beta": config.preference_beta,
        "max_length": config.preference_max_length,
        "max_prompt_length": config.preference_max_prompt_length,
    }
    return _build_training_arguments(
        preference_config_cls, **trainer_options, **preference_options
    )
|
|
|
|
def _load_preference_reference_model(output_dir: Path, config: TrainingConfig) -> Any:
    """Load the reference model for the DPO stage (adapters non-trainable).

    Uses ``config.preference_reference_model`` when set, otherwise the model
    saved in *output_dir*.
    """
    if config.preference_reference_model:
        source = config.preference_reference_model
    else:
        source = str(output_dir)
    return _load_model(source, config, trainable_adapter=False)
|
|
|
|
def _run_preference_optimization(
    model: Any,
    tokenizer: Any,
    config: TrainingConfig,
    *,
    output_dir: Path,
) -> tuple[Any, Any, dict[str, float]]:
    """Run the optional DPO/ORPO preference stage with TRL on top of *model*.

    The trainer is ``trl.<STRATEGY>Trainer``. Its config class is
    ``trl.<STRATEGY>Config``, falling back to ``transformers.TrainingArguments``.
    For ``"dpo"`` a reference model is loaded via
    ``_load_preference_reference_model``.

    Both ``tokenizer`` and ``processing_class`` are offered to the trainer, and
    ``_filter_supported_kwargs`` keeps only those its ``__init__`` accepts.
    This presumably covers older and newer TRL signatures.

    Returns:
        ``(model, trainer, metrics)``, or ``(model, None, {})`` when preference
        optimization is disabled. The metrics are:

        * each numeric train-result metric, prefixed with ``preference_``;
        * ``preference_examples``;
        * ``preference_stage`` (1.0);
        * ``preference_strategy`` (1.0 for DPO, 2.0 otherwise).

    Raises:
        ImportError: if ``trl`` is not installed.
        RuntimeError: if ``trl`` lacks the requested trainer class.
    """
    if not _uses_preference_optimization(config):
        return model, None, {}

    try:
        import trl
    except ImportError as exc:
        raise ImportError("Preference optimization (DPO/ORPO) vajag trl atkarību.") from exc

    dataset = _load_preference_training_dataset(config)
    strategy = config.preference_optimization
    class_prefix = strategy.upper()
    trainer_name = f"{class_prefix}Trainer"
    trainer_cls = getattr(trl, trainer_name, None)
    if trainer_cls is None:
        raise RuntimeError(f"TRL modulī nav pieejams {trainer_name}.")
    args_cls = getattr(trl, f"{class_prefix}Config", None)
    if args_cls is None:
        from transformers import TrainingArguments

        args_cls = TrainingArguments

    candidate_kwargs: dict[str, Any] = {
        "model": model,
        "args": _build_preference_training_arguments(args_cls, config),
        "train_dataset": dataset,
        "tokenizer": tokenizer,
        "processing_class": tokenizer,
    }
    if strategy == "dpo":
        candidate_kwargs["ref_model"] = _load_preference_reference_model(output_dir, config)

    trainer = trainer_cls(**_filter_supported_kwargs(trainer_cls.__init__, candidate_kwargs))
    train_result = trainer.train()

    metrics: dict[str, float] = {}
    for key, value in dict(getattr(train_result, "metrics", {})).items():
        if isinstance(value, int | float):
            metrics[f"preference_{key}"] = float(value)
    metrics["preference_examples"] = float(len(dataset))
    metrics["preference_stage"] = 1.0
    metrics["preference_strategy"] = 1.0 if strategy == "dpo" else 2.0
    return model, trainer, metrics
|
|
|
|
def _save_json(path: Path, payload: dict[str, Any]) -> None:
    """Persist the JSON object *payload* at *path* (formatting as in ``_save_json_payload``)."""
    return _save_json_payload(path, payload)
|
|
|
|
| def _save_json_payload(path: Path, payload: Any) -> None: |
| path.parent.mkdir(parents=True, exist_ok=True) |
| path.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8") |
|
|
|
|
| def _load_json_if_exists(path: Path) -> dict[str, Any] | None: |
| if not path.is_file(): |
| return None |
| payload = json.loads(path.read_text(encoding="utf-8")) |
| return payload if isinstance(payload, dict) else None |
|
|
|
|
| def _write_text(path: Path, content: str) -> None: |
| path.parent.mkdir(parents=True, exist_ok=True) |
| path.write_text(content, encoding="utf-8") |
|
|
|
|
def _build_training_artifact_config(config: TrainingConfig) -> dict[str, Any]:
    """Build the ``training-config.json`` payload without exposing the source model.

    ``model_name`` is removed. A fingerprint of it is stored under
    ``MODEL_SOURCE_FINGERPRINT_KEY``, which ``_load_local_model_source_fingerprint``
    reads back for auto-continue compatibility checks. Maris identity fields and
    the artifact format tag are added as well.
    """
    payload = config.to_dict().copy()
    payload.pop("model_name", None)
    payload.update(
        {
            "model_preset": config.model_preset or "custom",
            MODEL_SOURCE_FINGERPRINT_KEY: _build_model_source_fingerprint(config.model_name),
            "maris_origin": MARIS_ORIGIN_NAME,
            "maris_framework": MARIS_FRAMEWORK_NAME,
            "maris_model_id": config.hub_model_id,
            "artifact_format": "maris-training-config-v1",
        }
    )
    return payload
|
|
|
|
def _build_artifact_identity(config: TrainingConfig) -> dict[str, Any]:
    """Return the Maris origin plus dataset/branch lineage fields shared by artifacts."""
    return dict(
        maris_origin=MARIS_ORIGIN_NAME,
        maris_framework=MARIS_FRAMEWORK_NAME,
        dataset_repo=config.dataset_repo,
        dataset_repos=_resolve_training_dataset_repos(config),
        eval_dataset_repo=_resolve_primary_eval_dataset_repo(config),
        eval_dataset_repos=_resolve_eval_dataset_repos(config),
        branch_name=config.branch_name,
        branch_focus=config.branch_focus,
    )
|
|
|
|
def _build_model_artifact_identity(config: TrainingConfig) -> dict[str, Any]:
    """Return the shared artifact identity extended with the published ``maris_model_id``."""
    identity = _build_artifact_identity(config)
    identity["maris_model_id"] = config.hub_model_id
    return identity
|
|
|
|
def _build_metrics_artifact(
    config: TrainingConfig,
    metrics: dict[str, float],
    *,
    artifact_type: str,
    extra_payload: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Combine *metrics* with the model identity into a metrics artifact dict.

    Keys are merged in this order, later ones winning: metrics, model artifact
    identity, ``artifact_type``, then *extra_payload* (if non-empty).
    """
    artifact: dict[str, Any] = {**metrics, **_build_model_artifact_identity(config)}
    artifact["artifact_type"] = artifact_type
    if extra_payload:
        artifact |= extra_payload
    return artifact
|
|
|
|
def _build_branch_suite_artifact(
    config: TrainingConfig, suite_results: dict[str, dict[str, Any]]
) -> dict[str, Any]:
    """Wrap per-branch suite results with the shared identity and a ``branch-suite`` type tag."""
    artifact: dict[str, Any] = {"branches": suite_results}
    artifact.update(_build_artifact_identity(config))
    artifact["artifact_type"] = "branch-suite"
    return artifact
|
|
|
|
def _write_origin_metadata(output_dir: Path, config: TrainingConfig) -> None:
    """Write ``maris-metadata.json`` with origin, framework and model identity fields."""
    metadata: dict[str, Any] = {
        "origin": MARIS_ORIGIN_NAME,
        "framework": MARIS_FRAMEWORK_NAME,
    }
    metadata.update(_build_model_artifact_identity(config))
    _save_json(output_dir / "maris-metadata.json", metadata)
|
|
|
|
def _build_model_card(
    config: TrainingConfig,
    *,
    trained_at: str,
    train_examples: int,
    eval_examples: int,
    metrics: dict[str, float],
) -> str:
    """Render the ``README.md`` model card for a Maris training run.

    The card starts with YAML front matter: languages, license, pipeline tag
    and ``MODEL_CARD_TAGS``. It then has these sections:

    * model summary;
    * training data;
    * training configuration;
    * metrics (``train_loss``/``eval_loss``/``perplexity`` when present,
      otherwise a placeholder line);
    * Maris identity guarantees.

    Lines are joined with ``"\\n"`` and the text ends with a newline.
    """
    metric_lines = [
        f"- `{name}`: {metrics[name]}"
        for name in ("train_loss", "eval_loss", "perplexity")
        if name in metrics
    ] or ["- Nav pieejamu metriku šim skrējienam."]

    model_id = config.hub_model_id
    train_repos = ", ".join(_resolve_training_dataset_repos(config))
    primary_eval_repo = _resolve_primary_eval_dataset_repo(config)
    eval_repos = ", ".join(_resolve_eval_dataset_repos(config) or [primary_eval_repo])

    front_matter = [
        "---",
        "language:",
        " - lv",
        " - en",
        "license: mit",
        "pipeline_tag: text-generation",
        "tags:",
        *(f" - {tag}" for tag in MODEL_CARD_TAGS),
        "---",
        "",
    ]
    summary = [
        f"# {MARIS_ORIGIN_NAME} Model",
        "",
        "## Model Summary",
        "",
        f"Šis artefakts ir {MARIS_ORIGIN_NAME} treniņa izvads atzaram `{config.branch_name}`.",
        f"Modelis tiek publicēts kā `{model_id}` un saglabā Maris AI identitāti visos galvenajos artefaktos.",
        "",
    ]
    training_data = [
        "## Training Data",
        "",
        f"- Dataset repo: `{config.dataset_repo}`",
        f"- Dataset repos: `{train_repos}`",
        f"- Eval dataset repo: `{primary_eval_repo}`",
        f"- Eval dataset repos: `{eval_repos}`",
        f"- Train examples: `{train_examples}`",
        f"- Eval examples: `{eval_examples}`",
        f"- Branch focus: `{config.branch_focus}`",
        "",
    ]
    training_configuration = [
        "## Training Configuration",
        "",
        "- Base lineage: `Maris AI`",
        f"- Model preset: `{config.model_preset or 'custom'}`",
        f"- Adapter strategy: `{config.adapter_type}`",
        f"- Preference optimization: `{config.preference_optimization}`",
        f"- Output model id: `{model_id}`",
        f"- Trained at: `{trained_at}`",
        "",
    ]
    metrics_section = ["## Metrics", "", *metric_lines, ""]
    guarantees = [
        "## Maris Identity Guarantees",
        "",
        f"- `config.json`, `generation_config.json` un `tokenizer_config.json` tiek sanitizēti uz `{model_id}`",
        f"- `training-config.json`, `training-provenance.json` un `maris-metadata.json` satur tikai `{MARIS_ORIGIN_NAME}` izcelsmes metadatus",
        "- Šis artefakts ir paredzēts publicēšanai un lietošanai kā Maris AI modelis",
        "",
    ]
    return "\n".join(
        [
            *front_matter,
            *summary,
            *training_data,
            *training_configuration,
            *metrics_section,
            *guarantees,
        ]
    )
|
|
|
|
def _write_training_provenance(
    output_dir: Path,
    config: TrainingConfig,
    *,
    trained_at: str,
    train_examples: int,
    eval_examples: int,
    metrics: dict[str, float],
    quality_report: DatasetQualityReport | None = None,
    scoring_report: DatasetScoringReport | None = None,
    benchmark_feedback: DatasetBenchmarkFeedback | None = None,
) -> None:
    """Write ``training-provenance.json`` describing how this artifact was produced.

    The record contains:

    * dataset and branch lineage;
    * adapter and preference settings;
    * example counts and metrics;
    * optional quality/scoring reports and benchmark feedback;
    * the list of Maris artifact files.

    The base-model fields are set to ``MARIS_ORIGIN_NAME`` instead of the real
    source model.
    """
    optional_report_files = {
        "dataset-quality-report.json": quality_report,
        "dataset-scoring-report.json": scoring_report,
    }
    artifact_files = [
        "README.md",
        "training-config.json",
        "training-metrics.json",
        "maris-metadata.json",
        "training-provenance.json",
        *(name for name, report in optional_report_files.items() if report is not None),
    ]
    provenance: dict[str, Any] = {
        "trained_at": trained_at,
        "dataset_repo": config.dataset_repo,
        "dataset_repos": _resolve_training_dataset_repos(config),
        "eval_dataset_repo": _resolve_primary_eval_dataset_repo(config),
        "eval_dataset_repos": _resolve_eval_dataset_repos(config),
        "branch_name": config.branch_name,
        "branch_focus": config.branch_focus,
        "adapter_type": config.adapter_type,
        "preference_optimization": config.preference_optimization,
        "base_model_name": MARIS_ORIGIN_NAME,
        "base_model_lineage": MARIS_ORIGIN_NAME,
        "model_preset": config.model_preset or "custom",
        "hub_model_id": config.hub_model_id,
        "train_examples": train_examples,
        "eval_examples": eval_examples,
        "artifact_files": artifact_files,
        "metrics": metrics,
        "quality_report": None if quality_report is None else quality_report.to_dict(),
        "scoring_report": None if scoring_report is None else scoring_report.to_dict(),
        "benchmark_feedback": (
            None
            if benchmark_feedback is None
            else build_benchmark_feedback_artifact(benchmark_feedback)
        ),
        "maris_origin": MARIS_ORIGIN_NAME,
        "maris_framework": MARIS_FRAMEWORK_NAME,
    }
    _save_json(output_dir / "training-provenance.json", provenance)
|
|
|
|
def _sanitize_saved_artifact_json(path: Path, *, maris_model_id: str) -> None:
    """Rewrite a saved JSON artifact in place with foreign identity references replaced.

    Missing files are ignored. An existing file is always rewritten,
    pretty-printed, via ``_save_json_payload``. When its top-level value is an
    object, ``maris_origin``/``maris_framework``/``maris_model_id`` are also set.

    Raises:
        json.JSONDecodeError: if the file is not valid JSON.
    """
    if path.is_file():
        decoded = json.loads(path.read_text(encoding="utf-8"))
        cleaned = _sanitize_artifact_payload(decoded, maris_model_id=maris_model_id)
        if isinstance(cleaned, dict):
            cleaned.update(
                maris_origin=MARIS_ORIGIN_NAME,
                maris_framework=MARIS_FRAMEWORK_NAME,
                maris_model_id=maris_model_id,
            )
        _save_json_payload(path, cleaned)
|
|
|
|
def _sanitize_identity_text(value: str, *, maris_model_id: str) -> str:
    """Replace foreign model/brand mentions in *value* with Maris identifiers.

    Substitutions are applied in this order:

    1. Matches of ``FOREIGN_MODEL_REFERENCE_RE`` become *maris_model_id*.
    2. Each pattern in ``FOREIGN_AI_BRAND_PATTERNS`` is replaced with
       ``MARIS_ORIGIN_NAME``.
    3. Matches of ``MARIS_IDENTITY_VARIANT_RE`` are replaced with
       ``MARIS_ORIGIN_NAME``.
    """
    text = FOREIGN_MODEL_REFERENCE_RE.sub(maris_model_id, value)
    for pattern in (*FOREIGN_AI_BRAND_PATTERNS, MARIS_IDENTITY_VARIANT_RE):
        text = pattern.sub(MARIS_ORIGIN_NAME, text)
    return text
|
|
|
|
def _is_sanitized_reference_key(field_name: str | None) -> bool:
    """Return True for JSON keys whose string values are forced to the Maris model id.

    Matches keys listed in ``SANITIZED_ARTIFACT_KEYS`` and any key ending in
    ``_name_or_path``. None or empty keys never match.
    """
    if not field_name:
        return False
    return field_name in SANITIZED_ARTIFACT_KEYS or field_name.endswith("_name_or_path")
|
|
|
|
def _sanitize_artifact_payload(
    payload: Any,
    *,
    maris_model_id: str,
    field_name: str | None = None,
    path: tuple[str, ...] = (),
) -> Any:
    """Return a sanitized copy of a decoded JSON *payload*, recursing into containers.

    Dicts and lists are rebuilt element by element. List items inherit the
    enclosing key as *field_name*, and *path* tracks the key/index trail from
    the root. String leaves are handled in priority order:

    1. Reference keys (see ``_is_sanitized_reference_key``) become *maris_model_id*.
    2. ``base_model_name``/``base_model_lineage`` become ``MARIS_ORIGIN_NAME``.
    3. Values under a key in ``IDENTITY_TEXT_KEYS``, anything below a
       ``chat_template`` path segment, and text matching a foreign model/brand
       pattern go through ``_sanitize_identity_text``.

    All other values are returned unchanged.
    """
    if isinstance(payload, dict):
        rebuilt: dict[Any, Any] = {}
        for key, value in payload.items():
            key_text = str(key)
            rebuilt[key] = _sanitize_artifact_payload(
                value,
                maris_model_id=maris_model_id,
                field_name=key_text,
                path=(*path, key_text),
            )
        return rebuilt
    if isinstance(payload, list):
        return [
            _sanitize_artifact_payload(
                item,
                maris_model_id=maris_model_id,
                field_name=field_name,
                path=(*path, str(position)),
            )
            for position, item in enumerate(payload)
        ]
    if not isinstance(payload, str):
        return payload
    if _is_sanitized_reference_key(field_name):
        return maris_model_id
    if field_name in {"base_model_name", "base_model_lineage"}:
        return MARIS_ORIGIN_NAME
    needs_rewrite = (
        field_name in IDENTITY_TEXT_KEYS
        or "chat_template" in path
        or bool(FOREIGN_MODEL_REFERENCE_RE.search(payload))
        or any(pattern.search(payload) for pattern in FOREIGN_AI_BRAND_PATTERNS)
    )
    if needs_rewrite:
        return _sanitize_identity_text(payload, maris_model_id=maris_model_id)
    return payload
|
|
|
|
def _is_text_identity_artifact(path: Path) -> bool:
    """Return True for files whose raw text may carry identity strings.

    A file qualifies when either of these holds:

    * its extension is in ``TEXT_SANITIZED_ARTIFACT_EXTENSIONS``
      (case-insensitive);
    * its lower-cased name contains ``chat_template`` or ``chat-template``.
    """
    if path.suffix.lower() in TEXT_SANITIZED_ARTIFACT_EXTENSIONS:
        return True
    lowered_name = path.name.lower()
    return any(marker in lowered_name for marker in ("chat_template", "chat-template"))
|
|
|
|
def _sanitize_text_identity_artifact(path: Path, *, maris_model_id: str) -> None:
    """Rewrite a UTF-8 text artifact in place through ``_sanitize_identity_text``.

    Missing paths are ignored. Existing files are always rewritten.
    """
    if path.is_file():
        original_text = path.read_text(encoding="utf-8")
        cleaned_text = _sanitize_identity_text(original_text, maris_model_id=maris_model_id)
        path.write_text(cleaned_text, encoding="utf-8")
|
|
|
|
def _sanitize_saved_artifacts(output_dir: Path, *, maris_model_id: str) -> None:
    """Sanitize every JSON and text identity artifact under *output_dir*, recursively.

    Files are visited in sorted path order. JSON files go through
    ``_sanitize_saved_artifact_json``; other text identity artifacts go through
    ``_sanitize_text_identity_artifact``. Afterwards
    ``write_maris_compatibility_artifact`` and
    ``apply_maris_compatibility_identity`` run on *output_dir*.
    """
    saved_files = [entry for entry in sorted(output_dir.rglob("*")) if entry.is_file()]
    for file_path in saved_files:
        if file_path.suffix.lower() == ".json":
            _sanitize_saved_artifact_json(file_path, maris_model_id=maris_model_id)
        elif _is_text_identity_artifact(file_path):
            _sanitize_text_identity_artifact(file_path, maris_model_id=maris_model_id)
    write_maris_compatibility_artifact(output_dir, maris_model_id=maris_model_id)
    apply_maris_compatibility_identity(output_dir)
|
|
|
|
def _collect_unsanitized_identity_strings(
    payload: Any,
    *,
    path: tuple[str, ...] = (),
) -> list[str]:
    """Return ``/``-joined paths of string leaves that still contain foreign identity text.

    Dicts and lists are walked recursively; dict keys themselves are not
    checked. A string leaf is reported once if it matches
    ``FOREIGN_MODEL_REFERENCE_RE`` or any pattern in ``FOREIGN_AI_BRAND_PATTERNS``.
    A matching root string is reported as ``"<root>"``.
    """
    if isinstance(payload, str):
        foreign_patterns = (FOREIGN_MODEL_REFERENCE_RE, *FOREIGN_AI_BRAND_PATTERNS)
        if any(pattern.search(payload) for pattern in foreign_patterns):
            return ["/".join(path) or "<root>"]
        return []
    if isinstance(payload, dict):
        children = [(str(key), value) for key, value in payload.items()]
    elif isinstance(payload, list):
        children = [(str(index), item) for index, item in enumerate(payload)]
    else:
        return []

    found: list[str] = []
    for segment, child in children:
        found.extend(_collect_unsanitized_identity_strings(child, path=(*path, segment)))
    return found
|
|
|
|
def _verify_saved_artifacts(output_dir: Path) -> None:
    """Fail if any saved artifact under *output_dir* still carries foreign identity text.

    * JSON files are parsed and checked leaf by leaf; files that fail to parse
      are skipped.
    * Other text identity artifacts are searched as a whole.
    * The Maris compatibility artifact is exempt.

    Raises:
        ValueError: listing up to 20 sorted, de-duplicated offending entries
            (``file`` or ``file:json/path``), plus a count of any remainder.
    """
    foreign_patterns = (FOREIGN_MODEL_REFERENCE_RE, *FOREIGN_AI_BRAND_PATTERNS)
    problems: set[str] = set()
    for file_path in sorted(output_dir.rglob("*")):
        if not file_path.is_file() or file_path.name == MARIS_COMPATIBILITY_ARTIFACT_NAME:
            continue
        relative_path = file_path.relative_to(output_dir)
        if file_path.suffix.lower() == ".json":
            try:
                decoded = json.loads(file_path.read_text(encoding="utf-8"))
            except json.JSONDecodeError:
                continue
            problems.update(
                f"{relative_path}:{issue}"
                for issue in _collect_unsanitized_identity_strings(decoded)
            )
        elif _is_text_identity_artifact(file_path):
            text = file_path.read_text(encoding="utf-8")
            if any(pattern.search(text) for pattern in foreign_patterns):
                problems.add(str(relative_path))

    if not problems:
        return
    ordered_problems = sorted(problems)
    shown = ordered_problems[:20]
    hidden_count = len(ordered_problems) - len(shown)
    overflow_suffix = f" ... un vēl {hidden_count}" if hidden_count > 0 else ""
    raise ValueError(
        "Atrasti nesanitizēti Hugging Face artefakti: "
        + ", ".join(shown)
        + overflow_suffix
    )
|
|
|
|
def _write_training_artifacts(
    output_dir: Path,
    config: TrainingConfig,
    *,
    trained_at: str,
    train_examples: int,
    eval_examples: int,
    metrics: dict[str, float],
    quality_report: DatasetQualityReport | None = None,
    scoring_report: DatasetScoringReport | None = None,
    benchmark_feedback: DatasetBenchmarkFeedback | None = None,
) -> None:
    """Sanitize the saved model directory, then write all Maris metadata files.

    Runs in this order:

    1. Sanitize existing artifacts for ``config.hub_model_id``.
    2. Write ``training-config.json``.
    3. Write ``training-metrics.json``, including the scoring dashboard payload.
    4. Write the optional dataset quality and scoring reports.
    5. Write ``maris-metadata.json``.
    6. Write ``training-provenance.json``.
    7. Write ``README.md``.

    Because sanitization runs first, the files written here are not passed
    through it.
    """
    _sanitize_saved_artifacts(output_dir, maris_model_id=config.hub_model_id)
    _save_json(output_dir / "training-config.json", _build_training_artifact_config(config))

    dashboard_payload = _build_scoring_dashboard_payload(
        scoring_report=scoring_report,
        benchmark_feedback=benchmark_feedback,
    )
    metrics_artifact = _build_metrics_artifact(
        config,
        metrics,
        artifact_type="training-metrics",
        extra_payload=dashboard_payload,
    )
    _save_json(output_dir / "training-metrics.json", metrics_artifact)

    for report_filename, report in (
        ("dataset-quality-report.json", quality_report),
        ("dataset-scoring-report.json", scoring_report),
    ):
        if report is not None:
            _save_json(output_dir / report_filename, report.to_dict())

    _write_origin_metadata(output_dir, config)

    run_summary: dict[str, Any] = {
        "trained_at": trained_at,
        "train_examples": train_examples,
        "eval_examples": eval_examples,
        "metrics": metrics,
    }
    _write_training_provenance(
        output_dir,
        config,
        **run_summary,
        quality_report=quality_report,
        scoring_report=scoring_report,
        benchmark_feedback=benchmark_feedback,
    )
    _write_text(output_dir / "README.md", _build_model_card(config, **run_summary))
|
|
|
|
def _write_preference_artifact(output_dir: Path, config: TrainingConfig) -> dict[str, Any] | None:
    """Write preference-dataset summaries next to the model, if a dataset is configured.

    Writes these files:

    * ``preference-summary.json`` (tagged with the model artifact identity);
    * ``human-eval-summary.json`` (tagged with the model artifact identity);
    * ``blind-side-by-side-eval.json``.

    Returns:
        The preference summary payload, or None when
        ``config.preference_dataset_path`` is empty.
    """
    if not config.preference_dataset_path:
        return None

    examples = _load_filtered_preference_examples(config)
    preference_summary = summarize_preference_dataset(examples)
    preference_summary.update(_build_model_artifact_identity(config))
    human_eval = build_human_eval_summary(examples)
    human_eval.update(_build_model_artifact_identity(config))

    for filename, document in (
        ("preference-summary.json", preference_summary),
        ("human-eval-summary.json", human_eval),
    ):
        _save_json(output_dir / filename, document)
    _save_json(
        output_dir / "blind-side-by-side-eval.json",
        build_blind_side_by_side_artifact(examples),
    )
    return preference_summary
|
|
|
|
def _load_filtered_preference_examples(config: TrainingConfig) -> list[PreferenceExample]:
    """Load ``config.preference_dataset_path`` and keep examples for ``config.branch_name``.

    Raises:
        ValueError: from ``_filter_preference_examples_for_branch`` when a
            text-trainable branch ends up with no examples.
    """
    raw_examples = load_preference_dataset(config.preference_dataset_path)
    return _filter_preference_examples_for_branch(
        raw_examples,
        branch_name=config.branch_name,
        branch_filter_rules=config.branch_dataset_filter_rules,
    )
|
|
|
|
def _sync_output_dir_to_hub(
    output_dir: Path, config: TrainingConfig, *, commit_message: str
) -> None:
    """Best-effort upload of *output_dir* to the ``config.hub_model_id`` model repo.

    The repo is created if it does not exist. A token is requested from
    ``get_env_any`` with ``MARIS_REPO_TOKEN``, ``MARIS_TOKEN`` and ``HF_TOKEN``
    (presumably the first one set). Without a token the default ``HfApi()``
    credentials are used. A missing ``huggingface_hub`` or any upload error is
    logged as a warning and never raised.
    """
    try:
        from huggingface_hub import HfApi
    except ImportError:
        logger.warning("huggingface_hub nav pieejams — izlaižam Maris artefaktu sinhronizāciju.")
        return

    try:
        repo_id = config.hub_model_id
        token = get_env_any("MARIS_REPO_TOKEN", "MARIS_TOKEN", "HF_TOKEN")
        hub_api = HfApi(token=token) if token else HfApi()
        hub_api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
        hub_api.upload_folder(
            folder_path=str(output_dir),
            repo_id=repo_id,
            repo_type="model",
            commit_message=commit_message,
        )
    except Exception as exc:  # Broad catch is deliberate: publishing is best-effort.
        logger.warning("Neizdevās pārpublicēt Maris artefaktus uz Hub: %s", exc)
|
|
|
|
def _build_training_arguments(training_arguments_cls: Any, **kwargs: Any) -> Any:
    """Instantiate *training_arguments_cls* with only the keyword arguments it supports.

    If its ``__init__`` takes ``**kwargs``, everything is passed through
    unchanged. Otherwise the kwargs are adjusted before the call:

    * A kwarg whose name is a key of ``TRAINING_ARGUMENT_ALIASES`` is renamed
      to the mapped name, when only the mapped name is supported.
    * Names that are still unsupported are dropped and logged at INFO level.
    """
    parameters = inspect.signature(training_arguments_cls.__init__).parameters
    if any(param.kind is inspect.Parameter.VAR_KEYWORD for param in parameters.values()):
        return training_arguments_cls(**kwargs)

    keyword_kinds = {
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    }
    supported = {
        name
        for name, param in parameters.items()
        if name != "self" and param.kind in keyword_kinds
    }

    candidates = dict(kwargs)
    for original_name, replacement_name in TRAINING_ARGUMENT_ALIASES.items():
        can_rename = (
            original_name in candidates
            and original_name not in supported
            and replacement_name in supported
        )
        if can_rename:
            candidates[replacement_name] = candidates.pop(original_name)

    dropped = sorted(name for name in candidates if name not in supported)
    if dropped:
        logger.info(
            "Izlaiž neatbalstītos TrainingArguments parametrus: %s",
            ", ".join(dropped),
        )

    accepted = {name: value for name, value in candidates.items() if name in supported}
    return training_arguments_cls(**accepted)
|
|
|
|
def _materialize_split_records(split: Any) -> list[dict[str, Any]]:
    """Copy the rows of a split into a fresh list of plain dicts.

    Iterates ``split.items`` when the object exposes such an attribute,
    otherwise iterates ``split`` itself. Dict rows are shallow-copied; any other
    row is wrapped as ``{"text": str(row)}``.
    """
    rows = getattr(split, "items", split)
    materialized: list[dict[str, Any]] = []
    for row in rows:
        if isinstance(row, dict):
            materialized.append(dict(row))
        else:
            materialized.append({"text": str(row)})
    return materialized
|
|
|
|
def _rebuild_split_like(split: Any, records: list[dict[str, Any]]) -> Any:
    """Rebuild ``records`` into a container of the same kind as ``split``.

    * Objects from the HuggingFace ``datasets`` package (module ``datasets`` or
      ``datasets.*``) become ``datasets.Dataset.from_list(records)``.
    * Any other split type is constructed as ``type(split)(records)``.
    * If that constructor raises, the plain ``records`` list is returned.
    """
    split_type = type(split)
    module_name = split_type.__module__ or ""
    # Match the ``datasets`` package itself (e.g. ``datasets.arrow_dataset``),
    # not unrelated modules that merely share the prefix (e.g. ``datasets_utils``).
    if module_name == "datasets" or module_name.startswith("datasets."):
        from datasets import Dataset

        return Dataset.from_list(records)

    try:
        return split_type(records)
    except Exception:
        # Best effort: types that cannot be built from a list fall back to a plain list.
        return records
|
|
|
|
def _resolve_branch_benchmark_dataset_path(config: TrainingConfig, branch_name: str) -> str:
    """Return the benchmark dataset path to use for ``branch_name``.

    A non-blank override in ``config.branch_benchmark_dataset_paths`` (returned
    whitespace-stripped) wins; otherwise ``config.benchmark_dataset_path``.
    """
    override = config.branch_benchmark_dataset_paths.get(branch_name, "")
    override_text = str(override or "").strip()
    return override_text if override_text else config.benchmark_dataset_path
|
|
|
|
def _resolve_branch_benchmark_name(config: TrainingConfig, branch_name: str) -> str:
    """Return the benchmark name to use for ``branch_name``.

    A non-blank override in ``config.branch_benchmark_names`` (returned
    whitespace-stripped) wins; otherwise ``config.benchmark_name``.
    """
    override = config.branch_benchmark_names.get(branch_name, "")
    override_text = str(override or "").strip()
    return override_text if override_text else config.benchmark_name
|
|
|
|
def _resolve_branch_preference_dataset_path(config: TrainingConfig, branch_name: str) -> str:
    """Return the preference dataset path to use for ``branch_name``.

    A non-blank override in ``config.branch_preference_dataset_paths`` (returned
    whitespace-stripped) wins; otherwise ``config.preference_dataset_path``.
    """
    override = config.branch_preference_dataset_paths.get(branch_name, "")
    override_text = str(override or "").strip()
    return override_text if override_text else config.preference_dataset_path
|
|
|
|
def _apply_branch_runtime_defaults(config: TrainingConfig) -> TrainingConfig:
    """Fill unset benchmark / preference settings from the branch-specific maps.

    Nothing is changed when both ``benchmark_dataset_path`` and
    ``preference_dataset_path`` are already set. Otherwise:

    * an empty benchmark dataset path is resolved per branch (gate enabled only);
    * an empty or default benchmark name is resolved per branch (gate enabled only);
    * an empty preference dataset path is resolved per branch.

    Returns ``config`` itself when nothing changed, else a ``dataclasses.replace`` copy.
    """
    if config.benchmark_dataset_path and config.preference_dataset_path:
        return config

    branch = config.branch_name
    gate_enabled = config.benchmark_gate_enabled

    dataset_path = config.benchmark_dataset_path
    if gate_enabled and not dataset_path:
        dataset_path = _resolve_branch_benchmark_dataset_path(config, branch)

    name = config.benchmark_name
    if gate_enabled and (not name or name == DEFAULT_BENCHMARK_NAME):
        name = _resolve_branch_benchmark_name(config, branch)

    preference_path = config.preference_dataset_path or _resolve_branch_preference_dataset_path(
        config, branch
    )

    resolved = (dataset_path, name, preference_path)
    current = (
        config.benchmark_dataset_path,
        config.benchmark_name,
        config.preference_dataset_path,
    )
    if resolved == current:
        return config
    return replace(
        config,
        benchmark_dataset_path=dataset_path,
        benchmark_name=name,
        preference_dataset_path=preference_path,
    )
|
|
|
|
def _normalize_branch_filter_token(value: Any) -> str:
    """Canonicalize a filter label: stripped, lower-cased, ``-`` and spaces → ``_``.

    Falsy values (``None``, ``""``, ``0`` …) normalize to ``""``.
    """
    text = str(value or "").strip().lower()
    for separator in ("-", " "):
        text = text.replace(separator, "_")
    return text
|
|
|
|
def _normalize_branch_filter_values(value: Any) -> set[str]:
    """Normalize a label or a collection of labels into a set of non-empty tokens.

    Lists, tuples, sets and frozensets are normalized item by item; any other
    value is treated as a single label. Empty tokens are discarded.
    """
    items = value if isinstance(value, list | tuple | set | frozenset) else (value,)
    tokens = {_normalize_branch_filter_token(item) for item in items}
    tokens.discard("")
    return tokens
|
|
|
|
def _record_branch_labels(record: dict[str, Any]) -> set[str]:
    """Collect normalized branch-like labels from a training record.

    Looks at these fields:

    * top-level ``branch``, ``profile``, ``type``, ``task_type``, ``category``;
    * ``metadata`` keys ``branch``, ``profile``, ``type``, ``task``, ``workflow``,
      ``project_area`` and ``tags``;
    * top-level ``tags``.

    String values, and strings inside (possibly nested) lists, are normalized
    with ``_normalize_branch_filter_token``. Empty tokens and other types are ignored.
    """
    found: set[str] = set()

    def collect(value: Any) -> None:
        if isinstance(value, list):
            for element in value:
                collect(element)
            return
        if isinstance(value, str):
            token = _normalize_branch_filter_token(value)
            if token:
                found.add(token)

    sources: list[Any] = [
        record.get(key) for key in ("branch", "profile", "type", "task_type", "category")
    ]
    metadata = record.get("metadata")
    if isinstance(metadata, dict):
        sources.extend(
            metadata.get(key)
            for key in ("branch", "profile", "type", "task", "workflow", "project_area", "tags")
        )
    sources.append(record.get("tags"))
    for source in sources:
        collect(source)
    return found
|
|
|
|
def _record_repo_context_terms(record: dict[str, Any]) -> set[str]:
    """Return stripped, lower-cased repo-context terms of a record.

    Sources:

    * ``repo_context``, either a string or a list whose items are converted with ``str``;
    * the string values of ``metadata["project_area"]`` and ``metadata["audience"]``.

    Blank entries are skipped.
    """
    repo_context = record.get("repo_context")
    if isinstance(repo_context, list):
        candidates = [str(item) for item in repo_context]
    elif isinstance(repo_context, str):
        candidates = [repo_context]
    else:
        candidates = []

    metadata = record.get("metadata")
    if isinstance(metadata, dict):
        for key in ("project_area", "audience"):
            field_value = metadata.get(key)
            if isinstance(field_value, str):
                candidates.append(field_value)

    return {cleaned.lower() for candidate in candidates if (cleaned := candidate.strip())}
|
|
|
|
def _record_branch_filter_signals(record: dict[str, Any]) -> BranchFilterSignals:
    """Extract the signals ``_matches_branch_filter_rule`` evaluates for a dataset record.

    * ``explicit_branches``: labels from ``_record_branch_labels`` that name a
      text-trainable branch.
    * ``record_types``: normalized ``type`` and ``metadata.type``.
    * ``task_types``: normalized ``task_type``, ``metadata.task`` and
      ``metadata.workflow``.
    * ``languages``: normalized ``language`` and ``metadata.language``.
    * ``repo_context_terms``: ``_record_repo_context_terms`` normalized with
      ``_normalize_branch_filter_token``.
    * ``presence_keys``: normalized names of non-empty top-level fields
      (``metadata`` excluded).
    """
    metadata = record.get("metadata")
    # A non-dict metadata behaves like {}: normalizing ``None`` yields no tokens.
    meta: dict[str, Any] = metadata if isinstance(metadata, dict) else {}

    record_types = _normalize_branch_filter_values(record.get("type"))
    record_types |= _normalize_branch_filter_values(meta.get("type"))
    task_types = _normalize_branch_filter_values(record.get("task_type"))
    task_types |= _normalize_branch_filter_values(meta.get("task"))
    task_types |= _normalize_branch_filter_values(meta.get("workflow"))
    languages = _normalize_branch_filter_values(record.get("language"))
    languages |= _normalize_branch_filter_values(meta.get("language"))

    # Rule values are normalized with _normalize_branch_filter_token ("-"/" " -> "_"),
    # and preference examples normalize repo_context the same way. Normalizing here
    # too lets a record term such as "maris-core" match its configured rule value.
    repo_context_terms = {
        token
        for term in _record_repo_context_terms(record)
        if (token := _normalize_branch_filter_token(term))
    }
    presence_keys = {
        _normalize_branch_filter_token(key)
        for key, value in record.items()
        if key != "metadata" and value not in (None, "", [], {}, ())
    }
    explicit_branches = frozenset(
        label for label in _record_branch_labels(record) if label in TEXT_TRAINABLE_BRANCHES
    )
    return BranchFilterSignals(
        explicit_branches=explicit_branches,
        record_types=frozenset(record_types),
        task_types=frozenset(task_types),
        languages=frozenset(languages),
        repo_context_terms=frozenset(repo_context_terms),
        presence_keys=frozenset(presence_keys),
    )
|
|
|
|
def _preference_branch_filter_signals(example: PreferenceExample) -> BranchFilterSignals:
    """Extract branch-filter signals from a preference example.

    Uses the example's ``branch``, ``task_type``, ``language`` and
    ``repo_context`` fields; no ``record_types`` are supplied. ``presence_keys``
    names whichever of ``edit_target``, ``context`` and ``execution_required``
    are set. ``None``, ``""`` and ``False`` (and values equal to them) count as unset.
    """
    branch_tokens = _normalize_branch_filter_values(example.branch)
    explicit_branches = frozenset(
        token for token in branch_tokens if token in TEXT_TRAINABLE_BRANCHES
    )
    repo_terms = frozenset(
        _normalize_branch_filter_token(item) for item in example.repo_context if str(item).strip()
    )
    optional_fields = {
        "edit_target": example.edit_target,
        "context": example.context,
        "execution_required": example.execution_required,
    }
    present = frozenset(
        field_name
        for field_name, field_value in optional_fields.items()
        if field_value not in (None, "", False)
    )
    return BranchFilterSignals(
        explicit_branches=explicit_branches,
        task_types=frozenset(_normalize_branch_filter_values(example.task_type)),
        languages=frozenset(_normalize_branch_filter_values(example.language)),
        repo_context_terms=repo_terms,
        presence_keys=present,
    )
|
|
|
|
def _matches_branch_filter_rule(
    signals: BranchFilterSignals,
    *,
    branch_name: str,
    branch_filter_rules: dict[str, dict[str, Any]] | None,
) -> bool:
    """Decide whether an example described by ``signals`` belongs to ``branch_name``.

    Decision order:

    1. A branch outside ``TEXT_TRAINABLE_BRANCHES``, or one without rules
       (``branch_filter_rules`` falls back to the defaults), accepts everything.
    2. An explicit branch label in ``exclude_explicit_branches`` rejects; one in
       ``include_explicit_branches`` accepts.
    3. Any configured ``include_*`` rule that intersects its signal accepts.
    4. If inclusion rules exist but none matched, an example is accepted only
       when ``allow_unlabeled`` is set and it has no explicit branch label.
    5. With no inclusion rules at all, examples without an explicit branch
       label are accepted and labelled ones are rejected.
    """
    branch = _normalize_branch_filter_token(branch_name)
    if branch not in TEXT_TRAINABLE_BRANCHES:
        return True
    rules = (branch_filter_rules or DEFAULT_BRANCH_DATASET_FILTER_RULES).get(branch, {})
    if not rules:
        return True

    labelled = signals.explicit_branches
    included = _normalize_branch_filter_values(rules.get("include_explicit_branches"))
    excluded = _normalize_branch_filter_values(rules.get("exclude_explicit_branches"))
    if not labelled.isdisjoint(excluded):
        return False
    if not labelled.isdisjoint(included):
        return True

    signal_sources = (
        ("include_record_types", signals.record_types),
        ("include_task_types", signals.task_types),
        ("include_languages", signals.languages),
        ("include_repo_context_terms", signals.repo_context_terms),
        ("include_presence_keys", signals.presence_keys),
    )
    # One entry per configured (non-empty) include_* rule: did its signal match?
    outcomes = [
        not signal_values.isdisjoint(required)
        for rule_key, signal_values in signal_sources
        if (required := _normalize_branch_filter_values(rules.get(rule_key)))
    ]
    if any(outcomes):
        return True

    allow_unlabeled = bool(rules.get("allow_unlabeled", False))
    is_unlabeled = not labelled
    if outcomes or included:
        return allow_unlabeled and is_unlabeled
    return is_unlabeled
|
|
|
|
def _filter_records_for_branch(
    records: list[dict[str, Any]],
    *,
    branch_name: str,
    split_name: str,
    branch_filter_rules: dict[str, dict[str, Any]] | None = None,
) -> tuple[list[dict[str, Any]], DatasetBranchFilterReport]:
    """Keep only the records that belong to ``branch_name`` and report the counts.

    Branches outside ``TEXT_TRAINABLE_BRANCHES`` are not filtered: the input
    list itself is returned. Otherwise each record is checked with
    ``_matches_branch_filter_rule`` using its ``_record_branch_filter_signals``.
    """
    if branch_name in TEXT_TRAINABLE_BRANCHES:

        def belongs(record: dict[str, Any]) -> bool:
            return _matches_branch_filter_rule(
                _record_branch_filter_signals(record),
                branch_name=branch_name,
                branch_filter_rules=branch_filter_rules,
            )

        kept = list(filter(belongs, records))
    else:
        kept = records

    total = len(records)
    report = DatasetBranchFilterReport(
        split_name=split_name,
        branch_name=branch_name,
        input_records=total,
        kept_records=len(kept),
        dropped_records=max(total - len(kept), 0),
    )
    return kept, report
|
|
|
|
def _apply_branch_dataset_filter_to_split(
    split: Any,
    *,
    split_name: str,
    config: TrainingConfig,
    allow_empty: bool = False,
) -> tuple[Any | None, DatasetBranchFilterReport]:
    """Apply the branch dataset filter to one split and rebuild it in its original type.

    Raises:
        ValueError: when every record is filtered out and ``allow_empty`` is
            false. With ``allow_empty`` the split is logged and ``None`` is
            returned in its place.
    """
    kept, report = _filter_records_for_branch(
        _materialize_split_records(split),
        branch_name=config.branch_name,
        split_name=split_name,
        branch_filter_rules=config.branch_dataset_filter_rules,
    )
    if kept:
        return _rebuild_split_like(split, kept), report

    message = (
        f"Branch dataset filter atmeta visus ierakstus splitam '{split_name}' "
        f"atzaram '{config.branch_name}'. Ienākošie={report.input_records}."
    )
    if not allow_empty:
        raise ValueError(message)
    logger.warning("%s Šo splitu izlaidīšu šajā training skrējienā.", message)
    return None, report
|
|
|
|
def _dataset_quality_config(config: TrainingConfig) -> DatasetQualityGateConfig:
    """Translate training settings into a ``DatasetQualityGateConfig``.

    The maximum text length is ``max_seq_length * 8`` characters, but never
    less than ``quality_min_text_chars``.
    """
    min_chars = config.quality_min_text_chars
    max_chars = max(config.max_seq_length * 8, min_chars)
    return DatasetQualityGateConfig(
        enabled=config.quality_gate_enabled,
        dedupe_enabled=config.dedupe_enabled,
        min_text_chars=min_chars,
        max_text_chars=max_chars,
    )
|
|
|
|
def _dataset_scoring_config(config: TrainingConfig) -> DatasetScoringConfig:
    """Translate training settings into a ``DatasetScoringConfig``.

    Uses the same text-length cap as the quality gate
    (``max(max_seq_length * 8, quality_min_text_chars)``). The weight maps are
    shallow-copied so the scoring config does not share them with ``config``.
    """
    text_cap = max(config.max_seq_length * 8, config.quality_min_text_chars)
    return DatasetScoringConfig(
        enabled=config.scoring_enabled,
        # Weighted repetition of higher-scoring records.
        weighted_repetition_enabled=config.weighted_repetition_enabled,
        max_text_chars=text_cap,
        medium_score_repeat_count=config.medium_score_repeat_count,
        high_score_repeat_count=config.high_score_repeat_count,
        # Source / category weighting.
        source_weighting_enabled=config.source_weighting_enabled,
        source_weight_map=config.source_weight_map.copy(),
        category_weight_map=config.category_weight_map.copy(),
        max_effective_repeat_count=config.max_effective_repeat_count,
        # Benchmark-feedback driven boosts.
        benchmark_feedback_enabled=config.benchmark_feedback_enabled,
        benchmark_feedback_path=config.benchmark_feedback_path,
        benchmark_feedback_boost_scale=config.benchmark_feedback_boost_scale,
        benchmark_feedback_max_multiplier=config.benchmark_feedback_max_multiplier,
    )
|
|
|
|
def _benchmark_feedback_discovery_candidates(config: TrainingConfig) -> list[Path]:
    """List ``benchmark-feedback.json`` files near ``config.output_dir``.

    The file directly inside ``output_dir`` comes first, if it is a regular
    file. It is followed by files found recursively under the parent of
    ``output_dir`` and, when it differs, the grandparent. Results are
    de-duplicated by resolved path; paths that fail to resolve are skipped.
    """
    filename = "benchmark-feedback.json"
    output_dir = Path(config.output_dir)
    own_file = output_dir / filename

    found: list[Path] = [own_file] if own_file.is_file() else []
    visited: set[Path] = set()
    if own_file.exists():
        visited.add(own_file.resolve())

    parent = output_dir.parent
    grandparent = parent.parent
    roots = [parent] if parent == grandparent else [parent, grandparent]

    def hits_under_existing_roots():
        for root in roots:
            if root.exists():
                yield from root.rglob(filename)

    for hit in hits_under_existing_roots():
        try:
            location = hit.resolve()
        except OSError:
            continue
        if location not in visited and hit.is_file():
            visited.add(location)
            found.append(hit)
    return found
|
|
|
|
def _discover_benchmark_feedback_path(config: TrainingConfig) -> str:
    """Return the most recently modified discovered feedback file, or ``""`` if none.

    On equal modification times the earliest candidate in discovery order wins.
    """
    candidates = _benchmark_feedback_discovery_candidates(config)
    if candidates:
        return str(max(candidates, key=lambda candidate: candidate.stat().st_mtime))
    return ""
|
|
|
|
def _load_dataset_benchmark_feedback(config: TrainingConfig) -> DatasetBenchmarkFeedback | None:
    """Load benchmark feedback used to re-weight dataset scoring.

    Returns ``None`` when feedback is disabled or no feedback file is known.
    An explicit ``benchmark_feedback_path`` is used as-is. Otherwise, when
    ``benchmark_feedback_auto_discover`` is on, the newest discovered file is
    used. ``discovery_mode`` on the result records which path was taken.
    """
    if not config.benchmark_feedback_enabled:
        return None

    explicit_path = config.benchmark_feedback_path
    if explicit_path:
        path, mode = explicit_path, "explicit"
    elif config.benchmark_feedback_auto_discover:
        path, mode = _discover_benchmark_feedback_path(config), "auto_discovered"
    else:
        return None
    if not path:
        return None

    loaded = load_benchmark_feedback(
        path,
        targets=_default_branch_benchmark_targets(config),
        boost_scale=config.benchmark_feedback_boost_scale,
        max_multiplier=config.benchmark_feedback_max_multiplier,
    )
    return DatasetBenchmarkFeedback(
        artifact_path=loaded.artifact_path,
        deficient_metrics=loaded.deficient_metrics,
        overall_multiplier=loaded.overall_multiplier,
        discovery_mode=mode,
    )
|
|
|
|
def _build_scoring_dashboard_payload(
    *,
    scoring_report: DatasetScoringReport | None,
    benchmark_feedback: DatasetBenchmarkFeedback | None,
) -> dict[str, Any]:
    """Assemble the scoring-dashboard section of the training metrics payload.

    With a scoring report, the payload contains:

    * ``scoring_dashboard``: per split, the raw source and category dashboards;
    * the flattened scalar metrics from
      ``_build_flattened_scoring_dashboard_metrics``.

    With feedback, it also contains ``benchmark_feedback``. Returns ``{}`` when
    both inputs are ``None``.
    """
    payload: dict[str, Any] = {}
    if scoring_report is not None:
        per_split: dict[str, Any] = {}
        for split_name, split_report in scoring_report.splits.items():
            per_split[split_name] = {
                "sources": split_report.source_dashboard,
                "categories": split_report.category_dashboard,
            }
        payload["scoring_dashboard"] = per_split
        payload.update(_build_flattened_scoring_dashboard_metrics(scoring_report))
    if benchmark_feedback is not None:
        payload["benchmark_feedback"] = build_benchmark_feedback_artifact(benchmark_feedback)
    return payload
|
|
|
|
def _build_flattened_scoring_dashboard_metrics(
    scoring_report: DatasetScoringReport | None,
) -> dict[str, float]:
    """Flatten per-split source/category dashboards into scalar metrics.

    Keys have the form
    ``scoring_dashboard_{split}_{sources|categories}_{label}_{metric}`` and every
    value is converted with ``float``. Returns ``{}`` for ``None``.
    """
    if scoring_report is None:
        return {}

    flattened: dict[str, float] = {}
    for split_name, split_report in scoring_report.splits.items():
        dashboards = {
            "sources": split_report.source_dashboard,
            "categories": split_report.category_dashboard,
        }
        for dimension, dashboard in dashboards.items():
            prefix = f"scoring_dashboard_{split_name}_{dimension}"
            for label, metrics in dashboard.items():
                for metric_name, value in metrics.items():
                    flattened[f"{prefix}_{label}_{metric_name}"] = float(value)
    return flattened
|
|
|
|
def _apply_quality_gate_to_split(
    split: Any,
    *,
    split_name: str,
    config: TrainingConfig,
    allow_empty: bool = False,
) -> tuple[Any | None, DatasetQualitySplitReport]:
    """Run the dataset quality gate on one split and rebuild it in its original type.

    Raises:
        ValueError: when the gate drops every record and ``allow_empty`` is
            false. With ``allow_empty`` a warning (with reasons and sample
            rejections) is logged and ``None`` is returned instead of a split.
    """
    kept, report = apply_quality_gate_to_records(
        _materialize_split_records(split),
        split_name=split_name,
        config=_dataset_quality_config(config),
    )
    if kept:
        return _rebuild_split_like(split, kept), report

    message = (
        f"Dataset quality gate atmeta visus ierakstus splitam '{split_name}'. "
        f"Ienākošie={report.input_records}, paturētie={report.kept_records}, "
        f"atmestie={report.dropped_records}, dublikāti={report.duplicates_removed}. "
        "Pārbaudi īsos, zemas kvalitātes vai dublētos ierakstus."
    )
    if not allow_empty:
        raise ValueError(message)
    logger.warning(
        "%s Iemesli=%s; piemēri=%s. Šo splitu izlaidīšu šajā training skrējienā.",
        message,
        report.reasons,
        report.sample_rejections,
    )
    return None, report
|
|
|
|
def _apply_scoring_to_split(
    split: Any,
    *,
    split_name: str,
    config: TrainingConfig,
    expand_weights: bool,
    benchmark_feedback: DatasetBenchmarkFeedback | None,
) -> tuple[Any, DatasetScoringSplitReport]:
    """Score (and optionally weight-expand) one split and rebuild it in its original type.

    Raises:
        ValueError: when the scoring pipeline returns no records.
    """
    records = _materialize_split_records(split)
    scored_records, report = apply_scoring_to_records(
        records,
        split_name=split_name,
        config=_dataset_scoring_config(config),
        expand_weights=expand_weights,
        benchmark_feedback=benchmark_feedback,
    )
    if scored_records:
        return _rebuild_split_like(split, scored_records), report
    raise ValueError(
        f"Dataset scoring pipeline neatgrieza nevienu ierakstu splitam '{split_name}'. "
        f"Ienākošie={report.input_records}, expanded={report.expanded_records}, "
        f"source_weighting={config.source_weighting_enabled}, "
        f"benchmark_feedback={bool(benchmark_feedback)}."
    )
|
|
|
|
def build_branch_training_configs(config: TrainingConfig) -> list[TrainingConfig]:
    """Build one ``TrainingConfig`` per architecture branch.

    Branches, in order:

    * ``master``, ``coder``, ``planner``: text branches trained from
      ``config.model_name`` with ``config.adapter_type``.
    * ``image``, ``music``, ``tts``, ``stt``, ``video``: branches with
      ``adapter_type="specialist_model"`` that use the corresponding
      ``*_model_id`` as both model and Hub repo.

    Each branch writes to ``<config.output_dir>/<branch_name>``. The text
    branches additionally:

    * force the benchmark gate on;
    * raise ``benchmark_min_overall`` to at least 0.76;
    * raise ``quality_min_text_chars`` to a branch floor;
    * raise selected ``category_weight_map`` weights to branch minimums.

    Every other setting is copied from ``config``. Top-level lists and dicts
    are copied, and the inner dicts of ``branch_benchmark_targets`` and
    ``branch_dataset_filter_rules`` are shallow-copied, so the returned configs
    do not share those containers with ``config``.
    """
    base_output_dir = Path(config.output_dir)
    # Per-branch overrides; keys a spec omits fall back to ``config`` in the
    # TrainingConfig construction below (via ``branch.get(key, config.<field>)``).
    branch_specs = (
        {
            "branch_name": "master",
            "branch_focus": "Maris AI realtime_text_chat",
            "adapter_type": config.adapter_type,
            "model_name": config.model_name,
            "output_dir": base_output_dir / "master",
            # master publishes to the text model repo; coder/planner use hub_model_id.
            "hub_model_id": config.text_model_id,
            "benchmark_name": _resolve_branch_benchmark_name(config, "master"),
            "benchmark_dataset_path": _resolve_branch_benchmark_dataset_path(config, "master"),
            "benchmark_gate_enabled": True,
            # max(): these are floors — a stricter value from ``config`` is kept.
            "benchmark_min_overall": max(config.benchmark_min_overall, 0.76),
            "quality_min_text_chars": max(config.quality_min_text_chars, 12),
            "category_weight_map": {
                **config.category_weight_map,
                "reasoning": max(float(config.category_weight_map.get("reasoning", 1.0)), 1.2),
                "helpfulness": max(float(config.category_weight_map.get("helpfulness", 1.0)), 1.15),
                "grounding": max(float(config.category_weight_map.get("grounding", 1.0)), 1.15),
                "latvian_quality": max(
                    float(config.category_weight_map.get("latvian_quality", 1.0)),
                    1.1,
                ),
            },
        },
        {
            "branch_name": "coder",
            "branch_focus": "Maris AI code_generation_repo_grounded",
            "adapter_type": config.adapter_type,
            "model_name": config.model_name,
            "output_dir": base_output_dir / "coder",
            "hub_model_id": config.hub_model_id,
            "benchmark_name": _resolve_branch_benchmark_name(config, "coder"),
            "benchmark_dataset_path": _resolve_branch_benchmark_dataset_path(config, "coder"),
            # Only coder resolves a branch-specific preference dataset; other
            # branches inherit ``config.preference_dataset_path``.
            "preference_dataset_path": _resolve_branch_preference_dataset_path(config, "coder"),
            "benchmark_gate_enabled": True,
            "benchmark_min_overall": max(config.benchmark_min_overall, 0.76),
            "quality_min_text_chars": max(config.quality_min_text_chars, 18),
            "category_weight_map": {
                **config.category_weight_map,
                "coding": max(float(config.category_weight_map.get("coding", 1.0)), 1.35),
                "debugging": max(float(config.category_weight_map.get("debugging", 1.0)), 1.35),
                "grounding": max(float(config.category_weight_map.get("grounding", 1.0)), 1.25),
                "refactor": max(float(config.category_weight_map.get("refactor", 1.0)), 1.2),
                "tests": max(float(config.category_weight_map.get("tests", 1.0)), 1.15),
                "unsafe": max(float(config.category_weight_map.get("unsafe", 1.0)), 1.15),
            },
        },
        {
            "branch_name": "planner",
            "branch_focus": "Maris AI autonomous_tasks",
            "adapter_type": config.adapter_type,
            "model_name": config.model_name,
            "output_dir": base_output_dir / "planner",
            "hub_model_id": config.hub_model_id,
            "benchmark_name": _resolve_branch_benchmark_name(config, "planner"),
            "benchmark_dataset_path": _resolve_branch_benchmark_dataset_path(config, "planner"),
            "benchmark_gate_enabled": True,
            "benchmark_min_overall": max(config.benchmark_min_overall, 0.76),
            "quality_min_text_chars": max(config.quality_min_text_chars, 18),
            "category_weight_map": {
                **config.category_weight_map,
                "reasoning": max(float(config.category_weight_map.get("reasoning", 1.0)), 1.3),
                "planning": max(float(config.category_weight_map.get("planning", 1.0)), 1.3),
                "grounding": max(float(config.category_weight_map.get("grounding", 1.0)), 1.2),
                "safety": max(float(config.category_weight_map.get("safety", 1.0)), 1.1),
            },
        },
        # Specialist branches: no benchmark/quality/weight overrides, so those
        # settings come straight from ``config``.
        {
            "branch_name": "image",
            "branch_focus": "Maris AI image_generation",
            "adapter_type": "specialist_model",
            "model_name": config.image_model_id,
            "output_dir": base_output_dir / "image",
            "hub_model_id": config.image_model_id,
        },
        {
            "branch_name": "music",
            "branch_focus": "Maris AI music_generation",
            "adapter_type": "specialist_model",
            "model_name": config.music_model_id,
            "output_dir": base_output_dir / "music",
            "hub_model_id": config.music_model_id,
        },
        {
            "branch_name": "tts",
            "branch_focus": "Maris AI realtime_tts",
            "adapter_type": "specialist_model",
            "model_name": config.tts_model_id,
            "output_dir": base_output_dir / "tts",
            "hub_model_id": config.tts_model_id,
        },
        {
            "branch_name": "stt",
            "branch_focus": "Maris AI realtime_stt",
            "adapter_type": "specialist_model",
            "model_name": config.stt_model_id,
            "output_dir": base_output_dir / "stt",
            "hub_model_id": config.stt_model_id,
        },
        {
            "branch_name": "video",
            "branch_focus": "Maris AI video_generation",
            "adapter_type": "specialist_model",
            "model_name": config.video_model_id,
            "output_dir": base_output_dir / "video",
            "hub_model_id": config.video_model_id,
        },
    )


    # Build one TrainingConfig per spec. Mutable containers are copied per branch
    # (list()/.copy()/dict()) so branches never share them with each other or ``config``.
    return [
        TrainingConfig(
            model_name=str(branch["model_name"]),
            branch_name=str(branch["branch_name"]),
            branch_focus=str(branch["branch_focus"]),
            adapter_type=str(branch["adapter_type"]),
            lora_r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            lora_bias=config.lora_bias,
            peft_target_modules=list(config.peft_target_modules),
            qlora_quant_type=config.qlora_quant_type,
            qlora_use_double_quant=config.qlora_use_double_quant,
            qlora_compute_dtype=config.qlora_compute_dtype,
            dataset_repo=config.dataset_repo,
            eval_dataset_repo=config.eval_dataset_repo,
            output_dir=str(branch["output_dir"]),
            hub_model_id=str(branch["hub_model_id"]),
            text_model_id=config.text_model_id,
            image_model_id=config.image_model_id,
            music_model_id=config.music_model_id,
            tts_model_id=config.tts_model_id,
            stt_model_id=config.stt_model_id,
            video_model_id=config.video_model_id,
            num_epochs=config.num_epochs,
            learning_rate=config.learning_rate,
            per_device_train_batch_size=config.per_device_train_batch_size,
            per_device_eval_batch_size=config.per_device_eval_batch_size,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            warmup_ratio=config.warmup_ratio,
            weight_decay=config.weight_decay,
            logging_steps=config.logging_steps,
            save_steps=config.save_steps,
            eval_steps=config.eval_steps,
            save_total_limit=config.save_total_limit,
            max_seq_length=config.max_seq_length,
            validation_split_ratio=config.validation_split_ratio,
            seed=config.seed,
            fp16=config.fp16,
            bf16=config.bf16,
            gradient_checkpointing=config.gradient_checkpointing,
            gradient_checkpointing_use_reentrant=config.gradient_checkpointing_use_reentrant,
            report_to=list(config.report_to),
            push_to_hub=config.push_to_hub,
            save_safetensors=config.save_safetensors,
            lr_scheduler_type=config.lr_scheduler_type,
            # ``or ""`` turns a None path into an empty string rather than "None".
            benchmark_dataset_path=str(
                branch.get("benchmark_dataset_path", config.benchmark_dataset_path) or ""
            ),
            benchmark_name=str(branch.get("benchmark_name", config.benchmark_name)),
            benchmark_levels=list(config.benchmark_levels),
            benchmark_min_overall=float(
                branch.get("benchmark_min_overall", config.benchmark_min_overall)
            ),
            benchmark_gate_enabled=bool(
                branch.get("benchmark_gate_enabled", config.benchmark_gate_enabled)
            ),
            branch_benchmark_targets={
                key: value.copy() for key, value in config.branch_benchmark_targets.items()
            },
            branch_benchmark_names=config.branch_benchmark_names.copy(),
            branch_benchmark_dataset_paths=config.branch_benchmark_dataset_paths.copy(),
            branch_preference_dataset_paths=config.branch_preference_dataset_paths.copy(),
            branch_dataset_filter_rules={
                key: value.copy() for key, value in config.branch_dataset_filter_rules.items()
            },
            preference_dataset_path=str(
                branch.get("preference_dataset_path", config.preference_dataset_path) or ""
            ),
            preference_optimization=config.preference_optimization,
            preference_beta=config.preference_beta,
            preference_max_prompt_length=config.preference_max_prompt_length,
            preference_max_length=config.preference_max_length,
            preference_reference_model=config.preference_reference_model,
            quality_gate_enabled=config.quality_gate_enabled,
            dedupe_enabled=config.dedupe_enabled,
            quality_min_text_chars=int(
                branch.get("quality_min_text_chars", config.quality_min_text_chars)
            ),
            scoring_enabled=config.scoring_enabled,
            weighted_repetition_enabled=config.weighted_repetition_enabled,
            medium_score_repeat_count=config.medium_score_repeat_count,
            high_score_repeat_count=config.high_score_repeat_count,
            source_weighting_enabled=config.source_weighting_enabled,
            source_weight_map=config.source_weight_map.copy(),
            category_weight_map=dict(branch.get("category_weight_map", config.category_weight_map)),
            max_effective_repeat_count=config.max_effective_repeat_count,
            benchmark_feedback_enabled=config.benchmark_feedback_enabled,
            benchmark_feedback_auto_discover=config.benchmark_feedback_auto_discover,
            benchmark_feedback_path=config.benchmark_feedback_path,
            benchmark_feedback_boost_scale=config.benchmark_feedback_boost_scale,
            benchmark_feedback_max_multiplier=config.benchmark_feedback_max_multiplier,
            continue_from_latest_artifact=config.continue_from_latest_artifact,
            continue_model_path=config.continue_model_path,
        )
        for branch in branch_specs
    ]
|
|
|
|
def _augment_metrics(metrics: dict[str, float], *, eval_dataset: Any | None) -> dict[str, float]:
    """Return a copy of ``metrics`` with derived evaluation metrics added.

    * ``eval_examples``: ``float(len(eval_dataset))`` when an eval dataset is given.
    * ``perplexity``: ``exp(eval_loss)`` when ``eval_loss`` is present. The loss
      is capped at ``MAX_LOSS_FOR_PERPLEXITY`` first, presumably so
      ``math.exp`` cannot overflow.
    """
    result = dict(metrics)
    if eval_dataset is not None:
        result["eval_examples"] = float(len(eval_dataset))
    if "eval_loss" in result:
        capped_loss = min(result["eval_loss"], MAX_LOSS_FOR_PERPLEXITY)
        result["perplexity"] = float(math.exp(capped_loss))
    return result
|
|
|
|
def _default_branch_benchmark_targets(config: TrainingConfig) -> dict[str, float]:
    """Return the benchmark targets for ``config.branch_name``.

    Uses the branch's entry in ``config.branch_benchmark_targets`` when it is
    non-empty, otherwise ``{"overall": config.benchmark_min_overall}``. A fresh
    dict is returned so callers cannot mutate the targets stored on ``config``.
    """
    branch_targets = config.branch_benchmark_targets.get(config.branch_name)
    if branch_targets:
        return dict(branch_targets)
    return {"overall": config.benchmark_min_overall}
|
|
|
|
def _build_benchmark_gate_artifact(
    config: TrainingConfig,
    benchmark_manifest: dict[str, Any],
    *,
    regression_report: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Evaluate the benchmark release gate and return it as a JSON-ready artifact.

    A target fails when its metric value is below the target. Only ``overall``
    and targets whose metric appears in the manifest are checked. Any
    regression counted in ``regression_report`` also fails the gate
    (regressions are never allowed). ``passed`` is true when nothing failed.
    """
    targets = _augment_release_gate_targets(
        _default_branch_benchmark_targets(config),
        benchmark_manifest,
    )

    failed_metrics: dict[str, dict[str, Any]] = {}
    for metric, target in targets.items():
        if metric != "overall" and not _benchmark_metric_present(benchmark_manifest, metric):
            continue
        actual = _benchmark_metric_value(benchmark_manifest, metric)
        if actual < float(target):
            failed_metrics[metric] = {"required": target, "actual": actual}

    regressions = (
        int(regression_report.get("regression_count", 0) or 0) if regression_report else 0
    )
    if regressions > 0:
        failed_metrics["regression_count"] = {"required": 0.0, "actual": float(regressions)}

    return {
        "artifact_type": "benchmark-release-gate",
        **_build_model_artifact_identity(config),
        "benchmark_name": benchmark_manifest.get("benchmark_name", config.benchmark_name),
        "passed": not failed_metrics,
        "targets": targets,
        "failed_metrics": failed_metrics,
        "regression_policy": {
            "allow_regressions": False,
            "regression_count": regressions,
        },
        "levels_checked": config.benchmark_levels,
    }
|
|
|
|
def _augment_release_gate_targets(
    targets: dict[str, float],
    benchmark_manifest: dict[str, Any],
) -> dict[str, float]:
    """Return a copy of ``targets`` extended with default release-gate thresholds.

    A default is added only when both hold:

    * the manifest reports the metric (see ``_benchmark_metric_present``);
    * ``targets`` does not already define it, so explicit targets always win.

    Conditional defaults:

    * ``long_context`` only when ``score_manifest`` has long-context scores;
    * ``execution`` only when the manifest has execution cases;
    * ``production_like_pass_rate`` only when it has production-like cases.

    A missing or ``null`` ``score_manifest`` is treated as empty.
    """
    augmented = dict(targets)
    score_manifest = benchmark_manifest.get("score_manifest") or {}
    critical_defaults: dict[str, float] = {
        "success_rate": 0.85,
        "grounding": 0.75,
        "safety": 0.9,
        "judge_overall": 0.72,
        "judge_task_completion": 0.72,
        "judge_instruction_following": 0.74,
        "judge_safety": 0.9,
        "judge_regression_risk": 0.72,
        "pairwise_win_rate": 0.55,
    }
    # Plain defaults suffice: ``setdefault`` below never overrides an explicit
    # target, so no floor is applied on top of configured values.
    if "long_context" in score_manifest:
        critical_defaults["long_context"] = 0.72
    if int(benchmark_manifest.get("execution_cases", 0) or 0) > 0:
        critical_defaults["execution"] = 0.7
    if int(benchmark_manifest.get("production_like_cases", 0) or 0) > 0:
        critical_defaults["production_like_pass_rate"] = 0.75
    for metric, threshold in critical_defaults.items():
        if _benchmark_metric_present(benchmark_manifest, metric):
            augmented.setdefault(metric, threshold)
    return augmented
|
|
|
|
def _benchmark_metric_present(benchmark_manifest: dict[str, Any], metric: str) -> bool:
    """Return whether ``metric`` appears in ``score_manifest`` or at the manifest top level.

    A missing or ``null`` ``score_manifest`` (e.g. from a JSON file) is treated
    as empty instead of raising ``TypeError``.
    """
    score_manifest = benchmark_manifest.get("score_manifest") or {}
    return metric in score_manifest or metric in benchmark_manifest
|
|
|
|
def _benchmark_metric_value(benchmark_manifest: dict[str, Any], metric: str) -> float:
    """Return ``metric`` as a float, preferring ``score_manifest`` over the top level.

    Missing, ``None`` or other falsy values yield ``0.0``. A missing or ``null``
    ``score_manifest`` is treated as empty instead of raising ``TypeError``.
    """
    score_manifest = benchmark_manifest.get("score_manifest") or {}
    if metric in score_manifest:
        raw_value = score_manifest.get(metric)
    else:
        raw_value = benchmark_manifest.get(metric)
    return float(raw_value or 0.0)
|
|
|
|
def _same_benchmark_reference(left: dict[str, Any] | None, right: dict[str, Any]) -> bool:
    """Return whether two benchmark run records describe the same run.

    Records match when ``benchmark_name``, ``branch``, ``model`` and
    ``generated_at`` are equal after ``str()`` conversion and stripping; missing
    keys count as ``""``. A non-dict ``left`` never matches.
    """
    if not isinstance(left, dict):
        return False

    def field_text(entry: dict[str, Any], field: str) -> str:
        return str(entry.get(field, "")).strip()

    identity_fields = ("benchmark_name", "branch", "model", "generated_at")
    return all(field_text(left, name) == field_text(right, name) for name in identity_fields)
|
|
|
|
def _select_previous_benchmark_baseline(
    *,
    current_manifest: dict[str, Any],
    previous_history: dict[str, Any] | None,
    previous_manifest: dict[str, Any] | None,
) -> dict[str, Any] | None:
    """Pick the earlier benchmark run to compare ``current_manifest`` against.

    Prefers the dict entries of ``previous_history["runs"]``:

    * the latest run, unless it is the current run itself;
    * in that case, the run before it (``None`` if there is none).

    Without usable history, falls back to ``previous_manifest``, or ``None``
    when that manifest is the current run.
    """
    history_runs = previous_history.get("runs") if isinstance(previous_history, dict) else None
    runs = (
        [dict(item) for item in history_runs if isinstance(item, dict)]
        if isinstance(history_runs, list)
        else []
    )
    if runs:
        latest = runs[-1]
        if not _same_benchmark_reference(latest, current_manifest):
            return latest
        return runs[-2] if len(runs) >= 2 else None

    if _same_benchmark_reference(previous_manifest, current_manifest):
        return None
    return previous_manifest
|
|
|
|
def _write_benchmark_artifacts(
    output_dir: Path,
    config: TrainingConfig,
    benchmark_manifest: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Write benchmark manifest, release gate, history, regression report and feedback.

    Steps:

    1. Load the previous history and manifest from ``output_dir`` before
       anything is overwritten; they supply the regression baseline and the
       history chain.
    2. Save the new manifest, release gate, history and regression report.
    3. Derive ``benchmark-feedback.json`` from the manifest file just saved.

    Returns:
        ``(release_gate, regression_report)``.
    """
    manifest = _attach_human_eval_summary_to_manifest(config, benchmark_manifest)
    manifest_path = output_dir / "benchmark-manifest.json"
    history_path = output_dir / "benchmark-history.json"

    previous_history = _load_json_if_exists(history_path)
    previous_manifest = _load_json_if_exists(manifest_path)
    baseline = _select_previous_benchmark_baseline(
        current_manifest=manifest,
        previous_history=previous_history,
        previous_manifest=previous_manifest,
    )
    regression_report = build_chat_benchmark_regression_report(manifest, previous_run=baseline)
    gate = _build_benchmark_gate_artifact(config, manifest, regression_report=regression_report)
    history = build_chat_benchmark_history_artifact(manifest, previous_history=previous_history)

    artifacts = (
        (manifest_path, manifest),
        (output_dir / "benchmark-release-gate.json", gate),
        (history_path, history),
        (output_dir / "benchmark-regression-report.json", regression_report),
    )
    for artifact_path, payload in artifacts:
        _save_json(artifact_path, payload)

    # Feedback is computed from the manifest file written just above.
    feedback = load_benchmark_feedback(
        manifest_path,
        targets=_default_branch_benchmark_targets(config),
        boost_scale=config.benchmark_feedback_boost_scale,
        max_multiplier=config.benchmark_feedback_max_multiplier,
    )
    _save_json(output_dir / "benchmark-feedback.json", build_benchmark_feedback_artifact(feedback))
    return gate, regression_report
|
|
|
|
def _attach_human_eval_summary_to_manifest(
    config: TrainingConfig,
    benchmark_manifest: dict[str, Any],
) -> dict[str, Any]:
    """Return the benchmark manifest enriched with a human-evaluation summary.

    When ``config.preference_dataset_path`` is unset, the input manifest object is
    returned unchanged. Otherwise a new dict is returned (the input is not mutated).
    It contains:

    * ``human_eval_summary``: the summary built from the filtered preference examples;
    * ``score_manifest``: a copy of the original scores, extended with
      ``pairwise_win_rate`` and ``human_eval_confidence``, each rounded to 3 decimals.
      Missing or falsy source values count as ``0.0``.
    """
    if config.preference_dataset_path:
        summary = build_human_eval_summary(_load_filtered_preference_examples(config))
        scores = dict(benchmark_manifest.get("score_manifest", {}))
        scores.update(
            pairwise_win_rate=round(float(summary.get("pairwise_win_rate", 0.0) or 0.0), 3),
            human_eval_confidence=round(
                float(summary.get("average_confidence", 0.0) or 0.0), 3
            ),
        )
        return {
            **benchmark_manifest,
            "human_eval_summary": summary,
            "score_manifest": scores,
        }
    return benchmark_manifest
|
|
|
|
async def _run_post_training_benchmark(
    config: TrainingConfig,
    *,
    model_path: str,
) -> dict[str, Any] | None:
    """Run the chat benchmark against a saved model and build its manifest.

    Returns ``None`` when ``config.benchmark_dataset_path`` is not set.

    Otherwise the function:

    1. loads the benchmark cases and keeps those matching ``config.benchmark_levels``
       and ``config.branch_name``;
    2. builds a ``transformers`` text-generation pipeline from ``model_path``;
    3. answers every case sequentially (concurrency 1), with at most 512 new tokens
       and temperature 0.0.

    Raises:
        ValueError: if no benchmark case matches the configured levels/branch.
    """
    if not config.benchmark_dataset_path:
        return None

    from transformers import pipeline

    chosen_cases = select_chat_benchmark_cases(
        load_chat_benchmark_dataset(config.benchmark_dataset_path),
        levels=config.benchmark_levels,
        branch=config.branch_name,
    )
    if not chosen_cases:
        raise ValueError("Benchmark datasetā nav neviena case izvēlētajiem benchmark leveliem.")

    generator = pipeline(
        "text-generation",
        model=model_path,
        tokenizer=model_path,
        trust_remote_code=True,
    )
    # Cases without their own profile inherit one derived from the benchmarked branch;
    # it depends only on the config, so it is computed once.
    branch_profile = {"coder": "coder", "planner": "planner"}.get(config.branch_name, "general")

    async def responder(case: Any) -> dict[str, Any]:
        system_message = {
            "role": "system",
            "content": build_system_prompt(
                case.profile or branch_profile,
                persona_id=case.persona_id,
            ),
        }
        conversation = [system_message, *case.history, {"role": "user", "content": case.message}]
        # The HF pipeline call is blocking; run it off the event loop.
        raw_output = await asyncio.to_thread(
            call_generation_pipeline,
            generator,
            conversation,
            max_new_tokens=512,
            temperature=0.0,
        )
        answer = _extract_response_text(raw_output)
        return {
            "response": answer,
            "model": config.hub_model_id,
            "tokens_used": _extract_usage_tokens(raw_output) or 0,
            "persona_title": "Core Assistant",
        }

    benchmark_results = await run_chat_benchmark_with_responder(
        chosen_cases,
        responder=responder,
        concurrency=1,
    )
    return build_chat_benchmark_manifest(
        benchmark_results,
        benchmark_name=config.benchmark_name,
        branch=config.branch_name,
        model=config.hub_model_id,
    )
|
|
|
|
def train_with_config(config: TrainingConfig) -> dict[str, float]:
    """Train a model from a complete training configuration.

    Pipeline, in order:

    1. normalize the config and apply branch runtime defaults;
    2. load the combined training datasets and, if configured, separate eval datasets;
    3. apply the branch filter, quality gate and scoring to the train/eval splits;
    4. tokenize, build ``TrainingArguments`` and run the HuggingFace ``Trainer``;
    5. optionally run preference optimization on top of the trained model;
    6. save model, tokenizer and training artifacts into ``config.output_dir``;
    7. optionally run the post-training chat benchmark and enforce its release gate;
    8. optionally push the model and artifacts to the Hub.

    Args:
        config: Training configuration.

    Returns:
        The collected training/eval/data-pipeline/benchmark metrics.

    Raises:
        SystemExit: if the training datasets cannot be loaded (``HFDatasetError``).
        ValueError: if the benchmark release gate is enabled and fails, or if the
            benchmark has no matching cases.
        Exception: any other error is logged (with traceback) and re-raised.
    """
    config = _apply_branch_runtime_defaults(_normalize_training_runtime_config(config))
    _ensure_runtime_home_dir()
    training_dataset_repos = _resolve_training_dataset_repos(config)
    eval_dataset_repos = _resolve_eval_dataset_repos(config)
    try:
        _emit_training_progress_event(
            "prepare_dataset",
            stage="preparing",
            label="Ielādē training datasetu",
            total_epochs=config.num_epochs,
            dataset_repo=config.dataset_repo,
            dataset_repos=training_dataset_repos,
        )
        logger.info("Ielādē training datasetus: %s", ", ".join(training_dataset_repos))
        dataset = _load_combined_hf_dataset(training_dataset_repos)
    except HFDatasetError as exc:
        # Dataset access problems stop the run with a clean message instead of a traceback.
        logger.error("Apmācība apturēta: %s", exc)
        raise SystemExit(str(exc)) from None
    except Exception as exc:
        logger.exception("Apmācības kļūda: %s", exc)
        raise

    try:
        from transformers import (
            DataCollatorForLanguageModeling,
            Trainer,
            TrainingArguments,
        )

        try:
            from transformers import TrainerCallback
        except ImportError:
            # Fallback base class so the progress callback below can still be defined
            # when this transformers build does not export TrainerCallback.
            class TrainerCallback:  # type: ignore[no-redef]
                pass

        # --- model & tokenizer -------------------------------------------------------
        training_model_source = _resolve_training_model_source(config)
        _emit_training_progress_event(
            "prepare_model",
            stage="preparing",
            label="Ielādē tokenizeri un modeli",
            total_epochs=config.num_epochs,
            model_source=training_model_source,
        )
        logger.info("Ielādē modeli: %s", training_model_source)
        tokenizer = _load_tokenizer(training_model_source, config)
        _configure_tokenizer(tokenizer, config)
        model = _prepare_training_model(training_model_source, tokenizer, config)

        # --- dataset splits ----------------------------------------------------------
        train_split, eval_split = _prepare_train_eval_splits(dataset, config)
        if eval_dataset_repos:
            # Dedicated eval datasets replace the eval split carved from the training data.
            _emit_training_progress_event(
                "prepare_eval_dataset",
                stage="preparing",
                label="Ielādē eval datasetu",
                total_epochs=config.num_epochs,
                dataset_repo=config.eval_dataset_repo or eval_dataset_repos[0],
                dataset_repos=eval_dataset_repos,
            )
            logger.info("Ielādē atsevišķus eval datasetus: %s", ", ".join(eval_dataset_repos))
            eval_source_dataset = _load_combined_hf_dataset(eval_dataset_repos)
            eval_split = _select_eval_split(
                eval_source_dataset,
                config,
                allow_train_fallback=True,
            )
        train_split, train_branch_filter_report = _apply_branch_dataset_filter_to_split(
            train_split,
            split_name="train",
            config=config,
        )
        eval_branch_filter_report: DatasetBranchFilterReport | None = None
        if eval_split is not None:
            eval_split, eval_branch_filter_report = _apply_branch_dataset_filter_to_split(
                eval_split,
                split_name="eval",
                config=config,
                allow_empty=True,
            )
        train_split, train_quality_report = _apply_quality_gate_to_split(
            train_split,
            split_name="train",
            config=config,
        )
        benchmark_feedback = _load_dataset_benchmark_feedback(config)
        # Only the train split is weight-expanded; eval keeps one copy per record.
        train_split, train_scoring_report = _apply_scoring_to_split(
            train_split,
            split_name="train",
            config=config,
            expand_weights=True,
            benchmark_feedback=benchmark_feedback,
        )
        eval_quality_report: DatasetQualitySplitReport | None = None
        eval_scoring_report: DatasetScoringSplitReport | None = None
        if eval_split is not None:
            eval_split, eval_quality_report = _apply_quality_gate_to_split(
                eval_split,
                split_name="eval",
                config=config,
                allow_empty=True,
            )
        # Re-checked: the quality gate above may have replaced eval_split.
        if eval_split is not None:
            eval_split, eval_scoring_report = _apply_scoring_to_split(
                eval_split,
                split_name="eval",
                config=config,
                expand_weights=False,
                benchmark_feedback=benchmark_feedback,
            )
        _emit_training_progress_event(
            "prepare_tokenization",
            stage="preparing",
            label="Tokenizē treniņa un eval datus",
            total_epochs=config.num_epochs,
        )
        train_dataset = _tokenize_dataset(train_split, tokenizer, config.max_seq_length)
        eval_dataset = (
            _tokenize_dataset(eval_split, tokenizer, config.max_seq_length)
            if eval_split is not None
            else None
        )

        # --- trainer -----------------------------------------------------------------
        _emit_training_progress_event(
            "prepare_runtime",
            stage="preparing",
            label="Sagatavo treneri un runtime",
            total_epochs=config.num_epochs,
        )
        has_eval = eval_dataset is not None
        # NOTE(review): recent transformers releases renamed `evaluation_strategy` to
        # `eval_strategy`; presumably _build_training_arguments adapts kwargs to the
        # installed version — confirm.
        training_args = _build_training_arguments(
            TrainingArguments,
            output_dir=config.output_dir,
            overwrite_output_dir=True,
            num_train_epochs=config.num_epochs,
            learning_rate=config.learning_rate,
            per_device_train_batch_size=config.per_device_train_batch_size,
            per_device_eval_batch_size=config.per_device_eval_batch_size,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            warmup_ratio=config.warmup_ratio,
            weight_decay=config.weight_decay,
            logging_steps=config.logging_steps,
            save_steps=config.save_steps,
            eval_steps=config.eval_steps,
            save_total_limit=config.save_total_limit,
            lr_scheduler_type=config.lr_scheduler_type,
            seed=config.seed,
            fp16=config.fp16,
            bf16=config.bf16,
            report_to=config.report_to,
            save_safetensors=config.save_safetensors,
            remove_unused_columns=False,
            evaluation_strategy="steps" if has_eval else "no",
            load_best_model_at_end=has_eval,
            metric_for_best_model="eval_loss" if has_eval else None,
            greater_is_better=False if has_eval else None,
            **_build_runtime_training_argument_overrides(),
            **_build_distributed_training_argument_overrides(config),
        )

        class _SpaceTrainingProgressCallback(MarisTrainingProgressCallback, TrainerCallback):
            # Combines the Maris progress reporting with the transformers callback base.
            pass

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
        )
        if hasattr(trainer, "add_callback"):
            trainer.add_callback(_SpaceTrainingProgressCallback(total_epochs=config.num_epochs))

        # --- training ----------------------------------------------------------------
        logger.info("Sāk apmācību...")
        train_result = trainer.train()
        metrics = dict(getattr(train_result, "metrics", {}))
        if eval_dataset is not None:
            metrics.update(trainer.evaluate())
        metrics = _augment_metrics(metrics, eval_dataset=eval_dataset)
        train_examples = len(train_dataset)
        eval_examples = len(eval_dataset) if eval_dataset is not None else 0
        trained_at = datetime.now(UTC).isoformat()
        output_dir = Path(config.output_dir)
        if _uses_preference_optimization(config):
            # Persist the freshly trained model before preference optimization runs.
            trainer.save_model(str(output_dir))
            tokenizer.save_pretrained(str(output_dir))
            model, preference_trainer, preference_metrics = _run_preference_optimization(
                model,
                tokenizer,
                config,
                output_dir=output_dir,
            )
            metrics.update(preference_metrics)
            if preference_trainer is not None:
                trainer = preference_trainer

        # --- data-pipeline reports & metrics -----------------------------------------
        quality_report = build_dataset_quality_report(
            config=_dataset_quality_config(config),
            train_report=train_quality_report,
            eval_report=eval_quality_report,
        )
        scoring_report = build_dataset_scoring_report(
            config=_dataset_scoring_config(config),
            train_report=train_scoring_report,
            eval_report=eval_scoring_report,
        )
        metrics["branch_filter_train_kept"] = float(train_branch_filter_report.kept_records)
        metrics["branch_filter_train_dropped"] = float(train_branch_filter_report.dropped_records)
        metrics["quality_train_kept"] = float(train_quality_report.kept_records)
        metrics["quality_train_dropped"] = float(train_quality_report.dropped_records)
        metrics["quality_train_duplicates_removed"] = float(train_quality_report.duplicates_removed)
        metrics["scoring_train_average_score"] = float(train_scoring_report.average_score)
        metrics["scoring_train_expanded_records"] = float(train_scoring_report.expanded_records)
        metrics["scoring_train_repeated_records"] = float(train_scoring_report.repeated_records)
        metrics["scoring_train_average_repeat_multiplier"] = float(
            train_scoring_report.average_repeat_multiplier
        )
        metrics["scoring_train_feedback_boosted_records"] = float(
            train_scoring_report.feedback_boosted_records
        )
        if eval_branch_filter_report is not None:
            metrics["branch_filter_eval_kept"] = float(eval_branch_filter_report.kept_records)
            metrics["branch_filter_eval_dropped"] = float(
                eval_branch_filter_report.dropped_records
            )
        if eval_quality_report is not None:
            metrics["quality_eval_kept"] = float(eval_quality_report.kept_records)
            metrics["quality_eval_dropped"] = float(eval_quality_report.dropped_records)
            metrics["quality_eval_duplicates_removed"] = float(
                eval_quality_report.duplicates_removed
            )
            metrics["quality_eval_skipped"] = float(eval_split is None)
        if eval_scoring_report is not None:
            metrics["scoring_eval_average_score"] = float(eval_scoring_report.average_score)
            metrics["scoring_eval_expanded_records"] = float(eval_scoring_report.expanded_records)
            metrics["scoring_eval_feedback_boosted_records"] = float(
                eval_scoring_report.feedback_boosted_records
            )

        # --- persistence -------------------------------------------------------------
        trainer.save_model(str(output_dir))
        tokenizer.save_pretrained(str(output_dir))
        _write_training_artifacts(
            output_dir,
            config,
            trained_at=trained_at,
            train_examples=train_examples,
            eval_examples=eval_examples,
            metrics=metrics,
            quality_report=quality_report,
            scoring_report=scoring_report,
            benchmark_feedback=benchmark_feedback,
        )
        _write_preference_artifact(output_dir, config)

        # --- post-training benchmark -------------------------------------------------
        benchmark_manifest = None
        if config.benchmark_dataset_path:
            # Report the stage *before* the (potentially long) benchmark run, the same
            # way the "publish" event precedes push_to_hub below.
            _emit_training_progress_event(
                "benchmark",
                stage="benchmarking",
                label=_build_training_progress_label(stage="benchmarking"),
                output_dir=str(output_dir),
            )
            benchmark_manifest = asyncio.run(
                _run_post_training_benchmark(config, model_path=str(output_dir))
            )
        if benchmark_manifest is not None:
            benchmark_gate, regression_report = _write_benchmark_artifacts(
                output_dir,
                config,
                benchmark_manifest,
            )
            metrics["benchmark_overall"] = float(
                benchmark_manifest.get("score_manifest", {}).get("overall", 0.0)
            )
            metrics["benchmark_gate_passed"] = 1.0 if benchmark_gate["passed"] else 0.0
            metrics["benchmark_regressions"] = float(regression_report.get("regression_count", 0))
            _save_json(
                output_dir / "training-metrics.json",
                _build_metrics_artifact(
                    config,
                    metrics,
                    artifact_type="training-metrics",
                    extra_payload=_build_scoring_dashboard_payload(
                        scoring_report=scoring_report,
                        benchmark_feedback=benchmark_feedback,
                    ),
                ),
            )
            if config.benchmark_gate_enabled and not benchmark_gate["passed"]:
                failure_details = ", ".join(
                    f"{metric}: {details['actual']:.3f} < {details['required']:.3f}"
                    for metric, details in sorted(benchmark_gate["failed_metrics"].items())
                )
                raise ValueError("Benchmark release gate neizgāja: " + failure_details)
        _sanitize_saved_artifacts(output_dir, maris_model_id=config.hub_model_id)
        _verify_saved_artifacts(output_dir)

        # --- publishing --------------------------------------------------------------
        if config.push_to_hub and hasattr(trainer, "push_to_hub"):
            _emit_training_progress_event(
                "publish",
                stage="publishing",
                label=_build_training_progress_label(stage="publishing"),
                output_dir=str(output_dir),
            )
            trainer.push_to_hub(commit_message=f"Maris AI training sync ({config.branch_name})")
            # Re-write the Maris artifacts after push_to_hub, then sanitize/verify again
            # before syncing the whole output directory.
            _write_training_artifacts(
                output_dir,
                config,
                trained_at=trained_at,
                train_examples=train_examples,
                eval_examples=eval_examples,
                metrics=metrics,
                quality_report=quality_report,
                scoring_report=scoring_report,
                benchmark_feedback=benchmark_feedback,
            )
            if benchmark_manifest is not None:
                _write_benchmark_artifacts(output_dir, config, benchmark_manifest)
            _sanitize_saved_artifacts(output_dir, maris_model_id=config.hub_model_id)
            _verify_saved_artifacts(output_dir)
            _sync_output_dir_to_hub(
                output_dir,
                config,
                commit_message=f"Maris AI artifact sync ({config.branch_name})",
            )

        logger.info("Apmācība pabeigta. Modelis saglabāts: %s", output_dir)
        return metrics

    except Exception as exc:
        logger.exception("Apmācības kļūda: %s", exc)
        raise
|
|
|
|
def train_branch_suite(config: TrainingConfig) -> dict[str, dict[str, Any]]:
    """Run the branch-oriented training pipeline for every Maris branch.

    Branch behaviour:

    * Branches listed in ``TEXT_TRAINABLE_BRANCHES`` are trained with
      ``train_with_config``; their metrics are recorded with status ``"trained"``.
    * Every other branch is only documented: a ``branch-config.json`` with status
      ``"external_specialist"`` is written into its (created) output directory.

    The combined results are written to ``<config.output_dir>/branch-suite.json``.

    Returns:
        A mapping of branch name to that branch's result record.
    """
    suite_results: dict[str, dict[str, Any]] = {}
    base_output_dir = Path(config.output_dir)

    for branch_config in build_branch_training_configs(config):
        if branch_config.branch_name in TEXT_TRAINABLE_BRANCHES:
            branch_metrics = train_with_config(branch_config)
            outcome = {
                "status": "trained",
                "maris_model_id": branch_config.hub_model_id,
                "branch_focus": branch_config.branch_focus,
                "output_dir": branch_config.output_dir,
                "metrics": branch_metrics,
                "dataset_repo": branch_config.dataset_repo,
                "maris_origin": MARIS_ORIGIN_NAME,
                "maris_framework": MARIS_FRAMEWORK_NAME,
            }
        else:
            branch_dir = Path(branch_config.output_dir)
            branch_dir.mkdir(parents=True, exist_ok=True)
            outcome = {
                "status": "external_specialist",
                "maris_model_id": branch_config.hub_model_id,
                "branch_focus": branch_config.branch_focus,
                "output_dir": branch_config.output_dir,
                "reason": "Šim atzaram nepieciešams specializēts multimodāls treniņa pipeline ārpus CausalLM trainer.",
                "maris_origin": MARIS_ORIGIN_NAME,
                "maris_framework": MARIS_FRAMEWORK_NAME,
            }
            _save_json(branch_dir / "branch-config.json", outcome)
        suite_results[branch_config.branch_name] = outcome

    _save_json(
        base_output_dir / "branch-suite.json",
        _build_branch_suite_artifact(config, suite_results),
    )
    return suite_results
|
|
|
|
def evaluate_with_config(
    config: TrainingConfig,
    *,
    model_path: str | None = None,
) -> dict[str, float]:
    """Evaluate a model with the same data pipeline that training uses.

    Data selection:

    * If dedicated eval datasets are configured, they are loaded and their eval split
      is selected.
    * Otherwise the training datasets are split and the eval split is used, falling
      back to the train split.

    The selected split goes through the branch filter, quality gate and scoring
    (without weight expansion), is tokenized, and is evaluated with a HuggingFace
    ``Trainer``. Metrics are saved to ``<output_dir>/evaluation-metrics.json``. When a
    benchmark dataset is configured, the chat benchmark also runs and its results
    are merged into the metrics.

    Args:
        config: Training configuration.
        model_path: Model to evaluate; defaults to ``config.output_dir``.

    Returns:
        The evaluation (and optional benchmark) metrics.

    Raises:
        SystemExit: if the datasets cannot be loaded (``HFDatasetError``).
    """
    config = _apply_branch_runtime_defaults(config)
    _ensure_runtime_home_dir()
    eval_repos = _resolve_eval_dataset_repos(config)
    source_repos = eval_repos if eval_repos else _resolve_training_dataset_repos(config)
    try:
        dataset = _load_combined_hf_dataset(source_repos)
    except HFDatasetError as exc:
        logger.error("Novērtēšana apturēta: %s", exc)
        raise SystemExit(str(exc)) from None

    from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

    resolved_model = model_path or config.output_dir
    tokenizer = _load_tokenizer(resolved_model, config)
    _configure_tokenizer(tokenizer, config)
    model = _load_model(resolved_model, config)

    if eval_repos:
        split = _select_eval_split(dataset, config, allow_train_fallback=True)
    else:
        train_part, eval_part = _prepare_train_eval_splits(dataset, config)
        split = eval_part or train_part
    split, _ = _apply_branch_dataset_filter_to_split(split, split_name="evaluation", config=config)
    split, _ = _apply_quality_gate_to_split(split, split_name="evaluation", config=config)
    split, _ = _apply_scoring_to_split(
        split,
        split_name="evaluation",
        config=config,
        expand_weights=False,
        benchmark_feedback=_load_dataset_benchmark_feedback(config),
    )
    eval_dataset = _tokenize_dataset(split, tokenizer, config.max_seq_length)

    eval_args = _build_training_arguments(
        TrainingArguments,
        output_dir=config.output_dir,
        per_device_eval_batch_size=config.per_device_eval_batch_size,
        report_to=[],
        remove_unused_columns=False,
        **_build_runtime_training_argument_overrides(),
        **_build_distributed_training_argument_overrides(config),
    )
    trainer = Trainer(
        model=model,
        args=eval_args,
        eval_dataset=eval_dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )
    metrics = _augment_metrics(dict(trainer.evaluate()), eval_dataset=eval_dataset)

    output_dir = Path(config.output_dir)
    metrics_path = output_dir / "evaluation-metrics.json"
    _save_json(
        metrics_path,
        _build_metrics_artifact(config, metrics, artifact_type="evaluation-metrics"),
    )

    benchmark_manifest = (
        asyncio.run(_run_post_training_benchmark(config, model_path=str(resolved_model)))
        if config.benchmark_dataset_path
        else None
    )
    if benchmark_manifest is not None:
        gate, regressions = _write_benchmark_artifacts(output_dir, config, benchmark_manifest)
        metrics["benchmark_overall"] = float(
            benchmark_manifest.get("score_manifest", {}).get("overall", 0.0)
        )
        metrics["benchmark_gate_passed"] = 1.0 if gate["passed"] else 0.0
        metrics["benchmark_regressions"] = float(regressions.get("regression_count", 0))
        # Overwrite the metrics file so it also contains the benchmark results.
        _save_json(
            metrics_path,
            _build_metrics_artifact(config, metrics, artifact_type="evaluation-metrics"),
        )

    _sanitize_saved_artifacts(output_dir, maris_model_id=config.hub_model_id)
    _verify_saved_artifacts(output_dir)
    return metrics
|
|
|
|
def train(
    config_path: str | None = None,
    model_name: str | None = None,
    dataset_repo: str | None = None,
    dataset_repos: str | list[str] | None = None,
    eval_dataset_repo: str | None = None,
    eval_dataset_repos: str | list[str] | None = None,
    benchmark_dataset_path: str | None = None,
    benchmark_feedback_path: str | None = None,
    preference_dataset_path: str | None = None,
    preference_optimization: str | None = None,
    preference_beta: float | None = None,
    preference_max_prompt_length: int | None = None,
    preference_max_length: int | None = None,
    preference_reference_model: str | None = None,
    output_dir: str | None = None,
    num_epochs: int | None = None,
    learning_rate: float | None = None,
    max_seq_length: int | None = None,
    validation_split_ratio: float | None = None,
    adapter_type: str | None = None,
    lora_r: int | None = None,
    lora_alpha: int | None = None,
    lora_dropout: float | None = None,
    lora_bias: str | None = None,
    peft_target_modules: str | list[str] | None = None,
    qlora_quant_type: str | None = None,
    qlora_use_double_quant: bool | None = None,
    qlora_compute_dtype: str | None = None,
    distributed_strategy: str | None = None,
    distributed_config_path: str | None = None,
    use_accelerate: bool | None = None,
    accelerate_config_path: str | None = None,
    num_processes: int | None = None,
    num_machines: int | None = None,
    machine_rank: int | None = None,
    main_process_ip: str | None = None,
    main_process_port: int | None = None,
    fsdp_transformer_layer_cls_to_wrap: str | list[str] | None = None,
    fsdp_min_num_params: int | None = None,
    continue_from_latest_artifact: bool | None = None,
    continue_model_path: str | None = None,
) -> dict[str, float]:
    """Train a model on HuggingFace datasets.

    Every keyword argument except ``config_path`` is forwarded unchanged as an entry
    of the ``overrides`` mapping to ``load_training_config``. How ``None`` entries are
    treated is up to ``load_training_config``. The resulting configuration is then
    trained with ``train_with_config``.

    Returns:
        The metrics returned by ``train_with_config``.
    """
    overrides = dict(
        model_name=model_name,
        dataset_repo=dataset_repo,
        dataset_repos=dataset_repos,
        eval_dataset_repo=eval_dataset_repo,
        eval_dataset_repos=eval_dataset_repos,
        benchmark_dataset_path=benchmark_dataset_path,
        benchmark_feedback_path=benchmark_feedback_path,
        preference_dataset_path=preference_dataset_path,
        preference_optimization=preference_optimization,
        preference_beta=preference_beta,
        preference_max_prompt_length=preference_max_prompt_length,
        preference_max_length=preference_max_length,
        preference_reference_model=preference_reference_model,
        output_dir=output_dir,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        max_seq_length=max_seq_length,
        validation_split_ratio=validation_split_ratio,
        adapter_type=adapter_type,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        lora_bias=lora_bias,
        peft_target_modules=peft_target_modules,
        qlora_quant_type=qlora_quant_type,
        qlora_use_double_quant=qlora_use_double_quant,
        qlora_compute_dtype=qlora_compute_dtype,
        distributed_strategy=distributed_strategy,
        distributed_config_path=distributed_config_path,
        use_accelerate=use_accelerate,
        accelerate_config_path=accelerate_config_path,
        num_processes=num_processes,
        num_machines=num_machines,
        machine_rank=machine_rank,
        main_process_ip=main_process_ip,
        main_process_port=main_process_port,
        fsdp_transformer_layer_cls_to_wrap=fsdp_transformer_layer_cls_to_wrap,
        fsdp_min_num_params=fsdp_min_num_params,
        continue_from_latest_artifact=continue_from_latest_artifact,
        continue_model_path=continue_model_path,
    )
    resolved_config = load_training_config(config_path, overrides=overrides)
    return train_with_config(resolved_config)
|
|