MarisUK's picture
Maris AI model sync
f440f03 verified
"""Modeļa apmācība ar HuggingFace Trainer."""
from __future__ import annotations
import asyncio
import hashlib
import importlib.util
import inspect
import json
import logging
import math
import os
import re
import subprocess
import sys
import tempfile
import time
from contextlib import suppress
from dataclasses import dataclass, replace
from datetime import UTC, datetime
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as get_installed_package_version
from pathlib import Path
from typing import Any
from maris_core.data.datasets import HFDatasetError, load_hf_dataset
from maris_core.data.preprocessing import record_to_training_text
from maris_core.data.quality import (
DatasetQualityGateConfig,
DatasetQualityReport,
DatasetQualitySplitReport,
apply_quality_gate_to_records,
build_dataset_quality_report,
)
from maris_core.data.scoring import (
DatasetBenchmarkFeedback,
DatasetScoringConfig,
DatasetScoringReport,
DatasetScoringSplitReport,
apply_scoring_to_records,
build_benchmark_feedback_artifact,
build_dataset_scoring_report,
load_benchmark_feedback,
)
from maris_core.orchestrator.routing import build_system_prompt
from maris_core.text.benchmark import (
build_chat_benchmark_history_artifact,
build_chat_benchmark_manifest,
build_chat_benchmark_regression_report,
load_chat_benchmark_dataset,
run_chat_benchmark_with_responder,
select_chat_benchmark_cases,
)
from maris_core.text.generate import (
_extract_response_text,
_extract_usage_tokens,
call_generation_pipeline,
)
from maris_core.training.config import (
DEFAULT_BENCHMARK_NAME,
DEFAULT_BRANCH_DATASET_FILTER_RULES,
TrainingConfig,
load_training_config,
)
from maris_core.training.hf_compat import (
MARIS_COMPATIBILITY_ARTIFACT_NAME,
apply_maris_compatibility_identity,
maris_hf_compatible_path,
write_maris_compatibility_artifact,
)
from maris_core.training.preferences import (
PreferenceExample,
build_blind_side_by_side_artifact,
build_human_eval_summary,
load_preference_dataset,
summarize_preference_dataset,
)
from maris_core.utils.env import get_env_any
logger = logging.getLogger(__name__)
# Branch names treated as trainable text-model branches (consumer not visible in this chunk).
TEXT_TRAINABLE_BRANCHES = {"master", "coder", "planner"}
# NOTE(review): presumably an upper bound applied to loss before exp() when reporting
# perplexity — confirm at the use site.
MAX_LOSS_FOR_PERPLEXITY = 20.0
# Identity strings for the Maris origin/framework; MARIS_ORIGIN_NAME also seeds
# MARIS_IDENTITY_VARIANT_RE below.
MARIS_ORIGIN_NAME = "Maris AI"
MARIS_FRAMEWORK_NAME = "maris-ai-core"
# Exception types that trigger the tokenizer fallback chain (fast -> slow -> explicit slow
# class) and the backend auto-install retry in _load_tokenizer.
TOKENIZER_LOAD_RETRY_EXCEPTIONS = (ImportError, OSError, RuntimeError, TypeError, ValueError)
# NOTE(review): artifact/config keys that presumably get scrubbed of upstream model paths —
# the consumer is not visible in this chunk.
SANITIZED_ARTIFACT_KEYS = (
    "_name_or_path",
    "name_or_path",
    "base_model_name_or_path",
)
# NOTE(review): text-valued artifact keys presumably rewritten to carry the Maris identity —
# the consumer is not visible in this chunk.
IDENTITY_TEXT_KEYS = frozenset(
    {
        "base_model_lineage",
        "base_model_name",
        "chat_template",
        "default_system_prompt",
        "description",
        "model_name",
        "system_prompt",
    }
)
# File extensions of artifacts presumably sanitized as plain text (usage not visible here).
TEXT_SANITIZED_ARTIFACT_EXTENSIONS = frozenset({".jinja", ".md", ".template", ".txt"})
# Case-insensitive match for "<known vendor org>/<repo>" Hub references, e.g. "qwen/Qwen2-7B".
FOREIGN_MODEL_REFERENCE_RE = re.compile(
    r"(?i)\b(?:"
    r"allenai|anthropic|cohereforai|deepseek-ai|google|huggingfaceh4|meta-llama|"
    r"microsoft|mistralai|nousresearch|openai|qwen|stabilityai|tiiuae|tinyllama"
    r")/[A-Za-z0-9][\w.-]*\b"
)
# Case-insensitive whole-word patterns for third-party AI brand names.
FOREIGN_AI_BRAND_PATTERNS = (
    re.compile(r"(?i)\bTinyLlama\b"),
    re.compile(r"(?i)\bDeepSeek\b"),
    re.compile(r"(?i)\bMistral\b"),
    re.compile(r"(?i)\bLlama\b"),
    re.compile(r"(?i)\bQwen\b"),
    re.compile(r"(?i)\bChatGPT\b"),
    re.compile(r"(?i)\bClaude\b"),
    re.compile(r"(?i)\bGemini\b"),
    re.compile(r"(?i)\bOpenAI\b"),
    re.compile(r"(?i)\bAnthropic\b"),
)
# "Maris AI" followed by one or more "-x" / "_x" / "/x" suffix segments, e.g. "Maris AI-7B".
MARIS_IDENTITY_VARIANT_RE = re.compile(
    rf"{re.escape(MARIS_ORIGIN_NAME)}(?:[-_/][A-Za-z0-9][\w.-]*)+"
)
# NOTE(review): presumably maps legacy TrainingArguments keyword names to current ones —
# the consumer is not visible in this chunk.
TRAINING_ARGUMENT_ALIASES = {
    "evaluation_strategy": "eval_strategy",
}
# Tags presumably written into generated model cards (usage not visible here).
MODEL_CARD_TAGS = ("maris-ai", "maris-origin", "conversational-ai")
# adapter_type values that enable PEFT (checked by _uses_peft).
PEFT_ADAPTER_TYPES = {"lora", "qlora"}
# preference_optimization values honoured by _uses_preference_optimization.
PREFERENCE_OPTIMIZATION_TYPES = {"dpo", "orpo"}
# Parameter-count hint inside a model name, e.g. "7B" or "1.5B" (matches an uppercase "B" only).
MODEL_SIZE_HINT_RE = re.compile(r"(\d+(?:\.\d+)?)B")
# Size hint (billions of parameters) at or above which _normalize_training_runtime_config
# forces gradient checkpointing and batch size 1, and model loading uses device_map="auto".
LARGE_MODEL_RESOURCE_THRESHOLD_B = 30.0
# Size hint at or above which QLoRA is forced; also the size assumed for names matching
# GIANT_MODEL_NAME_HINTS when no explicit "<n>B" hint is present.
GIANT_MODEL_AUTO_QLORA_THRESHOLD_B = 70.0
# max_seq_length treated as long-context (enables DeepSpeed/Accelerate for giant models and
# allows raising tokenizer.model_max_length).
LONG_CONTEXT_MIN_SEQ_LENGTH = 32_768
# Lower-cased substrings that mark a model as giant when its name carries no "<n>B" hint.
GIANT_MODEL_NAME_HINTS = (
    "deepseek-v3",
    "deepseek-r1",
    "qwen3-coder-480b",
    "405b",
    "480b",
    "671b",
)
# Reusable fingerprints for detecting a previously saved model or adapter inside
# persistent storage before starting the next training run.
LOCAL_TRAINING_ARTIFACT_FILES = (
    "config.json",
    "adapter_config.json",
    "model.safetensors",
    "pytorch_model.bin",
)
# NOTE(review): metadata key presumably holding a fingerprint of the source model —
# the consumer is not visible in this chunk.
MODEL_SOURCE_FINGERPRINT_KEY = "model_source_fingerprint"
# Repository root (parents[3] of this file's resolved path); relative distributed-config
# paths and the default FSDP/DeepSpeed configs are resolved against it.
REPO_ROOT = Path(__file__).resolve().parents[3]
def _emit_training_progress_event(event: str, **payload: Any) -> None:
body = {
"maris_training_event": True,
"event": event,
**{key: value for key, value in payload.items() if value is not None},
}
print(json.dumps(body, ensure_ascii=False), flush=True)
def _build_training_progress_label(
*,
stage: str,
epoch: float | None = None,
total_epochs: int | None = None,
step: int | None = None,
total_steps: int | None = None,
) -> str:
if stage == "saving":
return "Saglabā modeļa artefaktus"
if stage == "evaluating":
return "Veic eval metriku aprēķinu"
if stage == "publishing":
return "Publicē modeli origin repozitorijā"
if stage == "benchmarking":
return "Palaiž benchmark un release gate pārbaudes"
if stage == "starting":
return "Inicializē treniņu"
if step is not None and total_steps:
return f"Trenē modeli · solis {step}/{total_steps}"
if epoch is not None and total_epochs:
return f"Trenē modeli · epoha {epoch:g}/{total_epochs}"
if epoch is not None:
return f"Trenē modeli · epoha {epoch:g}"
return "Trenē modeli"
class MarisTrainingProgressCallback:
    """Emit structured progress events for Space UI monitoring.

    Implements the HF ``TrainerCallback`` hook names ``on_train_begin``, ``on_log``,
    ``on_evaluate``, ``on_save`` and ``on_train_end``. Each hook prints one JSON event via
    ``_emit_training_progress_event`` carrying stage, label, epoch/step counters and, where
    available, loss, learning rate and an ETA estimate.

    NOTE(review): this class does not subclass ``transformers.TrainerCallback``. HF's
    callback handler looks event hooks up by name, so confirm the instance is mixed in with or
    wrapped by ``TrainerCallback`` before it is handed to ``Trainer``.
    """

    def __init__(self, *, total_epochs: int | None = None) -> None:
        # Planned epoch count, echoed into every event and used in epoch labels.
        self.total_epochs = total_epochs
        # Monotonic start time, set in on_train_begin; None disables ETA estimation.
        self.started_monotonic: float | None = None

    @staticmethod
    def _resolve_total_steps(state: Any) -> int | None:
        """Return ``state.max_steps`` as a positive int, or None when missing or unset.

        Non-positive values (such as ``-1``) are treated as unknown, so labels never show
        ``step/-1`` and the ETA is not reported as already finished.
        """
        max_steps = getattr(state, "max_steps", 0)
        if not max_steps:
            return None
        total_steps = int(max_steps)
        return total_steps if total_steps > 0 else None

    def on_train_begin(self, args, state, control, **kwargs) -> None:  # noqa: ANN001, D401
        self.started_monotonic = time.monotonic()
        _emit_training_progress_event(
            "train_begin",
            stage="starting",
            label=_build_training_progress_label(stage="starting"),
            epoch=_coerce_float(getattr(state, "epoch", None)),
            total_epochs=self.total_epochs,
            step=_coerce_int(getattr(state, "global_step", None)),
            total_steps=self._resolve_total_steps(state),
            # Coerced like on_log so every event carries a plain float (or nothing).
            learning_rate=_coerce_float(getattr(args, "learning_rate", None)),
        )

    def on_log(self, args, state, control, logs=None, **kwargs) -> None:  # noqa: ANN001, D401
        logs = dict(logs or {})
        step = _coerce_int(getattr(state, "global_step", None))
        total_steps = self._resolve_total_steps(state)
        epoch = _coerce_float(logs.get("epoch", getattr(state, "epoch", None)))
        # Prefer the scheduler's current LR from the log payload over the configured one.
        learning_rate = logs.get("learning_rate", getattr(args, "learning_rate", None))
        _emit_training_progress_event(
            "log",
            stage="training",
            label=_build_training_progress_label(
                stage="training",
                epoch=epoch,
                total_epochs=self.total_epochs,
                step=step,
                total_steps=total_steps,
            ),
            epoch=epoch,
            total_epochs=self.total_epochs,
            step=step,
            total_steps=total_steps,
            loss=_coerce_float(logs.get("loss")),
            eval_loss=_coerce_float(logs.get("eval_loss")),
            learning_rate=_coerce_float(learning_rate),
            eta_seconds=self._estimate_eta_seconds(step=step, total_steps=total_steps),
        )

    def on_evaluate(self, args, state, control, metrics=None, **kwargs) -> None:  # noqa: ANN001, D401
        metrics = dict(metrics or {})
        step = _coerce_int(getattr(state, "global_step", None))
        total_steps = self._resolve_total_steps(state)
        _emit_training_progress_event(
            "evaluate",
            stage="evaluating",
            label=_build_training_progress_label(stage="evaluating"),
            epoch=_coerce_float(getattr(state, "epoch", None)),
            total_epochs=self.total_epochs,
            step=step,
            total_steps=total_steps,
            eval_loss=_coerce_float(metrics.get("eval_loss")),
            eta_seconds=self._estimate_eta_seconds(step=step, total_steps=total_steps),
        )

    def on_save(self, args, state, control, **kwargs) -> None:  # noqa: ANN001, D401
        _emit_training_progress_event(
            "save",
            stage="saving",
            label=_build_training_progress_label(stage="saving"),
            epoch=_coerce_float(getattr(state, "epoch", None)),
            total_epochs=self.total_epochs,
            step=_coerce_int(getattr(state, "global_step", None)),
            total_steps=self._resolve_total_steps(state),
            output_dir=getattr(args, "output_dir", None),
        )

    def on_train_end(self, args, state, control, **kwargs) -> None:  # noqa: ANN001, D401
        _emit_training_progress_event(
            "train_end",
            stage="saving",
            label="Treniņš pabeigts, sagatavo artefaktus",
            epoch=_coerce_float(getattr(state, "epoch", None)),
            total_epochs=self.total_epochs,
            step=_coerce_int(getattr(state, "global_step", None)),
            total_steps=self._resolve_total_steps(state),
        )

    def _estimate_eta_seconds(self, *, step: int | None, total_steps: int | None) -> float | None:
        """Linear ETA: average seconds per completed step times the remaining steps.

        Returns None before training started or when step counts are missing or non-positive,
        and 0.0 once ``step`` has reached ``total_steps``.
        """
        if self.started_monotonic is None or not step or not total_steps or step <= 0:
            return None
        elapsed = time.monotonic() - self.started_monotonic
        if elapsed <= 0:
            return None
        remaining_steps = max(total_steps - step, 0)
        if remaining_steps <= 0:
            return 0.0
        return elapsed / step * remaining_steps
def _coerce_float(value: Any) -> float | None:
if isinstance(value, bool) or value in (None, ""):
return None
try:
return float(value)
except (TypeError, ValueError):
return None
def _coerce_int(value: Any) -> int | None:
if isinstance(value, bool) or value in (None, ""):
return None
try:
return int(float(value))
except (TypeError, ValueError):
return None
@dataclass(slots=True, frozen=True)
class DatasetBranchFilterReport:
    """Immutable record counts for one dataset split after a branch-specific filter pass.

    NOTE(review): meaning is inferred from the field names; the producer is not visible in
    this chunk, so confirm there that ``dropped_records == input_records - kept_records``.
    """

    # Name of the dataset split that was filtered (e.g. "train").
    split_name: str
    # Training branch the filter was applied for.
    branch_name: str
    # Number of records before filtering.
    input_records: int
    # Number of records that passed the filter.
    kept_records: int
    # Number of records removed by the filter.
    dropped_records: int
@dataclass(slots=True, frozen=True)
class BranchFilterSignals:
    """Immutable sets of signals used when matching records against a branch filter.

    Every field defaults to an empty ``frozenset``. NOTE(review): per-field meaning below is
    inferred from the names only; the builder and consumer are not visible in this chunk.
    """

    # Branch names a record explicitly declares.
    explicit_branches: frozenset[str] = frozenset()
    # Record type labels.
    record_types: frozenset[str] = frozenset()
    # Task type labels.
    task_types: frozenset[str] = frozenset()
    # Language identifiers (natural or programming — confirm).
    languages: frozenset[str] = frozenset()
    # Terms presumably matched against a record's repository context.
    repo_context_terms: frozenset[str] = frozenset()
    # Keys whose mere presence in a record presumably counts as a signal.
    presence_keys: frozenset[str] = frozenset()
def _get_split(dataset: Any, *names: str) -> Any:
for name in names:
if isinstance(dataset, dict) and name in dataset:
return dataset[name]
getter = getattr(dataset, "get", None)
if callable(getter):
candidate = getter(name)
if candidate is not None:
return candidate
return None
def _dataset_split_names(dataset: Any) -> list[str]:
if isinstance(dataset, dict):
return [str(name) for name in dataset]
keys = getattr(dataset, "keys", None)
if callable(keys):
try:
return [str(name) for name in keys()]
except Exception: # noqa: BLE001
pass
return [
name
for name in ("train", "validation", "eval", "test")
if _get_split(dataset, name) is not None
]
def _merge_dataset_splits(splits: list[Any]) -> Any:
merged_splits = [split for split in splits if split is not None]
if not merged_splits:
return None
if len(merged_splits) == 1:
return merged_splits[0]
if all(type(split).__module__.startswith("datasets") for split in merged_splits):
try:
from datasets import concatenate_datasets # type: ignore
return concatenate_datasets(merged_splits)
except Exception as exc: # noqa: BLE001
logger.warning(
"HF split concatenation neizdevās; pārslēdzamies uz ierakstu līmeņa merge: %s",
exc,
)
records: list[dict[str, Any]] = []
for split in merged_splits:
records.extend(_materialize_split_records(split))
return _rebuild_split_like(merged_splits[0], records)
def _load_combined_hf_dataset(repo_ids: list[str]) -> Any:
    """Load every repo in *repo_ids* and merge their splits by name.

    Split names keep their first-seen order across repos. Each merged split is produced by
    ``_merge_dataset_splits``. Raises ValueError if no repo exposes any split.
    """
    loaded = [load_hf_dataset(repo_id) for repo_id in repo_ids]
    ordered_names: dict[str, None] = {}
    for dataset in loaded:
        for split_name in _dataset_split_names(dataset):
            ordered_names.setdefault(split_name, None)
    if not ordered_names:
        raise ValueError("Neviens no norādītajiem dataset repo nesatur pieejamus splitus.")
    merged: dict[str, Any] = {}
    for split_name in ordered_names:
        merged[split_name] = _merge_dataset_splits(
            [_get_split(dataset, split_name) for dataset in loaded]
        )
    return merged
def _resolve_training_dataset_repos(config: TrainingConfig) -> list[str]:
return list(config.dataset_repos or [config.dataset_repo])
def _resolve_eval_dataset_repos(config: TrainingConfig) -> list[str]:
if config.eval_dataset_repos:
return list(config.eval_dataset_repos)
if config.eval_dataset_repo:
return [config.eval_dataset_repo]
return []
def _resolve_primary_eval_dataset_repo(config: TrainingConfig) -> str:
    """Return the first dedicated eval repo, falling back to the training ``dataset_repo``."""
    candidates = _resolve_eval_dataset_repos(config)
    if not candidates:
        return config.eval_dataset_repo or config.dataset_repo
    return candidates[0]
