| """Konfigurācija Maris apmācības pipeline'iem.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| import logging |
| import re |
| from dataclasses import asdict, dataclass, field |
| from pathlib import Path |
| from typing import Any |
|
|
| from maris_core.utils.env import get_env_any, validate_maris_model, validate_maris_repo |
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Delimiter for plain-text model/repo lists: any run of newlines, commas or semicolons.
EXTRA_MODEL_SPLIT_RE = re.compile(r"[\n,;]+")
# Matches runs of characters other than lowercase ASCII letters and digits;
# used to turn a lower-cased model name into a preset key.
EXTRA_MODEL_KEY_SANITIZE_RE = re.compile(r"[^a-z0-9]+")
|
|
# Default "owner/name" repository identifiers for the base model and for each
# modality-specific model slot of TrainingConfig.
DEFAULT_TRAINING_BASE_MODEL = "MarisUK/maris-ai-master"
DEFAULT_MASTER_MODEL_REPO = "MarisUK/maris-ai-master"
DEFAULT_TEXT_MODEL_REPO = "MarisUK/maris-ai-text"
DEFAULT_IMAGE_MODEL_REPO = "MarisUK/maris-ai-image"
DEFAULT_MUSIC_MODEL_REPO = "MarisUK/maris-ai-music"
DEFAULT_TTS_MODEL_REPO = "MarisUK/maris-tts-runtime"
DEFAULT_STT_MODEL_REPO = "MarisUK/maris-stt-runtime"
DEFAULT_VIDEO_MODEL_REPO = "MarisUK/maris-ai-video"

# Default dataset repositories. The primary repository is listed exactly once:
# the list previously contained it twice (once as a literal and once through
# the constant), which would hand the same repository to consumers twice.
DEFAULT_PRIMARY_TRAINING_DATASET_REPO = "MarisUK/maris-ai-memory"
DEFAULT_TRAINING_DATASET_REPOS: list[str] = [
    DEFAULT_PRIMARY_TRAINING_DATASET_REPO,
    "MarisUK/maris-ai-evals",
    "MarisUK/maris-ai-benchmark",
]
DEFAULT_EVAL_DATASET_REPOS: list[str] = [
    "MarisUK/maris-ai-evals",
    "MarisUK/maris-ai-benchmark",
]

# Built-in base-model presets, keyed by preset name. Every entry carries the
# keys "model_name", "label" and "description". The text presets reference
# DEFAULT_TEXT_MODEL_REPO so they cannot drift from the text model default.
AVAILABLE_TRAINING_BASE_MODELS: dict[str, dict[str, str]] = {
    "balanced": {
        "model_name": DEFAULT_TRAINING_BASE_MODEL,
        "label": "Balanced default",
        "description": "Galvenais Maris master modelis pilnam instruction fine-tuning ciklam.",
    },
    "reasoning": {
        "model_name": DEFAULT_TRAINING_BASE_MODEL,
        "label": "Master reasoning",
        "description": "Maris master modelis uzdevumiem, kur svarīga vispārējā reasoning kvalitāte.",
    },
    "coding": {
        "model_name": DEFAULT_TEXT_MODEL_REPO,
        "label": "Text specialist",
        "description": "Maris text modelis čata, instrukciju un tehniska teksta fine-tuning skrējieniem.",
    },
    "lightweight": {
        "model_name": DEFAULT_TEXT_MODEL_REPO,
        "label": "Lean text runtime",
        "description": "Teksta runtime modelis lētākiem vai ātrākiem eksperimentiem Maris ekosistēmā.",
    },
}
# Default value of TrainingConfig.benchmark_name.
DEFAULT_BENCHMARK_NAME = "chat-quality"
# Built-in branch name -> benchmark name mapping. This name is rebound further
# down in the module with entries from the branch defaults JSON layered on top.
DEFAULT_BRANCH_BENCHMARK_NAMES: dict[str, str] = {
    "master": "memory-quality",
    "coder": "coder-release-quality",
    "planner": "planner-release-quality",
}


# Per-branch metric targets, used as the default for
# TrainingConfig.branch_benchmark_targets.
# NOTE(review): values look like minimum scores in [0, 1] - confirm against the
# code that evaluates the benchmark gate.
DEFAULT_BRANCH_BENCHMARK_TARGETS: dict[str, dict[str, float]] = {
    "master": {
        "overall": 0.78,
        "reasoning": 0.74,
        "factuality": 0.72,
        "helpfulness": 0.76,
        "latvian_quality": 0.74,
        "memory_retrieval_pass_rate": 0.8,
        "memory_multi_turn_continuity": 0.74,
        "memory_cross_session_recall": 0.72,
        "memory_user_preferences_recall": 0.76,
        "memory_cross_lingual_retrieval": 0.72,
        "memory_stale_memory_rejection": 0.8,
        "production_like_pass_rate": 0.75,
    },
    "coder": {
        "overall": 0.74,
        "coding": 0.78,
        "reasoning": 0.72,
        "execution": 0.7,
        "grounding": 0.74,
        "safety": 0.9,
        "production_like_pass_rate": 0.75,
    },
    "planner": {
        "overall": 0.76,
        "reasoning": 0.77,
        "helpfulness": 0.74,
        "long_context": 0.72,
        "grounding": 0.72,
        "safety": 0.9,
        "production_like_pass_rate": 0.75,
    },
}
# Default for TrainingConfig.source_weight_map.
# NOTE(review): presumably per-source sampling multipliers - verify against the
# weighting code, which is not part of this module section.
DEFAULT_SOURCE_WEIGHT_MAP: dict[str, float] = {
    "production": 1.3,
    "synthetic": 1.0,
    "noisy": 0.65,
    "unknown": 1.0,
}
# Default for TrainingConfig.category_weight_map; empty by default.
DEFAULT_CATEGORY_WEIGHT_MAP: dict[str, float] = {}
# "evals" directory under the third ancestor (parents[2]) of this file's resolved path.
EVALS_DIR = Path(__file__).resolve().parents[2] / "evals"
# JSON file holding the branch defaults that are loaded at import time below.
DEFAULT_BRANCH_DATASET_FILTER_RULES_PATH = EVALS_DIR / "branch_dataset_filter_rules.json"
|
|
|
|
def _load_default_branch_config() -> dict[str, Any]:
    """Read the branch defaults JSON file and return its top-level object.

    The file location is ``DEFAULT_BRANCH_DATASET_FILTER_RULES_PATH`` and it is
    decoded as UTF-8.

    Returns:
        The decoded JSON document.

    Raises:
        ValueError: If the document's top-level value is not a JSON object.
    """
    text = DEFAULT_BRANCH_DATASET_FILTER_RULES_PATH.read_text(encoding="utf-8")
    document = json.loads(text)
    if isinstance(document, dict):
        return document
    raise ValueError("Branch dataset defaults failam jābūt JSON objektam.")
|
|
|
|
def _resolve_default_branch_config_path(path: Any) -> str:
    """Turn a path entry from the branch defaults file into an absolute path string.

    Args:
        path: Raw value from the JSON document. Falsy values and strings that
            are blank after stripping are treated as "no path".

    Returns:
        ``""`` when no path was given; otherwise the resolved path. Relative
        paths are anchored at the directory containing
        ``DEFAULT_BRANCH_DATASET_FILTER_RULES_PATH``.
    """
    text = str(path or "").strip()
    if text == "":
        return ""
    location = Path(text)
    if location.is_absolute():
        return str(location.resolve())
    # Relative entries are interpreted relative to the defaults file itself.
    anchored = DEFAULT_BRANCH_DATASET_FILTER_RULES_PATH.parent / location
    return str(anchored.resolve())
|
|
|
|
def _normalize_default_branch_path_map(value: Any) -> dict[str, str]:
    """Normalize a ``branch -> dataset path`` mapping from the branch defaults file.

    Args:
        value: Raw mapping taken from the JSON document.

    Returns:
        A new dict with string branch names and resolved path strings. Entries
        whose path resolves to an empty string are left out.

    Raises:
        ValueError: If ``value`` is not a dict.
    """
    if not isinstance(value, dict):
        raise ValueError("Branch path defaults jābūt objektam ar branch -> dataset path.")
    return {
        str(branch): absolute
        for branch, raw_path in value.items()
        if (absolute := _resolve_default_branch_config_path(raw_path))
    }
|
|
|
|
def _normalize_default_branch_rule_map(value: Any) -> dict[str, dict[str, Any]]:
    """Normalize the ``branch -> filter rules`` mapping from the branch defaults file.

    Args:
        value: Raw mapping taken from the JSON document.

    Returns:
        A new dict keyed by the string form of each branch name. Each rule
        payload is shallow-copied; entries whose payload is not a dict are
        dropped.

    Raises:
        ValueError: If ``value`` is not a dict.
    """
    if not isinstance(value, dict):
        raise ValueError("branch_dataset_filter_rules noklusējumam jābūt objektam.")
    rules: dict[str, dict[str, Any]] = {}
    for branch, payload in value.items():
        if not isinstance(payload, dict):
            continue
        rules[str(branch)] = dict(payload)
    return rules
|
|
|
|
# Branch defaults are read once, at import time, from the JSON defaults file.
DEFAULT_BRANCH_CONFIG = _load_default_branch_config()
DEFAULT_BRANCH_BENCHMARK_DATASET_PATHS: dict[str, str] = _normalize_default_branch_path_map(
    DEFAULT_BRANCH_CONFIG.get("branch_benchmark_dataset_paths", {})
)

# Validate the benchmark-name overrides like the other maps in this file, so a
# malformed defaults file fails with a ValueError instead of an AttributeError
# raised by ``.items()``.
_branch_benchmark_name_overrides = DEFAULT_BRANCH_CONFIG.get("branch_benchmark_names", {})
if not isinstance(_branch_benchmark_name_overrides, dict):
    raise ValueError("branch_benchmark_names noklusējumam jābūt objektam.")
# Built-in names first, overrides from the JSON file layered on top. The name is
# already annotated where it is first defined, so it is not re-annotated here.
DEFAULT_BRANCH_BENCHMARK_NAMES = {
    **DEFAULT_BRANCH_BENCHMARK_NAMES,
    **{
        str(branch): str(name).strip()
        for branch, name in _branch_benchmark_name_overrides.items()
        # A JSON null must not turn into the literal benchmark name "None".
        if name is not None and str(name).strip()
    },
}
DEFAULT_BRANCH_PREFERENCE_DATASET_PATHS: dict[str, str] = _normalize_default_branch_path_map(
    DEFAULT_BRANCH_CONFIG.get("branch_preference_dataset_paths", {})
)
# When the document has no "branch_dataset_filter_rules" key, the whole document
# is treated as the rule map; the normalizer drops every non-dict entry.
DEFAULT_BRANCH_DATASET_FILTER_RULES = _normalize_default_branch_rule_map(
    DEFAULT_BRANCH_CONFIG.get("branch_dataset_filter_rules", DEFAULT_BRANCH_CONFIG)
)
|
|
|
|
def _parse_bool(value: Any, *, default: bool) -> bool:
    """Interpret a loosely typed configuration value as a boolean.

    Args:
        value: Raw value from JSON, the environment or CLI overrides.
        default: Result used when ``value`` carries no information, i.e. it is
            ``None`` or blank after stripping.

    Returns:
        ``value`` itself when it is already a bool; otherwise ``True`` only for
        the case-insensitive tokens ``1``, ``true``, ``yes`` and ``on``.
    """
    if value is None:
        return default
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if not text:
        # A blank string means "not set" elsewhere in this module (see
        # _parse_optional_bool and the ``(None, "")`` filters), so it must not
        # silently override a True default with False.
        return default
    return text in {"1", "true", "yes", "on"}
|
|
|
|
def _parse_optional_bool(value: Any) -> bool | None:
    """Parse a tri-state flag: ``None`` when unset, otherwise a boolean.

    Args:
        value: Raw configuration value.

    Returns:
        ``None`` for ``None`` and for strings that are blank after stripping;
        otherwise the result of ``_parse_bool(value, default=False)``.
    """
    is_blank_text = isinstance(value, str) and value.strip() == ""
    if value is None or is_blank_text:
        return None
    return _parse_bool(value, default=False)
|
|
|
|
def _parse_list(value: Any) -> list[str]:
    """Normalize a list-like configuration value into a list of clean strings.

    Args:
        value: ``None``, a list/tuple of items, or a comma-separated string
            (any other object is stringified and split on commas).

    Returns:
        The items as stripped strings, in order. Blank items and ``None``
        entries are dropped.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        # Sequence items get the same treatment as comma-separated text:
        # surrounding whitespace is removed, and a None entry counts as blank
        # instead of becoming the string "None".
        candidates = ["" if item is None else str(item) for item in value]
    else:
        candidates = str(value).split(",")
    return [text for text in (candidate.strip() for candidate in candidates) if text]
