"""Configuration for the Maris training pipelines."""

from __future__ import annotations

import json
import logging
import re
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any

from maris_core.utils.env import get_env_any, validate_maris_model, validate_maris_repo

logger = logging.getLogger(__name__)

# Separators accepted between entries in list-like values (newline, comma, semicolon).
EXTRA_MODEL_SPLIT_RE = re.compile(r"[\n,;]+")
# Every run of characters outside [a-z0-9] is collapsed when deriving a preset key.
EXTRA_MODEL_KEY_SANITIZE_RE = re.compile(r"[^a-z0-9]+")

DEFAULT_TRAINING_BASE_MODEL = "MarisUK/maris-ai-master"
DEFAULT_MASTER_MODEL_REPO = "MarisUK/maris-ai-master"
DEFAULT_TEXT_MODEL_REPO = "MarisUK/maris-ai-text"
DEFAULT_IMAGE_MODEL_REPO = "MarisUK/maris-ai-image"
DEFAULT_MUSIC_MODEL_REPO = "MarisUK/maris-ai-music"
DEFAULT_TTS_MODEL_REPO = "MarisUK/maris-tts-runtime"
DEFAULT_STT_MODEL_REPO = "MarisUK/maris-stt-runtime"
DEFAULT_VIDEO_MODEL_REPO = "MarisUK/maris-ai-video"
DEFAULT_PRIMARY_TRAINING_DATASET_REPO = "MarisUK/maris-ai-memory"

# The memory repo literal and the primary-repo constant currently hold the same
# value; dict.fromkeys removes the duplicate while keeping the original order, and
# still keeps both entries should the primary repo ever point elsewhere.
DEFAULT_TRAINING_DATASET_REPOS: list[str] = list(
    dict.fromkeys(
        [
            "MarisUK/maris-ai-memory",
            DEFAULT_PRIMARY_TRAINING_DATASET_REPO,
            "MarisUK/maris-ai-evals",
            "MarisUK/maris-ai-benchmark",
        ]
    )
)
DEFAULT_EVAL_DATASET_REPOS: list[str] = [
    "MarisUK/maris-ai-evals",
    "MarisUK/maris-ai-benchmark",
]