def _prepare_train_eval_splits(
    dataset: Any,
    config: TrainingConfig,
) -> tuple[Any, Any | None]:
    """Return ``(train_split, eval_split)`` for a training run.

    An explicit ``validation``/``eval``/``test`` split wins. Otherwise a held-out split is
    carved from ``train`` via ``train_test_split``, using ``config.validation_split_ratio`` and
    ``config.seed``. This requires a positive ratio, split support and at least two rows;
    otherwise the eval split is None.

    Raises:
        ValueError: if the dataset has no ``train`` split.
    """
    train_split = _get_split(dataset, "train")
    if train_split is None:
        raise ValueError("Datasetam nav pieejams 'train' split apmācībai.")
    explicit_eval = _get_split(dataset, "validation", "eval", "test")
    if explicit_eval is not None:
        return train_split, explicit_eval
    cannot_split = (
        config.validation_split_ratio <= 0
        or not hasattr(train_split, "train_test_split")
        or len(train_split) < 2
    )
    if cannot_split:
        return train_split, None
    held_out = train_split.train_test_split(
        test_size=config.validation_split_ratio,
        seed=config.seed,
    )
    return held_out["train"], held_out["test"]
def _select_eval_split(dataset: Any, config: TrainingConfig, *, allow_train_fallback: bool) -> Any:
    """Pick the safest split for evaluation or for an external benchmark dataset.

    Preference order:
      1. an explicit ``validation``/``eval``/``test`` split;
      2. the ``train`` split itself, when *allow_train_fallback* is True;
      3. a held-out split derived via ``_prepare_train_eval_splits``.

    Raises:
        ValueError: if there is neither a train nor an eval split, or no eval split can be
            derived.
    """
    train_split = _get_split(dataset, "train")
    explicit_eval = _get_split(dataset, "validation", "eval", "test")
    if explicit_eval is not None:
        return explicit_eval
    if train_split is None:
        raise ValueError("Datasetam nav pieejams ne 'train', ne eval split novērtēšanai.")
    if allow_train_fallback:
        return train_split
    _, derived_eval = _prepare_train_eval_splits(dataset, config)
    if derived_eval is None:
        raise ValueError("Datasetam nav pieejams stabils eval split novērtēšanai.")
    return derived_eval
def _tokenize_dataset(split: Any, tokenizer: Any, max_seq_length: int) -> Any:
if not hasattr(split, "map"):
raise TypeError("HF apmācībai nepieciešams datasets.Dataset objekts ar map() atbalstu.")
remove_columns = list(getattr(split, "column_names", []))
def tokenize_batch(batch: dict[str, list[Any]]) -> dict[str, Any]:
batch_size = len(next(iter(batch.values()))) if batch else 0
texts = []
for index in range(batch_size):
record = {key: values[index] for key, values in batch.items()}
texts.append(record_to_training_text(record, max_chars=max_seq_length * 8))
return tokenizer(
texts,
truncation=True,
max_length=max_seq_length,
padding=False,
)
return split.map(
tokenize_batch,
batched=True,
remove_columns=remove_columns,
desc="Tokenizing training data",
)
def _configure_tokenizer(tokenizer: Any, config: TrainingConfig) -> None:
if getattr(tokenizer, "pad_token", None) is None:
fallback_token = getattr(tokenizer, "eos_token", None) or getattr(
tokenizer, "unk_token", None
)
if fallback_token is not None:
tokenizer.pad_token = fallback_token
if (
getattr(tokenizer, "pad_token_id", None) is None
and getattr(tokenizer, "eos_token_id", None) is not None
):
tokenizer.pad_token_id = tokenizer.eos_token_id
current_model_max_length = getattr(tokenizer, "model_max_length", None)
if (
isinstance(current_model_max_length, int)
and 0 < current_model_max_length < config.max_seq_length
and (
_is_large_training_model(config.model_name)
or config.max_seq_length >= LONG_CONTEXT_MIN_SEQ_LENGTH
)
):
tokenizer.model_max_length = config.max_seq_length
def _uses_peft(config: TrainingConfig) -> bool:
    """True when ``config.adapter_type`` is one of ``PEFT_ADAPTER_TYPES`` (LoRA / QLoRA)."""
    adapter_type = config.adapter_type
    return adapter_type in PEFT_ADAPTER_TYPES
def _uses_qlora(config: TrainingConfig) -> bool:
return config.adapter_type == "qlora"
def _uses_preference_optimization(config: TrainingConfig) -> bool:
return (
bool(config.preference_dataset_path)
and config.preference_optimization in PREFERENCE_OPTIMIZATION_TYPES
)
def _resolve_torch_dtype(dtype_name: str) -> Any | None:
normalized = dtype_name.strip().lower()
if not normalized:
return None
try:
import torch # type: ignore
except ImportError:
return None
return getattr(torch, normalized, None)
def _build_model_load_kwargs(config: TrainingConfig) -> dict[str, Any]:
    """Build ``AutoModelForCausalLM.from_pretrained`` kwargs for *config*.

    Always sets ``trust_remote_code`` and ``low_cpu_mem_usage``. Large models and QLoRA runs
    add ``device_map="auto"`` with an offload folder under ``output_dir``. QLoRA runs also
    add a 4-bit ``BitsAndBytesConfig`` built from the ``qlora_*`` settings.

    Raises:
        ImportError: QLoRA requested but transformers' ``BitsAndBytesConfig`` cannot be imported.
    """
    load_kwargs: dict[str, Any] = {"trust_remote_code": True, "low_cpu_mem_usage": True}
    if _is_large_training_model(config.model_name) or _uses_qlora(config):
        load_kwargs["device_map"] = "auto"
        load_kwargs["offload_folder"] = str(Path(config.output_dir) / ".offload")
    if not _uses_qlora(config):
        return load_kwargs

    try:
        from transformers import BitsAndBytesConfig  # type: ignore
    except ImportError as exc:
        raise ImportError(
            "QLoRA apmācībai vajag transformers BitsAndBytesConfig un bitsandbytes atkarību."
        ) from exc

    bnb_settings: dict[str, Any] = {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": config.qlora_quant_type,
        "bnb_4bit_use_double_quant": config.qlora_use_double_quant,
    }
    # Leave the compute dtype unset (library default) when the configured name does not resolve.
    compute_dtype = _resolve_torch_dtype(config.qlora_compute_dtype)
    if compute_dtype is not None:
        bnb_settings["bnb_4bit_compute_dtype"] = compute_dtype
    load_kwargs["quantization_config"] = BitsAndBytesConfig(**bnb_settings)
    return load_kwargs
def _extract_model_size_billions(model_name: str) -> float | None:
    """Guess a model's parameter count (in billions) from its name.

    The largest ``"<n>B"`` hint wins (e.g. "Qwen3-30B-A3B" -> 30.0). Without such a hint, names
    containing a ``GIANT_MODEL_NAME_HINTS`` substring report
    ``GIANT_MODEL_AUTO_QLORA_THRESHOLD_B``. Otherwise the result is None.
    """
    sizes = [float(raw) for raw in MODEL_SIZE_HINT_RE.findall(model_name)]
    if sizes:
        return max(sizes)
    lowered = model_name.strip().lower()
    for hint in GIANT_MODEL_NAME_HINTS:
        if hint in lowered:
            return GIANT_MODEL_AUTO_QLORA_THRESHOLD_B
    return None
def _is_large_training_model(model_name: str) -> bool:
    """True when the name's size hint is at least ``LARGE_MODEL_RESOURCE_THRESHOLD_B``."""
    size_hint = _extract_model_size_billions(model_name)
    if size_hint is None:
        return False
    return size_hint >= LARGE_MODEL_RESOURCE_THRESHOLD_B
def _build_hf_auth_kwargs(callable_obj: Any) -> dict[str, str]:
    """Return the HF auth kwarg accepted by *callable_obj*, or ``{}`` when there is no token.

    The token is taken from MARIS_REPO_TOKEN, MARIS_TOKEN or HF_TOKEN via ``get_env_any``.
    ``token`` is preferred, with legacy ``use_auth_token`` as a fallback. ``token`` is also
    used when the signature has ``**kwargs`` or cannot be inspected.
    """
    auth_token = get_env_any("MARIS_REPO_TOKEN", "MARIS_TOKEN", "HF_TOKEN")
    if not auth_token:
        return {}
    try:
        parameters = inspect.signature(callable_obj).parameters
    except (TypeError, ValueError):
        return {"token": auth_token}
    if "token" not in parameters and "use_auth_token" in parameters:
        return {"use_auth_token": auth_token}
    accepts_token = "token" in parameters or any(
        param.kind is inspect.Parameter.VAR_KEYWORD for param in parameters.values()
    )
    return {"token": auth_token} if accepts_token else {}
def _normalize_training_runtime_config(config: TrainingConfig) -> TrainingConfig:
    """Return *config* adjusted to the detected hardware and the model's size hint.

    Precision: when neither fp16 nor bf16 is requested, bf16 is enabled if the CUDA device
    supports it, else fp16 if CUDA is available at all.

    Size hint >= ``LARGE_MODEL_RESOURCE_THRESHOLD_B``: gradient checkpointing is turned on and
    per-device train/eval batch sizes are set to 1.

    Size hint >= ``GIANT_MODEL_AUTO_QLORA_THRESHOLD_B``: QLoRA is forced and gradient
    accumulation is raised to at least 16. With ``max_seq_length >=
    LONG_CONTEXT_MIN_SEQ_LENGTH``, a ``"none"`` distributed strategy also becomes DeepSpeed
    and Accelerate is enabled.

    The same object is returned when nothing changes. Otherwise a ``dataclasses.replace`` copy
    is returned and the applied overrides are logged.
    """
    overrides: dict[str, Any] = {}
    if not (config.fp16 or config.bf16):
        if _runtime_supports_bf16():
            overrides["bf16"] = True
        elif _runtime_has_cuda():
            overrides["fp16"] = True

    size_hint = _extract_model_size_billions(config.model_name)
    if size_hint is not None:
        if size_hint >= LARGE_MODEL_RESOURCE_THRESHOLD_B:
            if not config.gradient_checkpointing:
                overrides["gradient_checkpointing"] = True
            if config.per_device_train_batch_size != 1:
                overrides["per_device_train_batch_size"] = 1
            if config.per_device_eval_batch_size != 1:
                overrides["per_device_eval_batch_size"] = 1
        if size_hint >= GIANT_MODEL_AUTO_QLORA_THRESHOLD_B:
            if config.adapter_type != "qlora":
                overrides["adapter_type"] = "qlora"
            if config.gradient_accumulation_steps < 16:
                overrides["gradient_accumulation_steps"] = 16
            if config.max_seq_length >= LONG_CONTEXT_MIN_SEQ_LENGTH:
                if config.distributed_strategy == "none":
                    overrides["distributed_strategy"] = "deepspeed"
                if not config.use_accelerate:
                    overrides["use_accelerate"] = True

    if not overrides:
        return config
    normalized = replace(config, **overrides)
    logger.info(
        "Pielāgoju training runtime %s: %s",
        config.model_name,
        ", ".join(f"{key}={value}" for key, value in sorted(overrides.items())),
    )
    return normalized
def _get_torch_module() -> Any | None:
try:
import torch # type: ignore
except ImportError:
return None
return torch
def _runtime_has_cuda() -> bool:
    """True when torch is importable and ``torch.cuda.is_available()`` returns truthy."""
    cuda = getattr(_get_torch_module(), "cuda", None)
    if cuda is None:
        return False
    probe = getattr(cuda, "is_available", None)
    return bool(callable(probe) and probe())
def _runtime_supports_bf16() -> bool:
    """True when CUDA is available and ``torch.cuda.is_bf16_supported()`` returns truthy."""
    cuda = getattr(_get_torch_module(), "cuda", None)
    if cuda is None:
        return False
    availability_probe = getattr(cuda, "is_available", None)
    if not (callable(availability_probe) and availability_probe()):
        return False
    bf16_probe = getattr(cuda, "is_bf16_supported", None)
    return bool(callable(bf16_probe) and bf16_probe())
def _runtime_has_accelerator() -> bool:
    """True when CUDA is available or torch's Apple ``mps`` backend reports availability."""
    if _runtime_has_cuda():
        return True
    torch_backends = getattr(_get_torch_module(), "backends", None)
    mps_backend = getattr(torch_backends, "mps", None)
    mps_probe = getattr(mps_backend, "is_available", None)
    return bool(callable(mps_probe) and mps_probe())
def _build_runtime_training_argument_overrides() -> dict[str, Any]:
    """Return ``TrainingArguments`` overrides derived from the current process environment.

    Without an accelerator, ``dataloader_pin_memory`` is disabled. When stderr is not a TTY
    (e.g. piped logs), tqdm bars are disabled and the first step is logged, so progress still
    shows up in plain logs.
    """
    overrides: dict[str, Any] = {}
    if not _runtime_has_accelerator():
        overrides["dataloader_pin_memory"] = False
    isatty = getattr(getattr(sys, "stderr", None), "isatty", None)
    if not (callable(isatty) and isatty()):
        overrides["disable_tqdm"] = True
        overrides["logging_first_step"] = True
    return overrides
def _resolve_repo_relative_path(path_value: str) -> str:
normalized = path_value.strip()
if not normalized:
return ""
candidate = Path(normalized)
if candidate.is_absolute():
return str(candidate)
return str((REPO_ROOT / candidate).resolve())
def _resolve_distributed_config_path(config: TrainingConfig) -> str:
if config.distributed_config_path:
return _resolve_repo_relative_path(config.distributed_config_path)
if config.distributed_strategy == "fsdp":
return str((REPO_ROOT / "huggingface" / "fsdp-config.json").resolve())
if config.distributed_strategy == "deepspeed":
return str((REPO_ROOT / "huggingface" / "deepspeed-zero3.json").resolve())
return ""
def _load_json_object(path_value: str, *, label: str) -> dict[str, Any]:
resolved_path = Path(path_value)
if not resolved_path.is_file():
raise FileNotFoundError(f"{label} nav atrasts: {resolved_path}")
payload = json.loads(resolved_path.read_text(encoding="utf-8"))
if not isinstance(payload, dict):
raise ValueError(f"{label} jābūt JSON objektam.")
return payload
def _require_runtime_package(package_name: str, *, context_label: str) -> None:
try:
get_installed_package_version(package_name)
except (PackageNotFoundError, StopIteration) as exc:
raise ImportError(
f"{context_label} nepieciešams instalēt '{package_name}' Python pakotni."
) from exc
def _build_distributed_training_argument_overrides(config: TrainingConfig) -> dict[str, Any]:
overrides: dict[str, Any] = {}
if config.use_accelerate and config.num_processes > 1:
overrides["ddp_find_unused_parameters"] = False
if config.distributed_strategy == "fsdp":
fsdp_config_path = _resolve_distributed_config_path(config)
fsdp_config = (
_load_json_object(fsdp_config_path, label="FSDP konfigurācija")
if fsdp_config_path
else {}
)
fsdp_config.setdefault("min_num_params", config.fsdp_min_num_params)
if config.fsdp_transformer_layer_cls_to_wrap:
fsdp_config["transformer_layer_cls_to_wrap"] = list(
config.fsdp_transformer_layer_cls_to_wrap
)
overrides.update(
{
"fsdp": "full_shard auto_wrap",
"fsdp_config": fsdp_config,
"ddp_find_unused_parameters": False,
}
)
elif config.distributed_strategy == "deepspeed":
deepspeed_config_path = _resolve_distributed_config_path(config)
if not deepspeed_config_path:
raise ValueError(
"DeepSpeed režīmam vajag distributed_config_path vai repo default konfigurāciju."
)
if not Path(deepspeed_config_path).is_file():
raise FileNotFoundError(f"DeepSpeed konfigurācija nav atrasta: {deepspeed_config_path}")
_require_runtime_package("deepspeed", context_label="DeepSpeed režīms")
overrides.update(
{
"deepspeed": deepspeed_config_path,
"ddp_find_unused_parameters": False,
}
)
return overrides
def _load_tokenizer(model_name: str, config: TrainingConfig) -> Any:
    """Load the tokenizer for *model_name* through a Maris-compatible local path.

    Uses ``AutoTokenizer`` with ``trust_remote_code=True``, ``use_fast=True`` and HF auth
    kwargs, all passed through ``_filter_supported_kwargs``. ``_load_tokenizer_from_path``
    performs the fast -> slow fallbacks.

    If loading fails with an error naming a missing ``sentencepiece``/``tiktoken`` backend,
    the missing packages are pip-installed and the load is retried once. When the installer
    reports nothing was missing, the original error is re-raised, because an identical
    retry would fail the same way.
    """
    from transformers import AutoTokenizer  # type: ignore

    load_kwargs = _filter_supported_kwargs(
        AutoTokenizer.from_pretrained,
        {
            "trust_remote_code": True,
            **_build_hf_auth_kwargs(AutoTokenizer.from_pretrained),
            "use_fast": True,
        },
    )
    with maris_hf_compatible_path(model_name, allow_remote_snapshot=True) as compatible_model_path:
        try:
            return _load_tokenizer_from_path(
                compatible_model_path,
                config=config,
                auto_tokenizer_class=AutoTokenizer,
                load_kwargs=load_kwargs,
            )
        except TOKENIZER_LOAD_RETRY_EXCEPTIONS as exc:
            if not _should_retry_after_installing_tokenizer_backends(exc):
                raise
            # False means both backends were already locatable, so nothing changed to retry for.
            if not _install_missing_tokenizer_backends():
                raise
            logger.warning(
                "Retrying tokenizer load for %s after installing missing runtime backends.",
                config.model_name,
            )
            return _load_tokenizer_from_path(
                compatible_model_path,
                config=config,
                auto_tokenizer_class=AutoTokenizer,
                load_kwargs=load_kwargs,
            )
def _load_tokenizer_from_path(
    compatible_model_path: str,
    *,
    config: TrainingConfig,
    auto_tokenizer_class: Any,
    load_kwargs: dict[str, Any],
) -> Any:
    """Load a tokenizer from *compatible_model_path*, falling back to slow tokenizers.

    Attempts, in order:
      1. ``auto_tokenizer_class.from_pretrained`` with *load_kwargs* as given.
      2. If that raises one of ``TOKENIZER_LOAD_RETRY_EXCEPTIONS`` and *load_kwargs* contains
         ``use_fast``, the same call with ``use_fast=False``.
      3. If step 2 also fails, an explicit slow tokenizer class discovered from local
         artifacts via ``_load_explicit_slow_tokenizer``.

    Args:
        compatible_model_path: local or hub path passed to ``from_pretrained``.
        config: only used for ``config.model_name`` in log messages.
        auto_tokenizer_class: typically ``transformers.AutoTokenizer``.
        load_kwargs: base kwargs for ``from_pretrained``; not mutated.

    Returns:
        The loaded tokenizer object.

    Raises:
        The step-1 exception when *load_kwargs* has no ``use_fast`` key; otherwise the step-2
        exception when step 3 finds nothing. Exceptions outside
        ``TOKENIZER_LOAD_RETRY_EXCEPTIONS`` propagate immediately.
    """
    try:
        return auto_tokenizer_class.from_pretrained(compatible_model_path, **load_kwargs)
    # Tokenizer backends often fail with implementation-specific exceptions
    # before a slower Python tokenizer can succeed, so we retry once broadly.
    except TOKENIZER_LOAD_RETRY_EXCEPTIONS as exc:
        if "use_fast" not in load_kwargs:
            raise
        # Copy so the caller's kwargs stay untouched for any later retry.
        retry_kwargs = dict(load_kwargs)
        retry_kwargs["use_fast"] = False
        logger.warning(
            "Fast tokenizer load failed for %s; retrying with use_fast=False: %s",
            config.model_name,
            exc,
        )
        try:
            return auto_tokenizer_class.from_pretrained(compatible_model_path, **retry_kwargs)
        except TOKENIZER_LOAD_RETRY_EXCEPTIONS:
            explicit_slow_tokenizer = _load_explicit_slow_tokenizer(
                compatible_model_path,
                retry_kwargs,
            )
            if explicit_slow_tokenizer is not None:
                return explicit_slow_tokenizer
            # Bare raise re-raises the use_fast=False failure (the innermost active exception).
            raise