|
|
|
|
def _parse_repo_list(value: Any, *, default: list[str] | None = None) -> list[str]:
    """Parse a repository list from a list, a JSON array string or delimited text.

    Args:
        value: ``None``/``""`` (use the default), a list of items, a string
            whose first non-blank character is ``[`` (decoded as JSON), or any
            other value, which is stringified and split with
            ``EXTRA_MODEL_SPLIT_RE``.
        default: Fallback entries used when ``value`` is ``None`` or ``""``.

    Returns:
        Stripped, non-empty entries in first-seen order without duplicates.
        The default is returned as a fresh copy and is not de-duplicated.
    """
    if value in (None, ""):
        return list(default or [])
    decoded = value
    if isinstance(value, str) and value.lstrip().startswith("["):
        decoded = json.loads(value)
    if isinstance(decoded, list):
        entries = decoded
    else:
        entries = EXTRA_MODEL_SPLIT_RE.split(str(decoded))
    cleaned = (str(entry or "").strip() for entry in entries)
    # dict.fromkeys keeps the first occurrence of each name and preserves order.
    return list(dict.fromkeys(name for name in cleaned if name))
|
|
|
|
@dataclass(slots=True)
class TrainingConfig:
    """Full training configuration for a single Maris training run."""

    # --- Base model selection and branch identity ---
    model_name: str = DEFAULT_TRAINING_BASE_MODEL
    model_preset: str = ""
    branch_name: str = "master"
    branch_focus: str = "general_reasoning"
    # --- Adapter settings (adapter_type, lora_*, peft_*, qlora_*) ---
    adapter_type: str = "full"
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_bias: str = "none"
    peft_target_modules: list[str] = field(default_factory=list)
    qlora_quant_type: str = "nf4"
    qlora_use_double_quant: bool = True
    qlora_compute_dtype: str = "float16"
    # --- Training and evaluation dataset repositories ---
    dataset_repo: str = DEFAULT_PRIMARY_TRAINING_DATASET_REPO
    dataset_repos: list[str] = field(default_factory=list)
    eval_dataset_repo: str = ""
    eval_dataset_repos: list[str] = field(default_factory=list)
    # --- Output location and model repository identifiers ---
    output_dir: str = "./output/model"
    hub_model_id: str = DEFAULT_MASTER_MODEL_REPO
    text_model_id: str = DEFAULT_TEXT_MODEL_REPO
    image_model_id: str = DEFAULT_IMAGE_MODEL_REPO
    music_model_id: str = DEFAULT_MUSIC_MODEL_REPO
    tts_model_id: str = DEFAULT_TTS_MODEL_REPO
    stt_model_id: str = DEFAULT_STT_MODEL_REPO
    video_model_id: str = DEFAULT_VIDEO_MODEL_REPO
    # --- Optimisation hyperparameters and step intervals ---
    num_epochs: int = 3
    learning_rate: float = 2e-5
    per_device_train_batch_size: int = 1
    per_device_eval_batch_size: int = 1
    gradient_accumulation_steps: int = 8
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    logging_steps: int = 10
    save_steps: int = 100
    eval_steps: int = 100
    save_total_limit: int = 2
    max_seq_length: int = 1024
    validation_split_ratio: float = 0.1
    seed: int = 42
    fp16: bool = False
    bf16: bool = False
    gradient_checkpointing: bool = False
    # Tri-state: None means "not specified".
    gradient_checkpointing_use_reentrant: bool | None = None
    # --- Distributed / multi-process launch settings ---
    distributed_strategy: str = "none"
    distributed_config_path: str = ""
    use_accelerate: bool = False
    accelerate_config_path: str = ""
    num_processes: int = 1
    num_machines: int = 1
    machine_rank: int = 0
    main_process_ip: str = ""
    main_process_port: int = 29500
    fsdp_transformer_layer_cls_to_wrap: list[str] = field(default_factory=list)
    fsdp_min_num_params: int = 100_000_000
    # --- Reporting, publishing and scheduler ---
    report_to: list[str] = field(default_factory=list)
    push_to_hub: bool = False
    save_safetensors: bool = True
    lr_scheduler_type: str = "cosine"
    # --- Benchmark settings ---
    benchmark_dataset_path: str = ""
    benchmark_name: str = DEFAULT_BENCHMARK_NAME
    benchmark_levels: list[str] = field(default_factory=lambda: ["local", "ci", "release"])
    benchmark_min_overall: float = 0.7
    benchmark_gate_enabled: bool = False
    # Per-branch defaults: each instance gets its own copy of the module-level
    # mapping. The two nested mappings copy only their outer dict and each
    # inner dict (one level deep).
    branch_benchmark_targets: dict[str, dict[str, float]] = field(
        default_factory=lambda: {
            key: value.copy() for key, value in DEFAULT_BRANCH_BENCHMARK_TARGETS.items()
        }
    )
    branch_benchmark_names: dict[str, str] = field(
        default_factory=lambda: DEFAULT_BRANCH_BENCHMARK_NAMES.copy()
    )
    branch_benchmark_dataset_paths: dict[str, str] = field(
        default_factory=lambda: DEFAULT_BRANCH_BENCHMARK_DATASET_PATHS.copy()
    )
    branch_preference_dataset_paths: dict[str, str] = field(
        default_factory=lambda: DEFAULT_BRANCH_PREFERENCE_DATASET_PATHS.copy()
    )
    branch_dataset_filter_rules: dict[str, dict[str, Any]] = field(
        default_factory=lambda: {
            key: value.copy() for key, value in DEFAULT_BRANCH_DATASET_FILTER_RULES.items()
        }
    )
    # --- Preference optimisation settings ---
    preference_dataset_path: str = ""
    preference_optimization: str = "none"
    preference_beta: float = 0.1
    preference_max_prompt_length: int = 512
    preference_max_length: int = 1024
    preference_reference_model: str = ""
    # --- Dataset quality, scoring and weighting switches ---
    quality_gate_enabled: bool = True
    dedupe_enabled: bool = True
    quality_min_text_chars: int = 4
    scoring_enabled: bool = True
    weighted_repetition_enabled: bool = True
    medium_score_repeat_count: int = 2
    high_score_repeat_count: int = 3
    source_weighting_enabled: bool = True
    source_weight_map: dict[str, float] = field(
        default_factory=lambda: DEFAULT_SOURCE_WEIGHT_MAP.copy()
    )
    category_weight_map: dict[str, float] = field(
        default_factory=lambda: DEFAULT_CATEGORY_WEIGHT_MAP.copy()
    )
    max_effective_repeat_count: int = 6
    # --- Benchmark feedback settings ---
    benchmark_feedback_enabled: bool = True
    benchmark_feedback_auto_discover: bool = True
    benchmark_feedback_path: str = ""
    benchmark_feedback_boost_scale: float = 2.0
    benchmark_feedback_max_multiplier: float = 1.75
    # --- Continuing from an existing model ---
    continue_from_latest_artifact: bool = False
    continue_model_path: str = ""

    def to_dict(self) -> dict[str, Any]:
        """Serialize the configuration to a dict using ``dataclasses.asdict``."""
        return asdict(self)