# Built-in presets: preset key -> {"model_name", "label", "description"}.
AVAILABLE_TRAINING_BASE_MODELS: dict[str, dict[str, str]] = {
    "balanced": {
        "model_name": DEFAULT_TRAINING_BASE_MODEL,
        "label": "Balanced default",
        "description": "Galvenais Maris master modelis pilnam instruction fine-tuning ciklam.",
    },
    "reasoning": {
        "model_name": DEFAULT_TRAINING_BASE_MODEL,
        "label": "Master reasoning",
        "description": "Maris master modelis uzdevumiem, kur svarīga vispārējā reasoning kvalitāte.",
    },
    "coding": {
        "model_name": "MarisUK/maris-ai-text",
        "label": "Text specialist",
        "description": "Maris text modelis čata, instrukciju un tehniska teksta fine-tuning skrējieniem.",
    },
    "lightweight": {
        "model_name": "MarisUK/maris-ai-text",
        "label": "Lean text runtime",
        "description": "Teksta runtime modelis lētākiem vai ātrākiem eksperimentiem Maris ekosistēmā.",
    },
}
DEFAULT_BENCHMARK_NAME = "chat-quality" DEFAULT_BRANCH_BENCHMARK_NAMES: dict[str, str] = { "master": "memory-quality", "coder": "coder-release-quality", "planner": "planner-release-quality", } DEFAULT_BRANCH_BENCHMARK_TARGETS: dict[str, dict[str, float]] = { "master": { "overall": 0.78, "reasoning": 0.74, "factuality": 0.72, "helpfulness": 0.76, "latvian_quality": 0.74, "memory_retrieval_pass_rate": 0.8, "memory_multi_turn_continuity": 0.74, "memory_cross_session_recall": 0.72, "memory_user_preferences_recall": 0.76, "memory_cross_lingual_retrieval": 0.72, "memory_stale_memory_rejection": 0.8, "production_like_pass_rate": 0.75, }, "coder": { "overall": 0.74, "coding": 0.78, "reasoning": 0.72, "execution": 0.7, "grounding": 0.74, "safety": 0.9, "production_like_pass_rate": 0.75, }, "planner": { "overall": 0.76, "reasoning": 0.77, "helpfulness": 0.74, "long_context": 0.72, "grounding": 0.72, "safety": 0.9, "production_like_pass_rate": 0.75, }, } DEFAULT_SOURCE_WEIGHT_MAP: dict[str, float] = { "production": 1.3, "synthetic": 1.0, "noisy": 0.65, "unknown": 1.0, } DEFAULT_CATEGORY_WEIGHT_MAP: dict[str, float] = {} EVALS_DIR = Path(__file__).resolve().parents[2] / "evals" DEFAULT_BRANCH_DATASET_FILTER_RULES_PATH = EVALS_DIR / "branch_dataset_filter_rules.json" def _load_default_branch_config() -> dict[str, Any]: raw = json.loads(DEFAULT_BRANCH_DATASET_FILTER_RULES_PATH.read_text(encoding="utf-8")) if not isinstance(raw, dict): raise ValueError("Branch dataset defaults failam jābūt JSON objektam.") return raw def _resolve_default_branch_config_path(path: Any) -> str: candidate = str(path or "").strip() if not candidate: return "" resolved = Path(candidate) if not resolved.is_absolute(): resolved = DEFAULT_BRANCH_DATASET_FILTER_RULES_PATH.parent / resolved return str(resolved.resolve()) def _normalize_default_branch_path_map(value: Any) -> dict[str, str]: if not isinstance(value, dict): raise ValueError("Branch path defaults jābūt objektam ar branch -> dataset path.") 
normalized: dict[str, str] = {} for branch_name, path in value.items(): resolved = _resolve_default_branch_config_path(path) if resolved: normalized[str(branch_name)] = resolved return normalized def _normalize_default_branch_rule_map(value: Any) -> dict[str, dict[str, Any]]: if not isinstance(value, dict): raise ValueError("branch_dataset_filter_rules noklusējumam jābūt objektam.") return { str(branch): dict(payload) for branch, payload in value.items() if isinstance(payload, dict) } DEFAULT_BRANCH_CONFIG = _load_default_branch_config() DEFAULT_BRANCH_BENCHMARK_DATASET_PATHS: dict[str, str] = _normalize_default_branch_path_map( DEFAULT_BRANCH_CONFIG.get("branch_benchmark_dataset_paths", {}) ) DEFAULT_BRANCH_BENCHMARK_NAMES: dict[str, str] = { **DEFAULT_BRANCH_BENCHMARK_NAMES, **{ str(branch): str(name).strip() for branch, name in DEFAULT_BRANCH_CONFIG.get("branch_benchmark_names", {}).items() if str(name).strip() }, } DEFAULT_BRANCH_PREFERENCE_DATASET_PATHS: dict[str, str] = _normalize_default_branch_path_map( DEFAULT_BRANCH_CONFIG.get("branch_preference_dataset_paths", {}) ) DEFAULT_BRANCH_DATASET_FILTER_RULES = _normalize_default_branch_rule_map( DEFAULT_BRANCH_CONFIG.get("branch_dataset_filter_rules", DEFAULT_BRANCH_CONFIG) ) def _parse_bool(value: Any, *, default: bool) -> bool: if value is None: return default if isinstance(value, bool): return value return str(value).strip().lower() in {"1", "true", "yes", "on"} def _parse_optional_bool(value: Any) -> bool | None: if value is None: return None if isinstance(value, str) and not value.strip(): return None return _parse_bool(value, default=False) def _parse_list(value: Any) -> list[str]: if value is None: return [] if isinstance(value, list): return [str(item) for item in value if str(item).strip()] return [item.strip() for item in str(value).split(",") if item.strip()] def _parse_repo_list(value: Any, *, default: list[str] | None = None) -> list[str]: if value in (None, ""): return list(default or []) parsed = 
( json.loads(value) if isinstance(value, str) and value.lstrip().startswith("[") else value ) raw_items = parsed if isinstance(parsed, list) else EXTRA_MODEL_SPLIT_RE.split(str(parsed)) normalized: list[str] = [] for item in raw_items: candidate = str(item or "").strip() if candidate and candidate not in normalized: normalized.append(candidate) return normalized @dataclass(slots=True) class TrainingConfig: """Pilna apmācības konfigurācija vienam Maris treniņa skrējienam.""" model_name: str = DEFAULT_TRAINING_BASE_MODEL model_preset: str = "" branch_name: str = "master" branch_focus: str = "general_reasoning" adapter_type: str = "full" lora_r: int = 16 lora_alpha: int = 32 lora_dropout: float = 0.05 lora_bias: str = "none" peft_target_modules: list[str] = field(default_factory=list) qlora_quant_type: str = "nf4" qlora_use_double_quant: bool = True qlora_compute_dtype: str = "float16" dataset_repo: str = DEFAULT_PRIMARY_TRAINING_DATASET_REPO dataset_repos: list[str] = field(default_factory=list) eval_dataset_repo: str = "" eval_dataset_repos: list[str] = field(default_factory=list) output_dir: str = "./output/model" hub_model_id: str = DEFAULT_MASTER_MODEL_REPO text_model_id: str = DEFAULT_TEXT_MODEL_REPO image_model_id: str = DEFAULT_IMAGE_MODEL_REPO music_model_id: str = DEFAULT_MUSIC_MODEL_REPO tts_model_id: str = DEFAULT_TTS_MODEL_REPO stt_model_id: str = DEFAULT_STT_MODEL_REPO video_model_id: str = DEFAULT_VIDEO_MODEL_REPO num_epochs: int = 3 learning_rate: float = 2e-5 per_device_train_batch_size: int = 1 per_device_eval_batch_size: int = 1 gradient_accumulation_steps: int = 8 warmup_ratio: float = 0.1 weight_decay: float = 0.01 logging_steps: int = 10 save_steps: int = 100 eval_steps: int = 100 save_total_limit: int = 2 max_seq_length: int = 1024 validation_split_ratio: float = 0.1 seed: int = 42 fp16: bool = False bf16: bool = False gradient_checkpointing: bool = False gradient_checkpointing_use_reentrant: bool | None = None distributed_strategy: str = "none" 
distributed_config_path: str = "" use_accelerate: bool = False accelerate_config_path: str = "" num_processes: int = 1 num_machines: int = 1 machine_rank: int = 0 main_process_ip: str = "" main_process_port: int = 29500 fsdp_transformer_layer_cls_to_wrap: list[str] = field(default_factory=list) fsdp_min_num_params: int = 100_000_000 report_to: list[str] = field(default_factory=list) push_to_hub: bool = False save_safetensors: bool = True lr_scheduler_type: str = "cosine" benchmark_dataset_path: str = "" benchmark_name: str = DEFAULT_BENCHMARK_NAME benchmark_levels: list[str] = field(default_factory=lambda: ["local", "ci", "release"]) benchmark_min_overall: float = 0.7 benchmark_gate_enabled: bool = False branch_benchmark_targets: dict[str, dict[str, float]] = field( default_factory=lambda: { key: value.copy() for key, value in DEFAULT_BRANCH_BENCHMARK_TARGETS.items() } ) branch_benchmark_names: dict[str, str] = field( default_factory=lambda: DEFAULT_BRANCH_BENCHMARK_NAMES.copy() ) branch_benchmark_dataset_paths: dict[str, str] = field( default_factory=lambda: DEFAULT_BRANCH_BENCHMARK_DATASET_PATHS.copy() ) branch_preference_dataset_paths: dict[str, str] = field( default_factory=lambda: DEFAULT_BRANCH_PREFERENCE_DATASET_PATHS.copy() ) branch_dataset_filter_rules: dict[str, dict[str, Any]] = field( default_factory=lambda: { key: value.copy() for key, value in DEFAULT_BRANCH_DATASET_FILTER_RULES.items() } ) preference_dataset_path: str = "" preference_optimization: str = "none" preference_beta: float = 0.1 preference_max_prompt_length: int = 512 preference_max_length: int = 1024 preference_reference_model: str = "" quality_gate_enabled: bool = True dedupe_enabled: bool = True quality_min_text_chars: int = 4 scoring_enabled: bool = True weighted_repetition_enabled: bool = True medium_score_repeat_count: int = 2 high_score_repeat_count: int = 3 source_weighting_enabled: bool = True source_weight_map: dict[str, float] = field( default_factory=lambda: 
DEFAULT_SOURCE_WEIGHT_MAP.copy() ) category_weight_map: dict[str, float] = field( default_factory=lambda: DEFAULT_CATEGORY_WEIGHT_MAP.copy() ) max_effective_repeat_count: int = 6 benchmark_feedback_enabled: bool = True benchmark_feedback_auto_discover: bool = True benchmark_feedback_path: str = "" benchmark_feedback_boost_scale: float = 2.0 benchmark_feedback_max_multiplier: float = 1.75 continue_from_latest_artifact: bool = False continue_model_path: str = "" def to_dict(self) -> dict[str, Any]: """Serializē konfigurāciju uz dict.""" return asdict(self) def list_training_base_models() -> dict[str, dict[str, str]]: """Atgriež iepriekš definētos bāzes modeļu presetus.""" models = {key: value.copy() for key, value in AVAILABLE_TRAINING_BASE_MODELS.items()} models.update(_load_extra_training_base_models()) return models def _normalize_extra_training_base_model_payload( preset_key: str, payload: Any, ) -> dict[str, str] | None: if isinstance(payload, str): model_name = payload.strip() label = "" description = "" elif isinstance(payload, dict): model_name = str(payload.get("model_name", "") or "").strip() label = str(payload.get("label", "") or "").strip() description = str(payload.get("description", "") or "").strip() else: logger.warning( "Ignoring extra training preset %r because payload must be an object or model string.", preset_key, ) return None model_parts = model_name.split("/", 1) if len(model_parts) != 2 or not all(part.strip() for part in model_parts): logger.warning( "Ignoring extra training preset %r because model_name must use owner/name format.", preset_key, ) return None if not label: label = preset_key.replace("-", " ").replace("_", " ").title() if not description: description = f"External base model preset {model_name}." 
return { "model_name": model_name, "label": label, "description": description, } def _load_extra_training_base_models() -> dict[str, dict[str, str]]: raw_value = get_env_any("MARIS_TRAIN_EXTRA_MODELS", "HF_TRAIN_EXTRA_MODELS", default="") or "" normalized = raw_value.strip() if not normalized: return {} try: parsed = json.loads(normalized) except json.JSONDecodeError as exc: fallback_models = _parse_extra_training_base_models_fallback(normalized) if fallback_models: logger.info( "Parsed MARIS_TRAIN_EXTRA_MODELS/HF_TRAIN_EXTRA_MODELS using owner/name fallback syntax." ) return fallback_models logger.warning( "Ignoring MARIS_TRAIN_EXTRA_MODELS/HF_TRAIN_EXTRA_MODELS because value is not valid JSON or supported fallback syntax.", exc_info=exc, ) return {} normalized_payloads = _coerce_extra_training_base_models_payload(parsed) if normalized_payloads is None: logger.warning( "Ignoring MARIS_TRAIN_EXTRA_MODELS/HF_TRAIN_EXTRA_MODELS because top-level value must be a JSON object, JSON array, or owner/name fallback list." 
) return {} result: dict[str, dict[str, str]] = {} for preset_name, payload in normalized_payloads.items(): preset_key = str(preset_name).strip() if not preset_key: logger.warning("Ignoring extra training preset with empty name.") continue normalized_payload = _normalize_extra_training_base_model_payload(preset_key, payload) if normalized_payload is None: continue result[preset_key] = normalized_payload return result def _coerce_extra_training_base_models_payload(parsed: Any) -> dict[str, Any] | None: if isinstance(parsed, dict): return parsed if not isinstance(parsed, list): return None coerced: dict[str, Any] = {} for payload in parsed: if isinstance(payload, str): preset_key = _build_extra_training_preset_key(payload, existing=coerced) coerced[preset_key] = payload continue if isinstance(payload, dict): preset_key = str(payload.get("preset") or payload.get("key") or "").strip() model_name = str(payload.get("model_name") or payload.get("model") or "").strip() if not model_name: logger.warning( "Ignoring extra training preset list item because model_name is missing." 
) continue if not preset_key: preset_key = _build_extra_training_preset_key(model_name, existing=coerced) normalized_payload = { "model_name": model_name, "label": str(payload.get("label", "") or "").strip(), "description": str(payload.get("description", "") or "").strip(), } coerced[preset_key] = normalized_payload continue logger.warning( "Ignoring extra training preset list item %r because it must be a string or object.", payload, ) return coerced def _parse_extra_training_base_models_fallback(raw_value: str) -> dict[str, dict[str, str]]: result: dict[str, dict[str, str]] = {} candidates = [item.strip() for item in EXTRA_MODEL_SPLIT_RE.split(raw_value) if item.strip()] for candidate in candidates: preset_key = "" model_name = candidate if "=" in candidate: preset_key, model_name = candidate.split("=", 1) preset_key = preset_key.strip() model_name = model_name.strip() if not model_name: continue if not preset_key: preset_key = _build_extra_training_preset_key(model_name, existing=result) normalized_payload = _normalize_extra_training_base_model_payload(preset_key, model_name) if normalized_payload is None: return {} result[preset_key] = normalized_payload return result def _build_extra_training_preset_key( model_name: str, *, existing: dict[str, Any], ) -> str: owner_name = model_name.strip().lower().replace("/", "-") base_key = EXTRA_MODEL_KEY_SANITIZE_RE.sub("-", owner_name).strip("-") or "extra-model" candidate = base_key suffix = 2 while candidate in existing: candidate = f"{base_key}-{suffix}" suffix += 1 return candidate def resolve_training_model( model_name: str, model_preset: str | None, *, available_models: dict[str, dict[str, str]] | None = None, ) -> str: """Atrisina modeļa preset uz konkrētu bāzes modeli.""" normalized_preset = (model_preset or "").strip() if not normalized_preset: return model_name resolved_models = available_models or list_training_base_models() preset = resolved_models.get(normalized_preset) if preset is None: available = ", 
".join(sorted(resolved_models)) raise ValueError( f"Nezināms MARIS_TRAIN_MODEL_PRESET/model_preset. Izmanto vienu no: {available}." ) return preset["model_name"] def resolve_model_selection( default_model_name: str, *sources: dict[str, Any], available_models: dict[str, dict[str, str]] | None = None, ) -> tuple[str, str]: """Atrod efektīvo modeli no augstākās prioritātes avota. Katrā avotā tiešs `model_name` ir prioritārāks par `model_preset`, jo tas ir precīzāks override nekā presets. """ resolved_model_name = default_model_name resolved_model_preset = "" for source in sources: source_model_name = source.get("model_name") source_model_preset = source.get("model_preset") if source_model_name not in (None, ""): return str(source_model_name), "" if source_model_preset not in (None, ""): resolved_model_preset = str(source_model_preset) resolved_model_name = resolve_training_model( default_model_name, resolved_model_preset, available_models=available_models, ) return resolved_model_name, resolved_model_preset return resolved_model_name, resolved_model_preset def load_training_config( config_path: str | None = None, overrides: dict[str, Any] | None = None, ) -> TrainingConfig: """Ielādē konfigurāciju no JSON, vides mainīgajiem un CLI override'iem.""" data: dict[str, Any] = {} defaults = TrainingConfig().to_dict() if config_path: data = json.loads(Path(config_path).read_text(encoding="utf-8")) env_data: dict[str, Any] = { "model_name": get_env_any("MARIS_TRAIN_BASE_MODEL", "HF_TRAIN_BASE_MODEL", "TEXT_MODEL"), "model_preset": get_env_any("MARIS_TRAIN_MODEL_PRESET", "HF_TRAIN_MODEL_PRESET"), "branch_name": get_env_any("MARIS_TRAIN_BRANCH_NAME", "HF_TRAIN_BRANCH_NAME"), "branch_focus": get_env_any("MARIS_TRAIN_BRANCH_FOCUS", "HF_TRAIN_BRANCH_FOCUS"), "adapter_type": get_env_any("MARIS_TRAIN_ADAPTER_TYPE", "HF_TRAIN_ADAPTER_TYPE"), "lora_r": get_env_any("MARIS_TRAIN_LORA_R", "HF_TRAIN_LORA_R"), "lora_alpha": get_env_any("MARIS_TRAIN_LORA_ALPHA", "HF_TRAIN_LORA_ALPHA"), 
"lora_dropout": get_env_any("MARIS_TRAIN_LORA_DROPOUT", "HF_TRAIN_LORA_DROPOUT"), "lora_bias": get_env_any("MARIS_TRAIN_LORA_BIAS", "HF_TRAIN_LORA_BIAS"), "peft_target_modules": get_env_any( "MARIS_TRAIN_PEFT_TARGET_MODULES", "HF_TRAIN_PEFT_TARGET_MODULES", ), "qlora_quant_type": get_env_any( "MARIS_TRAIN_QLORA_QUANT_TYPE", "HF_TRAIN_QLORA_QUANT_TYPE", ), "qlora_use_double_quant": get_env_any( "MARIS_TRAIN_QLORA_USE_DOUBLE_QUANT", "HF_TRAIN_QLORA_USE_DOUBLE_QUANT", ), "qlora_compute_dtype": get_env_any( "MARIS_TRAIN_QLORA_COMPUTE_DTYPE", "HF_TRAIN_QLORA_COMPUTE_DTYPE", ), "dataset_repo": get_env_any("MARIS_MEMORY_REPO", "MARIS_DATASET_REPO", "HF_DATASET_REPO"), "dataset_repos": get_env_any("MARIS_DATASET_REPOS", "HF_DATASET_REPOS"), "eval_dataset_repo": get_env_any("MARIS_EVAL_DATASET_REPO", "HF_EVAL_DATASET_REPO"), "eval_dataset_repos": get_env_any("MARIS_EVAL_DATASET_REPOS", "HF_EVAL_DATASET_REPOS"), "output_dir": get_env_any("MARIS_TRAIN_OUTPUT_DIR", "HF_TRAIN_OUTPUT_DIR"), "hub_model_id": get_env_any("MARIS_MODEL_REPO", "HF_MODEL_REPO"), "text_model_id": get_env_any("TEXT_MODEL", default=DEFAULT_TEXT_MODEL_REPO), "image_model_id": get_env_any("IMAGE_MODEL", default=DEFAULT_IMAGE_MODEL_REPO), "music_model_id": get_env_any("MUSIC_MODEL", default=DEFAULT_MUSIC_MODEL_REPO), "tts_model_id": get_env_any("TTS_MODEL", default=DEFAULT_TTS_MODEL_REPO), "stt_model_id": get_env_any("STT_MODEL", default=DEFAULT_STT_MODEL_REPO), "video_model_id": get_env_any("VIDEO_MODEL", default=DEFAULT_VIDEO_MODEL_REPO), "num_epochs": get_env_any("MARIS_TRAIN_NUM_EPOCHS", "HF_TRAIN_NUM_EPOCHS"), "learning_rate": get_env_any("MARIS_TRAIN_LEARNING_RATE", "HF_TRAIN_LEARNING_RATE"), "per_device_train_batch_size": get_env_any("MARIS_TRAIN_BATCH_SIZE", "HF_TRAIN_BATCH_SIZE"), "per_device_eval_batch_size": get_env_any( "MARIS_TRAIN_EVAL_BATCH_SIZE", "HF_TRAIN_EVAL_BATCH_SIZE" ), "gradient_accumulation_steps": get_env_any( "MARIS_TRAIN_GRADIENT_ACCUMULATION_STEPS", 
"HF_TRAIN_GRADIENT_ACCUMULATION_STEPS", ), "warmup_ratio": get_env_any("MARIS_TRAIN_WARMUP_RATIO", "HF_TRAIN_WARMUP_RATIO"), "weight_decay": get_env_any("MARIS_TRAIN_WEIGHT_DECAY", "HF_TRAIN_WEIGHT_DECAY"), "logging_steps": get_env_any("MARIS_TRAIN_LOGGING_STEPS", "HF_TRAIN_LOGGING_STEPS"), "save_steps": get_env_any("MARIS_TRAIN_SAVE_STEPS", "HF_TRAIN_SAVE_STEPS"), "eval_steps": get_env_any("MARIS_TRAIN_EVAL_STEPS", "HF_TRAIN_EVAL_STEPS"), "save_total_limit": get_env_any( "MARIS_TRAIN_SAVE_TOTAL_LIMIT", "HF_TRAIN_SAVE_TOTAL_LIMIT" ), "max_seq_length": get_env_any("MARIS_TRAIN_MAX_SEQ_LENGTH", "HF_TRAIN_MAX_SEQ_LENGTH"), "validation_split_ratio": get_env_any( "MARIS_TRAIN_VALIDATION_SPLIT", "HF_TRAIN_VALIDATION_SPLIT" ), "seed": get_env_any("MARIS_TRAIN_SEED", "HF_TRAIN_SEED"), "fp16": get_env_any("MARIS_TRAIN_FP16", "HF_TRAIN_FP16"), "bf16": get_env_any("MARIS_TRAIN_BF16", "HF_TRAIN_BF16"), "gradient_checkpointing": get_env_any( "MARIS_TRAIN_GRADIENT_CHECKPOINTING", "HF_TRAIN_GRADIENT_CHECKPOINTING", ), "gradient_checkpointing_use_reentrant": get_env_any( "MARIS_TRAIN_GRADIENT_CHECKPOINTING_USE_REENTRANT", "HF_TRAIN_GRADIENT_CHECKPOINTING_USE_REENTRANT", ), "distributed_strategy": get_env_any( "MARIS_TRAIN_DISTRIBUTED_STRATEGY", "HF_TRAIN_DISTRIBUTED_STRATEGY", ), "distributed_config_path": get_env_any( "MARIS_TRAIN_DISTRIBUTED_CONFIG_PATH", "HF_TRAIN_DISTRIBUTED_CONFIG_PATH", ), "use_accelerate": get_env_any( "MARIS_TRAIN_USE_ACCELERATE", "HF_TRAIN_USE_ACCELERATE", ), "accelerate_config_path": get_env_any( "MARIS_TRAIN_ACCELERATE_CONFIG_PATH", "HF_TRAIN_ACCELERATE_CONFIG_PATH", ), "num_processes": get_env_any("MARIS_TRAIN_NUM_PROCESSES", "HF_TRAIN_NUM_PROCESSES"), "num_machines": get_env_any("MARIS_TRAIN_NUM_MACHINES", "HF_TRAIN_NUM_MACHINES"), "machine_rank": get_env_any("MARIS_TRAIN_MACHINE_RANK", "HF_TRAIN_MACHINE_RANK"), "main_process_ip": get_env_any( "MARIS_TRAIN_MAIN_PROCESS_IP", "HF_TRAIN_MAIN_PROCESS_IP", ), "main_process_port": get_env_any( 
"MARIS_TRAIN_MAIN_PROCESS_PORT", "HF_TRAIN_MAIN_PROCESS_PORT", ), "fsdp_transformer_layer_cls_to_wrap": get_env_any( "MARIS_TRAIN_FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP", "HF_TRAIN_FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP", ), "fsdp_min_num_params": get_env_any( "MARIS_TRAIN_FSDP_MIN_NUM_PARAMS", "HF_TRAIN_FSDP_MIN_NUM_PARAMS", ), "report_to": get_env_any("MARIS_TRAIN_REPORT_TO", "HF_TRAIN_REPORT_TO"), "push_to_hub": get_env_any("MARIS_TRAIN_PUBLISH", "HF_TRAIN_PUSH_TO_HUB"), "save_safetensors": get_env_any( "MARIS_TRAIN_SAVE_SAFETENSORS", "HF_TRAIN_SAVE_SAFETENSORS" ), "lr_scheduler_type": get_env_any( "MARIS_TRAIN_LR_SCHEDULER_TYPE", "HF_TRAIN_LR_SCHEDULER_TYPE" ), "benchmark_dataset_path": get_env_any( "MARIS_BENCHMARK_DATASET_PATH", "HF_BENCHMARK_DATASET_PATH" ), "benchmark_name": get_env_any("MARIS_BENCHMARK_NAME", "HF_BENCHMARK_NAME"), "benchmark_levels": get_env_any("MARIS_BENCHMARK_LEVELS", "HF_BENCHMARK_LEVELS"), "benchmark_min_overall": get_env_any( "MARIS_BENCHMARK_MIN_OVERALL", "HF_BENCHMARK_MIN_OVERALL" ), "benchmark_gate_enabled": get_env_any( "MARIS_BENCHMARK_GATE_ENABLED", "HF_BENCHMARK_GATE_ENABLED" ), "branch_benchmark_names": get_env_any( "MARIS_BRANCH_BENCHMARK_NAMES", "HF_BRANCH_BENCHMARK_NAMES", ), "branch_benchmark_dataset_paths": get_env_any( "MARIS_BRANCH_BENCHMARK_DATASET_PATHS", "HF_BRANCH_BENCHMARK_DATASET_PATHS", ), "branch_preference_dataset_paths": get_env_any( "MARIS_BRANCH_PREFERENCE_DATASET_PATHS", "HF_BRANCH_PREFERENCE_DATASET_PATHS", ), "branch_dataset_filter_rules": get_env_any( "MARIS_BRANCH_DATASET_FILTER_RULES", "HF_BRANCH_DATASET_FILTER_RULES", ), "preference_dataset_path": get_env_any( "MARIS_PREFERENCE_DATASET_PATH", "HF_PREFERENCE_DATASET_PATH" ), "preference_optimization": get_env_any( "MARIS_PREFERENCE_OPTIMIZATION", "HF_PREFERENCE_OPTIMIZATION", ), "preference_beta": get_env_any("MARIS_PREFERENCE_BETA", "HF_PREFERENCE_BETA"), "preference_max_prompt_length": get_env_any( "MARIS_PREFERENCE_MAX_PROMPT_LENGTH", 
"HF_PREFERENCE_MAX_PROMPT_LENGTH", ), "preference_max_length": get_env_any( "MARIS_PREFERENCE_MAX_LENGTH", "HF_PREFERENCE_MAX_LENGTH", ), "preference_reference_model": get_env_any( "MARIS_PREFERENCE_REFERENCE_MODEL", "HF_PREFERENCE_REFERENCE_MODEL", ), "quality_gate_enabled": get_env_any( "MARIS_TRAIN_QUALITY_GATE_ENABLED", "HF_TRAIN_QUALITY_GATE_ENABLED", ), "dedupe_enabled": get_env_any( "MARIS_TRAIN_DEDUPE_ENABLED", "HF_TRAIN_DEDUPE_ENABLED", ), "quality_min_text_chars": get_env_any( "MARIS_TRAIN_QUALITY_MIN_CHARS", "HF_TRAIN_QUALITY_MIN_CHARS", ), "scoring_enabled": get_env_any( "MARIS_TRAIN_SCORING_ENABLED", "HF_TRAIN_SCORING_ENABLED", ), "weighted_repetition_enabled": get_env_any( "MARIS_TRAIN_WEIGHTED_REPETITION_ENABLED", "HF_TRAIN_WEIGHTED_REPETITION_ENABLED", ), "medium_score_repeat_count": get_env_any( "MARIS_TRAIN_MEDIUM_SCORE_REPEAT_COUNT", "HF_TRAIN_MEDIUM_SCORE_REPEAT_COUNT", ), "high_score_repeat_count": get_env_any( "MARIS_TRAIN_HIGH_SCORE_REPEAT_COUNT", "HF_TRAIN_HIGH_SCORE_REPEAT_COUNT", ), "source_weighting_enabled": get_env_any( "MARIS_TRAIN_SOURCE_WEIGHTING_ENABLED", "HF_TRAIN_SOURCE_WEIGHTING_ENABLED", ), "source_weight_map": get_env_any( "MARIS_TRAIN_SOURCE_WEIGHT_MAP", "HF_TRAIN_SOURCE_WEIGHT_MAP", ), "category_weight_map": get_env_any( "MARIS_TRAIN_CATEGORY_WEIGHT_MAP", "HF_TRAIN_CATEGORY_WEIGHT_MAP", ), "max_effective_repeat_count": get_env_any( "MARIS_TRAIN_MAX_EFFECTIVE_REPEAT_COUNT", "HF_TRAIN_MAX_EFFECTIVE_REPEAT_COUNT", ), "benchmark_feedback_enabled": get_env_any( "MARIS_TRAIN_BENCHMARK_FEEDBACK_ENABLED", "HF_TRAIN_BENCHMARK_FEEDBACK_ENABLED", ), "benchmark_feedback_auto_discover": get_env_any( "MARIS_TRAIN_BENCHMARK_FEEDBACK_AUTO_DISCOVER", "HF_TRAIN_BENCHMARK_FEEDBACK_AUTO_DISCOVER", ), "benchmark_feedback_path": get_env_any( "MARIS_TRAIN_BENCHMARK_FEEDBACK_PATH", "HF_TRAIN_BENCHMARK_FEEDBACK_PATH", ), "benchmark_feedback_boost_scale": get_env_any( "MARIS_TRAIN_BENCHMARK_FEEDBACK_BOOST_SCALE", 
"HF_TRAIN_BENCHMARK_FEEDBACK_BOOST_SCALE", ), "benchmark_feedback_max_multiplier": get_env_any( "MARIS_TRAIN_BENCHMARK_FEEDBACK_MAX_MULTIPLIER", "HF_TRAIN_BENCHMARK_FEEDBACK_MAX_MULTIPLIER", ), "continue_from_latest_artifact": get_env_any( "MARIS_TRAIN_CONTINUE_FROM_LATEST", "HF_TRAIN_CONTINUE_FROM_LATEST", ), "continue_model_path": get_env_any( "MARIS_TRAIN_CONTINUE_MODEL_PATH", "HF_TRAIN_CONTINUE_MODEL_PATH", "MARIS_LOCAL_MODEL_DIR", "HF_LOCAL_MODEL_DIR", ), } env_overrides = {key: value for key, value in env_data.items() if value not in (None, "")} cli_overrides = {key: value for key, value in (overrides or {}).items() if value is not None} merged: dict[str, Any] = {} merged.update(defaults) merged.update(data) merged.update(env_overrides) merged.update(cli_overrides) resolved_distributed_strategy = str( merged.get("distributed_strategy", "none") or "none" ).lower() explicit_use_accelerate = next( ( source["use_accelerate"] for source in (cli_overrides, env_overrides, data) if source.get("use_accelerate") not in (None, "") ), None, ) available_base_models = list_training_base_models() resolved_model_name, resolved_model_preset = resolve_model_selection( str(defaults["model_name"]), cli_overrides, env_overrides, data, available_models=available_base_models, ) config = TrainingConfig( model_name=resolved_model_name, model_preset=resolved_model_preset, branch_name=str(merged["branch_name"]), branch_focus=str(merged["branch_focus"]), adapter_type=str(merged["adapter_type"]), lora_r=int(merged.get("lora_r", 16)), lora_alpha=int(merged.get("lora_alpha", 32)), lora_dropout=float(merged.get("lora_dropout", 0.05)), lora_bias=str(merged.get("lora_bias", "none") or "none"), peft_target_modules=_parse_list(merged.get("peft_target_modules")), qlora_quant_type=str(merged.get("qlora_quant_type", "nf4") or "nf4"), qlora_use_double_quant=_parse_bool( merged.get("qlora_use_double_quant"), default=True, ), qlora_compute_dtype=str(merged.get("qlora_compute_dtype") or 
"float16").lower(), dataset_repo=validate_maris_repo( str(merged["dataset_repo"]), "MARIS_MEMORY_REPO/MARIS_DATASET_REPO/HF_DATASET_REPO/dataset_repo", label="dataset repozitorijs", ), dataset_repos=[], eval_dataset_repo=( validate_maris_repo( str(merged["eval_dataset_repo"]), "MARIS_EVAL_DATASET_REPO/HF_EVAL_DATASET_REPO/eval_dataset_repo", label="eval dataset repozitorijs", ) if merged.get("eval_dataset_repo") not in (None, "") else "" ), eval_dataset_repos=[], output_dir=str(merged["output_dir"]), hub_model_id=validate_maris_model( str(merged["hub_model_id"]), "MARIS_MODEL_REPO/HF_MODEL_REPO/hub_model_id", ), text_model_id=validate_maris_model( str(merged["text_model_id"]), "TEXT_MODEL/text_model_id", ), image_model_id=validate_maris_model( str(merged["image_model_id"]), "IMAGE_MODEL/image_model_id", ), music_model_id=validate_maris_model( str(merged["music_model_id"]), "MUSIC_MODEL/music_model_id", ), tts_model_id=validate_maris_model( str(merged["tts_model_id"]), "TTS_MODEL/tts_model_id", ), stt_model_id=validate_maris_model( str(merged["stt_model_id"]), "STT_MODEL/stt_model_id", ), video_model_id=validate_maris_model( str(merged["video_model_id"]), "VIDEO_MODEL/video_model_id", ), num_epochs=int(merged["num_epochs"]), learning_rate=float(merged["learning_rate"]), per_device_train_batch_size=int(merged["per_device_train_batch_size"]), per_device_eval_batch_size=int(merged["per_device_eval_batch_size"]), gradient_accumulation_steps=int(merged["gradient_accumulation_steps"]), warmup_ratio=float(merged["warmup_ratio"]), weight_decay=float(merged["weight_decay"]), logging_steps=int(merged["logging_steps"]), save_steps=int(merged["save_steps"]), eval_steps=int(merged["eval_steps"]), save_total_limit=int(merged["save_total_limit"]), max_seq_length=int(merged["max_seq_length"]), validation_split_ratio=float(merged["validation_split_ratio"]), seed=int(merged["seed"]), fp16=_parse_bool(merged.get("fp16"), default=False), bf16=_parse_bool(merged.get("bf16"), 
default=False), gradient_checkpointing=_parse_bool(merged.get("gradient_checkpointing"), default=False), gradient_checkpointing_use_reentrant=_parse_optional_bool( merged.get("gradient_checkpointing_use_reentrant") ), distributed_strategy=resolved_distributed_strategy, distributed_config_path=str(merged.get("distributed_config_path", "") or ""), use_accelerate=_parse_bool( explicit_use_accelerate, default=resolved_distributed_strategy != "none", ), accelerate_config_path=str(merged.get("accelerate_config_path", "") or ""), num_processes=int(merged.get("num_processes", 1)), num_machines=int(merged.get("num_machines", 1)), machine_rank=int(merged.get("machine_rank", 0)), main_process_ip=str(merged.get("main_process_ip", "") or ""), main_process_port=int(merged.get("main_process_port", 29500)), fsdp_transformer_layer_cls_to_wrap=_parse_list( merged.get("fsdp_transformer_layer_cls_to_wrap") ), fsdp_min_num_params=int(merged.get("fsdp_min_num_params", 100_000_000)), report_to=_parse_list(merged.get("report_to")), push_to_hub=_parse_bool(merged.get("push_to_hub"), default=False), save_safetensors=_parse_bool(merged.get("save_safetensors"), default=True), lr_scheduler_type=str(merged["lr_scheduler_type"]), benchmark_dataset_path=str(merged.get("benchmark_dataset_path", "") or ""), benchmark_name=str( merged.get("benchmark_name", DEFAULT_BENCHMARK_NAME) or DEFAULT_BENCHMARK_NAME ), benchmark_levels=_parse_list(merged.get("benchmark_levels")) or ["local", "ci", "release"], benchmark_min_overall=float(merged.get("benchmark_min_overall", 0.7)), benchmark_gate_enabled=_parse_bool(merged.get("benchmark_gate_enabled"), default=False), branch_benchmark_targets=_parse_branch_targets(merged.get("branch_benchmark_targets")), branch_benchmark_names=_parse_branch_benchmark_names(merged.get("branch_benchmark_names")), branch_benchmark_dataset_paths=_parse_branch_benchmark_dataset_paths( merged.get("branch_benchmark_dataset_paths") ), 
branch_preference_dataset_paths=_parse_branch_preference_dataset_paths( merged.get("branch_preference_dataset_paths") ), branch_dataset_filter_rules=_parse_branch_dataset_filter_rules( merged.get("branch_dataset_filter_rules") ), preference_dataset_path=str(merged.get("preference_dataset_path", "") or ""), preference_optimization=str( merged.get("preference_optimization", "none") or "none" ).lower(), preference_beta=float(merged.get("preference_beta", 0.1)), preference_max_prompt_length=int(merged.get("preference_max_prompt_length", 512)), preference_max_length=int(merged.get("preference_max_length", 1024)), preference_reference_model=( validate_maris_model( str(merged["preference_reference_model"]), "MARIS_PREFERENCE_REFERENCE_MODEL/HF_PREFERENCE_REFERENCE_MODEL/preference_reference_model", ) if merged.get("preference_reference_model") not in (None, "") else "" ), quality_gate_enabled=_parse_bool(merged.get("quality_gate_enabled"), default=True), dedupe_enabled=_parse_bool(merged.get("dedupe_enabled"), default=True), quality_min_text_chars=int(merged.get("quality_min_text_chars", 4)), scoring_enabled=_parse_bool(merged.get("scoring_enabled"), default=True), weighted_repetition_enabled=_parse_bool( merged.get("weighted_repetition_enabled"), default=True, ), medium_score_repeat_count=int(merged.get("medium_score_repeat_count", 2)), high_score_repeat_count=int(merged.get("high_score_repeat_count", 3)), source_weighting_enabled=_parse_bool( merged.get("source_weighting_enabled"), default=True, ), source_weight_map=_parse_source_weight_map(merged.get("source_weight_map")), category_weight_map=_parse_category_weight_map(merged.get("category_weight_map")), max_effective_repeat_count=int(merged.get("max_effective_repeat_count", 6)), benchmark_feedback_enabled=_parse_bool( merged.get("benchmark_feedback_enabled"), default=True, ), benchmark_feedback_auto_discover=_parse_bool( merged.get("benchmark_feedback_auto_discover"), default=True, ), 
benchmark_feedback_path=str(merged.get("benchmark_feedback_path", "") or ""), benchmark_feedback_boost_scale=float(merged.get("benchmark_feedback_boost_scale", 2.0)), benchmark_feedback_max_multiplier=float( merged.get("benchmark_feedback_max_multiplier", 1.75) ), continue_from_latest_artifact=_parse_bool( merged.get("continue_from_latest_artifact"), default=False, ), continue_model_path=str(merged.get("continue_model_path", "") or ""), ) config.dataset_repos = [ validate_maris_repo( repo_id, "MARIS_DATASET_REPOS/HF_DATASET_REPOS/dataset_repos", label="dataset repozitorijs", ) for repo_id in _parse_repo_list( merged.get("dataset_repos"), default=[config.dataset_repo], ) ] if config.dataset_repo not in config.dataset_repos: config.dataset_repos.insert(0, config.dataset_repo) config.eval_dataset_repos = [ validate_maris_repo( repo_id, "MARIS_EVAL_DATASET_REPOS/HF_EVAL_DATASET_REPOS/eval_dataset_repos", label="eval dataset repozitorijs", ) for repo_id in _parse_repo_list( merged.get("eval_dataset_repos"), default=[config.eval_dataset_repo] if config.eval_dataset_repo else [], ) ] if config.eval_dataset_repo and config.eval_dataset_repo not in config.eval_dataset_repos: config.eval_dataset_repos.insert(0, config.eval_dataset_repo) if config.fp16 and config.bf16: raise ValueError("Maris training konfigurācijā nevar vienlaikus ieslēgt fp16 un bf16.") if config.adapter_type not in {"full", "lora", "qlora", "specialist_model"}: raise ValueError("adapter_type jābūt vienam no: full, lora, qlora, specialist_model.") if config.distributed_strategy not in {"none", "fsdp", "deepspeed"}: raise ValueError("distributed_strategy jābūt vienam no: none, fsdp, deepspeed.") if config.preference_optimization not in {"none", "dpo", "orpo"}: raise ValueError("preference_optimization jābūt vienam no: none, dpo, orpo.") if config.preference_optimization != "none" and not config.preference_dataset_path: raise ValueError( "Preference optimization vajag preference_dataset_path ar 
prompt/chosen/rejected datiem." ) if config.num_processes < 1: raise ValueError("num_processes jābūt vismaz 1.") if config.num_machines < 1: raise ValueError("num_machines jābūt vismaz 1.") if config.machine_rank < 0: raise ValueError("machine_rank nedrīkst būt negatīvs.") if config.main_process_port < 1: raise ValueError("main_process_port jābūt pozitīvam portam.") if config.fsdp_min_num_params < 0: raise ValueError("fsdp_min_num_params nedrīkst būt negatīvs.") return config def _parse_branch_targets(value: Any) -> dict[str, dict[str, float]]: if value in (None, ""): return {key: item.copy() for key, item in DEFAULT_BRANCH_BENCHMARK_TARGETS.items()} parsed = json.loads(value) if isinstance(value, str) else value if not isinstance(parsed, dict): raise ValueError("branch_benchmark_targets jābūt objektam ar branch -> metric -> score.") normalized: dict[str, dict[str, float]] = {} for branch, metrics in parsed.items(): if not isinstance(metrics, dict): raise ValueError("Katram branch_benchmark_targets ierakstam jābūt objektam.") normalized[str(branch)] = {str(name): float(score) for name, score in metrics.items()} return normalized def _parse_source_weight_map(value: Any) -> dict[str, float]: if value in (None, ""): return DEFAULT_SOURCE_WEIGHT_MAP.copy() parsed = json.loads(value) if isinstance(value, str) else value if not isinstance(parsed, dict): raise ValueError("source_weight_map jābūt objektam ar source_tier -> weight.") normalized = DEFAULT_SOURCE_WEIGHT_MAP.copy() for tier, weight in parsed.items(): normalized[str(tier)] = float(weight) return normalized def _parse_branch_benchmark_names(value: Any) -> dict[str, str]: if value in (None, ""): return DEFAULT_BRANCH_BENCHMARK_NAMES.copy() if isinstance(value, str): value = json.loads(value) if not isinstance(value, dict): raise ValueError("branch_benchmark_names jābūt objektam ar branch -> benchmark name.") normalized: dict[str, str] = {} for branch_name, benchmark_name in value.items(): name = 
str(benchmark_name or "").strip() if name: normalized[str(branch_name)] = name return normalized def _parse_branch_benchmark_dataset_paths(value: Any) -> dict[str, str]: if value in (None, ""): return DEFAULT_BRANCH_BENCHMARK_DATASET_PATHS.copy() parsed = json.loads(value) if isinstance(value, str) else value if not isinstance(parsed, dict): raise ValueError( "branch_benchmark_dataset_paths jābūt objektam ar branch -> benchmark dataset path." ) normalized = DEFAULT_BRANCH_BENCHMARK_DATASET_PATHS.copy() for branch_name, path in parsed.items(): candidate = str(path or "").strip() if not candidate: continue normalized[str(branch_name)] = candidate return normalized def _parse_branch_preference_dataset_paths(value: Any) -> dict[str, str]: if value in (None, ""): return DEFAULT_BRANCH_PREFERENCE_DATASET_PATHS.copy() parsed = json.loads(value) if isinstance(value, str) else value if not isinstance(parsed, dict): raise ValueError( "branch_preference_dataset_paths jābūt objektam ar branch -> preference dataset path." 
) normalized = DEFAULT_BRANCH_PREFERENCE_DATASET_PATHS.copy() for branch_name, path in parsed.items(): candidate = str(path or "").strip() if not candidate: continue normalized[str(branch_name)] = candidate return normalized def _parse_branch_dataset_filter_rules(value: Any) -> dict[str, dict[str, Any]]: if value in (None, ""): return {key: item.copy() for key, item in DEFAULT_BRANCH_DATASET_FILTER_RULES.items()} parsed = json.loads(value) if isinstance(value, str) else value if not isinstance(parsed, dict): raise ValueError("branch_dataset_filter_rules jābūt objektam ar branch -> filter rule map.") normalized = {key: item.copy() for key, item in DEFAULT_BRANCH_DATASET_FILTER_RULES.items()} for branch_name, rules in parsed.items(): if not isinstance(rules, dict): raise ValueError("Katram branch_dataset_filter_rules ierakstam jābūt objektam.") normalized[str(branch_name)] = dict(rules) return normalized def _parse_category_weight_map(value: Any) -> dict[str, float]: if value in (None, ""): return DEFAULT_CATEGORY_WEIGHT_MAP.copy() parsed = json.loads(value) if isinstance(value, str) else value if not isinstance(parsed, dict): raise ValueError("category_weight_map jābūt objektam ar category -> weight.") return {str(label): float(weight) for label, weight in parsed.items()}