def _should_retry_after_installing_tokenizer_backends(exc: BaseException) -> bool:
message = str(exc).lower()
missing_backend_markers = (
"you need to have sentencepiece or tiktoken installed",
"requires the sentencepiece library",
"requires the tiktoken library",
"no module named 'sentencepiece'",
'no module named "sentencepiece"',
"no module named 'tiktoken'",
'no module named "tiktoken"',
)
return any(marker in message for marker in missing_backend_markers)
def _install_missing_tokenizer_backends() -> bool:
    """pip-install whichever of ``sentencepiece``/``tiktoken`` ``find_spec`` cannot locate.

    Returns:
        False when both packages are already locatable (nothing installed). True after a
        successful install, once import caches have been invalidated so the new packages are
        importable in this process.

    Raises:
        RuntimeError: if the ``pip install`` subprocess exits with a non-zero status; pip's
            combined output is logged first.
    """
    missing_packages = []
    for candidate in ("sentencepiece", "tiktoken"):
        if importlib.util.find_spec(candidate) is None:
            missing_packages.append(candidate)
    if not missing_packages:
        return False

    logger.warning(
        "Installing missing tokenizer runtime backends for training: %s",
        ", ".join(missing_packages),
    )
    pip_command = [sys.executable, "-m", "pip", "install", "--no-cache-dir", *missing_packages]
    try:
        completed = subprocess.run(
            pip_command,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
    except subprocess.CalledProcessError as exc:
        failure_output = (exc.stdout or "").strip()
        if failure_output:
            logger.error("Tokenizer backend install failed: %s", failure_output)
        raise RuntimeError("Couldn't install missing tokenizer backends automatically.") from exc

    install_output = (completed.stdout or "").strip()
    if install_output:
        logger.info("Tokenizer backend install output: %s", install_output)
    importlib.invalidate_caches()
    return True
def _load_local_tokenizer_artifact(path: Path) -> dict[str, Any] | None:
"""Nolasa lokālu tokenizer JSON artefaktu un atgriež dict vai None."""
if not path.is_file():
return None
try:
payload = json.loads(path.read_text(encoding="utf-8"))
except (OSError, json.JSONDecodeError) as exc:
logger.warning("Couldn't read tokenizer artifact %s: %s", path, exc)
return None
return payload if isinstance(payload, dict) else None
def _collect_explicit_slow_tokenizer_classes(model_name_or_path: str) -> list[str]:
"""Savāc deduplicētus slow tokenizer klašu nosaukumus no lokāliem HF artefaktiem."""
model_dir = Path(model_name_or_path)
if not model_dir.is_dir():
return []
class_names: list[str] = []
seen: set[str] = set()
def add_class_name(candidate: Any) -> None:
if not isinstance(candidate, str):
return
normalized = candidate.strip()
if not normalized:
return
if normalized.endswith("Fast"):
normalized = normalized[:-4]
if not normalized or normalized in seen:
return
seen.add(normalized)
class_names.append(normalized)
for artifact_name in ("tokenizer_config.json", "config.json"):
payload = _load_local_tokenizer_artifact(model_dir / artifact_name)
if payload is None:
continue
add_class_name(payload.get("tokenizer_class"))
auto_map = payload.get("auto_map")
if not isinstance(auto_map, dict):
continue
auto_tokenizer_entry = auto_map.get("AutoTokenizer")
if isinstance(auto_tokenizer_entry, list):
for candidate in auto_tokenizer_entry:
add_class_name(candidate)
else:
add_class_name(auto_tokenizer_entry)
return class_names
def _load_explicit_slow_tokenizer(
    model_name_or_path: str, load_kwargs: dict[str, Any]
) -> Any | None:
    """Try explicit slow tokenizer classes after the AutoTokenizer fallbacks have failed.

    Candidate class names come from ``_collect_explicit_slow_tokenizer_classes``. Names not
    exported by ``transformers`` are skipped. Each remaining class is tried with
    *load_kwargs* minus ``use_fast``, and the first success is returned. Failures listed in
    ``TOKENIZER_LOAD_RETRY_EXCEPTIONS`` are logged and skipped; None means nothing worked.
    """
    class_names = _collect_explicit_slow_tokenizer_classes(model_name_or_path)
    if not class_names:
        return None
    import transformers  # type: ignore

    slow_kwargs = {key: value for key, value in load_kwargs.items() if key != "use_fast"}
    for class_name in class_names:
        tokenizer_class = getattr(transformers, class_name, None)
        if tokenizer_class is None:
            continue
        logger.warning(
            "AutoTokenizer retry still failed for %s; loading explicit slow tokenizer %s.",
            model_name_or_path,
            class_name,
        )
        try:
            return tokenizer_class.from_pretrained(model_name_or_path, **slow_kwargs)
        except TOKENIZER_LOAD_RETRY_EXCEPTIONS as exc:
            logger.warning(
                "Explicit slow tokenizer %s failed for %s: %s",
                class_name,
                model_name_or_path,
                exc,
            )
    return None
def _load_model(
    model_name_or_path: str,
    config: TrainingConfig,
    *,
    trainable_adapter: bool = False,
) -> Any:
    """Load a causal-LM model, optionally with a PEFT adapter, via a Maris-compatible path.

    ``peft.AutoPeftModelForCausalLM`` is tried first when ``config.adapter_type`` is a PEFT
    type (``_uses_peft``) and either *trainable_adapter* is True or the resolved local path
    contains ``adapter_config.json``. It receives ``is_trainable=trainable_adapter`` only when
    its ``from_pretrained`` declares that parameter.

    ``transformers.AutoModelForCausalLM`` is used instead, with ``_build_model_load_kwargs``
    plus HF auth kwargs, when:
      - the PEFT condition is false, or
      - peft is not installed, or
      - peft lacks ``AutoPeftModelForCausalLM``.

    Args:
        model_name_or_path: hub id or local directory, resolved via ``maris_hf_compatible_path``.
        config: training configuration (adapter type, quantization/device settings).
        trainable_adapter: ask the AutoPeft loader for a trainable adapter.

    Returns:
        The loaded model object.

    NOTE(review): the AutoPeft branch does not pass ``_build_model_load_kwargs(config)``
    (``quantization_config``, ``device_map``, ``offload_folder``). The base model under an
    adapter is therefore presumably loaded without QLoRA 4-bit quantization — confirm this is
    intended.
    """
    from transformers import AutoModelForCausalLM  # type: ignore

    with maris_hf_compatible_path(
        model_name_or_path, allow_remote_snapshot=True
    ) as compatible_model_path:
        local_path = Path(compatible_model_path)
        # A saved adapter is recognised by adapter_config.json inside the resolved directory.
        has_local_adapter_config = (
            local_path.exists() and local_path.joinpath("adapter_config.json").is_file()
        )
        should_try_auto_peft = _uses_peft(config) and (
            trainable_adapter or has_local_adapter_config
        )
        if should_try_auto_peft:
            try:
                import peft  # type: ignore
            except ImportError:
                auto_peft_model_cls = None  # type: ignore[assignment]
            else:
                auto_peft_model_cls = getattr(peft, "AutoPeftModelForCausalLM", None)
            if auto_peft_model_cls is not None:
                adapter_kwargs = {}
                # Only pass is_trainable when the installed peft's from_pretrained declares it.
                if (
                    "is_trainable"
                    in inspect.signature(auto_peft_model_cls.from_pretrained).parameters
                ):
                    adapter_kwargs["is_trainable"] = trainable_adapter
                adapter_kwargs.update(_build_hf_auth_kwargs(auto_peft_model_cls.from_pretrained))
                adapter_kwargs.setdefault("trust_remote_code", True)
                return auto_peft_model_cls.from_pretrained(
                    compatible_model_path,
                    **_filter_supported_kwargs(
                        auto_peft_model_cls.from_pretrained,
                        adapter_kwargs,
                    ),
                )
        # Plain (non-adapter) load, also the fallback when the AutoPeft class is unavailable.
        return AutoModelForCausalLM.from_pretrained(
            compatible_model_path,
            **_filter_supported_kwargs(
                AutoModelForCausalLM.from_pretrained,
                {
                    **_build_model_load_kwargs(config),
                    **_build_hf_auth_kwargs(AutoModelForCausalLM.from_pretrained),
                },
            ),
        )
def _apply_peft_adapter(model: Any, config: TrainingConfig) -> Any:
    """Wrap *model* with a fresh LoRA adapter when the config selects PEFT.

    Non-PEFT configs return *model* unchanged. QLoRA runs first go through
    ``prepare_model_for_kbit_training``. Target modules come from
    ``config.peft_target_modules`` and default to ``"all-linear"``. The trainable-parameter
    summary is printed when the wrapped model offers one.

    Raises:
        ImportError: PEFT requested but the ``peft`` package is not installed.
    """
    if not _uses_peft(config):
        return model
    try:
        from peft import (  # type: ignore
            LoraConfig,
            TaskType,
            get_peft_model,
            prepare_model_for_kbit_training,
        )
    except ImportError as exc:
        raise ImportError("LoRA/QLoRA apmācībai vajag peft atkarību.") from exc

    if _uses_qlora(config):
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=config.gradient_checkpointing,
        )
    lora_settings = LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        bias=config.lora_bias,
        task_type=TaskType.CAUSAL_LM,
        target_modules=(
            list(config.peft_target_modules) if config.peft_target_modules else "all-linear"
        ),
    )
    peft_model = get_peft_model(model, lora_settings)
    if hasattr(peft_model, "print_trainable_parameters"):
        peft_model.print_trainable_parameters()
    return peft_model