|
|
|
|
def list_training_base_models() -> dict[str, dict[str, str]]:
    """Return the built-in base-model presets merged with the extra presets.

    Returns:
        A new mapping of preset name to preset payload. Built-in payloads are
        copied, so callers may mutate the result. Presets returned by
        ``_load_extra_training_base_models`` override built-in presets that
        share the same name.
    """
    builtin = {name: preset.copy() for name, preset in AVAILABLE_TRAINING_BASE_MODELS.items()}
    return builtin | _load_extra_training_base_models()
|
|
|
|
def _normalize_extra_training_base_model_payload(
    preset_key: str,
    payload: Any,
) -> dict[str, str] | None:
    """Validate one extra preset and fill in its missing display fields.

    Args:
        preset_key: Name of the preset; used in log messages and to derive a
            label when none is supplied.
        payload: Either a model name string or a dict with the optional keys
            ``model_name``, ``label`` and ``description``.

    Returns:
        A dict with ``model_name``, ``label`` and ``description``, or ``None``
        (after logging a warning) when the payload has an unsupported type or
        the model name is not of the form ``owner/name``.
    """
    if isinstance(payload, str):
        model_name, label, description = payload.strip(), "", ""
    elif isinstance(payload, dict):
        model_name, label, description = (
            str(payload.get(field_name, "") or "").strip()
            for field_name in ("model_name", "label", "description")
        )
    else:
        logger.warning(
            "Ignoring extra training preset %r because payload must be an object or model string.",
            preset_key,
        )
        return None

    # Only the first "/" separates owner from name; both sides must be non-blank.
    owner, slash, repo = model_name.partition("/")
    if not (slash and owner.strip() and repo.strip()):
        logger.warning(
            "Ignoring extra training preset %r because model_name must use owner/name format.",
            preset_key,
        )
        return None

    return {
        "model_name": model_name,
        "label": label or preset_key.replace("-", " ").replace("_", " ").title(),
        "description": description or f"External base model preset {model_name}.",
    }
|
|
|
|
def _load_extra_training_base_models() -> dict[str, dict[str, str]]:
    """Load extra base-model presets from the environment.

    The value is read from ``MARIS_TRAIN_EXTRA_MODELS`` / ``HF_TRAIN_EXTRA_MODELS``
    via ``get_env_any``. It is tried as JSON first; if it is not valid JSON, the
    plain-text fallback parser is used instead.

    Returns:
        Preset name to normalized payload. Empty when the variable is unset or
        blank, or when the value cannot be interpreted. Individual invalid
        presets are skipped.
    """
    env_text = (
        get_env_any("MARIS_TRAIN_EXTRA_MODELS", "HF_TRAIN_EXTRA_MODELS", default="") or ""
    ).strip()
    if env_text == "":
        return {}

    try:
        decoded = json.loads(env_text)
    except json.JSONDecodeError as exc:
        # Not JSON: accept the delimiter-separated fallback syntax if it yields anything.
        from_fallback = _parse_extra_training_base_models_fallback(env_text)
        if not from_fallback:
            logger.warning(
                "Ignoring MARIS_TRAIN_EXTRA_MODELS/HF_TRAIN_EXTRA_MODELS because value is not valid JSON or supported fallback syntax.",
                exc_info=exc,
            )
            return {}
        logger.info(
            "Parsed MARIS_TRAIN_EXTRA_MODELS/HF_TRAIN_EXTRA_MODELS using owner/name fallback syntax."
        )
        return from_fallback

    payload_by_key = _coerce_extra_training_base_models_payload(decoded)
    if payload_by_key is None:
        logger.warning(
            "Ignoring MARIS_TRAIN_EXTRA_MODELS/HF_TRAIN_EXTRA_MODELS because top-level value must be a JSON object, JSON array, or owner/name fallback list."
        )
        return {}

    presets: dict[str, dict[str, str]] = {}
    for raw_key, raw_payload in payload_by_key.items():
        key = str(raw_key).strip()
        if key == "":
            logger.warning("Ignoring extra training preset with empty name.")
            continue
        preset = _normalize_extra_training_base_model_payload(key, raw_payload)
        if preset is not None:
            presets[key] = preset
    return presets
|
|
|
|
def _coerce_extra_training_base_models_payload(parsed: Any) -> dict[str, Any] | None:
    """Bring decoded extra-model JSON into ``preset key -> payload`` form.

    Args:
        parsed: Decoded JSON value.

    Returns:
        ``parsed`` itself when it is already a dict; ``None`` when it is
        neither a dict nor a list. For a list, a new dict is built: string
        items get a generated key, dict items use their ``preset``/``key``
        entry (or a generated key) and ``model_name``/``model`` as the model
        name. Items without a model name or of another type are skipped with
        a warning.
    """
    if isinstance(parsed, dict):
        return parsed
    if not isinstance(parsed, list):
        return None

    by_key: dict[str, Any] = {}
    for entry in parsed:
        if isinstance(entry, str):
            generated = _build_extra_training_preset_key(entry, existing=by_key)
            by_key[generated] = entry
        elif isinstance(entry, dict):
            model_name = str(entry.get("model_name") or entry.get("model") or "").strip()
            if not model_name:
                logger.warning(
                    "Ignoring extra training preset list item because model_name is missing."
                )
                continue
            key = str(entry.get("preset") or entry.get("key") or "").strip()
            if not key:
                key = _build_extra_training_preset_key(model_name, existing=by_key)
            by_key[key] = {
                "model_name": model_name,
                "label": str(entry.get("label", "") or "").strip(),
                "description": str(entry.get("description", "") or "").strip(),
            }
        else:
            logger.warning(
                "Ignoring extra training preset list item %r because it must be a string or object.",
                entry,
            )
    return by_key
|
|
|
|
def _parse_extra_training_base_models_fallback(raw_value: str) -> dict[str, dict[str, str]]:
    """Parse the non-JSON syntax for extra presets.

    Entries are separated by ``EXTRA_MODEL_SPLIT_RE`` and are either
    ``owner/name`` or ``preset=owner/name``.

    Args:
        raw_value: Raw text that failed to decode as JSON.

    Returns:
        Preset name to normalized payload. Entries with an empty model part
        are skipped; if any remaining entry fails normalization, the whole
        result is ``{}``.
    """
    presets: dict[str, dict[str, str]] = {}
    for chunk in EXTRA_MODEL_SPLIT_RE.split(raw_value):
        entry = chunk.strip()
        if not entry:
            continue
        key_part, separator, model_part = entry.partition("=")
        if separator:
            key, model_name = key_part.strip(), model_part.strip()
        else:
            key, model_name = "", entry
        if not model_name:
            continue
        if not key:
            key = _build_extra_training_preset_key(model_name, existing=presets)
        preset = _normalize_extra_training_base_model_payload(key, model_name)
        if preset is None:
            # One malformed entry invalidates the whole fallback parse.
            return {}
        presets[key] = preset
    return presets
|
|
|
|
def _build_extra_training_preset_key(
    model_name: str,
    *,
    existing: dict[str, Any],
) -> str:
    """Derive a preset key from a model name that is not yet used in ``existing``.

    The name is stripped and lower-cased, ``/`` becomes ``-``, every run matched
    by ``EXTRA_MODEL_KEY_SANITIZE_RE`` collapses to a single ``-`` and
    leading/trailing dashes are removed. An empty result falls back to
    ``"extra-model"``.

    Args:
        model_name: Model identifier to derive the key from.
        existing: Mapping whose keys are already taken.

    Returns:
        The base key when it is free, otherwise the base key with the first
        free numeric suffix, starting at ``-2``.
    """
    lowered = model_name.strip().lower().replace("/", "-")
    slug = EXTRA_MODEL_KEY_SANITIZE_RE.sub("-", lowered).strip("-")
    base_key = slug if slug else "extra-model"
    if base_key not in existing:
        return base_key
    counter = 2
    while f"{base_key}-{counter}" in existing:
        counter += 1
    return f"{base_key}-{counter}"
|
|
|
|
def resolve_training_model(
    model_name: str,
    model_preset: str | None,
    *,
    available_models: dict[str, dict[str, str]] | None = None,
) -> str:
    """Resolve a model preset to a concrete base model name.

    Args:
        model_name: Model name returned unchanged when no preset is given.
        model_preset: Preset name; ``None`` or a blank value means "no preset".
        available_models: Preset catalog to look in. When falsy, the catalog
            from ``list_training_base_models()`` is used.

    Returns:
        ``model_name`` when no preset is given, otherwise the ``model_name``
        entry of the selected preset.

    Raises:
        ValueError: If the preset name is not in the catalog.
    """
    preset_key = model_preset.strip() if model_preset else ""
    if preset_key == "":
        return model_name

    catalog = available_models or list_training_base_models()
    entry = catalog.get(preset_key)
    if entry is not None:
        return entry["model_name"]

    available = ", ".join(sorted(catalog))
    raise ValueError(
        f"Nezināms MARIS_TRAIN_MODEL_PRESET/model_preset. Izmanto vienu no: {available}."
    )
|
|
|
|
def resolve_model_selection(
    default_model_name: str,
    *sources: dict[str, Any],
    available_models: dict[str, dict[str, str]] | None = None,
) -> tuple[str, str]:
    """Pick the effective model from the highest-priority source that sets one.

    Sources are examined in the order given. Within a single source an explicit
    ``model_name`` wins over ``model_preset``, because it is a more precise
    override than a preset.

    Args:
        default_model_name: Model used when no source selects anything; also
            passed to ``resolve_training_model`` as its fallback name.
        *sources: Mappings that may contain ``model_name`` and/or
            ``model_preset``; ``None`` and ``""`` count as unset.
        available_models: Optional preset catalog forwarded to
            ``resolve_training_model``.

    Returns:
        ``(model_name, model_preset)``. The preset is ``""`` unless the model
        was chosen through a preset.

    Raises:
        ValueError: Propagated from ``resolve_training_model`` for an unknown
            preset.
    """
    unset = (None, "")
    for source in sources:
        explicit_name = source.get("model_name")
        if explicit_name not in unset:
            return str(explicit_name), ""
        preset = source.get("model_preset")
        if preset not in unset:
            preset_text = str(preset)
            model = resolve_training_model(
                default_model_name,
                preset_text,
                available_models=available_models,
            )
            return model, preset_text
    return default_model_name, ""