def _supports_gradient_checkpointing_kwargs(method: Any) -> bool:
with suppress(TypeError, ValueError):
signature = inspect.signature(method)
return "gradient_checkpointing_kwargs" in signature.parameters or any(
parameter.kind is inspect.Parameter.VAR_KEYWORD
for parameter in signature.parameters.values()
)
return False
def _prepare_training_model(model_name: str, tokenizer: Any, config: TrainingConfig) -> Any:
    """Load the base model, align it with the tokenizer and runtime flags, then attach PEFT.

    Steps:
    - copy the tokenizer's ``pad_token_id`` onto the model config when the model has none;
    - set ``use_cache`` to the inverse of ``config.gradient_checkpointing``;
    - enable gradient checkpointing when requested, forwarding ``use_reentrant`` only
      when the runtime's ``gradient_checkpointing_enable`` accepts keyword overrides;
    - hand the result to ``_apply_peft_adapter``.
    """
    model = _load_model(model_name, config)
    model_config = getattr(model, "config", None)
    if model_config is not None:
        if getattr(model_config, "pad_token_id", None) is None:
            model_config.pad_token_id = getattr(tokenizer, "pad_token_id", None)
        if hasattr(model_config, "use_cache"):
            # The KV cache is turned off whenever gradient checkpointing is on.
            model_config.use_cache = not config.gradient_checkpointing

    if not (config.gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable")):
        return _apply_peft_adapter(model, config)

    enable_checkpointing = model.gradient_checkpointing_enable
    use_reentrant = config.gradient_checkpointing_use_reentrant
    if use_reentrant is not None and _supports_gradient_checkpointing_kwargs(
        enable_checkpointing
    ):
        enable_checkpointing(gradient_checkpointing_kwargs={"use_reentrant": use_reentrant})
    else:
        if use_reentrant is not None:
            logger.warning(
                "Ignoring explicit gradient_checkpointing_use_reentrant=%s because "
                "gradient_checkpointing_enable() does not accept keyword overrides in "
                "the current runtime.",
                use_reentrant,
            )
        enable_checkpointing()
    return _apply_peft_adapter(model, config)
def _load_preference_training_dataset(config: TrainingConfig) -> Any:
    """Load branch-filtered preference pairs as a ``datasets.Dataset`` or plain rows.

    Examples come from ``_load_filtered_preference_examples``, so the preference
    trainer and the preference summary artifacts see the same filtered subset.
    Optional string fields are coerced to ``""`` and sequence fields to lists, so
    every row has a uniform column set.

    Returns:
        A ``datasets.Dataset``. Falls back to a list of row dicts when ``datasets``
        is not installed, or when ``Dataset.from_list`` raises the
        ``PreTrainedTokenizerBase`` AttributeError seen with partial transformers
        installs.

    Raises:
        AttributeError: any other AttributeError from ``Dataset.from_list``.
    """
    rows = [
        {
            "prompt": example.prompt,
            "chosen": example.chosen,
            "rejected": example.rejected,
            "source": example.source,
            "annotator": example.annotator or "",
            "edit_target": example.edit_target or "",
            "context": example.context or "",
            "branch": example.branch or "",
            "task_type": example.task_type or "",
            "language": example.language or "",
            "repo_context": list(example.repo_context),
            "execution_required": example.execution_required,
            "tags": list(example.tags),
        }
        for example in _load_filtered_preference_examples(config)
    ]
    try:
        from datasets import Dataset  # type: ignore
    except ImportError:
        return rows
    try:
        return Dataset.from_list(rows)
    except AttributeError as exc:
        # Only the known fingerprinting failure is downgraded to plain rows.
        if "PreTrainedTokenizerBase" not in str(exc):
            raise
        logger.warning(
            "Falling back to plain preference rows because datasets fingerprinting needs a fuller transformers install: %s",
            exc,
        )
        return rows
def _filter_preference_examples_for_branch(
    examples: list[PreferenceExample],
    *,
    branch_name: str,
    branch_filter_rules: dict[str, dict[str, Any]] | None = None,
) -> list[PreferenceExample]:
    """Keep only the preference examples that match the branch's filter rules.

    Branches outside ``TEXT_TRAINABLE_BRANCHES`` get the input list back unchanged.

    Raises:
        ValueError: when filtering a trainable branch leaves no examples.
    """
    if branch_name not in TEXT_TRAINABLE_BRANCHES:
        return examples
    kept: list[PreferenceExample] = []
    for example in examples:
        signals = _preference_branch_filter_signals(example)
        if _matches_branch_filter_rule(
            signals,
            branch_name=branch_name,
            branch_filter_rules=branch_filter_rules,
        ):
            kept.append(example)
    if not kept:
        raise ValueError(
            f"Preference dataset filter neatgrieza nevienu piemēru atzaram '{branch_name}'."
        )
    return kept
def _is_local_training_artifact_dir(path: Path) -> bool:
"""Return whether a directory contains a reusable saved model or adapter artifact."""
return path.is_dir() and any(
path.joinpath(name).is_file() for name in LOCAL_TRAINING_ARTIFACT_FILES
)
def _build_model_source_fingerprint(model_name: str) -> str:
normalized = model_name.strip().casefold().encode("utf-8")
return hashlib.sha256(normalized).hexdigest()
def _load_local_model_source_fingerprint(path: Path) -> str | None:
    """Read the saved base-model fingerprint from ``path / "training-config.json"``.

    Returns:
        The trimmed, casefolded fingerprint, or ``None`` when any of these holds:
        - the file is missing or not a JSON object;
        - the fingerprint is absent, not a string, or empty;
        - the file cannot be read or parsed.

    Unreadable or corrupt metadata is treated like missing metadata. The
    auto-continue search then skips this directory instead of aborting the run.
    """
    config_path = path / "training-config.json"
    try:
        training_config = _load_json_if_exists(config_path)
    except (OSError, ValueError) as exc:  # ValueError covers JSON/Unicode decode errors.
        logger.warning(
            "Could not read model source fingerprint from %s: %s", config_path, exc
        )
        return None
    if not isinstance(training_config, dict):
        return None
    fingerprint = training_config.get(MODEL_SOURCE_FINGERPRINT_KEY)
    if not isinstance(fingerprint, str):
        return None
    return fingerprint.strip().casefold() or None
def _resolve_training_model_source(config: TrainingConfig) -> str:
"""Resolve the effective source model, preferring local persistent artifacts when enabled."""
if not config.continue_from_latest_artifact:
return config.model_name
candidates: list[tuple[Path, bool]] = []
if config.continue_model_path:
candidates.append((Path(config.continue_model_path), True))
output_dir = Path(config.output_dir)
candidates.append((output_dir, False))
expected_fingerprint = _build_model_source_fingerprint(config.model_name)
seen: set[Path] = set()
for candidate, explicit_candidate in candidates:
try:
resolved = candidate.expanduser().resolve()
except OSError:
continue
if resolved in seen:
continue
seen.add(resolved)
if _is_local_training_artifact_dir(resolved):
if not explicit_candidate:
saved_fingerprint = _load_local_model_source_fingerprint(resolved)
if saved_fingerprint != expected_fingerprint:
if saved_fingerprint:
logger.info(
"Izlaižu auto-continue no %s, jo saglabātā modeļa bāze neatbilst izvēlētajam modelim.",
resolved,
)
else:
logger.info(
"Izlaižu auto-continue no %s, jo trūkst modeļa saderības metadata.",
resolved,
)
continue
logger.info("Turpinu treniņu no lokālā artefakta: %s", resolved)
return str(resolved)
return config.model_name
def _resolve_runtime_uid_suffix() -> str:
"""Best-effort UID suffix for runtime fallback paths and user labels."""
uid = "unknown"
getuid = getattr(os, "getuid", None)
if callable(getuid):
with suppress(OSError):
uid = str(getuid())
return uid
def _ensure_runtime_home_dir() -> str:
    """Make sure ``HOME`` and the user identity env vars are set, then return HOME.

    Intended for containers whose UID has no passwd entry:
    - an empty or missing ``HOME`` is replaced by ``<tmpdir>/maris-home-<uid>``,
      which is created if needed;
    - every empty ``USER``/``LOGNAME``/``USERNAME`` is filled with the first
      non-empty one of those vars, or ``maris-<uid>`` when all are empty.

    Returns the (stripped) HOME value in effect.
    """
    home = os.environ.get("HOME", "").strip()
    uid = _resolve_runtime_uid_suffix()
    if not home:
        fallback = (Path(tempfile.gettempdir()) / f"maris-home-{uid}").resolve()
        fallback.mkdir(parents=True, exist_ok=True)
        home = str(fallback)
        os.environ["HOME"] = home
        logger.warning(
            "HOME nav iestatīts; izmantojam fallback runtime home %s, lai treniņš strādātu konteineros bez passwd ieraksta.",
            fallback,
        )

    identity_vars = ("USER", "LOGNAME", "USERNAME")
    present_values = (os.environ.get(var, "").strip() for var in identity_vars)
    user = next((value for value in present_values if value), "")
    if not user:
        user = f"maris-{uid}"
        logger.warning(
            "Lietotāja vides mainīgie nav iestatīti; izmantojam fallback runtime lietotāju %s, lai izvairītos no getpwuid kļūdām.",
            user,
        )
    for var in identity_vars:
        if not os.environ.get(var, "").strip():
            os.environ[var] = user
    return home
def _filter_supported_kwargs(callable_obj: Any, kwargs: dict[str, Any]) -> dict[str, Any]:
signature = inspect.signature(callable_obj)
parameters = signature.parameters
if any(parameter.kind is inspect.Parameter.VAR_KEYWORD for parameter in parameters.values()):
return kwargs
supported_names = {
name
for name, parameter in parameters.items()
if name != "self"
and parameter.kind
in {inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY}
}
return {name: value for name, value in kwargs.items() if name in supported_names}
def _build_preference_training_arguments(preference_config_cls: Any, config: TrainingConfig) -> Any:
    """Build the argument object for the preference (DPO/ORPO) training stage.

    The object:
    - mirrors the supervised hyper-parameters from ``config``;
    - keeps unused dataset columns (``remove_unused_columns=False``);
    - disables evaluation;
    - adds the preference-specific ``beta`` and length limits.

    Unless the target class accepts ``**kwargs``, options it does not accept are
    dropped by ``_build_training_arguments``.
    """
    options: dict[str, Any] = {
        "output_dir": config.output_dir,
        "overwrite_output_dir": True,
        "num_train_epochs": config.num_epochs,
        "learning_rate": config.learning_rate,
        "per_device_train_batch_size": config.per_device_train_batch_size,
        "per_device_eval_batch_size": config.per_device_eval_batch_size,
        "gradient_accumulation_steps": config.gradient_accumulation_steps,
        "warmup_ratio": config.warmup_ratio,
        "weight_decay": config.weight_decay,
        "logging_steps": config.logging_steps,
        "save_steps": config.save_steps,
        "eval_steps": config.eval_steps,
        "save_total_limit": config.save_total_limit,
        "lr_scheduler_type": config.lr_scheduler_type,
        "seed": config.seed,
        "fp16": config.fp16,
        "bf16": config.bf16,
        "report_to": config.report_to,
        "save_safetensors": config.save_safetensors,
        "remove_unused_columns": False,
        # Preference-optimization specific settings.
        "beta": config.preference_beta,
        "max_length": config.preference_max_length,
        "max_prompt_length": config.preference_max_prompt_length,
        "evaluation_strategy": "no",
    }
    return _build_training_arguments(preference_config_cls, **options)
def _load_preference_reference_model(output_dir: Path, config: TrainingConfig) -> Any:
    """Load the DPO reference model.

    Uses ``config.preference_reference_model`` when set, otherwise ``output_dir``.
    The model is loaded with ``trainable_adapter=False``.
    """
    if config.preference_reference_model:
        source = config.preference_reference_model
    else:
        source = str(output_dir)
    return _load_model(source, config, trainable_adapter=False)
def _run_preference_optimization(
    model: Any,
    tokenizer: Any,
    config: TrainingConfig,
    *,
    output_dir: Path,
) -> tuple[Any, Any, dict[str, float]]:
    """Run the optional TRL preference-optimization stage (DPO/ORPO).

    Returns:
        ``(model, trainer, metrics)``, or ``(model, None, {})`` when preference
        optimization is disabled. ``metrics`` holds:
        - the numeric trainer metrics, each prefixed with ``preference_``;
        - ``preference_examples``: the example count;
        - ``preference_stage``: always 1.0;
        - ``preference_strategy``: 1.0 for DPO, 2.0 otherwise.

    Raises:
        ImportError: if ``trl`` is not installed.
        RuntimeError: if TRL does not expose ``<STRATEGY>Trainer``.
    """
    if not _uses_preference_optimization(config):
        return model, None, {}
    try:
        import trl  # type: ignore
    except ImportError as exc:
        raise ImportError("Preference optimization (DPO/ORPO) vajag trl atkarību.") from exc
    preference_dataset = _load_preference_training_dataset(config)
    strategy = config.preference_optimization
    trainer_name = f"{strategy.upper()}Trainer"
    trainer_cls = getattr(trl, trainer_name, None)
    if trainer_cls is None:
        raise RuntimeError(f"TRL modulī nav pieejams {trainer_name}.")
    trainer_config_cls = getattr(trl, f"{strategy.upper()}Config", None)
    if trainer_config_cls is None:
        from transformers import TrainingArguments  # type: ignore

        trainer_config_cls = TrainingArguments
    preference_args = _build_preference_training_arguments(trainer_config_cls, config)
    trainer_kwargs: dict[str, Any] = {
        "model": model,
        "args": preference_args,
        "train_dataset": preference_dataset,
        "tokenizer": tokenizer,
        "processing_class": tokenizer,
    }
    if strategy == "dpo":
        trainer_kwargs["ref_model"] = _load_preference_reference_model(output_dir, config)
    supported_kwargs = dict(_filter_supported_kwargs(trainer_cls.__init__, trainer_kwargs))
    # Newer trainers take the tokenizer as ``processing_class`` and treat ``tokenizer``
    # as a deprecated alias. Never forward both, e.g. when __init__ accepts **kwargs.
    if "processing_class" in supported_kwargs:
        supported_kwargs.pop("tokenizer", None)
    preference_trainer = trainer_cls(**supported_kwargs)
    preference_result = preference_trainer.train()
    # ``metrics`` may be missing or None on custom trainer outputs.
    raw_metrics = getattr(preference_result, "metrics", None) or {}
    preference_metrics = {
        f"preference_{key}": float(value)
        for key, value in dict(raw_metrics).items()
        if isinstance(value, int | float)
    }
    preference_metrics["preference_examples"] = float(len(preference_dataset))
    preference_metrics["preference_stage"] = 1.0
    preference_metrics["preference_strategy"] = 1.0 if strategy == "dpo" else 2.0
    return model, preference_trainer, preference_metrics
def _save_json(path: Path, payload: dict[str, Any]) -> None:
    """Write a JSON object to ``path``. Typed wrapper over ``_save_json_payload``."""
    return _save_json_payload(path, payload)
def _save_json_payload(path: Path, payload: Any) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
def _load_json_if_exists(path: Path) -> dict[str, Any] | None:
if not path.is_file():
return None
payload = json.loads(path.read_text(encoding="utf-8"))
return payload if isinstance(payload, dict) else None
def _write_text(path: Path, content: str) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(content, encoding="utf-8")
def _build_training_artifact_config(config: TrainingConfig) -> dict[str, Any]:
    """Build the ``training-config.json`` payload.

    The raw ``model_name`` is removed and replaced by its fingerprint, which
    ``_load_local_model_source_fingerprint`` later compares for auto-continue.
    Maris identity fields and the artifact format tag are then added.
    """
    artifact = config.to_dict().copy()
    artifact.pop("model_name", None)
    artifact.update(
        {
            "model_preset": config.model_preset or "custom",
            MODEL_SOURCE_FINGERPRINT_KEY: _build_model_source_fingerprint(config.model_name),
            "maris_origin": MARIS_ORIGIN_NAME,
            "maris_framework": MARIS_FRAMEWORK_NAME,
            "maris_model_id": config.hub_model_id,
            "artifact_format": "maris-training-config-v1",
        }
    )
    return artifact
def _build_artifact_identity(config: TrainingConfig) -> dict[str, Any]:
    """Return the shared identity fields (origin, datasets, branch) stamped on artifacts."""
    return dict(
        maris_origin=MARIS_ORIGIN_NAME,
        maris_framework=MARIS_FRAMEWORK_NAME,
        dataset_repo=config.dataset_repo,
        dataset_repos=_resolve_training_dataset_repos(config),
        eval_dataset_repo=_resolve_primary_eval_dataset_repo(config),
        eval_dataset_repos=_resolve_eval_dataset_repos(config),
        branch_name=config.branch_name,
        branch_focus=config.branch_focus,
    )
def _build_model_artifact_identity(config: TrainingConfig) -> dict[str, Any]:
    """Return the shared artifact identity plus the published ``maris_model_id``."""
    identity = _build_artifact_identity(config)
    identity["maris_model_id"] = config.hub_model_id
    return identity
def _build_metrics_artifact(
    config: TrainingConfig,
    metrics: dict[str, float],
    *,
    artifact_type: str,
    extra_payload: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Combine metrics, model identity, the artifact type and optional extras into one dict.

    Later sources win on key collisions. The order is metrics, then identity,
    then ``artifact_type``, then ``extra_payload``.
    """
    artifact = {
        **metrics,
        **_build_model_artifact_identity(config),
        "artifact_type": artifact_type,
    }
    if extra_payload:
        artifact.update(extra_payload)
    return artifact
def _build_branch_suite_artifact(
    config: TrainingConfig, suite_results: dict[str, dict[str, Any]]
) -> dict[str, Any]:
    """Wrap per-branch suite results with artifact identity and a ``branch-suite`` type tag."""
    artifact: dict[str, Any] = {"branches": suite_results}
    artifact.update(_build_artifact_identity(config))
    artifact["artifact_type"] = "branch-suite"
    return artifact
def _write_origin_metadata(output_dir: Path, config: TrainingConfig) -> None:
    """Write ``maris-metadata.json`` with origin/framework plus the model identity fields."""
    target = output_dir / "maris-metadata.json"
    metadata: dict[str, Any] = {
        "origin": MARIS_ORIGIN_NAME,
        "framework": MARIS_FRAMEWORK_NAME,
    }
    metadata.update(_build_model_artifact_identity(config))
    _save_json(target, metadata)
def _build_model_card(
    config: TrainingConfig,
    *,
    trained_at: str,
    train_examples: int,
    eval_examples: int,
    metrics: dict[str, float],
) -> str:
    """Render the README.md model card (YAML front matter + Markdown) for a training run.

    Args:
        config: Training configuration. Supplies the branch, dataset, adapter,
            preference-optimization and hub-id fields shown in the card.
        trained_at: Timestamp string, rendered verbatim.
        train_examples: Training example count, rendered verbatim.
        eval_examples: Evaluation example count, rendered verbatim.
        metrics: Run metrics. Only ``train_loss``, ``eval_loss`` and ``perplexity``
            are rendered, in that order, when present.

    Returns:
        The card text: the lines below joined with newlines. The final empty
        element makes it end with a trailing newline.
    """
    metrics_lines = []
    for key in ("train_loss", "eval_loss", "perplexity"):
        if key in metrics:
            metrics_lines.append(f"- `{key}`: {metrics[key]}")
    if not metrics_lines:
        # Placeholder bullet so the Metrics section is never empty.
        metrics_lines.append("- Nav pieejamu metriku šim skrējienam.")
    return "\n".join(
        [
            # YAML front-matter block with model-card metadata.
            "---",
            "language:",
            "  - lv",
            "  - en",
            "license: mit",
            "pipeline_tag: text-generation",
            "tags:",
            *(f"  - {tag}" for tag in MODEL_CARD_TAGS),
            "---",
            "",
            f"# {MARIS_ORIGIN_NAME} Model",
            "",
            "## Model Summary",
            "",
            f"Šis artefakts ir {MARIS_ORIGIN_NAME} treniņa izvads atzaram `{config.branch_name}`.",
            f"Modelis tiek publicēts kā `{config.hub_model_id}` un saglabā Maris AI identitāti visos galvenajos artefaktos.",
            "",
            "## Training Data",
            "",
            f"- Dataset repo: `{config.dataset_repo}`",
            f"- Dataset repos: `{', '.join(_resolve_training_dataset_repos(config))}`",
            f"- Eval dataset repo: `{_resolve_primary_eval_dataset_repo(config)}`",
            # Falls back to the primary eval repo when no eval repo list is resolved.
            f"- Eval dataset repos: `{', '.join(_resolve_eval_dataset_repos(config) or [_resolve_primary_eval_dataset_repo(config)])}`",
            f"- Train examples: `{train_examples}`",
            f"- Eval examples: `{eval_examples}`",
            f"- Branch focus: `{config.branch_focus}`",
            "",
            "## Training Configuration",
            "",
            # The upstream base model name is deliberately not shown; lineage is always Maris AI.
            "- Base lineage: `Maris AI`",
            f"- Model preset: `{config.model_preset or 'custom'}`",
            f"- Adapter strategy: `{config.adapter_type}`",
            f"- Preference optimization: `{config.preference_optimization}`",
            f"- Output model id: `{config.hub_model_id}`",
            f"- Trained at: `{trained_at}`",
            "",
            "## Metrics",
            "",
            *metrics_lines,
            "",
            "## Maris Identity Guarantees",
            "",
            f"- `config.json`, `generation_config.json` un `tokenizer_config.json` tiek sanitizēti uz `{config.hub_model_id}`",
            f"- `training-config.json`, `training-provenance.json` un `maris-metadata.json` satur tikai `{MARIS_ORIGIN_NAME}` izcelsmes metadatus",
            "- Šis artefakts ir paredzēts publicēšanai un lietošanai kā Maris AI modelis",
            "",
        ]
    )
def _write_training_provenance(
    output_dir: Path,
    config: TrainingConfig,
    *,
    trained_at: str,
    train_examples: int,
    eval_examples: int,
    metrics: dict[str, float],
    quality_report: DatasetQualityReport | None = None,
    scoring_report: DatasetScoringReport | None = None,
    benchmark_feedback: DatasetBenchmarkFeedback | None = None,
) -> None:
    """Write ``training-provenance.json`` describing where a trained artifact came from.

    The record contains:
    - datasets, branch, adapter and preference settings;
    - example counts and metrics;
    - the list of artifact files;
    - the optional dataset quality/scoring reports and benchmark feedback, or
      ``None`` when they are not supplied.

    Args:
        output_dir: Artifact directory the JSON file is written into.
        config: Training configuration the provenance fields are read from.
        trained_at: Timestamp string stored verbatim.
        train_examples: Training example count.
        eval_examples: Evaluation example count.
        metrics: Run metrics stored verbatim.
        quality_report: Optional dataset quality report. When given, it is
            serialized via ``to_dict`` and its JSON file is listed.
        scoring_report: Optional dataset scoring report. When given, it is
            serialized via ``to_dict`` and its JSON file is listed.
        benchmark_feedback: Optional benchmark feedback, serialized with
            ``build_benchmark_feedback_artifact``.
    """
    # Files that make up the published artifact; optional reports are appended only
    # when the corresponding report object was provided.
    artifact_files = [
        "README.md",
        "training-config.json",
        "training-metrics.json",
        "maris-metadata.json",
        "training-provenance.json",
    ]
    if quality_report is not None:
        artifact_files.append("dataset-quality-report.json")
    if scoring_report is not None:
        artifact_files.append("dataset-scoring-report.json")
    _save_json(
        output_dir / "training-provenance.json",
        {
            "trained_at": trained_at,
            "dataset_repo": config.dataset_repo,
            "dataset_repos": _resolve_training_dataset_repos(config),
            "eval_dataset_repo": _resolve_primary_eval_dataset_repo(config),
            "eval_dataset_repos": _resolve_eval_dataset_repos(config),
            "branch_name": config.branch_name,
            "branch_focus": config.branch_focus,
            "adapter_type": config.adapter_type,
            "preference_optimization": config.preference_optimization,
            # Base-model fields are deliberately set to the Maris origin name, not
            # config.model_name, so the upstream model id is not recorded here.
            "base_model_name": MARIS_ORIGIN_NAME,
            "base_model_lineage": MARIS_ORIGIN_NAME,
            "model_preset": config.model_preset or "custom",
            "hub_model_id": config.hub_model_id,
            "train_examples": train_examples,
            "eval_examples": eval_examples,
            "artifact_files": artifact_files,
            "metrics": metrics,
            "quality_report": quality_report.to_dict() if quality_report is not None else None,
            "scoring_report": scoring_report.to_dict() if scoring_report is not None else None,
            "benchmark_feedback": (
                build_benchmark_feedback_artifact(benchmark_feedback)
                if benchmark_feedback is not None
                else None
            ),
            "maris_origin": MARIS_ORIGIN_NAME,
            "maris_framework": MARIS_FRAMEWORK_NAME,
        },
    )
def _sanitize_saved_artifact_json(path: Path, *, maris_model_id: str) -> None:
if not path.is_file():
return
payload = json.loads(path.read_text(encoding="utf-8"))
sanitized_payload = _sanitize_artifact_payload(
payload,
maris_model_id=maris_model_id,
)
if isinstance(sanitized_payload, dict):
sanitized_payload["maris_origin"] = MARIS_ORIGIN_NAME
sanitized_payload["maris_framework"] = MARIS_FRAMEWORK_NAME
sanitized_payload["maris_model_id"] = maris_model_id
_save_json_payload(path, sanitized_payload)
def _sanitize_identity_text(value: str, *, maris_model_id: str) -> str:
    """Replace foreign model/brand references in ``value`` with Maris identifiers.

    The rewrites are applied in order:
    1. foreign model references become ``maris_model_id``;
    2. each foreign AI brand pattern becomes ``MARIS_ORIGIN_NAME``;
    3. Maris identity spelling variants are normalized to ``MARIS_ORIGIN_NAME``.
    """
    text = FOREIGN_MODEL_REFERENCE_RE.sub(maris_model_id, value)
    for brand_pattern in FOREIGN_AI_BRAND_PATTERNS:
        text = brand_pattern.sub(MARIS_ORIGIN_NAME, text)
    return MARIS_IDENTITY_VARIANT_RE.sub(MARIS_ORIGIN_NAME, text)
def _is_sanitized_reference_key(field_name: str | None) -> bool:
return bool(field_name) and (
field_name in SANITIZED_ARTIFACT_KEYS or field_name.endswith("_name_or_path")
)
def _sanitize_artifact_payload(
payload: Any,
*,
maris_model_id: str,
field_name: str | None = None,
path: tuple[str, ...] = (),
) -> Any:
if isinstance(payload, dict):
return {
key: _sanitize_artifact_payload(
value,
maris_model_id=maris_model_id,
field_name=str(key),
path=(*path, str(key)),
)
for key, value in payload.items()
}
if isinstance(payload, list):
return [
_sanitize_artifact_payload(
item,
maris_model_id=maris_model_id,
field_name=field_name,
path=(*path, str(index)),
)
for index, item in enumerate(payload)
]
if not isinstance(payload, str):
return payload
if _is_sanitized_reference_key(field_name):
return maris_model_id
if field_name in {"base_model_name", "base_model_lineage"}:
return MARIS_ORIGIN_NAME
if field_name in IDENTITY_TEXT_KEYS or "chat_template" in path:
return _sanitize_identity_text(payload, maris_model_id=maris_model_id)
if FOREIGN_MODEL_REFERENCE_RE.search(payload) or any(
pattern.search(payload) for pattern in FOREIGN_AI_BRAND_PATTERNS
):
return _sanitize_identity_text(payload, maris_model_id=maris_model_id)
return payload
def _is_text_identity_artifact(path: Path) -> bool:
    """Return True for non-JSON text artifacts that may carry identity strings.

    Matches a file whose (case-insensitive) extension is in
    ``TEXT_SANITIZED_ARTIFACT_EXTENSIONS``, or whose name mentions a chat template.
    """
    if path.suffix.lower() in TEXT_SANITIZED_ARTIFACT_EXTENSIONS:
        return True
    file_name = path.name.lower()
    return any(marker in file_name for marker in ("chat_template", "chat-template"))
def _sanitize_text_identity_artifact(path: Path, *, maris_model_id: str) -> None:
if not path.is_file():
return
path.write_text(
_sanitize_identity_text(path.read_text(encoding="utf-8"), maris_model_id=maris_model_id),
encoding="utf-8",
)
def _sanitize_saved_artifacts(output_dir: Path, *, maris_model_id: str) -> None:
    """Apply Maris identity sanitization to every file under ``output_dir``.

    Steps:
    1. Walk all files recursively in sorted order.
    2. Rewrite JSON files with ``_sanitize_saved_artifact_json``.
    3. Rewrite matching text artifacts with ``_sanitize_text_identity_artifact``.
    4. Write the Maris compatibility artifact and apply its identity to the directory.
    """
    candidate_files = (entry for entry in sorted(output_dir.rglob("*")) if entry.is_file())
    for artifact in candidate_files:
        if artifact.suffix.lower() == ".json":
            _sanitize_saved_artifact_json(artifact, maris_model_id=maris_model_id)
        elif _is_text_identity_artifact(artifact):
            _sanitize_text_identity_artifact(artifact, maris_model_id=maris_model_id)
    write_maris_compatibility_artifact(output_dir, maris_model_id=maris_model_id)
    apply_maris_compatibility_identity(output_dir)
def _collect_unsanitized_identity_strings(
payload: Any,
*,
path: tuple[str, ...] = (),
) -> list[str]:
if isinstance(payload, dict):
matches: list[str] = []
for key, value in payload.items():
matches.extend(
_collect_unsanitized_identity_strings(
value,
path=(*path, str(key)),
)
)
return matches
if isinstance(payload, list):
list_matches: list[str] = []
for index, item in enumerate(payload):
list_matches.extend(
_collect_unsanitized_identity_strings(item, path=(*path, str(index)))
)
return list_matches
if not isinstance(payload, str):
return []
if FOREIGN_MODEL_REFERENCE_RE.search(payload):
return ["/".join(path) or "<root>"]
for pattern in FOREIGN_AI_BRAND_PATTERNS:
if pattern.search(payload):
return ["/".join(path) or "<root>"]
return []
def _verify_saved_artifacts(output_dir: Path) -> None:
issues: list[str] = []
for path in sorted(output_dir.rglob("*")):
if not path.is_file():
continue
if path.name == MARIS_COMPATIBILITY_ARTIFACT_NAME:
continue
if path.suffix.lower() == ".json":
try:
payload = json.loads(path.read_text(encoding="utf-8"))
except json.JSONDecodeError:
continue
for issue in _collect_unsanitized_identity_strings(payload):
issues.append(f"{path.relative_to(output_dir)}:{issue}")
continue
if _is_text_identity_artifact(path):
content = path.read_text(encoding="utf-8")
if FOREIGN_MODEL_REFERENCE_RE.search(content) or any(
pattern.search(content) for pattern in FOREIGN_AI_BRAND_PATTERNS
):
issues.append(str(path.relative_to(output_dir)))
if issues:
unique_issues = sorted(set(issues))
visible_issues = unique_issues[:20]
overflow_suffix = ""
if len(unique_issues) > len(visible_issues):
overflow_suffix = f" ... un vēl {len(unique_issues) - len(visible_issues)}"
raise ValueError(
"Atrasti nesanitizēti Hugging Face artefakti: "
+ ", ".join(visible_issues)
+ overflow_suffix
)
def _write_training_artifacts(
    output_dir: Path,
    config: TrainingConfig,
    *,
    trained_at: str,
    train_examples: int,
    eval_examples: int,
    metrics: dict[str, float],
    quality_report: DatasetQualityReport | None = None,
    scoring_report: DatasetScoringReport | None = None,
    benchmark_feedback: DatasetBenchmarkFeedback | None = None,
) -> None:
    """Sanitize the saved model files and write all Maris metadata artifacts.

    Files written, in order:
    - ``training-config.json``;
    - ``training-metrics.json``, with the scoring dashboard payload;
    - the optional dataset quality/scoring reports;
    - ``maris-metadata.json``;
    - ``training-provenance.json``;
    - ``README.md``, the model card.

    Args:
        output_dir: Directory holding the saved model; artifacts are written here.
        config: Training configuration for identity and card fields.
        trained_at: Timestamp string recorded in provenance and the card.
        train_examples: Training example count.
        eval_examples: Evaluation example count.
        metrics: Run metrics.
        quality_report: Optional dataset quality report. Written to
            ``dataset-quality-report.json`` when given.
        scoring_report: Optional dataset scoring report. Written to
            ``dataset-scoring-report.json`` when given.
        benchmark_feedback: Optional benchmark feedback, forwarded to the metrics
            payload and the provenance record.
    """
    # Sanitization runs first over whatever is already in output_dir, so the
    # Maris-owned files written below are not passed through the sanitizer.
    _sanitize_saved_artifacts(output_dir, maris_model_id=config.hub_model_id)
    _save_json(output_dir / "training-config.json", _build_training_artifact_config(config))
    _save_json(
        output_dir / "training-metrics.json",
        _build_metrics_artifact(
            config,
            metrics,
            artifact_type="training-metrics",
            extra_payload=_build_scoring_dashboard_payload(
                scoring_report=scoring_report,
                benchmark_feedback=benchmark_feedback,
            ),
        ),
    )
    if quality_report is not None:
        _save_json(output_dir / "dataset-quality-report.json", quality_report.to_dict())
    if scoring_report is not None:
        _save_json(output_dir / "dataset-scoring-report.json", scoring_report.to_dict())
    _write_origin_metadata(output_dir, config)
    _write_training_provenance(
        output_dir,
        config,
        trained_at=trained_at,
        train_examples=train_examples,
        eval_examples=eval_examples,
        metrics=metrics,
        quality_report=quality_report,
        scoring_report=scoring_report,
        benchmark_feedback=benchmark_feedback,
    )
    _write_text(
        output_dir / "README.md",
        _build_model_card(
            config,
            trained_at=trained_at,
            train_examples=train_examples,
            eval_examples=eval_examples,
            metrics=metrics,
        ),
    )
def _write_preference_artifact(output_dir: Path, config: TrainingConfig) -> dict[str, Any] | None:
if not config.preference_dataset_path:
return None
examples = _load_filtered_preference_examples(config)
payload = summarize_preference_dataset(examples)
payload.update(_build_model_artifact_identity(config))
human_eval_summary = build_human_eval_summary(examples)
human_eval_summary.update(_build_model_artifact_identity(config))
_save_json(output_dir / "preference-summary.json", payload)
_save_json(output_dir / "human-eval-summary.json", human_eval_summary)
_save_json(
output_dir / "blind-side-by-side-eval.json",
build_blind_side_by_side_artifact(examples),
)
return payload
def _load_filtered_preference_examples(config: TrainingConfig) -> list[PreferenceExample]:
    """Load ``config.preference_dataset_path`` and keep only examples for ``config.branch_name``."""
    raw_examples = load_preference_dataset(config.preference_dataset_path)
    return _filter_preference_examples_for_branch(
        raw_examples,
        branch_name=config.branch_name,
        branch_filter_rules=config.branch_dataset_filter_rules,
    )
def _sync_output_dir_to_hub(
    output_dir: Path, config: TrainingConfig, *, commit_message: str
) -> None:
    """Best-effort upload of ``output_dir`` to the ``config.hub_model_id`` model repo.

    The repo is created if missing, then the whole folder is uploaded. The token
    comes from ``get_env_any`` over ``MARIS_REPO_TOKEN``/``MARIS_TOKEN``/``HF_TOKEN``.
    Without a token, ``HfApi()`` is built with no explicit token. A missing
    ``huggingface_hub`` package or any upload error is logged as a warning and
    never raised.
    """
    try:
        from huggingface_hub import HfApi  # type: ignore
    except ImportError:
        logger.warning("huggingface_hub nav pieejams — izlaižam Maris artefaktu sinhronizāciju.")
        return
    repo_target = {"repo_id": config.hub_model_id, "repo_type": "model"}
    try:
        token = get_env_any("MARIS_REPO_TOKEN", "MARIS_TOKEN", "HF_TOKEN")
        client = HfApi(token=token) if token else HfApi()
        client.create_repo(**repo_target, exist_ok=True)
        client.upload_folder(
            folder_path=str(output_dir),
            commit_message=commit_message,
            **repo_target,
        )
    except Exception as exc:  # noqa: BLE001
        logger.warning("Neizdevās pārpublicēt Maris artefaktus uz Hub: %s", exc)
def _build_training_arguments(training_arguments_cls: Any, **kwargs: Any) -> Any:
signature = inspect.signature(training_arguments_cls.__init__)
parameters = signature.parameters
if any(parameter.kind is inspect.Parameter.VAR_KEYWORD for parameter in parameters.values()):
return training_arguments_cls(**kwargs)
supported_names = {
name
for name, parameter in parameters.items()
if name != "self"
and parameter.kind
in {inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY}
}
resolved_kwargs = dict(kwargs)
for source_name, target_name in TRAINING_ARGUMENT_ALIASES.items():
if (
source_name in resolved_kwargs
and source_name not in supported_names
and target_name in supported_names
):
resolved_kwargs[target_name] = resolved_kwargs.pop(source_name)
ignored_arguments = sorted(set(resolved_kwargs) - supported_names)
if ignored_arguments:
logger.info(
"Izlaiž neatbalstītos TrainingArguments parametrus: %s",
", ".join(ignored_arguments),
)
return training_arguments_cls(
**{name: value for name, value in resolved_kwargs.items() if name in supported_names}
)
def _materialize_split_records(split: Any) -> list[dict[str, Any]]:
raw_records = getattr(split, "items", split)
return [
dict(record) if isinstance(record, dict) else {"text": str(record)}
for record in raw_records
]
def _rebuild_split_like(split: Any, records: list[dict[str, Any]]) -> Any:
if type(split).__module__.startswith("datasets"):
from datasets import Dataset # type: ignore
return Dataset.from_list(records)
split_type = type(split)
try:
return split_type(records)
except Exception: # noqa: BLE001
return records
def _resolve_branch_benchmark_dataset_path(config: TrainingConfig, branch_name: str) -> str:
explicit = str(config.branch_benchmark_dataset_paths.get(branch_name, "") or "").strip()
if explicit:
return explicit
return config.benchmark_dataset_path
def _resolve_branch_benchmark_name(config: TrainingConfig, branch_name: str) -> str:
explicit = str(config.branch_benchmark_names.get(branch_name, "") or "").strip()
if explicit:
return explicit
return config.benchmark_name
def _resolve_branch_preference_dataset_path(config: TrainingConfig, branch_name: str) -> str:
explicit = str(config.branch_preference_dataset_paths.get(branch_name, "") or "").strip()
if explicit:
return explicit
return config.preference_dataset_path
def _apply_branch_runtime_defaults(config: TrainingConfig) -> TrainingConfig:
    """Fill branch-specific benchmark and preference defaults into ``config``.

    Each field is resolved on its own:

    - ``benchmark_dataset_path``: the branch default, when unset and the benchmark
      gate is enabled;
    - ``benchmark_name``: the branch default, when the gate is enabled and the name
      is unset or still ``DEFAULT_BENCHMARK_NAME``;
    - ``preference_dataset_path``: the branch default, when unset.

    Returns the same ``config`` object when nothing changes, otherwise a
    ``dataclasses.replace`` copy.
    """
    benchmark_dataset_path = config.benchmark_dataset_path
    benchmark_name = config.benchmark_name
    preference_dataset_path = config.preference_dataset_path
    if not benchmark_dataset_path and config.benchmark_gate_enabled:
        benchmark_dataset_path = _resolve_branch_benchmark_dataset_path(config, config.branch_name)
    # Resolve the benchmark name independently of the dataset paths. Whether the
    # branch name applies must not depend on whether the preference path is already set.
    if config.benchmark_gate_enabled and (
        not benchmark_name or benchmark_name == DEFAULT_BENCHMARK_NAME
    ):
        benchmark_name = _resolve_branch_benchmark_name(config, config.branch_name)
    if not preference_dataset_path:
        preference_dataset_path = _resolve_branch_preference_dataset_path(
            config, config.branch_name
        )
    if (
        benchmark_dataset_path == config.benchmark_dataset_path
        and benchmark_name == config.benchmark_name
        and preference_dataset_path == config.preference_dataset_path
    ):
        return config
    return replace(
        config,
        benchmark_dataset_path=benchmark_dataset_path,
        benchmark_name=benchmark_name,
        preference_dataset_path=preference_dataset_path,
    )
def _normalize_branch_filter_token(value: Any) -> str:
return str(value or "").strip().lower().replace("-", "_").replace(" ", "_")
def _normalize_branch_filter_values(value: Any) -> set[str]:
    """Normalize a scalar or a list/tuple/set/frozenset of labels into a set of tokens.

    Tokens that normalize to an empty string are dropped.
    """
    items = value if isinstance(value, list | tuple | set | frozenset) else (value,)
    return {token for token in map(_normalize_branch_filter_token, items) if token}
def _record_branch_labels(record: dict[str, Any]) -> set[str]:
    """Collect normalized labels that may identify the branch a training record belongs to.

    Labels are read from two places:
    - top-level ``branch``/``profile``/``type``/``task_type``/``category``/``tags``;
    - ``metadata`` ``branch``/``profile``/``type``/``task``/``workflow``/
      ``project_area``/``tags``.

    String values are normalized; (nested) lists are flattened; other types are ignored.
    """
    pending: list[Any] = [
        record.get(key) for key in ("branch", "profile", "type", "task_type", "category")
    ]
    metadata = record.get("metadata")
    if isinstance(metadata, dict):
        pending.extend(
            metadata.get(key)
            for key in ("branch", "profile", "type", "task", "workflow", "project_area")
        )
        pending.append(metadata.get("tags"))
    pending.append(record.get("tags"))

    labels: set[str] = set()
    while pending:
        current = pending.pop()
        if isinstance(current, str):
            token = _normalize_branch_filter_token(current)
            if token:
                labels.add(token)
        elif isinstance(current, list):
            pending.extend(current)
    return labels
def _record_repo_context_terms(record: dict[str, Any]) -> set[str]:
values: list[str] = []
repo_context = record.get("repo_context")
if isinstance(repo_context, list):
values.extend(str(item).strip().lower() for item in repo_context if str(item).strip())
elif isinstance(repo_context, str) and repo_context.strip():
values.append(repo_context.strip().lower())
metadata = record.get("metadata")
if isinstance(metadata, dict):
for key in ("project_area", "audience"):
value = metadata.get(key)
if isinstance(value, str) and value.strip():
values.append(value.strip().lower())
return set(values)
def _record_branch_filter_signals(record: dict[str, Any]) -> BranchFilterSignals:
    """Extract the branch-routing signals ``_matches_branch_filter_rule`` evaluates for a record.

    Signals:
    - explicit branches: branch labels (see ``_record_branch_labels``) that name a
      trainable branch;
    - record/task types and languages: from top-level fields plus ``metadata``;
    - repo-context terms;
    - presence keys: normalized names of non-empty top-level fields other than
      ``metadata``.

    Repo-context terms are normalized with ``_normalize_branch_filter_token``. That is
    the same normalization applied to rule values and to preference examples, so
    spelling variants such as ``maris-core`` / ``Maris Core`` still match the rule
    term ``maris_core``.
    """
    metadata = record.get("metadata")
    task_types = _normalize_branch_filter_values(record.get("task_type"))
    languages = _normalize_branch_filter_values(record.get("language"))
    record_types = _normalize_branch_filter_values(record.get("type"))
    if isinstance(metadata, dict):
        task_types.update(_normalize_branch_filter_values(metadata.get("task")))
        task_types.update(_normalize_branch_filter_values(metadata.get("workflow")))
        languages.update(_normalize_branch_filter_values(metadata.get("language")))
        record_types.update(_normalize_branch_filter_values(metadata.get("type")))
    labels = _record_branch_labels(record)
    repo_context_terms = _normalize_branch_filter_values(_record_repo_context_terms(record))
    presence_keys = {
        _normalize_branch_filter_token(key)
        for key, value in record.items()
        if key != "metadata" and value not in (None, "", [], {}, ())
    }
    return BranchFilterSignals(
        explicit_branches=frozenset(label for label in labels if label in TEXT_TRAINABLE_BRANCHES),
        record_types=frozenset(record_types),
        task_types=frozenset(task_types),
        languages=frozenset(languages),
        repo_context_terms=frozenset(repo_context_terms),
        presence_keys=frozenset(presence_keys),
    )
def _preference_branch_filter_signals(example: PreferenceExample) -> BranchFilterSignals:
    """Extract branch-routing signals from a preference example.

    Signals:
    - explicit branch: kept only if it is a trainable branch;
    - task type, language and repo context: normalized;
    - presence keys: whichever of ``edit_target``/``context``/``execution_required``
      are set, i.e. not ``None``, ``""`` or ``False``.

    ``record_types`` is not supplied and keeps the ``BranchFilterSignals`` default.
    """
    branches = frozenset(
        label
        for label in _normalize_branch_filter_values(example.branch)
        if label in TEXT_TRAINABLE_BRANCHES
    )
    repo_terms = frozenset(
        _normalize_branch_filter_token(entry)
        for entry in example.repo_context
        if str(entry).strip()
    )
    optional_fields = {
        "edit_target": example.edit_target,
        "context": example.context,
        "execution_required": example.execution_required,
    }
    present = frozenset(
        name for name, value in optional_fields.items() if value not in (None, "", False)
    )
    return BranchFilterSignals(
        explicit_branches=branches,
        task_types=frozenset(_normalize_branch_filter_values(example.task_type)),
        languages=frozenset(_normalize_branch_filter_values(example.language)),
        repo_context_terms=repo_terms,
        presence_keys=present,
    )
def _matches_branch_filter_rule(
    signals: BranchFilterSignals,
    *,
    branch_name: str,
    branch_filter_rules: dict[str, dict[str, Any]] | None,
) -> bool:
    """Decide whether an item with ``signals`` belongs to ``branch_name``.

    Decision order:
    1. Non-trainable branches, or branches without rules, accept everything.
    2. An explicit branch in ``exclude_explicit_branches`` rejects.
    3. An explicit branch in ``include_explicit_branches`` accepts.
    4. Any matching ``include_*`` signal rule accepts.
    5. If positive rules exist but none matched, only unlabeled items pass, and
       only when ``allow_unlabeled`` is set.
    6. With no positive rules at all, items without explicit branch labels pass.

    ``branch_filter_rules`` falls back to ``DEFAULT_BRANCH_DATASET_FILTER_RULES``
    when empty/None.
    """
    branch_key = _normalize_branch_filter_token(branch_name)
    if branch_key not in TEXT_TRAINABLE_BRANCHES:
        return True
    rule_table = branch_filter_rules or DEFAULT_BRANCH_DATASET_FILTER_RULES
    rules = rule_table.get(branch_key, {})
    if not rules:
        return True

    explicit = signals.explicit_branches
    include_explicit = _normalize_branch_filter_values(rules.get("include_explicit_branches"))
    exclude_explicit = _normalize_branch_filter_values(rules.get("exclude_explicit_branches"))
    if explicit.intersection(exclude_explicit):
        return False
    if explicit.intersection(include_explicit):
        return True

    signal_by_rule = {
        "include_record_types": signals.record_types,
        "include_task_types": signals.task_types,
        "include_languages": signals.languages,
        "include_repo_context_terms": signals.repo_context_terms,
        "include_presence_keys": signals.presence_keys,
    }
    evaluated_rules = [
        bool(signal_values.intersection(required))
        for rule_key, signal_values in signal_by_rule.items()
        if (required := _normalize_branch_filter_values(rules.get(rule_key)))
    ]
    if any(evaluated_rules):
        return True

    is_unlabeled = not explicit
    if evaluated_rules or include_explicit:
        return bool(rules.get("allow_unlabeled", False)) and is_unlabeled
    return is_unlabeled
def _filter_records_for_branch(
    records: list[dict[str, Any]],
    *,
    branch_name: str,
    split_name: str,
    branch_filter_rules: dict[str, dict[str, Any]] | None = None,
) -> tuple[list[dict[str, Any]], DatasetBranchFilterReport]:
    """Keep only the records that belong to ``branch_name``.

    Args:
        records: Materialized dataset records.
        branch_name: Architecture branch being trained.
        split_name: Split label copied into the report (e.g. ``"train"``).
        branch_filter_rules: Per-branch rule overrides. ``None`` lets
            ``_matches_branch_filter_rule`` fall back to its defaults.

    Returns:
        The kept records and a ``DatasetBranchFilterReport`` with input, kept
        and dropped counts. For branches that are not text-trainable,
        ``records`` is returned unchanged (the same list object).
    """
    # Normalize exactly like _matches_branch_filter_rule does. Comparing the raw
    # name here let a branch name that only matches TEXT_TRAINABLE_BRANCHES after
    # normalization bypass per-record filtering entirely.
    if _normalize_branch_filter_token(branch_name) not in TEXT_TRAINABLE_BRANCHES:
        return records, DatasetBranchFilterReport(
            split_name=split_name,
            branch_name=branch_name,
            input_records=len(records),
            kept_records=len(records),
            dropped_records=0,
        )
    kept_records = [
        record
        for record in records
        if _matches_branch_filter_rule(
            _record_branch_filter_signals(record),
            branch_name=branch_name,
            branch_filter_rules=branch_filter_rules,
        )
    ]
    return kept_records, DatasetBranchFilterReport(
        split_name=split_name,
        branch_name=branch_name,
        input_records=len(records),
        kept_records=len(kept_records),
        # kept_records is a subset of records, so this is never negative.
        dropped_records=len(records) - len(kept_records),
    )
def _apply_branch_dataset_filter_to_split(
    split: Any,
    *,
    split_name: str,
    config: TrainingConfig,
    allow_empty: bool = False,
) -> tuple[Any | None, DatasetBranchFilterReport]:
    """Apply the branch dataset filter to one split and rebuild it.

    Returns the rebuilt split and the filter report. If every record is dropped,
    a warning is logged and ``(None, report)`` is returned when ``allow_empty``
    is true.

    Raises:
        ValueError: every record was dropped and ``allow_empty`` is false.
    """
    split_records = _materialize_split_records(split)
    kept_records, report = _filter_records_for_branch(
        split_records,
        branch_name=config.branch_name,
        split_name=split_name,
        branch_filter_rules=config.branch_dataset_filter_rules,
    )
    if kept_records:
        return _rebuild_split_like(split, kept_records), report
    message = (
        f"Branch dataset filter atmeta visus ierakstus splitam '{split_name}' "
        f"atzaram '{config.branch_name}'. Ienākošie={report.input_records}."
    )
    if not allow_empty:
        raise ValueError(message)
    logger.warning("%s Šo splitu izlaidīšu šajā training skrējienā.", message)
    return None, report
def _dataset_quality_config(config: TrainingConfig) -> DatasetQualityGateConfig:
    """Build the dataset quality-gate settings from a training config.

    The maximum text length is ``max_seq_length * 8``, but never less than
    ``quality_min_text_chars``.
    """
    text_ceiling = max(config.max_seq_length * 8, config.quality_min_text_chars)
    return DatasetQualityGateConfig(
        enabled=config.quality_gate_enabled,
        dedupe_enabled=config.dedupe_enabled,
        min_text_chars=config.quality_min_text_chars,
        max_text_chars=text_ceiling,
    )
def _dataset_scoring_config(config: TrainingConfig) -> DatasetScoringConfig:
    """Build the dataset scoring settings from a training config.

    Weight maps are copied so the scoring config never aliases the training
    config's dictionaries. The text-length ceiling matches the quality gate:
    ``max_seq_length * 8``, but never below ``quality_min_text_chars``.
    """
    text_ceiling = max(config.max_seq_length * 8, config.quality_min_text_chars)
    feedback_settings = {
        "benchmark_feedback_enabled": config.benchmark_feedback_enabled,
        "benchmark_feedback_path": config.benchmark_feedback_path,
        "benchmark_feedback_boost_scale": config.benchmark_feedback_boost_scale,
        "benchmark_feedback_max_multiplier": config.benchmark_feedback_max_multiplier,
    }
    return DatasetScoringConfig(
        enabled=config.scoring_enabled,
        weighted_repetition_enabled=config.weighted_repetition_enabled,
        max_text_chars=text_ceiling,
        medium_score_repeat_count=config.medium_score_repeat_count,
        high_score_repeat_count=config.high_score_repeat_count,
        source_weighting_enabled=config.source_weighting_enabled,
        source_weight_map=config.source_weight_map.copy(),
        category_weight_map=config.category_weight_map.copy(),
        max_effective_repeat_count=config.max_effective_repeat_count,
        **feedback_settings,
    )
def _benchmark_feedback_discovery_candidates(config: TrainingConfig) -> list[Path]:
output_dir = Path(config.output_dir)
candidates: list[Path] = []
direct_candidate = output_dir / "benchmark-feedback.json"
if direct_candidate.is_file():
candidates.append(direct_candidate)
search_roots = [output_dir.parent]
if output_dir.parent != output_dir.parent.parent:
search_roots.append(output_dir.parent.parent)
seen: set[Path] = {direct_candidate.resolve()} if direct_candidate.exists() else set()
for root in search_roots:
if not root.exists():
continue
for candidate in root.rglob("benchmark-feedback.json"):
try:
resolved = candidate.resolve()
except OSError:
continue
if resolved in seen or not candidate.is_file():
continue
seen.add(resolved)
candidates.append(candidate)
return candidates
def _discover_benchmark_feedback_path(config: TrainingConfig) -> str:
    """Return the most recently modified discovered feedback file, or ``""``.

    Candidates come from ``_benchmark_feedback_discovery_candidates``. When
    several files share the newest mtime, the earliest candidate wins, which is
    the branch's own ``output_dir`` file when it exists. A candidate that can no
    longer be stat-ed (for example, deleted after discovery) is skipped. The
    lookup is best effort and should not abort training.

    NOTE(review): discovery also scans sibling directories under output_dir's
    parent/grandparent, so the newest file may belong to another branch.
    Confirm this is intended.
    """
    newest_path: Path | None = None
    newest_mtime = 0.0
    for candidate in _benchmark_feedback_discovery_candidates(config):
        try:
            mtime = candidate.stat().st_mtime
        except OSError:
            continue
        # Strict ">" keeps the first candidate on ties, matching max() semantics.
        if newest_path is None or mtime > newest_mtime:
            newest_path, newest_mtime = candidate, mtime
    return "" if newest_path is None else str(newest_path)
def _load_dataset_benchmark_feedback(config: TrainingConfig) -> DatasetBenchmarkFeedback | None:
    """Load benchmark feedback used to re-weight training records, if configured.

    An explicit ``benchmark_feedback_path`` takes precedence. Otherwise, when
    ``benchmark_feedback_auto_discover`` is enabled, a discovered
    ``benchmark-feedback.json`` is used. Returns ``None`` if feedback is
    disabled or no path is available. ``discovery_mode`` on the result is
    ``"explicit"`` or ``"auto_discovered"`` accordingly.
    """
    if not config.benchmark_feedback_enabled:
        return None
    if config.benchmark_feedback_path:
        source_path = config.benchmark_feedback_path
        mode = "explicit"
    elif config.benchmark_feedback_auto_discover:
        source_path = _discover_benchmark_feedback_path(config)
        mode = "auto_discovered"
    else:
        return None
    if not source_path:
        return None
    loaded = load_benchmark_feedback(
        source_path,
        targets=_default_branch_benchmark_targets(config),
        boost_scale=config.benchmark_feedback_boost_scale,
        max_multiplier=config.benchmark_feedback_max_multiplier,
    )
    # NOTE(review): only these fields are copied from `loaded`; if
    # DatasetBenchmarkFeedback has other fields they are not carried over here.
    return DatasetBenchmarkFeedback(
        artifact_path=loaded.artifact_path,
        deficient_metrics=loaded.deficient_metrics,
        overall_multiplier=loaded.overall_multiplier,
        discovery_mode=mode,
    )
def _build_scoring_dashboard_payload(
    *,
    scoring_report: DatasetScoringReport | None,
    benchmark_feedback: DatasetBenchmarkFeedback | None,
) -> dict[str, Any]:
    """Assemble dashboard data for the scoring report and benchmark feedback.

    With a scoring report, the payload gets a nested ``scoring_dashboard``
    (per split: ``sources`` and ``categories``) plus the flattened numeric
    metrics. With feedback, it gets a ``benchmark_feedback`` artifact. If both
    are ``None``, an empty dict is returned.
    """
    payload: dict[str, Any] = {}
    if scoring_report is not None:
        dashboard: dict[str, Any] = {}
        for split_name, split_report in scoring_report.splits.items():
            dashboard[split_name] = {
                "sources": split_report.source_dashboard,
                "categories": split_report.category_dashboard,
            }
        payload["scoring_dashboard"] = dashboard
        payload.update(_build_flattened_scoring_dashboard_metrics(scoring_report))
    if benchmark_feedback is not None:
        payload["benchmark_feedback"] = build_benchmark_feedback_artifact(benchmark_feedback)
    return payload
def _build_flattened_scoring_dashboard_metrics(
scoring_report: DatasetScoringReport | None,
) -> dict[str, float]:
if scoring_report is None:
return {}
flattened: dict[str, float] = {}
for split_name, split_report in scoring_report.splits.items():
for dimension_name, dimension_payload in (
("sources", split_report.source_dashboard),
("categories", split_report.category_dashboard),
):
for label, metrics in dimension_payload.items():
for metric_name, metric_value in metrics.items():
flattened[
f"scoring_dashboard_{split_name}_{dimension_name}_{label}_{metric_name}"
] = float(metric_value)
return flattened
def _apply_quality_gate_to_split(
    split: Any,
    *,
    split_name: str,
    config: TrainingConfig,
    allow_empty: bool = False,
) -> tuple[Any | None, DatasetQualitySplitReport]:
    """Run the dataset quality gate over one split and rebuild it.

    Returns the rebuilt split and the quality report. If the gate drops every
    record and ``allow_empty`` is true, the reasons and sample rejections are
    logged as a warning and ``(None, report)`` is returned.

    Raises:
        ValueError: every record was dropped and ``allow_empty`` is false.
    """
    split_records = _materialize_split_records(split)
    kept_records, report = apply_quality_gate_to_records(
        split_records,
        split_name=split_name,
        config=_dataset_quality_config(config),
    )
    if kept_records:
        return _rebuild_split_like(split, kept_records), report
    message = (
        f"Dataset quality gate atmeta visus ierakstus splitam '{split_name}'. "
        f"Ienākošie={report.input_records}, paturētie={report.kept_records}, "
        f"atmestie={report.dropped_records}, dublikāti={report.duplicates_removed}. "
        "Pārbaudi īsos, zemas kvalitātes vai dublētos ierakstus."
    )
    if not allow_empty:
        raise ValueError(message)
    logger.warning(
        "%s Iemesli=%s; piemēri=%s. Šo splitu izlaidīšu šajā training skrējienā.",
        message,
        report.reasons,
        report.sample_rejections,
    )
    return None, report
def _apply_scoring_to_split(
    split: Any,
    *,
    split_name: str,
    config: TrainingConfig,
    expand_weights: bool,
    benchmark_feedback: DatasetBenchmarkFeedback | None,
) -> tuple[Any, DatasetScoringSplitReport]:
    """Score (and optionally weight-expand) one split, then rebuild it.

    Args:
        split: Dataset split to score.
        split_name: Label used in the report and error message.
        config: Training config that provides the scoring settings.
        expand_weights: Forwarded to ``apply_scoring_to_records``. The caller
            passes ``True`` for train and ``False`` for eval.
        benchmark_feedback: Optional feedback forwarded to the scorer.

    Raises:
        ValueError: scoring returned no records.
    """
    split_records = _materialize_split_records(split)
    scored_records, report = apply_scoring_to_records(
        split_records,
        split_name=split_name,
        config=_dataset_scoring_config(config),
        expand_weights=expand_weights,
        benchmark_feedback=benchmark_feedback,
    )
    if scored_records:
        return _rebuild_split_like(split, scored_records), report
    raise ValueError(
        f"Dataset scoring pipeline neatgrieza nevienu ierakstu splitam '{split_name}'. "
        f"Ienākošie={report.input_records}, expanded={report.expanded_records}, "
        f"source_weighting={config.source_weighting_enabled}, "
        f"benchmark_feedback={bool(benchmark_feedback)}."
    )
def build_branch_training_configs(config: TrainingConfig) -> list[TrainingConfig]:
    """Build one training configuration per architecture branch.

    Branch order: master, coder, planner (text branches with benchmark gates,
    raised quality floors and boosted category weights), then the specialist
    branches image, music, tts, stt and video. Each branch writes to
    ``<output_dir>/<branch_name>``. All other settings are copied from
    ``config``, and mutable containers are copied per branch.
    """
    root_dir = Path(config.output_dir)
    shared_weights = config.category_weight_map

    def boosted_weights(**floors: float) -> dict[str, Any]:
        # Start from the shared category weights and lift each listed category
        # to at least its floor (a category missing from the map counts as 1.0).
        weights = {**shared_weights}
        for category, floor in floors.items():
            weights[category] = max(float(shared_weights.get(category, 1.0)), floor)
        return weights

    text_branches: list[dict[str, Any]] = [
        {
            "branch_name": "master",
            "branch_focus": "Maris AI realtime_text_chat",
            "hub_model_id": config.text_model_id,
            "benchmark_name": _resolve_branch_benchmark_name(config, "master"),
            "benchmark_dataset_path": _resolve_branch_benchmark_dataset_path(config, "master"),
            "quality_min_text_chars": max(config.quality_min_text_chars, 12),
            "category_weight_map": boosted_weights(
                reasoning=1.2,
                helpfulness=1.15,
                grounding=1.15,
                latvian_quality=1.1,
            ),
        },
        {
            "branch_name": "coder",
            "branch_focus": "Maris AI code_generation_repo_grounded",
            # NOTE(review): coder and planner both publish to config.hub_model_id;
            # confirm they are not meant to target distinct hub repos.
            "hub_model_id": config.hub_model_id,
            "benchmark_name": _resolve_branch_benchmark_name(config, "coder"),
            "benchmark_dataset_path": _resolve_branch_benchmark_dataset_path(config, "coder"),
            "preference_dataset_path": _resolve_branch_preference_dataset_path(config, "coder"),
            "quality_min_text_chars": max(config.quality_min_text_chars, 18),
            "category_weight_map": boosted_weights(
                coding=1.35,
                debugging=1.35,
                grounding=1.25,
                refactor=1.2,
                tests=1.15,
                unsafe=1.15,
            ),
        },
        {
            "branch_name": "planner",
            "branch_focus": "Maris AI autonomous_tasks",
            "hub_model_id": config.hub_model_id,
            "benchmark_name": _resolve_branch_benchmark_name(config, "planner"),
            "benchmark_dataset_path": _resolve_branch_benchmark_dataset_path(config, "planner"),
            "quality_min_text_chars": max(config.quality_min_text_chars, 18),
            "category_weight_map": boosted_weights(
                reasoning=1.3,
                planning=1.3,
                grounding=1.2,
                safety=1.1,
            ),
        },
    ]
    # Settings shared by every text branch.
    for spec in text_branches:
        spec.update(
            adapter_type=config.adapter_type,
            model_name=config.model_name,
            output_dir=root_dir / spec["branch_name"],
            benchmark_gate_enabled=True,
            benchmark_min_overall=max(config.benchmark_min_overall, 0.76),
        )

    # Specialist branches train/publish the dedicated model id for their modality.
    specialist_branches: list[dict[str, Any]] = [
        {
            "branch_name": name,
            "branch_focus": focus,
            "adapter_type": "specialist_model",
            "model_name": model_id,
            "output_dir": root_dir / name,
            "hub_model_id": model_id,
        }
        for name, focus, model_id in (
            ("image", "Maris AI image_generation", config.image_model_id),
            ("music", "Maris AI music_generation", config.music_model_id),
            ("tts", "Maris AI realtime_tts", config.tts_model_id),
            ("stt", "Maris AI realtime_stt", config.stt_model_id),
            ("video", "Maris AI video_generation", config.video_model_id),
        )
    ]

    def to_training_config(spec: dict[str, Any]) -> TrainingConfig:
        # Branch-specific keys come from `spec`. Keys a spec omits fall back to
        # the base config via spec.get(..., config.<field>).
        return TrainingConfig(
            model_name=str(spec["model_name"]),
            branch_name=str(spec["branch_name"]),
            branch_focus=str(spec["branch_focus"]),
            adapter_type=str(spec["adapter_type"]),
            lora_r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            lora_bias=config.lora_bias,
            peft_target_modules=list(config.peft_target_modules),
            qlora_quant_type=config.qlora_quant_type,
            qlora_use_double_quant=config.qlora_use_double_quant,
            qlora_compute_dtype=config.qlora_compute_dtype,
            dataset_repo=config.dataset_repo,
            eval_dataset_repo=config.eval_dataset_repo,
            output_dir=str(spec["output_dir"]),
            hub_model_id=str(spec["hub_model_id"]),
            text_model_id=config.text_model_id,
            image_model_id=config.image_model_id,
            music_model_id=config.music_model_id,
            tts_model_id=config.tts_model_id,
            stt_model_id=config.stt_model_id,
            video_model_id=config.video_model_id,
            num_epochs=config.num_epochs,
            learning_rate=config.learning_rate,
            per_device_train_batch_size=config.per_device_train_batch_size,
            per_device_eval_batch_size=config.per_device_eval_batch_size,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            warmup_ratio=config.warmup_ratio,
            weight_decay=config.weight_decay,
            logging_steps=config.logging_steps,
            save_steps=config.save_steps,
            eval_steps=config.eval_steps,
            save_total_limit=config.save_total_limit,
            max_seq_length=config.max_seq_length,
            validation_split_ratio=config.validation_split_ratio,
            seed=config.seed,
            fp16=config.fp16,
            bf16=config.bf16,
            gradient_checkpointing=config.gradient_checkpointing,
            gradient_checkpointing_use_reentrant=config.gradient_checkpointing_use_reentrant,
            report_to=list(config.report_to),
            push_to_hub=config.push_to_hub,
            save_safetensors=config.save_safetensors,
            lr_scheduler_type=config.lr_scheduler_type,
            benchmark_dataset_path=str(
                spec.get("benchmark_dataset_path", config.benchmark_dataset_path) or ""
            ),
            benchmark_name=str(spec.get("benchmark_name", config.benchmark_name)),
            benchmark_levels=list(config.benchmark_levels),
            benchmark_min_overall=float(
                spec.get("benchmark_min_overall", config.benchmark_min_overall)
            ),
            benchmark_gate_enabled=bool(
                spec.get("benchmark_gate_enabled", config.benchmark_gate_enabled)
            ),
            branch_benchmark_targets={
                name: targets.copy()
                for name, targets in config.branch_benchmark_targets.items()
            },
            branch_benchmark_names=config.branch_benchmark_names.copy(),
            branch_benchmark_dataset_paths=config.branch_benchmark_dataset_paths.copy(),
            branch_preference_dataset_paths=config.branch_preference_dataset_paths.copy(),
            branch_dataset_filter_rules={
                name: rules.copy() for name, rules in config.branch_dataset_filter_rules.items()
            },
            preference_dataset_path=str(
                spec.get("preference_dataset_path", config.preference_dataset_path) or ""
            ),
            preference_optimization=config.preference_optimization,
            preference_beta=config.preference_beta,
            preference_max_prompt_length=config.preference_max_prompt_length,
            preference_max_length=config.preference_max_length,
            preference_reference_model=config.preference_reference_model,
            quality_gate_enabled=config.quality_gate_enabled,
            dedupe_enabled=config.dedupe_enabled,
            quality_min_text_chars=int(
                spec.get("quality_min_text_chars", config.quality_min_text_chars)
            ),
            scoring_enabled=config.scoring_enabled,
            weighted_repetition_enabled=config.weighted_repetition_enabled,
            medium_score_repeat_count=config.medium_score_repeat_count,
            high_score_repeat_count=config.high_score_repeat_count,
            source_weighting_enabled=config.source_weighting_enabled,
            source_weight_map=config.source_weight_map.copy(),
            category_weight_map=dict(spec.get("category_weight_map", config.category_weight_map)),
            max_effective_repeat_count=config.max_effective_repeat_count,
            benchmark_feedback_enabled=config.benchmark_feedback_enabled,
            benchmark_feedback_auto_discover=config.benchmark_feedback_auto_discover,
            benchmark_feedback_path=config.benchmark_feedback_path,
            benchmark_feedback_boost_scale=config.benchmark_feedback_boost_scale,
            benchmark_feedback_max_multiplier=config.benchmark_feedback_max_multiplier,
            continue_from_latest_artifact=config.continue_from_latest_artifact,
            continue_model_path=config.continue_model_path,
        )

    return [to_training_config(spec) for spec in (*text_branches, *specialist_branches)]
def _augment_metrics(metrics: dict[str, float], *, eval_dataset: Any | None) -> dict[str, float]:
    """Return a copy of ``metrics`` with derived evaluation entries added.

    Adds ``eval_examples`` (the eval dataset size as a float) when an eval
    dataset is given. When ``eval_loss`` is present, also adds ``perplexity``
    as ``exp(eval_loss)``. The loss is capped at ``MAX_LOSS_FOR_PERPLEXITY``
    (a module-level constant defined elsewhere in this file) so that a huge
    loss cannot overflow ``math.exp``.
    """
    result = {**metrics}
    if eval_dataset is not None:
        result["eval_examples"] = float(len(eval_dataset))
    if "eval_loss" in result:
        capped_loss = min(result["eval_loss"], MAX_LOSS_FOR_PERPLEXITY)
        result["perplexity"] = float(math.exp(capped_loss))
    return result
def _default_branch_benchmark_targets(config: TrainingConfig) -> dict[str, float]:
return config.branch_benchmark_targets.get(config.branch_name) or {
"overall": config.benchmark_min_overall
}
def _build_benchmark_gate_artifact(
    config: TrainingConfig,
    benchmark_manifest: dict[str, Any],
    *,
    regression_report: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Evaluate the benchmark release gate and return it as an artifact dict.

    Targets are the branch targets augmented with the critical release-gate
    defaults. A metric fails when its value is below ``float(target)``.
    ``overall`` is always checked. Other targets are checked only when the
    manifest reports that metric. Any regression in ``regression_report`` is
    also recorded as a failure (regressions are never allowed). ``passed`` is
    true only when nothing failed.
    """
    targets = _augment_release_gate_targets(
        _default_branch_benchmark_targets(config),
        benchmark_manifest,
    )
    failed_metrics: dict[str, dict[str, Any]] = {}
    for metric, target in targets.items():
        if metric != "overall" and not _benchmark_metric_present(benchmark_manifest, metric):
            continue
        actual = _benchmark_metric_value(benchmark_manifest, metric)
        if actual < float(target):
            failed_metrics[metric] = {"required": target, "actual": actual}

    regression_count = 0
    if regression_report:
        regression_count = int(regression_report.get("regression_count", 0) or 0)
    if regression_count > 0:
        failed_metrics["regression_count"] = {"required": 0.0, "actual": float(regression_count)}

    return {
        "artifact_type": "benchmark-release-gate",
        **_build_model_artifact_identity(config),
        "benchmark_name": benchmark_manifest.get("benchmark_name", config.benchmark_name),
        "passed": not failed_metrics,
        "targets": targets,
        "failed_metrics": failed_metrics,
        "regression_policy": {
            "allow_regressions": False,
            "regression_count": regression_count,
        },
        "levels_checked": config.benchmark_levels,
    }
def _augment_release_gate_targets(
    targets: dict[str, float],
    benchmark_manifest: dict[str, Any],
) -> dict[str, float]:
    """Add critical release-gate thresholds to a copy of ``targets``.

    A default threshold is added only for metrics the manifest reports (see
    ``_benchmark_metric_present``). It never overrides a target the caller
    already supplied. ``long_context`` gets a default only when it appears in
    ``score_manifest``. ``execution``, ``grounding`` and
    ``production_like_pass_rate`` depend on the corresponding ``*_cases`` count
    being positive.
    """
    merged = dict(targets)
    scores = benchmark_manifest.get("score_manifest", {})
    defaults: dict[str, float] = {
        "success_rate": 0.85,
        "grounding": 0.75,
        "safety": 0.9,
        "judge_overall": 0.72,
        "judge_task_completion": 0.72,
        "judge_instruction_following": 0.74,
        "judge_safety": 0.9,
        "judge_regression_risk": 0.72,
        "pairwise_win_rate": 0.55,
    }
    if "long_context" in scores:
        # NOTE(review): because of the setdefault below, this max() result is
        # only used when long_context is not yet targeted, and then it is always
        # 0.72. If a minimum floor was intended, setdefault defeats it.
        defaults["long_context"] = max(float(merged.get("long_context", 0.0) or 0.0), 0.72)
    case_count_gates = (
        ("execution_cases", "execution", 0.7),
        ("grounding_cases", "grounding", 0.75),
        ("production_like_cases", "production_like_pass_rate", 0.75),
    )
    for count_key, metric, threshold in case_count_gates:
        if int(benchmark_manifest.get(count_key, 0) or 0) > 0:
            defaults[metric] = threshold
    for metric, threshold in defaults.items():
        if _benchmark_metric_present(benchmark_manifest, metric):
            merged.setdefault(metric, threshold)
    return merged
def _benchmark_metric_present(benchmark_manifest: dict[str, Any], metric: str) -> bool:
score_manifest = benchmark_manifest.get("score_manifest", {})
return metric in score_manifest or metric in benchmark_manifest
def _benchmark_metric_value(benchmark_manifest: dict[str, Any], metric: str) -> float:
score_manifest = benchmark_manifest.get("score_manifest", {})
raw_value = (
score_manifest.get(metric) if metric in score_manifest else benchmark_manifest.get(metric)
)
return float(raw_value or 0.0)
def _same_benchmark_reference(left: dict[str, Any] | None, right: dict[str, Any]) -> bool:
if not isinstance(left, dict):
return False
return (
str(left.get("benchmark_name", "")).strip() == str(right.get("benchmark_name", "")).strip()
and str(left.get("branch", "")).strip() == str(right.get("branch", "")).strip()
and str(left.get("model", "")).strip() == str(right.get("model", "")).strip()
and str(left.get("generated_at", "")).strip() == str(right.get("generated_at", "")).strip()
)
def _select_previous_benchmark_baseline(
    *,
    current_manifest: dict[str, Any],
    previous_history: dict[str, Any] | None,
    previous_manifest: dict[str, Any] | None,
) -> dict[str, Any] | None:
    """Pick the earlier benchmark run to compare ``current_manifest`` against.

    Preference order:

    * the last dict entry of ``previous_history["runs"]``. If that entry is the
      current run itself, the entry before it is used (or ``None`` if there is
      no earlier entry);
    * otherwise ``previous_manifest``, unless it is the current run (then ``None``).
    """
    runs = previous_history.get("runs") if isinstance(previous_history, dict) else None
    history_runs = (
        [dict(item) for item in runs if isinstance(item, dict)] if isinstance(runs, list) else []
    )
    if history_runs:
        latest = history_runs[-1]
        if not _same_benchmark_reference(latest, current_manifest):
            return latest
        return history_runs[-2] if len(history_runs) >= 2 else None
    if _same_benchmark_reference(previous_manifest, current_manifest):
        return None
    return previous_manifest
def _write_benchmark_artifacts(
    output_dir: Path,
    config: TrainingConfig,
    benchmark_manifest: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Write the benchmark manifest, gate, history, regression and feedback JSON files.

    The previous history and manifest are read from ``output_dir`` before
    anything is overwritten, and are used to build the regression baseline.
    The feedback file is derived from the manifest file written in this call,
    so it is written last.

    Returns:
        ``(release_gate_artifact, regression_report)``.
    """
    manifest = _attach_human_eval_summary_to_manifest(config, benchmark_manifest)
    manifest_path = output_dir / "benchmark-manifest.json"
    history_path = output_dir / "benchmark-history.json"
    prior_history = _load_json_if_exists(history_path)
    prior_manifest = _load_json_if_exists(manifest_path)

    baseline = _select_previous_benchmark_baseline(
        current_manifest=manifest,
        previous_history=prior_history,
        previous_manifest=prior_manifest,
    )
    regression_report = build_chat_benchmark_regression_report(manifest, previous_run=baseline)
    release_gate = _build_benchmark_gate_artifact(
        config,
        manifest,
        regression_report=regression_report,
    )
    history = build_chat_benchmark_history_artifact(manifest, previous_history=prior_history)

    outputs = (
        (manifest_path, manifest),
        (output_dir / "benchmark-release-gate.json", release_gate),
        (history_path, history),
        (output_dir / "benchmark-regression-report.json", regression_report),
    )
    for path, payload in outputs:
        _save_json(path, payload)

    # Must run after the manifest file above has been written.
    feedback_path = output_dir / "benchmark-feedback.json"
    feedback = load_benchmark_feedback(
        manifest_path,
        targets=_default_branch_benchmark_targets(config),
        boost_scale=config.benchmark_feedback_boost_scale,
        max_multiplier=config.benchmark_feedback_max_multiplier,
    )
    _save_json(feedback_path, build_benchmark_feedback_artifact(feedback))
    return release_gate, regression_report
def _attach_human_eval_summary_to_manifest(
    config: TrainingConfig,
    benchmark_manifest: dict[str, Any],
) -> dict[str, Any]:
    """Return a manifest copy enriched with the human-eval (preference) summary.

    Without a ``preference_dataset_path`` the input manifest is returned
    unchanged. Otherwise a shallow copy gains ``human_eval_summary``, and its
    ``score_manifest`` copy gains ``pairwise_win_rate`` and
    ``human_eval_confidence`` (from the summary's ``average_confidence``), both
    rounded to 3 decimals. Missing or falsy summary values count as 0.0.
    """
    if not config.preference_dataset_path:
        return benchmark_manifest
    summary = build_human_eval_summary(_load_filtered_preference_examples(config))
    enriched = {**benchmark_manifest, "human_eval_summary": summary}
    scores = dict(enriched.get("score_manifest", {}))
    win_rate = float(summary.get("pairwise_win_rate", 0.0) or 0.0)
    scores["pairwise_win_rate"] = round(win_rate, 3)
    confidence = float(summary.get("average_confidence", 0.0) or 0.0)
    scores["human_eval_confidence"] = round(confidence, 3)
    enriched["score_manifest"] = scores
    return enriched
async def _run_post_training_benchmark(
    config: TrainingConfig,
    *,
    model_path: str,
) -> dict[str, Any] | None:
    """Benchmark the trained model at ``model_path`` and return a benchmark manifest.

    Returns ``None`` when no ``benchmark_dataset_path`` is configured. Cases are
    filtered by ``benchmark_levels`` and the current branch, then answered one
    at a time (concurrency 1) by a local ``text-generation`` pipeline with
    ``temperature=0.0`` and up to 512 new tokens.

    Raises:
        ValueError: no benchmark case matched the selected levels/branch.
    """
    benchmark_path = config.benchmark_dataset_path
    if not benchmark_path:
        return None
    from transformers import pipeline  # type: ignore

    cases = select_chat_benchmark_cases(
        load_chat_benchmark_dataset(benchmark_path),
        levels=config.benchmark_levels,
        branch=config.branch_name,
    )
    if not cases:
        raise ValueError("Benchmark datasetā nav neviena case izvēlētajiem benchmark leveliem.")
    # NOTE(review): trust_remote_code=True executes code shipped with the model
    # files at model_path; acceptable only if that directory is trusted.
    generator = pipeline(
        "text-generation",
        model=model_path,
        tokenizer=model_path,
        trust_remote_code=True,
    )
    # Profile used when a case does not pin its own profile.
    branch_profiles = {"coder": "coder", "planner": "planner"}

    async def respond(case: Any) -> dict[str, Any]:
        profile = case.profile or branch_profiles.get(config.branch_name, "general")
        system_message = {
            "role": "system",
            "content": build_system_prompt(profile, persona_id=case.persona_id),
        }
        conversation = [
            system_message,
            *list(case.history),
            {"role": "user", "content": case.message},
        ]
        # The pipeline call is blocking, so run it off the event loop.
        raw = await asyncio.to_thread(
            call_generation_pipeline,
            generator,
            conversation,
            max_new_tokens=512,
            temperature=0.0,
        )
        return {
            "response": _extract_response_text(raw),
            "model": config.hub_model_id,
            "tokens_used": _extract_usage_tokens(raw) or 0,
            "persona_title": "Core Assistant",
        }

    results = await run_chat_benchmark_with_responder(cases, responder=respond, concurrency=1)
    return build_chat_benchmark_manifest(
        results,
        benchmark_name=config.benchmark_name,
        branch=config.branch_name,
        model=config.hub_model_id,
    )
def train_with_config(config: TrainingConfig) -> dict[str, float]:
"""Apmāca modeli pēc pilnas konfigurācijas."""
config = _apply_branch_runtime_defaults(_normalize_training_runtime_config(config))
_ensure_runtime_home_dir()
training_dataset_repos = _resolve_training_dataset_repos(config)
eval_dataset_repos = _resolve_eval_dataset_repos(config)
try:
_emit_training_progress_event(
"prepare_dataset",
stage="preparing",
label="Ielādē training datasetu",
total_epochs=config.num_epochs,
dataset_repo=config.dataset_repo,
dataset_repos=training_dataset_repos,
)
logger.info("Ielādē training datasetus: %s", ", ".join(training_dataset_repos))
dataset = _load_combined_hf_dataset(training_dataset_repos)
except HFDatasetError as exc:
logger.error("Apmācība apturēta: %s", exc)
raise SystemExit(str(exc)) from None
except Exception as exc: # noqa: BLE001
logger.error("Apmācības kļūda: %s", exc)
raise
try:
from transformers import ( # type: ignore
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)
try:
from transformers import TrainerCallback # type: ignore
except ImportError: # pragma: no cover - fallback for lightweight test doubles
class TrainerCallback: # type: ignore[no-redef]
pass
training_model_source = _resolve_training_model_source(config)
_emit_training_progress_event(
"prepare_model",
stage="preparing",
label="Ielādē tokenizeri un modeli",
total_epochs=config.num_epochs,
model_source=training_model_source,
)
logger.info("Ielādē modeli: %s", training_model_source)
tokenizer = _load_tokenizer(training_model_source, config)
_configure_tokenizer(tokenizer, config)
model = _prepare_training_model(training_model_source, tokenizer, config)
train_split, eval_split = _prepare_train_eval_splits(dataset, config)
if eval_dataset_repos:
_emit_training_progress_event(
"prepare_eval_dataset",
stage="preparing",
label="Ielādē eval datasetu",
total_epochs=config.num_epochs,
dataset_repo=config.eval_dataset_repo or eval_dataset_repos[0],
dataset_repos=eval_dataset_repos,
)
logger.info("Ielādē atsevišķus eval datasetus: %s", ", ".join(eval_dataset_repos))
eval_source_dataset = _load_combined_hf_dataset(eval_dataset_repos)
eval_split = _select_eval_split(
eval_source_dataset,
config,
allow_train_fallback=True,
)
train_split, train_branch_filter_report = _apply_branch_dataset_filter_to_split(
train_split,
split_name="train",
config=config,
)
eval_branch_filter_report: DatasetBranchFilterReport | None = None
if eval_split is not None:
eval_split, eval_branch_filter_report = _apply_branch_dataset_filter_to_split(
eval_split,
split_name="eval",
config=config,
allow_empty=True,
)
train_split, train_quality_report = _apply_quality_gate_to_split(
train_split,
split_name="train",
config=config,
)
benchmark_feedback = _load_dataset_benchmark_feedback(config)
train_split, train_scoring_report = _apply_scoring_to_split(
train_split,
split_name="train",
config=config,
expand_weights=True,
benchmark_feedback=benchmark_feedback,
)
eval_quality_report: DatasetQualitySplitReport | None = None
eval_scoring_report: DatasetScoringSplitReport | None = None
if eval_split is not None:
eval_split, eval_quality_report = _apply_quality_gate_to_split(
eval_split,
split_name="eval",
config=config,
allow_empty=True,
)
if eval_split is not None:
eval_split, eval_scoring_report = _apply_scoring_to_split(
eval_split,
split_name="eval",
config=config,
expand_weights=False,
benchmark_feedback=benchmark_feedback,
)
_emit_training_progress_event(
"prepare_tokenization",
stage="preparing",
label="Tokenizē treniņa un eval datus",
total_epochs=config.num_epochs,
)
train_dataset = _tokenize_dataset(train_split, tokenizer, config.max_seq_length)
eval_dataset = (
_tokenize_dataset(eval_split, tokenizer, config.max_seq_length)
if eval_split is not None
else None
)
_emit_training_progress_event(
"prepare_runtime",
stage="preparing",
label="Sagatavo treneri un runtime",
total_epochs=config.num_epochs,
)
training_args = _build_training_arguments(
TrainingArguments,
output_dir=config.output_dir,
overwrite_output_dir=True,
num_train_epochs=config.num_epochs,
learning_rate=config.learning_rate,
per_device_train_batch_size=config.per_device_train_batch_size,
per_device_eval_batch_size=config.per_device_eval_batch_size,
gradient_accumulation_steps=config.gradient_accumulation_steps,
warmup_ratio=config.warmup_ratio,
weight_decay=config.weight_decay,
logging_steps=config.logging_steps,
save_steps=config.save_steps,
eval_steps=config.eval_steps,
save_total_limit=config.save_total_limit,
lr_scheduler_type=config.lr_scheduler_type,
seed=config.seed,
fp16=config.fp16,
bf16=config.bf16,
report_to=config.report_to,
save_safetensors=config.save_safetensors,
remove_unused_columns=False,
evaluation_strategy="steps" if eval_dataset is not None else "no",
load_best_model_at_end=eval_dataset is not None,
metric_for_best_model="eval_loss" if eval_dataset is not None else None,
greater_is_better=False if eval_dataset is not None else None,
**_build_runtime_training_argument_overrides(),
**_build_distributed_training_argument_overrides(config),
)
class _SpaceTrainingProgressCallback(MarisTrainingProgressCallback, TrainerCallback):
pass
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)
if hasattr(trainer, "add_callback"):
trainer.add_callback(_SpaceTrainingProgressCallback(total_epochs=config.num_epochs))
logger.info("Sāk apmācību...")
train_result = trainer.train()
metrics = dict(getattr(train_result, "metrics", {}))
if eval_dataset is not None:
metrics.update(trainer.evaluate())
metrics = _augment_metrics(metrics, eval_dataset=eval_dataset)
train_examples = len(train_dataset)
eval_examples = len(eval_dataset) if eval_dataset is not None else 0
trained_at = datetime.now(UTC).isoformat()
output_dir = Path(config.output_dir)
if _uses_preference_optimization(config):
trainer.save_model(str(output_dir))
tokenizer.save_pretrained(str(output_dir))
model, preference_trainer, preference_metrics = _run_preference_optimization(
model,
tokenizer,
config,
output_dir=output_dir,
)
metrics.update(preference_metrics)
if preference_trainer is not None:
trainer = preference_trainer
quality_report = build_dataset_quality_report(
config=_dataset_quality_config(config),
train_report=train_quality_report,
eval_report=eval_quality_report,
)
scoring_report = build_dataset_scoring_report(
config=_dataset_scoring_config(config),
train_report=train_scoring_report,
eval_report=eval_scoring_report,
)
metrics["branch_filter_train_kept"] = float(train_branch_filter_report.kept_records)
metrics["branch_filter_train_dropped"] = float(train_branch_filter_report.dropped_records)
metrics["quality_train_kept"] = float(train_quality_report.kept_records)
metrics["quality_train_dropped"] = float(train_quality_report.dropped_records)
metrics["quality_train_duplicates_removed"] = float(train_quality_report.duplicates_removed)
metrics["scoring_train_average_score"] = float(train_scoring_report.average_score)
metrics["scoring_train_expanded_records"] = float(train_scoring_report.expanded_records)
metrics["scoring_train_repeated_records"] = float(train_scoring_report.repeated_records)
metrics["scoring_train_average_repeat_multiplier"] = float(
train_scoring_report.average_repeat_multiplier
)
metrics["scoring_train_feedback_boosted_records"] = float(
train_scoring_report.feedback_boosted_records
)
if eval_quality_report is not None:
if eval_branch_filter_report is not None:
metrics["branch_filter_eval_kept"] = float(eval_branch_filter_report.kept_records)
metrics["branch_filter_eval_dropped"] = float(
eval_branch_filter_report.dropped_records
)
metrics["quality_eval_kept"] = float(eval_quality_report.kept_records)
metrics["quality_eval_dropped"] = float(eval_quality_report.dropped_records)
metrics["quality_eval_duplicates_removed"] = float(
eval_quality_report.duplicates_removed
)
metrics["quality_eval_skipped"] = float(eval_split is None)
elif eval_branch_filter_report is not None:
metrics["branch_filter_eval_kept"] = float(eval_branch_filter_report.kept_records)
metrics["branch_filter_eval_dropped"] = float(eval_branch_filter_report.dropped_records)
if eval_scoring_report is not None:
metrics["scoring_eval_average_score"] = float(eval_scoring_report.average_score)
metrics["scoring_eval_expanded_records"] = float(eval_scoring_report.expanded_records)
metrics["scoring_eval_feedback_boosted_records"] = float(
eval_scoring_report.feedback_boosted_records
)
trainer.save_model(str(output_dir))
tokenizer.save_pretrained(str(output_dir))
_write_training_artifacts(
output_dir,
config,
trained_at=trained_at,
train_examples=train_examples,
eval_examples=eval_examples,
metrics=metrics,
quality_report=quality_report,
scoring_report=scoring_report,
benchmark_feedback=benchmark_feedback,
)
_write_preference_artifact(output_dir, config)
benchmark_manifest = (
asyncio.run(_run_post_training_benchmark(config, model_path=str(output_dir)))
if config.benchmark_dataset_path
else None
)
if benchmark_manifest is not None:
_emit_training_progress_event(
"benchmark",
stage="benchmarking",
label=_build_training_progress_label(stage="benchmarking"),
output_dir=str(output_dir),
)
benchmark_gate, regression_report = _write_benchmark_artifacts(
output_dir,
config,
benchmark_manifest,
)
metrics["benchmark_overall"] = float(
benchmark_manifest.get("score_manifest", {}).get("overall", 0.0)
)
metrics["benchmark_gate_passed"] = 1.0 if benchmark_gate["passed"] else 0.0
metrics["benchmark_regressions"] = float(regression_report.get("regression_count", 0))
_save_json(
output_dir / "training-metrics.json",
_build_metrics_artifact(
config,
metrics,
artifact_type="training-metrics",
extra_payload=_build_scoring_dashboard_payload(
scoring_report=scoring_report,
benchmark_feedback=benchmark_feedback,
),
),
)
if config.benchmark_gate_enabled and not benchmark_gate["passed"]:
failure_details = ", ".join(
f"{metric}: {details['actual']:.3f} < {details['required']:.3f}"
for metric, details in sorted(benchmark_gate["failed_metrics"].items())
)
raise ValueError("Benchmark release gate neizgāja: " + failure_details)
_sanitize_saved_artifacts(output_dir, maris_model_id=config.hub_model_id)
_verify_saved_artifacts(output_dir)
if config.push_to_hub and hasattr(trainer, "push_to_hub"):
_emit_training_progress_event(
"publish",
stage="publishing",
label=_build_training_progress_label(stage="publishing"),
output_dir=str(output_dir),
)
trainer.push_to_hub(commit_message=f"Maris AI training sync ({config.branch_name})")
_write_training_artifacts(
output_dir,
config,
trained_at=trained_at,
train_examples=train_examples,
eval_examples=eval_examples,
metrics=metrics,
quality_report=quality_report,
scoring_report=scoring_report,
benchmark_feedback=benchmark_feedback,
)
if benchmark_manifest is not None:
_write_benchmark_artifacts(output_dir, config, benchmark_manifest)
_sanitize_saved_artifacts(output_dir, maris_model_id=config.hub_model_id)
_verify_saved_artifacts(output_dir)
_sync_output_dir_to_hub(
output_dir,
config,
commit_message=f"Maris AI artifact sync ({config.branch_name})",
)
logger.info("Apmācība pabeigta. Modelis saglabāts: %s", output_dir)
return metrics
except Exception as exc: # noqa: BLE001
logger.error("Apmācības kļūda: %s", exc)
raise
def train_branch_suite(config: TrainingConfig) -> dict[str, dict[str, Any]]:
    """Run the branch-oriented training pipeline for every Maris branch.

    Every branch config produced by ``build_branch_training_configs`` is
    handled in turn:

    * branches listed in ``TEXT_TRAINABLE_BRANCHES`` are trained with
      ``train_with_config`` and recorded with status ``"trained"`` together
      with the returned metrics;
    * all other branches get their output directory created and a
      ``branch-config.json`` manifest with status ``"external_specialist"``.

    After all branches are processed, a combined ``branch-suite.json`` is
    written into ``config.output_dir``.

    Args:
        config: Base training configuration used to derive the branch configs.

    Returns:
        Mapping of branch name to that branch's result entry.
    """
    results: dict[str, dict[str, Any]] = {}
    suite_root = Path(config.output_dir)

    for branch in build_branch_training_configs(config):
        branch_name = branch.branch_name
        if branch_name in TEXT_TRAINABLE_BRANCHES:
            branch_metrics = train_with_config(branch)
            # Key order is kept stable because the entries end up serialized to JSON.
            results[branch_name] = {
                "status": "trained",
                "maris_model_id": branch.hub_model_id,
                "branch_focus": branch.branch_focus,
                "output_dir": branch.output_dir,
                "metrics": branch_metrics,
                "dataset_repo": branch.dataset_repo,
                "maris_origin": MARIS_ORIGIN_NAME,
                "maris_framework": MARIS_FRAMEWORK_NAME,
            }
        else:
            # Non-text branches are not trained here; only a manifest is recorded.
            branch_dir = Path(branch.output_dir)
            branch_dir.mkdir(parents=True, exist_ok=True)
            manifest = {
                "status": "external_specialist",
                "maris_model_id": branch.hub_model_id,
                "branch_focus": branch.branch_focus,
                "output_dir": branch.output_dir,
                "reason": "Šim atzaram nepieciešams specializēts multimodāls treniņa pipeline ārpus CausalLM trainer.",
                "maris_origin": MARIS_ORIGIN_NAME,
                "maris_framework": MARIS_FRAMEWORK_NAME,
            }
            _save_json(branch_dir / "branch-config.json", manifest)
            results[branch_name] = manifest

    _save_json(
        suite_root / "branch-suite.json",
        _build_branch_suite_artifact(config, results),
    )
    return results
def evaluate_with_config(
    config: TrainingConfig,
    *,
    model_path: str | None = None,
) -> dict[str, float]:
    """Evaluate a model using the same data pipeline as training.

    The reference split is chosen as follows:

    * if dedicated eval dataset repos are configured, it comes from
      ``_select_eval_split`` (with train fallback allowed);
    * otherwise the training repos are split and the eval part is used, or
      the train part if the eval part is falsy.

    The split then passes through the branch filter, the quality gate and
    scoring (without weight expansion), is tokenized, and is evaluated with a
    ``transformers.Trainer``.

    Args:
        config: Training configuration; branch runtime defaults are applied first.
        model_path: Location of the model/tokenizer to evaluate. Falls back
            to ``config.output_dir`` when omitted or empty.

    Returns:
        Evaluation metrics (passed through ``_augment_metrics``). These also
        include ``benchmark_overall``, ``benchmark_gate_passed`` and
        ``benchmark_regressions`` when ``config.benchmark_dataset_path`` is
        set and the benchmark run returned a manifest.

    Raises:
        SystemExit: If loading the HF dataset(s) raises ``HFDatasetError``.
    """
    config = _apply_branch_runtime_defaults(config)
    _ensure_runtime_home_dir()

    explicit_eval_repos = _resolve_eval_dataset_repos(config)
    source_repos = explicit_eval_repos or _resolve_training_dataset_repos(config)
    try:
        dataset = _load_combined_hf_dataset(source_repos)
    except HFDatasetError as exc:
        logger.error("Novērtēšana apturēta: %s", exc)
        raise SystemExit(str(exc)) from None

    # transformers is imported only after the dataset loaded successfully
    # (presumably to keep the heavy import lazy).
    from transformers import (  # type: ignore
        DataCollatorForLanguageModeling,
        Trainer,
        TrainingArguments,
    )

    model_source = model_path or config.output_dir
    tokenizer = _load_tokenizer(model_source, config)
    _configure_tokenizer(tokenizer, config)
    model = _load_model(model_source, config)

    if not explicit_eval_repos:
        train_part, eval_part = _prepare_train_eval_splits(dataset, config)
        split = eval_part or train_part
    else:
        split = _select_eval_split(dataset, config, allow_train_fallback=True)

    # Same filtering stages as training; the per-stage reports are not needed here.
    split, _ = _apply_branch_dataset_filter_to_split(
        split, split_name="evaluation", config=config
    )
    split, _ = _apply_quality_gate_to_split(
        split, split_name="evaluation", config=config
    )
    split, _ = _apply_scoring_to_split(
        split,
        split_name="evaluation",
        config=config,
        expand_weights=False,
        benchmark_feedback=_load_dataset_benchmark_feedback(config),
    )
    tokenized_eval = _tokenize_dataset(split, tokenizer, config.max_seq_length)

    eval_args = _build_training_arguments(
        TrainingArguments,
        output_dir=config.output_dir,
        per_device_eval_batch_size=config.per_device_eval_batch_size,
        report_to=[],
        remove_unused_columns=False,
        **_build_runtime_training_argument_overrides(),
        **_build_distributed_training_argument_overrides(config),
    )
    trainer = Trainer(
        model=model,
        args=eval_args,
        eval_dataset=tokenized_eval,
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )
    raw_metrics = dict(trainer.evaluate())
    metrics = _augment_metrics(raw_metrics, eval_dataset=tokenized_eval)

    target_dir = Path(config.output_dir)
    metrics_file = target_dir / "evaluation-metrics.json"
    # Write the plain eval metrics first; the file is rewritten below if
    # benchmark results are added.
    _save_json(
        metrics_file,
        _build_metrics_artifact(config, metrics, artifact_type="evaluation-metrics"),
    )

    manifest = (
        asyncio.run(_run_post_training_benchmark(config, model_path=str(model_source)))
        if config.benchmark_dataset_path
        else None
    )
    if manifest is not None:
        gate, regressions = _write_benchmark_artifacts(target_dir, config, manifest)
        overall_score = manifest.get("score_manifest", {}).get("overall", 0.0)
        metrics["benchmark_overall"] = float(overall_score)
        metrics["benchmark_gate_passed"] = 1.0 if gate["passed"] else 0.0
        metrics["benchmark_regressions"] = float(regressions.get("regression_count", 0))
        _save_json(
            metrics_file,
            _build_metrics_artifact(config, metrics, artifact_type="evaluation-metrics"),
        )

    _sanitize_saved_artifacts(target_dir, maris_model_id=config.hub_model_id)
    _verify_saved_artifacts(target_dir)
    return metrics
def train(
    config_path: str | None = None,
    model_name: str | None = None,
    dataset_repo: str | None = None,
    dataset_repos: str | list[str] | None = None,
    eval_dataset_repo: str | None = None,
    eval_dataset_repos: str | list[str] | None = None,
    benchmark_dataset_path: str | None = None,
    benchmark_feedback_path: str | None = None,
    preference_dataset_path: str | None = None,
    preference_optimization: str | None = None,
    preference_beta: float | None = None,
    preference_max_prompt_length: int | None = None,
    preference_max_length: int | None = None,
    preference_reference_model: str | None = None,
    output_dir: str | None = None,
    num_epochs: int | None = None,
    learning_rate: float | None = None,
    max_seq_length: int | None = None,
    validation_split_ratio: float | None = None,
    adapter_type: str | None = None,
    lora_r: int | None = None,
    lora_alpha: int | None = None,
    lora_dropout: float | None = None,
    lora_bias: str | None = None,
    peft_target_modules: str | list[str] | None = None,
    qlora_quant_type: str | None = None,
    qlora_use_double_quant: bool | None = None,
    qlora_compute_dtype: str | None = None,
    distributed_strategy: str | None = None,
    distributed_config_path: str | None = None,
    use_accelerate: bool | None = None,
    accelerate_config_path: str | None = None,
    num_processes: int | None = None,
    num_machines: int | None = None,
    machine_rank: int | None = None,
    main_process_ip: str | None = None,
    main_process_port: int | None = None,
    fsdp_transformer_layer_cls_to_wrap: str | list[str] | None = None,
    fsdp_min_num_params: int | None = None,
    continue_from_latest_artifact: bool | None = None,
    continue_model_path: str | None = None,
) -> dict[str, float]:
    """Train a model on HuggingFace datasets.

    Convenience entry point: every keyword argument except ``config_path`` is
    forwarded unchanged as an override to ``load_training_config``. The
    resulting configuration is passed to ``train_with_config``.

    Returns:
        The metrics dictionary produced by ``train_with_config``.
    """
    overrides: dict[str, Any] = dict(
        model_name=model_name,
        dataset_repo=dataset_repo,
        dataset_repos=dataset_repos,
        eval_dataset_repo=eval_dataset_repo,
        eval_dataset_repos=eval_dataset_repos,
        benchmark_dataset_path=benchmark_dataset_path,
        benchmark_feedback_path=benchmark_feedback_path,
        preference_dataset_path=preference_dataset_path,
        preference_optimization=preference_optimization,
        preference_beta=preference_beta,
        preference_max_prompt_length=preference_max_prompt_length,
        preference_max_length=preference_max_length,
        preference_reference_model=preference_reference_model,
        output_dir=output_dir,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        max_seq_length=max_seq_length,
        validation_split_ratio=validation_split_ratio,
        adapter_type=adapter_type,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        lora_bias=lora_bias,
        peft_target_modules=peft_target_modules,
        qlora_quant_type=qlora_quant_type,
        qlora_use_double_quant=qlora_use_double_quant,
        qlora_compute_dtype=qlora_compute_dtype,
        distributed_strategy=distributed_strategy,
        distributed_config_path=distributed_config_path,
        use_accelerate=use_accelerate,
        accelerate_config_path=accelerate_config_path,
        num_processes=num_processes,
        num_machines=num_machines,
        machine_rank=machine_rank,
        main_process_ip=main_process_ip,
        main_process_port=main_process_port,
        fsdp_transformer_layer_cls_to_wrap=fsdp_transformer_layer_cls_to_wrap,
        fsdp_min_num_params=fsdp_min_num_params,
        continue_from_latest_artifact=continue_from_latest_artifact,
        continue_model_path=continue_model_path,
    )
    resolved_config = load_training_config(config_path, overrides=overrides)
    return train_with_config(resolved_config)