|
|
|
|
| def load_training_config( |
| config_path: str | None = None, |
| overrides: dict[str, Any] | None = None, |
| ) -> TrainingConfig: |
| """Ielādē konfigurāciju no JSON, vides mainīgajiem un CLI override'iem.""" |
| data: dict[str, Any] = {} |
| defaults = TrainingConfig().to_dict() |
|
|
| if config_path: |
| data = json.loads(Path(config_path).read_text(encoding="utf-8")) |
|
|
| env_data: dict[str, Any] = { |
| "model_name": get_env_any("MARIS_TRAIN_BASE_MODEL", "HF_TRAIN_BASE_MODEL", "TEXT_MODEL"), |
| "model_preset": get_env_any("MARIS_TRAIN_MODEL_PRESET", "HF_TRAIN_MODEL_PRESET"), |
| "branch_name": get_env_any("MARIS_TRAIN_BRANCH_NAME", "HF_TRAIN_BRANCH_NAME"), |
| "branch_focus": get_env_any("MARIS_TRAIN_BRANCH_FOCUS", "HF_TRAIN_BRANCH_FOCUS"), |
| "adapter_type": get_env_any("MARIS_TRAIN_ADAPTER_TYPE", "HF_TRAIN_ADAPTER_TYPE"), |
| "lora_r": get_env_any("MARIS_TRAIN_LORA_R", "HF_TRAIN_LORA_R"), |
| "lora_alpha": get_env_any("MARIS_TRAIN_LORA_ALPHA", "HF_TRAIN_LORA_ALPHA"), |
| "lora_dropout": get_env_any("MARIS_TRAIN_LORA_DROPOUT", "HF_TRAIN_LORA_DROPOUT"), |
| "lora_bias": get_env_any("MARIS_TRAIN_LORA_BIAS", "HF_TRAIN_LORA_BIAS"), |
| "peft_target_modules": get_env_any( |
| "MARIS_TRAIN_PEFT_TARGET_MODULES", |
| "HF_TRAIN_PEFT_TARGET_MODULES", |
| ), |
| "qlora_quant_type": get_env_any( |
| "MARIS_TRAIN_QLORA_QUANT_TYPE", |
| "HF_TRAIN_QLORA_QUANT_TYPE", |
| ), |
| "qlora_use_double_quant": get_env_any( |
| "MARIS_TRAIN_QLORA_USE_DOUBLE_QUANT", |
| "HF_TRAIN_QLORA_USE_DOUBLE_QUANT", |
| ), |
| "qlora_compute_dtype": get_env_any( |
| "MARIS_TRAIN_QLORA_COMPUTE_DTYPE", |
| "HF_TRAIN_QLORA_COMPUTE_DTYPE", |
| ), |
| "dataset_repo": get_env_any("MARIS_MEMORY_REPO", "MARIS_DATASET_REPO", "HF_DATASET_REPO"), |
| "dataset_repos": get_env_any("MARIS_DATASET_REPOS", "HF_DATASET_REPOS"), |
| "eval_dataset_repo": get_env_any("MARIS_EVAL_DATASET_REPO", "HF_EVAL_DATASET_REPO"), |
| "eval_dataset_repos": get_env_any("MARIS_EVAL_DATASET_REPOS", "HF_EVAL_DATASET_REPOS"), |
| "output_dir": get_env_any("MARIS_TRAIN_OUTPUT_DIR", "HF_TRAIN_OUTPUT_DIR"), |
| "hub_model_id": get_env_any("MARIS_MODEL_REPO", "HF_MODEL_REPO"), |
| "text_model_id": get_env_any("TEXT_MODEL", default=DEFAULT_TEXT_MODEL_REPO), |
| "image_model_id": get_env_any("IMAGE_MODEL", default=DEFAULT_IMAGE_MODEL_REPO), |
| "music_model_id": get_env_any("MUSIC_MODEL", default=DEFAULT_MUSIC_MODEL_REPO), |
| "tts_model_id": get_env_any("TTS_MODEL", default=DEFAULT_TTS_MODEL_REPO), |
| "stt_model_id": get_env_any("STT_MODEL", default=DEFAULT_STT_MODEL_REPO), |
| "video_model_id": get_env_any("VIDEO_MODEL", default=DEFAULT_VIDEO_MODEL_REPO), |
| "num_epochs": get_env_any("MARIS_TRAIN_NUM_EPOCHS", "HF_TRAIN_NUM_EPOCHS"), |
| "learning_rate": get_env_any("MARIS_TRAIN_LEARNING_RATE", "HF_TRAIN_LEARNING_RATE"), |
| "per_device_train_batch_size": get_env_any("MARIS_TRAIN_BATCH_SIZE", "HF_TRAIN_BATCH_SIZE"), |
| "per_device_eval_batch_size": get_env_any( |
| "MARIS_TRAIN_EVAL_BATCH_SIZE", "HF_TRAIN_EVAL_BATCH_SIZE" |
| ), |
| "gradient_accumulation_steps": get_env_any( |
| "MARIS_TRAIN_GRADIENT_ACCUMULATION_STEPS", |
| "HF_TRAIN_GRADIENT_ACCUMULATION_STEPS", |
| ), |
| "warmup_ratio": get_env_any("MARIS_TRAIN_WARMUP_RATIO", "HF_TRAIN_WARMUP_RATIO"), |
| "weight_decay": get_env_any("MARIS_TRAIN_WEIGHT_DECAY", "HF_TRAIN_WEIGHT_DECAY"), |
| "logging_steps": get_env_any("MARIS_TRAIN_LOGGING_STEPS", "HF_TRAIN_LOGGING_STEPS"), |
| "save_steps": get_env_any("MARIS_TRAIN_SAVE_STEPS", "HF_TRAIN_SAVE_STEPS"), |
| "eval_steps": get_env_any("MARIS_TRAIN_EVAL_STEPS", "HF_TRAIN_EVAL_STEPS"), |
| "save_total_limit": get_env_any( |
| "MARIS_TRAIN_SAVE_TOTAL_LIMIT", "HF_TRAIN_SAVE_TOTAL_LIMIT" |
| ), |
| "max_seq_length": get_env_any("MARIS_TRAIN_MAX_SEQ_LENGTH", "HF_TRAIN_MAX_SEQ_LENGTH"), |
| "validation_split_ratio": get_env_any( |
| "MARIS_TRAIN_VALIDATION_SPLIT", "HF_TRAIN_VALIDATION_SPLIT" |
| ), |
| "seed": get_env_any("MARIS_TRAIN_SEED", "HF_TRAIN_SEED"), |
| "fp16": get_env_any("MARIS_TRAIN_FP16", "HF_TRAIN_FP16"), |
| "bf16": get_env_any("MARIS_TRAIN_BF16", "HF_TRAIN_BF16"), |
| "gradient_checkpointing": get_env_any( |
| "MARIS_TRAIN_GRADIENT_CHECKPOINTING", |
| "HF_TRAIN_GRADIENT_CHECKPOINTING", |
| ), |
| "gradient_checkpointing_use_reentrant": get_env_any( |
| "MARIS_TRAIN_GRADIENT_CHECKPOINTING_USE_REENTRANT", |
| "HF_TRAIN_GRADIENT_CHECKPOINTING_USE_REENTRANT", |
| ), |
| "distributed_strategy": get_env_any( |
| "MARIS_TRAIN_DISTRIBUTED_STRATEGY", |
| "HF_TRAIN_DISTRIBUTED_STRATEGY", |
| ), |
| "distributed_config_path": get_env_any( |
| "MARIS_TRAIN_DISTRIBUTED_CONFIG_PATH", |
| "HF_TRAIN_DISTRIBUTED_CONFIG_PATH", |
| ), |
| "use_accelerate": get_env_any( |
| "MARIS_TRAIN_USE_ACCELERATE", |
| "HF_TRAIN_USE_ACCELERATE", |
| ), |
| "accelerate_config_path": get_env_any( |
| "MARIS_TRAIN_ACCELERATE_CONFIG_PATH", |
| "HF_TRAIN_ACCELERATE_CONFIG_PATH", |
| ), |
| "num_processes": get_env_any("MARIS_TRAIN_NUM_PROCESSES", "HF_TRAIN_NUM_PROCESSES"), |
| "num_machines": get_env_any("MARIS_TRAIN_NUM_MACHINES", "HF_TRAIN_NUM_MACHINES"), |
| "machine_rank": get_env_any("MARIS_TRAIN_MACHINE_RANK", "HF_TRAIN_MACHINE_RANK"), |
| "main_process_ip": get_env_any( |
| "MARIS_TRAIN_MAIN_PROCESS_IP", |
| "HF_TRAIN_MAIN_PROCESS_IP", |
| ), |
| "main_process_port": get_env_any( |
| "MARIS_TRAIN_MAIN_PROCESS_PORT", |
| "HF_TRAIN_MAIN_PROCESS_PORT", |
| ), |
| "fsdp_transformer_layer_cls_to_wrap": get_env_any( |
| "MARIS_TRAIN_FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP", |
| "HF_TRAIN_FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP", |
| ), |
| "fsdp_min_num_params": get_env_any( |
| "MARIS_TRAIN_FSDP_MIN_NUM_PARAMS", |
| "HF_TRAIN_FSDP_MIN_NUM_PARAMS", |
| ), |
| "report_to": get_env_any("MARIS_TRAIN_REPORT_TO", "HF_TRAIN_REPORT_TO"), |
| "push_to_hub": get_env_any("MARIS_TRAIN_PUBLISH", "HF_TRAIN_PUSH_TO_HUB"), |
| "save_safetensors": get_env_any( |
| "MARIS_TRAIN_SAVE_SAFETENSORS", "HF_TRAIN_SAVE_SAFETENSORS" |
| ), |
| "lr_scheduler_type": get_env_any( |
| "MARIS_TRAIN_LR_SCHEDULER_TYPE", "HF_TRAIN_LR_SCHEDULER_TYPE" |
| ), |
| "benchmark_dataset_path": get_env_any( |
| "MARIS_BENCHMARK_DATASET_PATH", "HF_BENCHMARK_DATASET_PATH" |
| ), |
| "benchmark_name": get_env_any("MARIS_BENCHMARK_NAME", "HF_BENCHMARK_NAME"), |
| "benchmark_levels": get_env_any("MARIS_BENCHMARK_LEVELS", "HF_BENCHMARK_LEVELS"), |
| "benchmark_min_overall": get_env_any( |
| "MARIS_BENCHMARK_MIN_OVERALL", "HF_BENCHMARK_MIN_OVERALL" |
| ), |
| "benchmark_gate_enabled": get_env_any( |
| "MARIS_BENCHMARK_GATE_ENABLED", "HF_BENCHMARK_GATE_ENABLED" |
| ), |
| "branch_benchmark_names": get_env_any( |
| "MARIS_BRANCH_BENCHMARK_NAMES", |
| "HF_BRANCH_BENCHMARK_NAMES", |
| ), |
| "branch_benchmark_dataset_paths": get_env_any( |
| "MARIS_BRANCH_BENCHMARK_DATASET_PATHS", |
| "HF_BRANCH_BENCHMARK_DATASET_PATHS", |
| ), |
| "branch_preference_dataset_paths": get_env_any( |
| "MARIS_BRANCH_PREFERENCE_DATASET_PATHS", |
| "HF_BRANCH_PREFERENCE_DATASET_PATHS", |
| ), |
| "branch_dataset_filter_rules": get_env_any( |
| "MARIS_BRANCH_DATASET_FILTER_RULES", |
| "HF_BRANCH_DATASET_FILTER_RULES", |
| ), |
| "preference_dataset_path": get_env_any( |
| "MARIS_PREFERENCE_DATASET_PATH", "HF_PREFERENCE_DATASET_PATH" |
| ), |
| "preference_optimization": get_env_any( |
| "MARIS_PREFERENCE_OPTIMIZATION", |
| "HF_PREFERENCE_OPTIMIZATION", |
| ), |
| "preference_beta": get_env_any("MARIS_PREFERENCE_BETA", "HF_PREFERENCE_BETA"), |
| "preference_max_prompt_length": get_env_any( |
| "MARIS_PREFERENCE_MAX_PROMPT_LENGTH", |
| "HF_PREFERENCE_MAX_PROMPT_LENGTH", |
| ), |
| "preference_max_length": get_env_any( |
| "MARIS_PREFERENCE_MAX_LENGTH", |
| "HF_PREFERENCE_MAX_LENGTH", |
| ), |
| "preference_reference_model": get_env_any( |
| "MARIS_PREFERENCE_REFERENCE_MODEL", |
| "HF_PREFERENCE_REFERENCE_MODEL", |
| ), |
| "quality_gate_enabled": get_env_any( |
| "MARIS_TRAIN_QUALITY_GATE_ENABLED", |
| "HF_TRAIN_QUALITY_GATE_ENABLED", |
| ), |
| "dedupe_enabled": get_env_any( |
| "MARIS_TRAIN_DEDUPE_ENABLED", |
| "HF_TRAIN_DEDUPE_ENABLED", |
| ), |
| "quality_min_text_chars": get_env_any( |
| "MARIS_TRAIN_QUALITY_MIN_CHARS", |
| "HF_TRAIN_QUALITY_MIN_CHARS", |
| ), |
| "scoring_enabled": get_env_any( |
| "MARIS_TRAIN_SCORING_ENABLED", |
| "HF_TRAIN_SCORING_ENABLED", |
| ), |
| "weighted_repetition_enabled": get_env_any( |
| "MARIS_TRAIN_WEIGHTED_REPETITION_ENABLED", |
| "HF_TRAIN_WEIGHTED_REPETITION_ENABLED", |
| ), |
| "medium_score_repeat_count": get_env_any( |
| "MARIS_TRAIN_MEDIUM_SCORE_REPEAT_COUNT", |
| "HF_TRAIN_MEDIUM_SCORE_REPEAT_COUNT", |
| ), |
| "high_score_repeat_count": get_env_any( |
| "MARIS_TRAIN_HIGH_SCORE_REPEAT_COUNT", |
| "HF_TRAIN_HIGH_SCORE_REPEAT_COUNT", |
| ), |
| "source_weighting_enabled": get_env_any( |
| "MARIS_TRAIN_SOURCE_WEIGHTING_ENABLED", |
| "HF_TRAIN_SOURCE_WEIGHTING_ENABLED", |
| ), |
| "source_weight_map": get_env_any( |
| "MARIS_TRAIN_SOURCE_WEIGHT_MAP", |
| "HF_TRAIN_SOURCE_WEIGHT_MAP", |
| ), |
| "category_weight_map": get_env_any( |
| "MARIS_TRAIN_CATEGORY_WEIGHT_MAP", |
| "HF_TRAIN_CATEGORY_WEIGHT_MAP", |
| ), |
| "max_effective_repeat_count": get_env_any( |
| "MARIS_TRAIN_MAX_EFFECTIVE_REPEAT_COUNT", |
| "HF_TRAIN_MAX_EFFECTIVE_REPEAT_COUNT", |
| ), |
| "benchmark_feedback_enabled": get_env_any( |
| "MARIS_TRAIN_BENCHMARK_FEEDBACK_ENABLED", |
| "HF_TRAIN_BENCHMARK_FEEDBACK_ENABLED", |
| ), |
| "benchmark_feedback_auto_discover": get_env_any( |
| "MARIS_TRAIN_BENCHMARK_FEEDBACK_AUTO_DISCOVER", |
| "HF_TRAIN_BENCHMARK_FEEDBACK_AUTO_DISCOVER", |
| ), |
| "benchmark_feedback_path": get_env_any( |
| "MARIS_TRAIN_BENCHMARK_FEEDBACK_PATH", |
| "HF_TRAIN_BENCHMARK_FEEDBACK_PATH", |
| ), |
| "benchmark_feedback_boost_scale": get_env_any( |
| "MARIS_TRAIN_BENCHMARK_FEEDBACK_BOOST_SCALE", |
| "HF_TRAIN_BENCHMARK_FEEDBACK_BOOST_SCALE", |
| ), |
| "benchmark_feedback_max_multiplier": get_env_any( |
| "MARIS_TRAIN_BENCHMARK_FEEDBACK_MAX_MULTIPLIER", |
| "HF_TRAIN_BENCHMARK_FEEDBACK_MAX_MULTIPLIER", |
| ), |
| "continue_from_latest_artifact": get_env_any( |
| "MARIS_TRAIN_CONTINUE_FROM_LATEST", |
| "HF_TRAIN_CONTINUE_FROM_LATEST", |
| ), |
| "continue_model_path": get_env_any( |
| "MARIS_TRAIN_CONTINUE_MODEL_PATH", |
| "HF_TRAIN_CONTINUE_MODEL_PATH", |
| "MARIS_LOCAL_MODEL_DIR", |
| "HF_LOCAL_MODEL_DIR", |
| ), |
| } |
| env_overrides = {key: value for key, value in env_data.items() if value not in (None, "")} |
| cli_overrides = {key: value for key, value in (overrides or {}).items() if value is not None} |
|
|
| merged: dict[str, Any] = {} |
| merged.update(defaults) |
| merged.update(data) |
| merged.update(env_overrides) |
| merged.update(cli_overrides) |
| resolved_distributed_strategy = str( |
| merged.get("distributed_strategy", "none") or "none" |
| ).lower() |
| explicit_use_accelerate = next( |
| ( |
| source["use_accelerate"] |
| for source in (cli_overrides, env_overrides, data) |
| if source.get("use_accelerate") not in (None, "") |
| ), |
| None, |
| ) |
|
|
| available_base_models = list_training_base_models() |
| resolved_model_name, resolved_model_preset = resolve_model_selection( |
| str(defaults["model_name"]), |
| cli_overrides, |
| env_overrides, |
| data, |
| available_models=available_base_models, |
| ) |
|
|
| config = TrainingConfig( |
| model_name=resolved_model_name, |
| model_preset=resolved_model_preset, |
| branch_name=str(merged["branch_name"]), |
| branch_focus=str(merged["branch_focus"]), |
| adapter_type=str(merged["adapter_type"]), |
| lora_r=int(merged.get("lora_r", 16)), |
| lora_alpha=int(merged.get("lora_alpha", 32)), |
| lora_dropout=float(merged.get("lora_dropout", 0.05)), |
| lora_bias=str(merged.get("lora_bias", "none") or "none"), |
| peft_target_modules=_parse_list(merged.get("peft_target_modules")), |
| qlora_quant_type=str(merged.get("qlora_quant_type", "nf4") or "nf4"), |
| qlora_use_double_quant=_parse_bool( |
| merged.get("qlora_use_double_quant"), |
| default=True, |
| ), |
| qlora_compute_dtype=str(merged.get("qlora_compute_dtype") or "float16").lower(), |
| dataset_repo=validate_maris_repo( |
| str(merged["dataset_repo"]), |
| "MARIS_MEMORY_REPO/MARIS_DATASET_REPO/HF_DATASET_REPO/dataset_repo", |
| label="dataset repozitorijs", |
| ), |
| dataset_repos=[], |
| eval_dataset_repo=( |
| validate_maris_repo( |
| str(merged["eval_dataset_repo"]), |
| "MARIS_EVAL_DATASET_REPO/HF_EVAL_DATASET_REPO/eval_dataset_repo", |
| label="eval dataset repozitorijs", |
| ) |
| if merged.get("eval_dataset_repo") not in (None, "") |
| else "" |
| ), |
| eval_dataset_repos=[], |
| output_dir=str(merged["output_dir"]), |
| hub_model_id=validate_maris_model( |
| str(merged["hub_model_id"]), |
| "MARIS_MODEL_REPO/HF_MODEL_REPO/hub_model_id", |
| ), |
| text_model_id=validate_maris_model( |
| str(merged["text_model_id"]), |
| "TEXT_MODEL/text_model_id", |
| ), |
| image_model_id=validate_maris_model( |
| str(merged["image_model_id"]), |
| "IMAGE_MODEL/image_model_id", |
| ), |
| music_model_id=validate_maris_model( |
| str(merged["music_model_id"]), |
| "MUSIC_MODEL/music_model_id", |
| ), |
| tts_model_id=validate_maris_model( |
| str(merged["tts_model_id"]), |
| "TTS_MODEL/tts_model_id", |
| ), |
| stt_model_id=validate_maris_model( |
| str(merged["stt_model_id"]), |
| "STT_MODEL/stt_model_id", |
| ), |
| video_model_id=validate_maris_model( |
| str(merged["video_model_id"]), |
| "VIDEO_MODEL/video_model_id", |
| ), |
| num_epochs=int(merged["num_epochs"]), |
| learning_rate=float(merged["learning_rate"]), |
| per_device_train_batch_size=int(merged["per_device_train_batch_size"]), |
| per_device_eval_batch_size=int(merged["per_device_eval_batch_size"]), |
| gradient_accumulation_steps=int(merged["gradient_accumulation_steps"]), |
| warmup_ratio=float(merged["warmup_ratio"]), |
| weight_decay=float(merged["weight_decay"]), |
| logging_steps=int(merged["logging_steps"]), |
| save_steps=int(merged["save_steps"]), |
| eval_steps=int(merged["eval_steps"]), |
| save_total_limit=int(merged["save_total_limit"]), |
| max_seq_length=int(merged["max_seq_length"]), |
| validation_split_ratio=float(merged["validation_split_ratio"]), |
| seed=int(merged["seed"]), |
| fp16=_parse_bool(merged.get("fp16"), default=False), |
| bf16=_parse_bool(merged.get("bf16"), default=False), |
| gradient_checkpointing=_parse_bool(merged.get("gradient_checkpointing"), default=False), |
| gradient_checkpointing_use_reentrant=_parse_optional_bool( |
| merged.get("gradient_checkpointing_use_reentrant") |
| ), |
| distributed_strategy=resolved_distributed_strategy, |
| distributed_config_path=str(merged.get("distributed_config_path", "") or ""), |
| use_accelerate=_parse_bool( |
| explicit_use_accelerate, |
| default=resolved_distributed_strategy != "none", |
| ), |
| accelerate_config_path=str(merged.get("accelerate_config_path", "") or ""), |
| num_processes=int(merged.get("num_processes", 1)), |
| num_machines=int(merged.get("num_machines", 1)), |
| machine_rank=int(merged.get("machine_rank", 0)), |
| main_process_ip=str(merged.get("main_process_ip", "") or ""), |
| main_process_port=int(merged.get("main_process_port", 29500)), |
| fsdp_transformer_layer_cls_to_wrap=_parse_list( |
| merged.get("fsdp_transformer_layer_cls_to_wrap") |
| ), |
| fsdp_min_num_params=int(merged.get("fsdp_min_num_params", 100_000_000)), |
| report_to=_parse_list(merged.get("report_to")), |
| push_to_hub=_parse_bool(merged.get("push_to_hub"), default=False), |
| save_safetensors=_parse_bool(merged.get("save_safetensors"), default=True), |
| lr_scheduler_type=str(merged["lr_scheduler_type"]), |
| benchmark_dataset_path=str(merged.get("benchmark_dataset_path", "") or ""), |
| benchmark_name=str( |
| merged.get("benchmark_name", DEFAULT_BENCHMARK_NAME) or DEFAULT_BENCHMARK_NAME |
| ), |
| benchmark_levels=_parse_list(merged.get("benchmark_levels")) or ["local", "ci", "release"], |
| benchmark_min_overall=float(merged.get("benchmark_min_overall", 0.7)), |
| benchmark_gate_enabled=_parse_bool(merged.get("benchmark_gate_enabled"), default=False), |
| branch_benchmark_targets=_parse_branch_targets(merged.get("branch_benchmark_targets")), |
| branch_benchmark_names=_parse_branch_benchmark_names(merged.get("branch_benchmark_names")), |
| branch_benchmark_dataset_paths=_parse_branch_benchmark_dataset_paths( |
| merged.get("branch_benchmark_dataset_paths") |
| ), |
| branch_preference_dataset_paths=_parse_branch_preference_dataset_paths( |
| merged.get("branch_preference_dataset_paths") |
| ), |
| branch_dataset_filter_rules=_parse_branch_dataset_filter_rules( |
| merged.get("branch_dataset_filter_rules") |
| ), |
| preference_dataset_path=str(merged.get("preference_dataset_path", "") or ""), |
| preference_optimization=str( |
| merged.get("preference_optimization", "none") or "none" |
| ).lower(), |
| preference_beta=float(merged.get("preference_beta", 0.1)), |
| preference_max_prompt_length=int(merged.get("preference_max_prompt_length", 512)), |
| preference_max_length=int(merged.get("preference_max_length", 1024)), |
| preference_reference_model=( |
| validate_maris_model( |
| str(merged["preference_reference_model"]), |
| "MARIS_PREFERENCE_REFERENCE_MODEL/HF_PREFERENCE_REFERENCE_MODEL/preference_reference_model", |
| ) |
| if merged.get("preference_reference_model") not in (None, "") |
| else "" |
| ), |
| quality_gate_enabled=_parse_bool(merged.get("quality_gate_enabled"), default=True), |
| dedupe_enabled=_parse_bool(merged.get("dedupe_enabled"), default=True), |
| quality_min_text_chars=int(merged.get("quality_min_text_chars", 4)), |
| scoring_enabled=_parse_bool(merged.get("scoring_enabled"), default=True), |
| weighted_repetition_enabled=_parse_bool( |
| merged.get("weighted_repetition_enabled"), |
| default=True, |
| ), |
| medium_score_repeat_count=int(merged.get("medium_score_repeat_count", 2)), |
| high_score_repeat_count=int(merged.get("high_score_repeat_count", 3)), |
| source_weighting_enabled=_parse_bool( |
| merged.get("source_weighting_enabled"), |
| default=True, |
| ), |
| source_weight_map=_parse_source_weight_map(merged.get("source_weight_map")), |
| category_weight_map=_parse_category_weight_map(merged.get("category_weight_map")), |
| max_effective_repeat_count=int(merged.get("max_effective_repeat_count", 6)), |
| benchmark_feedback_enabled=_parse_bool( |
| merged.get("benchmark_feedback_enabled"), |
| default=True, |
| ), |
| benchmark_feedback_auto_discover=_parse_bool( |
| merged.get("benchmark_feedback_auto_discover"), |
| default=True, |
| ), |
| benchmark_feedback_path=str(merged.get("benchmark_feedback_path", "") or ""), |
| benchmark_feedback_boost_scale=float(merged.get("benchmark_feedback_boost_scale", 2.0)), |
| benchmark_feedback_max_multiplier=float( |
| merged.get("benchmark_feedback_max_multiplier", 1.75) |
| ), |
| continue_from_latest_artifact=_parse_bool( |
| merged.get("continue_from_latest_artifact"), |
| default=False, |
| ), |
| continue_model_path=str(merged.get("continue_model_path", "") or ""), |
| ) |
| config.dataset_repos = [ |
| validate_maris_repo( |
| repo_id, |
| "MARIS_DATASET_REPOS/HF_DATASET_REPOS/dataset_repos", |
| label="dataset repozitorijs", |
| ) |
| for repo_id in _parse_repo_list( |
| merged.get("dataset_repos"), |
| default=[config.dataset_repo], |
| ) |
| ] |
| if config.dataset_repo not in config.dataset_repos: |
| config.dataset_repos.insert(0, config.dataset_repo) |
| config.eval_dataset_repos = [ |
| validate_maris_repo( |
| repo_id, |
| "MARIS_EVAL_DATASET_REPOS/HF_EVAL_DATASET_REPOS/eval_dataset_repos", |
| label="eval dataset repozitorijs", |
| ) |
| for repo_id in _parse_repo_list( |
| merged.get("eval_dataset_repos"), |
| default=[config.eval_dataset_repo] if config.eval_dataset_repo else [], |
| ) |
| ] |
| if config.eval_dataset_repo and config.eval_dataset_repo not in config.eval_dataset_repos: |
| config.eval_dataset_repos.insert(0, config.eval_dataset_repo) |
| if config.fp16 and config.bf16: |
| raise ValueError("Maris training konfigurācijā nevar vienlaikus ieslēgt fp16 un bf16.") |
| if config.adapter_type not in {"full", "lora", "qlora", "specialist_model"}: |
| raise ValueError("adapter_type jābūt vienam no: full, lora, qlora, specialist_model.") |
| if config.distributed_strategy not in {"none", "fsdp", "deepspeed"}: |
| raise ValueError("distributed_strategy jābūt vienam no: none, fsdp, deepspeed.") |
| if config.preference_optimization not in {"none", "dpo", "orpo"}: |
| raise ValueError("preference_optimization jābūt vienam no: none, dpo, orpo.") |
| if config.preference_optimization != "none" and not config.preference_dataset_path: |
| raise ValueError( |
| "Preference optimization vajag preference_dataset_path ar prompt/chosen/rejected datiem." |
| ) |
| if config.num_processes < 1: |
| raise ValueError("num_processes jābūt vismaz 1.") |
| if config.num_machines < 1: |
| raise ValueError("num_machines jābūt vismaz 1.") |
| if config.machine_rank < 0: |
| raise ValueError("machine_rank nedrīkst būt negatīvs.") |
| if config.main_process_port < 1: |
| raise ValueError("main_process_port jābūt pozitīvam portam.") |
| if config.fsdp_min_num_params < 0: |
| raise ValueError("fsdp_min_num_params nedrīkst būt negatīvs.") |
| return config |
|
|
|
|
| def _parse_branch_targets(value: Any) -> dict[str, dict[str, float]]: |
| if value in (None, ""): |
| return {key: item.copy() for key, item in DEFAULT_BRANCH_BENCHMARK_TARGETS.items()} |
| parsed = json.loads(value) if isinstance(value, str) else value |
| if not isinstance(parsed, dict): |
| raise ValueError("branch_benchmark_targets jābūt objektam ar branch -> metric -> score.") |
|
|
| normalized: dict[str, dict[str, float]] = {} |
| for branch, metrics in parsed.items(): |
| if not isinstance(metrics, dict): |
| raise ValueError("Katram branch_benchmark_targets ierakstam jābūt objektam.") |
| normalized[str(branch)] = {str(name): float(score) for name, score in metrics.items()} |
| return normalized |
|
|
|
|
def _parse_source_weight_map(value: Any) -> dict[str, float]:
    """Parse source-tier weights, layering overrides on top of the defaults.

    Args:
        value: ``None`` or ``""`` to use the defaults unchanged, a JSON string,
            or an already-decoded ``source_tier -> weight`` mapping.

    Returns:
        A copy of ``DEFAULT_SOURCE_WEIGHT_MAP`` with each supplied tier added
        or overwritten by its float weight.

    Raises:
        ValueError: If the JSON is malformed (``json.JSONDecodeError`` is a
            ``ValueError`` subclass), the payload is not an object, or a
            weight cannot be converted to ``float``.
    """
    if value in (None, ""):
        return DEFAULT_SOURCE_WEIGHT_MAP.copy()
    parsed = json.loads(value) if isinstance(value, str) else value
    if not isinstance(parsed, dict):
        raise ValueError("source_weight_map jābūt objektam ar source_tier -> weight.")

    normalized = DEFAULT_SOURCE_WEIGHT_MAP.copy()
    for tier, weight in parsed.items():
        try:
            normalized[str(tier)] = float(weight)
        except (TypeError, ValueError) as exc:
            # float() raises TypeError for None/list/dict; report every bad
            # weight as a ValueError like the structural checks above.
            raise ValueError(
                f"source_weight_map vērtībai {tier} jābūt skaitlim, saņemts {weight!r}."
            ) from exc
    return normalized
|
|
|
|
| def _parse_branch_benchmark_names(value: Any) -> dict[str, str]: |
| if value in (None, ""): |
| return DEFAULT_BRANCH_BENCHMARK_NAMES.copy() |
| if isinstance(value, str): |
| value = json.loads(value) |
| if not isinstance(value, dict): |
| raise ValueError("branch_benchmark_names jābūt objektam ar branch -> benchmark name.") |
| normalized: dict[str, str] = {} |
| for branch_name, benchmark_name in value.items(): |
| name = str(benchmark_name or "").strip() |
| if name: |
| normalized[str(branch_name)] = name |
| return normalized |
|
|
|
|
def _parse_branch_benchmark_dataset_paths(value: Any) -> dict[str, str]:
    """Parse per-branch benchmark dataset paths, layered over the defaults.

    Args:
        value: ``None`` or ``""`` to use the defaults unchanged, a JSON string,
            or an already-decoded ``branch -> path`` mapping.

    Returns:
        A copy of ``DEFAULT_BRANCH_BENCHMARK_DATASET_PATHS`` in which every
        supplied branch with a non-blank path is added or overwritten. Blank
        or falsy paths leave the default for that branch in place.

    Raises:
        ValueError: If the decoded payload is not an object (malformed JSON
            also surfaces as a ``ValueError`` subclass).
    """
    if value in (None, ""):
        return DEFAULT_BRANCH_BENCHMARK_DATASET_PATHS.copy()
    decoded = json.loads(value) if isinstance(value, str) else value
    if not isinstance(decoded, dict):
        raise ValueError(
            "branch_benchmark_dataset_paths jābūt objektam ar branch -> benchmark dataset path."
        )
    merged_paths = DEFAULT_BRANCH_BENCHMARK_DATASET_PATHS.copy()
    stripped_pairs = (
        (branch_name, str(raw_path or "").strip()) for branch_name, raw_path in decoded.items()
    )
    merged_paths.update(
        {str(branch_name): cleaned for branch_name, cleaned in stripped_pairs if cleaned}
    )
    return merged_paths
|
|
|
|
def _parse_branch_preference_dataset_paths(value: Any) -> dict[str, str]:
    """Parse per-branch preference dataset paths, layered over the defaults.

    Args:
        value: ``None`` or ``""`` to use the defaults unchanged, a JSON string,
            or an already-decoded ``branch -> path`` mapping.

    Returns:
        A copy of ``DEFAULT_BRANCH_PREFERENCE_DATASET_PATHS`` where each
        supplied branch with a non-blank path overrides or extends the
        defaults. Blank or falsy paths are ignored.

    Raises:
        ValueError: If the decoded payload is not an object (malformed JSON
            also surfaces as a ``ValueError`` subclass).
    """
    if value in (None, ""):
        return DEFAULT_BRANCH_PREFERENCE_DATASET_PATHS.copy()

    decoded = value
    if isinstance(decoded, str):
        decoded = json.loads(decoded)
    if not isinstance(decoded, dict):
        raise ValueError(
            "branch_preference_dataset_paths jābūt objektam ar branch -> preference dataset path."
        )

    result = DEFAULT_BRANCH_PREFERENCE_DATASET_PATHS.copy()
    for branch_key, raw_path in decoded.items():
        cleaned = str(raw_path or "").strip()
        if cleaned:
            # Only non-blank overrides replace the default for a branch.
            result[str(branch_key)] = cleaned
    return result
|
|
|
|
def _parse_branch_dataset_filter_rules(value: Any) -> dict[str, dict[str, Any]]:
    """Parse per-branch dataset filter rules, layered over the defaults.

    Args:
        value: ``None`` or ``""`` to use the defaults unchanged, a JSON string,
            or an already-decoded ``branch -> rule map`` mapping.

    Returns:
        A dict seeded with a one-level copy of each entry in
        ``DEFAULT_BRANCH_DATASET_FILTER_RULES``. A supplied branch replaces
        its default rule map wholesale; rule maps are not merged key by key.

    Raises:
        ValueError: If the decoded payload or any branch entry is not an
            object (malformed JSON also surfaces as a ``ValueError`` subclass).
    """

    def _copy_defaults() -> dict[str, dict[str, Any]]:
        # Copy each rule map so callers cannot mutate the module-level defaults'
        # top-level keys through the returned dict.
        return {key: item.copy() for key, item in DEFAULT_BRANCH_DATASET_FILTER_RULES.items()}

    if value in (None, ""):
        return _copy_defaults()

    decoded = value
    if isinstance(decoded, str):
        decoded = json.loads(decoded)
    if not isinstance(decoded, dict):
        raise ValueError("branch_dataset_filter_rules jābūt objektam ar branch -> filter rule map.")

    result = _copy_defaults()
    for branch_key, branch_rules in decoded.items():
        if isinstance(branch_rules, dict):
            result[str(branch_key)] = dict(branch_rules)
        else:
            raise ValueError("Katram branch_dataset_filter_rules ierakstam jābūt objektam.")
    return result
|
|
|
|
| def _parse_category_weight_map(value: Any) -> dict[str, float]: |
| if value in (None, ""): |
| return DEFAULT_CATEGORY_WEIGHT_MAP.copy() |
| parsed = json.loads(value) if isinstance(value, str) else value |
| if not isinstance(parsed, dict): |
| raise ValueError("category_weight_map jābūt objektam ar category -> weight.") |
| return {str(label): float(weight) for label, weight in parsed.items()} |
|
|