MarisUK's picture
Maris AI model sync
f440f03 verified
"""Konfigurācija Maris apmācības pipeline'iem."""
from __future__ import annotations
import json
import logging
import re
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any
from maris_core.utils.env import get_env_any, validate_maris_model, validate_maris_repo
# Module-level logger; the extra-preset helpers below use it to report ignored input.
logger = logging.getLogger(__name__)
# Splits free-form lists on newlines, commas, or semicolons (a run of them counts as one separator).
EXTRA_MODEL_SPLIT_RE = re.compile(r"[\n,;]+")
# Matches every run of characters outside [a-z0-9]; used when deriving preset keys from model names.
EXTRA_MODEL_KEY_SANITIZE_RE = re.compile(r"[^a-z0-9]+")
# Default base model for training; currently the same repository id as the master model repo.
DEFAULT_TRAINING_BASE_MODEL = "MarisUK/maris-ai-master"
# Default repository ids, one per model kind. They are the defaults of the matching
# `*_model_id` / `hub_model_id` fields on TrainingConfig.
DEFAULT_MASTER_MODEL_REPO = "MarisUK/maris-ai-master"
DEFAULT_TEXT_MODEL_REPO = "MarisUK/maris-ai-text"
DEFAULT_IMAGE_MODEL_REPO = "MarisUK/maris-ai-image"
DEFAULT_MUSIC_MODEL_REPO = "MarisUK/maris-ai-music"
DEFAULT_TTS_MODEL_REPO = "MarisUK/maris-tts-runtime"
DEFAULT_STT_MODEL_REPO = "MarisUK/maris-stt-runtime"
DEFAULT_VIDEO_MODEL_REPO = "MarisUK/maris-ai-video"
# Primary training dataset repository; default value of TrainingConfig.dataset_repo.
DEFAULT_PRIMARY_TRAINING_DATASET_REPO = "MarisUK/maris-ai-memory"
# Default training dataset repositories, primary repository first.
# The primary repo is referenced only through its constant: listing the literal
# as well produced a duplicate entry, so the same dataset appeared twice.
DEFAULT_TRAINING_DATASET_REPOS: list[str] = [
    DEFAULT_PRIMARY_TRAINING_DATASET_REPO,
    "MarisUK/maris-ai-evals",
    "MarisUK/maris-ai-benchmark",
]
# Default evaluation dataset repositories.
DEFAULT_EVAL_DATASET_REPOS: list[str] = [
    "MarisUK/maris-ai-evals",
    "MarisUK/maris-ai-benchmark",
]
# Built-in base-model presets, keyed by preset name. Every entry carries
# "model_name", "label", and "description". "balanced" and "reasoning" both
# point at the default training base model; "coding" and "lightweight" both
# point at the text model repository.
AVAILABLE_TRAINING_BASE_MODELS: dict[str, dict[str, str]] = {
    "balanced": {
        "model_name": DEFAULT_TRAINING_BASE_MODEL,
        "label": "Balanced default",
        "description": "Galvenais Maris master modelis pilnam instruction fine-tuning ciklam.",
    },
    "reasoning": {
        "model_name": DEFAULT_TRAINING_BASE_MODEL,
        "label": "Master reasoning",
        "description": "Maris master modelis uzdevumiem, kur svarīga vispārējā reasoning kvalitāte.",
    },
    "coding": {
        "model_name": "MarisUK/maris-ai-text",
        "label": "Text specialist",
        "description": "Maris text modelis čata, instrukciju un tehniska teksta fine-tuning skrējieniem.",
    },
    "lightweight": {
        "model_name": "MarisUK/maris-ai-text",
        "label": "Lean text runtime",
        "description": "Teksta runtime modelis lētākiem vai ātrākiem eksperimentiem Maris ekosistēmā.",
    },
}
# Benchmark name used when none is configured (default of TrainingConfig.benchmark_name).
DEFAULT_BENCHMARK_NAME = "chat-quality"
# Hard-coded branch -> benchmark name mapping. This name is rebound later in the
# module, after the branch defaults file is loaded, with file entries layered on top.
DEFAULT_BRANCH_BENCHMARK_NAMES: dict[str, str] = {
    "master": "memory-quality",
    "coder": "coder-release-quality",
    "planner": "planner-release-quality",
}
# Branch -> {metric name -> target score}. Every branch defines "overall" and
# "production_like_pass_rate"; the remaining metrics differ per branch.
# NOTE(review): values look like minimum scores on a 0..1 scale — confirm against the consumer.
DEFAULT_BRANCH_BENCHMARK_TARGETS: dict[str, dict[str, float]] = {
    "master": {
        "overall": 0.78,
        "reasoning": 0.74,
        "factuality": 0.72,
        "helpfulness": 0.76,
        "latvian_quality": 0.74,
        "memory_retrieval_pass_rate": 0.8,
        "memory_multi_turn_continuity": 0.74,
        "memory_cross_session_recall": 0.72,
        "memory_user_preferences_recall": 0.76,
        "memory_cross_lingual_retrieval": 0.72,
        "memory_stale_memory_rejection": 0.8,
        "production_like_pass_rate": 0.75,
    },
    "coder": {
        "overall": 0.74,
        "coding": 0.78,
        "reasoning": 0.72,
        "execution": 0.7,
        "grounding": 0.74,
        "safety": 0.9,
        "production_like_pass_rate": 0.75,
    },
    "planner": {
        "overall": 0.76,
        "reasoning": 0.77,
        "helpfulness": 0.74,
        "long_context": 0.72,
        "grounding": 0.72,
        "safety": 0.9,
        "production_like_pass_rate": 0.75,
    },
}
# Default weight per sample source label (default of TrainingConfig.source_weight_map).
DEFAULT_SOURCE_WEIGHT_MAP: dict[str, float] = {
    "production": 1.3,
    "synthetic": 1.0,
    "noisy": 0.65,
    "unknown": 1.0,
}
# No per-category weights by default (default of TrainingConfig.category_weight_map).
DEFAULT_CATEGORY_WEIGHT_MAP: dict[str, float] = {}
# "evals" directory located two levels above the directory that contains this file.
EVALS_DIR = Path(__file__).resolve().parents[2] / "evals"
# JSON file holding the bundled per-branch defaults; read once at import time.
DEFAULT_BRANCH_DATASET_FILTER_RULES_PATH = EVALS_DIR / "branch_dataset_filter_rules.json"
def _load_default_branch_config() -> dict[str, Any]:
    """Read the bundled branch defaults file and return its top-level JSON object.

    Returns:
        The decoded JSON object from ``DEFAULT_BRANCH_DATASET_FILTER_RULES_PATH``.

    Raises:
        ValueError: If the decoded document is not a JSON object.
    """
    document = DEFAULT_BRANCH_DATASET_FILTER_RULES_PATH.read_text(encoding="utf-8")
    payload = json.loads(document)
    if isinstance(payload, dict):
        return payload
    raise ValueError("Branch dataset defaults failam jābūt JSON objektam.")
def _resolve_default_branch_config_path(path: Any) -> str:
    """Turn a path entry from the branch defaults file into an absolute path string.

    Args:
        path: Raw path value; falsy or whitespace-only values count as "not set".

    Returns:
        An empty string when no path is set, otherwise the resolved path.
        Relative paths are anchored at the directory of the branch defaults file.
    """
    text = str(path or "").strip()
    if text == "":
        return ""
    location = Path(text)
    if location.is_absolute():
        return str(location.resolve())
    # Relative entries are interpreted relative to the defaults file, not the CWD.
    anchored = DEFAULT_BRANCH_DATASET_FILTER_RULES_PATH.parent / location
    return str(anchored.resolve())
def _normalize_default_branch_path_map(value: Any) -> dict[str, str]:
    """Normalize a branch -> dataset path mapping taken from the defaults file.

    Args:
        value: Expected to be a dict of branch name to path.

    Returns:
        A new dict with string branch names and resolved paths; entries whose
        path resolves to an empty string are left out.

    Raises:
        ValueError: If ``value`` is not a dict.
    """
    if not isinstance(value, dict):
        raise ValueError("Branch path defaults jābūt objektam ar branch -> dataset path.")
    resolved_pairs = (
        (branch, _resolve_default_branch_config_path(raw_path))
        for branch, raw_path in value.items()
    )
    return {str(branch): resolved for branch, resolved in resolved_pairs if resolved}
def _normalize_default_branch_rule_map(value: Any) -> dict[str, dict[str, Any]]:
    """Normalize the per-branch dataset filter rules from the defaults file.

    Args:
        value: Expected to be a dict of branch name to rule object.

    Returns:
        A new dict with string branch names and shallow copies of the rule
        dicts; entries whose payload is not a dict are dropped.

    Raises:
        ValueError: If ``value`` is not a dict.
    """
    if not isinstance(value, dict):
        raise ValueError("branch_dataset_filter_rules noklusējumam jābūt objektam.")
    rules: dict[str, dict[str, Any]] = {}
    for branch, payload in value.items():
        if not isinstance(payload, dict):
            continue
        rules[str(branch)] = dict(payload)
    return rules
# Load the bundled branch defaults once at import time; the constants below derive from it.
DEFAULT_BRANCH_CONFIG = _load_default_branch_config()
# Branch -> benchmark dataset path. Paths are resolved to absolute form, with
# relative entries anchored at the defaults file's directory.
DEFAULT_BRANCH_BENCHMARK_DATASET_PATHS: dict[str, str] = _normalize_default_branch_path_map(
    DEFAULT_BRANCH_CONFIG.get("branch_benchmark_dataset_paths", {})
)
# Rebinds the name defined earlier in this module: the hard-coded names stay as
# the base, and non-blank names from the defaults file override them per branch.
DEFAULT_BRANCH_BENCHMARK_NAMES: dict[str, str] = {
    **DEFAULT_BRANCH_BENCHMARK_NAMES,
    **{
        str(branch): str(name).strip()
        for branch, name in DEFAULT_BRANCH_CONFIG.get("branch_benchmark_names", {}).items()
        if str(name).strip()
    },
}
# Branch -> preference dataset path, resolved the same way as the benchmark paths.
DEFAULT_BRANCH_PREFERENCE_DATASET_PATHS: dict[str, str] = _normalize_default_branch_path_map(
    DEFAULT_BRANCH_CONFIG.get("branch_preference_dataset_paths", {})
)
# When the file has no "branch_dataset_filter_rules" key, the whole top-level
# object is treated as the rule map; the normalizer drops non-object entries.
DEFAULT_BRANCH_DATASET_FILTER_RULES = _normalize_default_branch_rule_map(
    DEFAULT_BRANCH_CONFIG.get("branch_dataset_filter_rules", DEFAULT_BRANCH_CONFIG)
)
def _parse_bool(value: Any, *, default: bool) -> bool:
    """Interpret a loosely typed value as a boolean.

    Args:
        value: ``None``, a real ``bool``, or anything convertible with ``str()``.
        default: Result to use when ``value`` is ``None``.

    Returns:
        ``default`` for ``None``; the value itself for a ``bool``; otherwise
        ``True`` only when the trimmed, lower-cased text is one of
        ``1``, ``true``, ``yes``, ``on``.
    """
    if isinstance(value, bool):
        return value
    if value is None:
        return default
    token = str(value).strip().lower()
    return token in ("1", "true", "yes", "on")
def _parse_optional_bool(value: Any) -> bool | None:
    """Interpret a value as a tri-state boolean.

    Args:
        value: ``None``, a string, or any value accepted by ``_parse_bool``.

    Returns:
        ``None`` when the value is ``None`` or a blank string; otherwise the
        result of ``_parse_bool`` with a ``False`` default.
    """
    is_blank_text = isinstance(value, str) and not value.strip()
    if value is None or is_blank_text:
        return None
    return _parse_bool(value, default=False)
def _parse_list(value: Any) -> list[str]:
    """Convert a list-like or comma-separated value into a list of trimmed strings.

    Sequence items are now stripped in the same way as comma-separated text, so
    both input forms produce identical entries. Tuples are iterated like lists
    instead of being stringified.

    Args:
        value: ``None``, a list or tuple of items, or any value whose ``str()``
            form is a comma-separated list.

    Returns:
        The non-blank entries, each converted with ``str()`` and stripped of
        surrounding whitespace, in their original order.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        raw_items = value
    else:
        raw_items = str(value).split(",")
    trimmed = (str(item).strip() for item in raw_items)
    return [item for item in trimmed if item]
def _parse_repo_list(value: Any, *, default: list[str] | None = None) -> list[str]:
    """Parse a repository list from a list, a JSON array string, or delimited text.

    Args:
        value: ``None`` or ``""`` (use the default), a list, a string starting
            with ``[`` (decoded as JSON), or text separated by newlines,
            commas, or semicolons.
        default: Entries to return (as a copy) when ``value`` is not set.

    Returns:
        The trimmed, non-empty entries with duplicates removed, in first-seen order.

    Raises:
        json.JSONDecodeError: If a string starting with ``[`` is not valid JSON.
    """
    if value in (None, ""):
        return list(default or [])
    looks_like_json_array = isinstance(value, str) and value.lstrip().startswith("[")
    decoded = json.loads(value) if looks_like_json_array else value
    if isinstance(decoded, list):
        entries = decoded
    else:
        entries = EXTRA_MODEL_SPLIT_RE.split(str(decoded))
    cleaned = (str(entry or "").strip() for entry in entries)
    # dict.fromkeys keeps the first occurrence of each name and preserves order.
    return list(dict.fromkeys(name for name in cleaned if name))
@dataclass(slots=True)
class TrainingConfig:
    """Full training configuration for a single Maris training run."""

    # --- Base model and branch selection ---
    model_name: str = DEFAULT_TRAINING_BASE_MODEL
    model_preset: str = ""
    branch_name: str = "master"
    branch_focus: str = "general_reasoning"
    # --- Adapter (LoRA / QLoRA) fields ---
    adapter_type: str = "full"
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_bias: str = "none"
    peft_target_modules: list[str] = field(default_factory=list)
    qlora_quant_type: str = "nf4"
    qlora_use_double_quant: bool = True
    qlora_compute_dtype: str = "float16"
    # --- Dataset repositories ---
    dataset_repo: str = DEFAULT_PRIMARY_TRAINING_DATASET_REPO
    dataset_repos: list[str] = field(default_factory=list)
    eval_dataset_repo: str = ""
    eval_dataset_repos: list[str] = field(default_factory=list)
    # --- Output location and model repository ids ---
    output_dir: str = "./output/model"
    hub_model_id: str = DEFAULT_MASTER_MODEL_REPO
    text_model_id: str = DEFAULT_TEXT_MODEL_REPO
    image_model_id: str = DEFAULT_IMAGE_MODEL_REPO
    music_model_id: str = DEFAULT_MUSIC_MODEL_REPO
    tts_model_id: str = DEFAULT_TTS_MODEL_REPO
    stt_model_id: str = DEFAULT_STT_MODEL_REPO
    video_model_id: str = DEFAULT_VIDEO_MODEL_REPO
    # --- Optimisation hyperparameters and step intervals ---
    num_epochs: int = 3
    learning_rate: float = 2e-5
    per_device_train_batch_size: int = 1
    per_device_eval_batch_size: int = 1
    gradient_accumulation_steps: int = 8
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    logging_steps: int = 10
    save_steps: int = 100
    eval_steps: int = 100
    save_total_limit: int = 2
    max_seq_length: int = 1024
    validation_split_ratio: float = 0.1
    seed: int = 42
    # --- Precision and gradient checkpointing ---
    fp16: bool = False
    bf16: bool = False
    gradient_checkpointing: bool = False
    # Tri-state: None means "not specified".
    gradient_checkpointing_use_reentrant: bool | None = None
    # --- Distributed / accelerate launch settings ---
    distributed_strategy: str = "none"
    distributed_config_path: str = ""
    use_accelerate: bool = False
    accelerate_config_path: str = ""
    num_processes: int = 1
    num_machines: int = 1
    machine_rank: int = 0
    main_process_ip: str = ""
    main_process_port: int = 29500
    fsdp_transformer_layer_cls_to_wrap: list[str] = field(default_factory=list)
    fsdp_min_num_params: int = 100_000_000
    # --- Reporting, publishing, and scheduler ---
    report_to: list[str] = field(default_factory=list)
    push_to_hub: bool = False
    save_safetensors: bool = True
    lr_scheduler_type: str = "cosine"
    # --- Benchmark settings ---
    benchmark_dataset_path: str = ""
    benchmark_name: str = DEFAULT_BENCHMARK_NAME
    benchmark_levels: list[str] = field(default_factory=lambda: ["local", "ci", "release"])
    benchmark_min_overall: float = 0.7
    benchmark_gate_enabled: bool = False
    # --- Per-branch defaults ---
    # Each factory copies the module-level default so instances never share
    # (or mutate) the module constants; nested dicts are copied one level deep.
    branch_benchmark_targets: dict[str, dict[str, float]] = field(
        default_factory=lambda: {
            key: value.copy() for key, value in DEFAULT_BRANCH_BENCHMARK_TARGETS.items()
        }
    )
    branch_benchmark_names: dict[str, str] = field(
        default_factory=lambda: DEFAULT_BRANCH_BENCHMARK_NAMES.copy()
    )
    branch_benchmark_dataset_paths: dict[str, str] = field(
        default_factory=lambda: DEFAULT_BRANCH_BENCHMARK_DATASET_PATHS.copy()
    )
    branch_preference_dataset_paths: dict[str, str] = field(
        default_factory=lambda: DEFAULT_BRANCH_PREFERENCE_DATASET_PATHS.copy()
    )
    branch_dataset_filter_rules: dict[str, dict[str, Any]] = field(
        default_factory=lambda: {
            key: value.copy() for key, value in DEFAULT_BRANCH_DATASET_FILTER_RULES.items()
        }
    )
    # --- Preference optimisation ---
    preference_dataset_path: str = ""
    preference_optimization: str = "none"
    preference_beta: float = 0.1
    preference_max_prompt_length: int = 512
    preference_max_length: int = 1024
    preference_reference_model: str = ""
    # --- Dataset quality gate, scoring, and repetition weighting ---
    quality_gate_enabled: bool = True
    dedupe_enabled: bool = True
    quality_min_text_chars: int = 4
    scoring_enabled: bool = True
    weighted_repetition_enabled: bool = True
    medium_score_repeat_count: int = 2
    high_score_repeat_count: int = 3
    source_weighting_enabled: bool = True
    source_weight_map: dict[str, float] = field(
        default_factory=lambda: DEFAULT_SOURCE_WEIGHT_MAP.copy()
    )
    category_weight_map: dict[str, float] = field(
        default_factory=lambda: DEFAULT_CATEGORY_WEIGHT_MAP.copy()
    )
    max_effective_repeat_count: int = 6
    # --- Benchmark feedback ---
    benchmark_feedback_enabled: bool = True
    benchmark_feedback_auto_discover: bool = True
    benchmark_feedback_path: str = ""
    benchmark_feedback_boost_scale: float = 2.0
    benchmark_feedback_max_multiplier: float = 1.75
    # --- Continuing from an earlier artifact ---
    continue_from_latest_artifact: bool = False
    continue_model_path: str = ""

    def to_dict(self) -> dict[str, Any]:
        """Serialize the configuration to a dict via ``dataclasses.asdict``."""
        return asdict(self)
def list_training_base_models() -> dict[str, dict[str, str]]:
    """Return the built-in base-model presets merged with any extra presets.

    Built-in entries are copied, so callers may modify the result without
    affecting ``AVAILABLE_TRAINING_BASE_MODELS``. Extra presets loaded by
    ``_load_extra_training_base_models`` replace built-ins that share a key.
    """
    presets = {name: dict(entry) for name, entry in AVAILABLE_TRAINING_BASE_MODELS.items()}
    for name, entry in _load_extra_training_base_models().items():
        presets[name] = entry
    return presets
def _normalize_extra_training_base_model_payload(
    preset_key: str,
    payload: Any,
) -> dict[str, str] | None:
    """Validate one extra preset and fill in its missing display fields.

    Args:
        preset_key: Name of the preset; used in log messages and as the source
            of the fallback label.
        payload: Either a model name string or a dict with ``model_name`` and
            optional ``label`` / ``description``.

    Returns:
        A dict with ``model_name``, ``label``, and ``description``, or ``None``
        (after logging a warning) when the payload has an unsupported type or
        the model name is not in ``owner/name`` form.
    """
    if isinstance(payload, str):
        # A bare string is shorthand for a payload that only names the model.
        fields: dict[str, Any] = {"model_name": payload}
    elif isinstance(payload, dict):
        fields = payload
    else:
        logger.warning(
            "Ignoring extra training preset %r because payload must be an object or model string.",
            preset_key,
        )
        return None

    model_name, label, description = (
        str(fields.get(name, "") or "").strip()
        for name in ("model_name", "label", "description")
    )

    # Only the first slash separates owner from name; both sides must be non-blank.
    owner, slash, repo = model_name.partition("/")
    if not (slash and owner.strip() and repo.strip()):
        logger.warning(
            "Ignoring extra training preset %r because model_name must use owner/name format.",
            preset_key,
        )
        return None

    return {
        "model_name": model_name,
        "label": label or preset_key.replace("-", " ").replace("_", " ").title(),
        "description": description or f"External base model preset {model_name}.",
    }
def _load_extra_training_base_models() -> dict[str, dict[str, str]]:
    """Load extra base-model presets from the environment.

    Reads ``MARIS_TRAIN_EXTRA_MODELS`` / ``HF_TRAIN_EXTRA_MODELS`` through
    ``get_env_any``. The value is parsed as JSON first; if that fails, the
    ``owner/name`` fallback syntax is tried. Invalid input is logged and
    ignored rather than raised.

    Returns:
        Preset key -> normalized preset payload; empty when nothing usable is set.
    """
    env_text = (
        get_env_any("MARIS_TRAIN_EXTRA_MODELS", "HF_TRAIN_EXTRA_MODELS", default="") or ""
    ).strip()
    if not env_text:
        return {}

    try:
        decoded = json.loads(env_text)
    except json.JSONDecodeError as exc:
        from_fallback = _parse_extra_training_base_models_fallback(env_text)
        if not from_fallback:
            logger.warning(
                "Ignoring MARIS_TRAIN_EXTRA_MODELS/HF_TRAIN_EXTRA_MODELS because value is not valid JSON or supported fallback syntax.",
                exc_info=exc,
            )
            return {}
        logger.info(
            "Parsed MARIS_TRAIN_EXTRA_MODELS/HF_TRAIN_EXTRA_MODELS using owner/name fallback syntax."
        )
        return from_fallback

    candidates = _coerce_extra_training_base_models_payload(decoded)
    if candidates is None:
        logger.warning(
            "Ignoring MARIS_TRAIN_EXTRA_MODELS/HF_TRAIN_EXTRA_MODELS because top-level value must be a JSON object, JSON array, or owner/name fallback list."
        )
        return {}

    presets: dict[str, dict[str, str]] = {}
    for raw_name, payload in candidates.items():
        name = str(raw_name).strip()
        if not name:
            logger.warning("Ignoring extra training preset with empty name.")
            continue
        entry = _normalize_extra_training_base_model_payload(name, payload)
        if entry is not None:
            presets[name] = entry
    return presets
def _coerce_extra_training_base_models_payload(parsed: Any) -> dict[str, Any] | None:
    """Bring decoded extra-preset JSON into a preset key -> payload mapping.

    Args:
        parsed: Decoded JSON value.

    Returns:
        The input itself when it is already a dict; a newly built dict when it
        is a list of model-name strings and/or preset objects; ``None`` for any
        other top-level type. List items of an unsupported type, or objects
        without a model name, are logged and skipped.
    """
    if isinstance(parsed, dict):
        return parsed
    if not isinstance(parsed, list):
        return None

    by_key: dict[str, Any] = {}
    for item in parsed:
        if isinstance(item, str):
            generated_key = _build_extra_training_preset_key(item, existing=by_key)
            by_key[generated_key] = item
        elif isinstance(item, dict):
            # Objects may name the model as "model_name" or "model".
            model_name = str(item.get("model_name") or item.get("model") or "").strip()
            if not model_name:
                logger.warning(
                    "Ignoring extra training preset list item because model_name is missing."
                )
                continue
            # An explicit "preset" / "key" wins; otherwise derive one from the model name.
            key = str(item.get("preset") or item.get("key") or "").strip()
            if not key:
                key = _build_extra_training_preset_key(model_name, existing=by_key)
            by_key[key] = {
                "model_name": model_name,
                "label": str(item.get("label", "") or "").strip(),
                "description": str(item.get("description", "") or "").strip(),
            }
        else:
            logger.warning(
                "Ignoring extra training preset list item %r because it must be a string or object.",
                item,
            )
    return by_key
def _parse_extra_training_base_models_fallback(raw_value: str) -> dict[str, dict[str, str]]:
    """Parse the non-JSON ``[preset=]owner/name`` list syntax for extra presets.

    Entries are separated by newlines, commas, or semicolons. An entry may be
    prefixed with ``preset=``; without a preset name, a key is derived from the
    model name.

    Args:
        raw_value: Raw text taken from the environment.

    Returns:
        Preset key -> normalized payload. If any entry fails normalization the
        whole value is rejected and an empty dict is returned.
    """
    presets: dict[str, dict[str, str]] = {}
    for chunk in EXTRA_MODEL_SPLIT_RE.split(raw_value):
        entry = chunk.strip()
        if not entry:
            continue
        left, equals, right = entry.partition("=")
        if equals:
            key, model_name = left.strip(), right.strip()
        else:
            key, model_name = "", entry
        if not model_name:
            continue
        if not key:
            key = _build_extra_training_preset_key(model_name, existing=presets)
        payload = _normalize_extra_training_base_model_payload(key, model_name)
        if payload is None:
            # One malformed entry invalidates the entire fallback value.
            return {}
        presets[key] = payload
    return presets
def _build_extra_training_preset_key(
    model_name: str,
    *,
    existing: dict[str, Any],
) -> str:
    """Derive a preset key from a model name that is not yet used in ``existing``.

    The name is trimmed and lower-cased, ``/`` becomes ``-``, every run of
    characters outside ``[a-z0-9]`` collapses to a single ``-``, and
    leading/trailing dashes are removed. An empty result falls back to
    ``extra-model``. On a collision, ``-2``, ``-3``, ... is appended until the
    key is free.

    Args:
        model_name: Model identifier to derive the key from.
        existing: Mapping whose keys are already taken.

    Returns:
        A key that is not present in ``existing``.
    """
    slug = model_name.strip().lower().replace("/", "-")
    stem = EXTRA_MODEL_KEY_SANITIZE_RE.sub("-", slug).strip("-") or "extra-model"
    if stem not in existing:
        return stem
    counter = 2
    while f"{stem}-{counter}" in existing:
        counter += 1
    return f"{stem}-{counter}"
def resolve_training_model(
    model_name: str,
    model_preset: str | None,
    *,
    available_models: dict[str, dict[str, str]] | None = None,
) -> str:
    """Resolve a model preset to a concrete base model.

    Args:
        model_name: Model to return when no preset is given.
        model_preset: Preset key; ``None`` or blank text means "no preset".
        available_models: Preset catalogue to look in. When it is ``None`` or
            empty, the catalogue from ``list_training_base_models`` is used.

    Returns:
        ``model_name`` when no preset is given, otherwise the preset's
        ``model_name`` entry.

    Raises:
        ValueError: If the preset key is not in the catalogue.
    """
    preset_key = (model_preset or "").strip()
    if preset_key == "":
        return model_name
    catalogue = available_models or list_training_base_models()
    entry = catalogue.get(preset_key)
    if entry is not None:
        return entry["model_name"]
    known_presets = ", ".join(sorted(catalogue))
    raise ValueError(
        f"Nezināms MARIS_TRAIN_MODEL_PRESET/model_preset. Izmanto vienu no: {known_presets}."
    )
def resolve_model_selection(
    default_model_name: str,
    *sources: dict[str, Any],
    available_models: dict[str, dict[str, str]] | None = None,
) -> tuple[str, str]:
    """Find the effective model from the highest-priority source.

    Sources are checked in the order given, and the first one that sets either
    key decides the result. Within a source, a direct ``model_name`` takes
    precedence over ``model_preset`` because it is a more specific override
    than a preset.

    Args:
        default_model_name: Model used when no source sets anything; also
            passed to ``resolve_training_model`` as its fallback name.
        *sources: Mappings that may contain ``model_name`` / ``model_preset``;
            ``None`` and ``""`` count as "not set".
        available_models: Preset catalogue forwarded to ``resolve_training_model``.

    Returns:
        ``(model_name, model_preset)``; the preset is ``""`` unless the model
        was chosen through a preset.

    Raises:
        ValueError: If the selected preset is unknown.
    """
    unset = (None, "")
    for source in sources:
        direct_name = source.get("model_name")
        if direct_name not in unset:
            return str(direct_name), ""
        preset = source.get("model_preset")
        if preset in unset:
            continue
        preset_text = str(preset)
        model = resolve_training_model(
            default_model_name,
            preset_text,
            available_models=available_models,
        )
        return model, preset_text
    return default_model_name, ""
def load_training_config(
config_path: str | None = None,
overrides: dict[str, Any] | None = None,
) -> TrainingConfig:
"""Ielādē konfigurāciju no JSON, vides mainīgajiem un CLI override'iem."""
data: dict[str, Any] = {}
defaults = TrainingConfig().to_dict()
if config_path:
data = json.loads(Path(config_path).read_text(encoding="utf-8"))
env_data: dict[str, Any] = {
"model_name": get_env_any("MARIS_TRAIN_BASE_MODEL", "HF_TRAIN_BASE_MODEL", "TEXT_MODEL"),
"model_preset": get_env_any("MARIS_TRAIN_MODEL_PRESET", "HF_TRAIN_MODEL_PRESET"),
"branch_name": get_env_any("MARIS_TRAIN_BRANCH_NAME", "HF_TRAIN_BRANCH_NAME"),
"branch_focus": get_env_any("MARIS_TRAIN_BRANCH_FOCUS", "HF_TRAIN_BRANCH_FOCUS"),
"adapter_type": get_env_any("MARIS_TRAIN_ADAPTER_TYPE", "HF_TRAIN_ADAPTER_TYPE"),
"lora_r": get_env_any("MARIS_TRAIN_LORA_R", "HF_TRAIN_LORA_R"),
"lora_alpha": get_env_any("MARIS_TRAIN_LORA_ALPHA", "HF_TRAIN_LORA_ALPHA"),
"lora_dropout": get_env_any("MARIS_TRAIN_LORA_DROPOUT", "HF_TRAIN_LORA_DROPOUT"),
"lora_bias": get_env_any("MARIS_TRAIN_LORA_BIAS", "HF_TRAIN_LORA_BIAS"),
"peft_target_modules": get_env_any(
"MARIS_TRAIN_PEFT_TARGET_MODULES",
"HF_TRAIN_PEFT_TARGET_MODULES",
),
"qlora_quant_type": get_env_any(
"MARIS_TRAIN_QLORA_QUANT_TYPE",
"HF_TRAIN_QLORA_QUANT_TYPE",
),
"qlora_use_double_quant": get_env_any(
"MARIS_TRAIN_QLORA_USE_DOUBLE_QUANT",
"HF_TRAIN_QLORA_USE_DOUBLE_QUANT",
),
"qlora_compute_dtype": get_env_any(
"MARIS_TRAIN_QLORA_COMPUTE_DTYPE",
"HF_TRAIN_QLORA_COMPUTE_DTYPE",
),
"dataset_repo": get_env_any("MARIS_MEMORY_REPO", "MARIS_DATASET_REPO", "HF_DATASET_REPO"),
"dataset_repos": get_env_any("MARIS_DATASET_REPOS", "HF_DATASET_REPOS"),
"eval_dataset_repo": get_env_any("MARIS_EVAL_DATASET_REPO", "HF_EVAL_DATASET_REPO"),
"eval_dataset_repos": get_env_any("MARIS_EVAL_DATASET_REPOS", "HF_EVAL_DATASET_REPOS"),
"output_dir": get_env_any("MARIS_TRAIN_OUTPUT_DIR", "HF_TRAIN_OUTPUT_DIR"),
"hub_model_id": get_env_any("MARIS_MODEL_REPO", "HF_MODEL_REPO"),
"text_model_id": get_env_any("TEXT_MODEL", default=DEFAULT_TEXT_MODEL_REPO),
"image_model_id": get_env_any("IMAGE_MODEL", default=DEFAULT_IMAGE_MODEL_REPO),
"music_model_id": get_env_any("MUSIC_MODEL", default=DEFAULT_MUSIC_MODEL_REPO),
"tts_model_id": get_env_any("TTS_MODEL", default=DEFAULT_TTS_MODEL_REPO),
"stt_model_id": get_env_any("STT_MODEL", default=DEFAULT_STT_MODEL_REPO),
"video_model_id": get_env_any("VIDEO_MODEL", default=DEFAULT_VIDEO_MODEL_REPO),
"num_epochs": get_env_any("MARIS_TRAIN_NUM_EPOCHS", "HF_TRAIN_NUM_EPOCHS"),
"learning_rate": get_env_any("MARIS_TRAIN_LEARNING_RATE", "HF_TRAIN_LEARNING_RATE"),
"per_device_train_batch_size": get_env_any("MARIS_TRAIN_BATCH_SIZE", "HF_TRAIN_BATCH_SIZE"),
"per_device_eval_batch_size": get_env_any(
"MARIS_TRAIN_EVAL_BATCH_SIZE", "HF_TRAIN_EVAL_BATCH_SIZE"
),
"gradient_accumulation_steps": get_env_any(
"MARIS_TRAIN_GRADIENT_ACCUMULATION_STEPS",
"HF_TRAIN_GRADIENT_ACCUMULATION_STEPS",
),
"warmup_ratio": get_env_any("MARIS_TRAIN_WARMUP_RATIO", "HF_TRAIN_WARMUP_RATIO"),
"weight_decay": get_env_any("MARIS_TRAIN_WEIGHT_DECAY", "HF_TRAIN_WEIGHT_DECAY"),
"logging_steps": get_env_any("MARIS_TRAIN_LOGGING_STEPS", "HF_TRAIN_LOGGING_STEPS"),
"save_steps": get_env_any("MARIS_TRAIN_SAVE_STEPS", "HF_TRAIN_SAVE_STEPS"),
"eval_steps": get_env_any("MARIS_TRAIN_EVAL_STEPS", "HF_TRAIN_EVAL_STEPS"),
"save_total_limit": get_env_any(
"MARIS_TRAIN_SAVE_TOTAL_LIMIT", "HF_TRAIN_SAVE_TOTAL_LIMIT"
),
"max_seq_length": get_env_any("MARIS_TRAIN_MAX_SEQ_LENGTH", "HF_TRAIN_MAX_SEQ_LENGTH"),
"validation_split_ratio": get_env_any(
"MARIS_TRAIN_VALIDATION_SPLIT", "HF_TRAIN_VALIDATION_SPLIT"
),
"seed": get_env_any("MARIS_TRAIN_SEED", "HF_TRAIN_SEED"),
"fp16": get_env_any("MARIS_TRAIN_FP16", "HF_TRAIN_FP16"),
"bf16": get_env_any("MARIS_TRAIN_BF16", "HF_TRAIN_BF16"),
"gradient_checkpointing": get_env_any(
"MARIS_TRAIN_GRADIENT_CHECKPOINTING",
"HF_TRAIN_GRADIENT_CHECKPOINTING",
),
"gradient_checkpointing_use_reentrant": get_env_any(
"MARIS_TRAIN_GRADIENT_CHECKPOINTING_USE_REENTRANT",
"HF_TRAIN_GRADIENT_CHECKPOINTING_USE_REENTRANT",
),
"distributed_strategy": get_env_any(
"MARIS_TRAIN_DISTRIBUTED_STRATEGY",
"HF_TRAIN_DISTRIBUTED_STRATEGY",
),
"distributed_config_path": get_env_any(
"MARIS_TRAIN_DISTRIBUTED_CONFIG_PATH",
"HF_TRAIN_DISTRIBUTED_CONFIG_PATH",
),
"use_accelerate": get_env_any(
"MARIS_TRAIN_USE_ACCELERATE",
"HF_TRAIN_USE_ACCELERATE",
),
"accelerate_config_path": get_env_any(
"MARIS_TRAIN_ACCELERATE_CONFIG_PATH",
"HF_TRAIN_ACCELERATE_CONFIG_PATH",
),
"num_processes": get_env_any("MARIS_TRAIN_NUM_PROCESSES", "HF_TRAIN_NUM_PROCESSES"),
"num_machines": get_env_any("MARIS_TRAIN_NUM_MACHINES", "HF_TRAIN_NUM_MACHINES"),
"machine_rank": get_env_any("MARIS_TRAIN_MACHINE_RANK", "HF_TRAIN_MACHINE_RANK"),
"main_process_ip": get_env_any(
"MARIS_TRAIN_MAIN_PROCESS_IP",
"HF_TRAIN_MAIN_PROCESS_IP",
),
"main_process_port": get_env_any(
"MARIS_TRAIN_MAIN_PROCESS_PORT",
"HF_TRAIN_MAIN_PROCESS_PORT",
),
"fsdp_transformer_layer_cls_to_wrap": get_env_any(
"MARIS_TRAIN_FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP",
"HF_TRAIN_FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP",
),
"fsdp_min_num_params": get_env_any(
"MARIS_TRAIN_FSDP_MIN_NUM_PARAMS",
"HF_TRAIN_FSDP_MIN_NUM_PARAMS",
),
"report_to": get_env_any("MARIS_TRAIN_REPORT_TO", "HF_TRAIN_REPORT_TO"),
"push_to_hub": get_env_any("MARIS_TRAIN_PUBLISH", "HF_TRAIN_PUSH_TO_HUB"),
"save_safetensors": get_env_any(
"MARIS_TRAIN_SAVE_SAFETENSORS", "HF_TRAIN_SAVE_SAFETENSORS"
),
"lr_scheduler_type": get_env_any(
"MARIS_TRAIN_LR_SCHEDULER_TYPE", "HF_TRAIN_LR_SCHEDULER_TYPE"
),
"benchmark_dataset_path": get_env_any(
"MARIS_BENCHMARK_DATASET_PATH", "HF_BENCHMARK_DATASET_PATH"
),
"benchmark_name": get_env_any("MARIS_BENCHMARK_NAME", "HF_BENCHMARK_NAME"),
"benchmark_levels": get_env_any("MARIS_BENCHMARK_LEVELS", "HF_BENCHMARK_LEVELS"),
"benchmark_min_overall": get_env_any(
"MARIS_BENCHMARK_MIN_OVERALL", "HF_BENCHMARK_MIN_OVERALL"
),
"benchmark_gate_enabled": get_env_any(
"MARIS_BENCHMARK_GATE_ENABLED", "HF_BENCHMARK_GATE_ENABLED"
),
"branch_benchmark_names": get_env_any(
"MARIS_BRANCH_BENCHMARK_NAMES",
"HF_BRANCH_BENCHMARK_NAMES",
),
"branch_benchmark_dataset_paths": get_env_any(
"MARIS_BRANCH_BENCHMARK_DATASET_PATHS",
"HF_BRANCH_BENCHMARK_DATASET_PATHS",
),
"branch_preference_dataset_paths": get_env_any(
"MARIS_BRANCH_PREFERENCE_DATASET_PATHS",
"HF_BRANCH_PREFERENCE_DATASET_PATHS",
),
"branch_dataset_filter_rules": get_env_any(
"MARIS_BRANCH_DATASET_FILTER_RULES",
"HF_BRANCH_DATASET_FILTER_RULES",
),
"preference_dataset_path": get_env_any(
"MARIS_PREFERENCE_DATASET_PATH", "HF_PREFERENCE_DATASET_PATH"
),
"preference_optimization": get_env_any(
"MARIS_PREFERENCE_OPTIMIZATION",
"HF_PREFERENCE_OPTIMIZATION",
),
"preference_beta": get_env_any("MARIS_PREFERENCE_BETA", "HF_PREFERENCE_BETA"),
"preference_max_prompt_length": get_env_any(
"MARIS_PREFERENCE_MAX_PROMPT_LENGTH",
"HF_PREFERENCE_MAX_PROMPT_LENGTH",
),
"preference_max_length": get_env_any(
"MARIS_PREFERENCE_MAX_LENGTH",
"HF_PREFERENCE_MAX_LENGTH",
),
"preference_reference_model": get_env_any(
"MARIS_PREFERENCE_REFERENCE_MODEL",
"HF_PREFERENCE_REFERENCE_MODEL",
),
"quality_gate_enabled": get_env_any(
"MARIS_TRAIN_QUALITY_GATE_ENABLED",
"HF_TRAIN_QUALITY_GATE_ENABLED",
),
"dedupe_enabled": get_env_any(
"MARIS_TRAIN_DEDUPE_ENABLED",
"HF_TRAIN_DEDUPE_ENABLED",
),
"quality_min_text_chars": get_env_any(
"MARIS_TRAIN_QUALITY_MIN_CHARS",
"HF_TRAIN_QUALITY_MIN_CHARS",
),
"scoring_enabled": get_env_any(
"MARIS_TRAIN_SCORING_ENABLED",
"HF_TRAIN_SCORING_ENABLED",
),
"weighted_repetition_enabled": get_env_any(
"MARIS_TRAIN_WEIGHTED_REPETITION_ENABLED",
"HF_TRAIN_WEIGHTED_REPETITION_ENABLED",
),
"medium_score_repeat_count": get_env_any(
"MARIS_TRAIN_MEDIUM_SCORE_REPEAT_COUNT",
"HF_TRAIN_MEDIUM_SCORE_REPEAT_COUNT",
),
"high_score_repeat_count": get_env_any(
"MARIS_TRAIN_HIGH_SCORE_REPEAT_COUNT",
"HF_TRAIN_HIGH_SCORE_REPEAT_COUNT",
),
"source_weighting_enabled": get_env_any(
"MARIS_TRAIN_SOURCE_WEIGHTING_ENABLED",
"HF_TRAIN_SOURCE_WEIGHTING_ENABLED",
),
"source_weight_map": get_env_any(
"MARIS_TRAIN_SOURCE_WEIGHT_MAP",
"HF_TRAIN_SOURCE_WEIGHT_MAP",
),
"category_weight_map": get_env_any(
"MARIS_TRAIN_CATEGORY_WEIGHT_MAP",
"HF_TRAIN_CATEGORY_WEIGHT_MAP",
),
"max_effective_repeat_count": get_env_any(
"MARIS_TRAIN_MAX_EFFECTIVE_REPEAT_COUNT",
"HF_TRAIN_MAX_EFFECTIVE_REPEAT_COUNT",
),
"benchmark_feedback_enabled": get_env_any(
"MARIS_TRAIN_BENCHMARK_FEEDBACK_ENABLED",
"HF_TRAIN_BENCHMARK_FEEDBACK_ENABLED",
),
"benchmark_feedback_auto_discover": get_env_any(
"MARIS_TRAIN_BENCHMARK_FEEDBACK_AUTO_DISCOVER",
"HF_TRAIN_BENCHMARK_FEEDBACK_AUTO_DISCOVER",
),
"benchmark_feedback_path": get_env_any(
"MARIS_TRAIN_BENCHMARK_FEEDBACK_PATH",
"HF_TRAIN_BENCHMARK_FEEDBACK_PATH",
),
"benchmark_feedback_boost_scale": get_env_any(
"MARIS_TRAIN_BENCHMARK_FEEDBACK_BOOST_SCALE",
"HF_TRAIN_BENCHMARK_FEEDBACK_BOOST_SCALE",
),
"benchmark_feedback_max_multiplier": get_env_any(
"MARIS_TRAIN_BENCHMARK_FEEDBACK_MAX_MULTIPLIER",
"HF_TRAIN_BENCHMARK_FEEDBACK_MAX_MULTIPLIER",
),
"continue_from_latest_artifact": get_env_any(
"MARIS_TRAIN_CONTINUE_FROM_LATEST",
"HF_TRAIN_CONTINUE_FROM_LATEST",
),
"continue_model_path": get_env_any(
"MARIS_TRAIN_CONTINUE_MODEL_PATH",
"HF_TRAIN_CONTINUE_MODEL_PATH",
"MARIS_LOCAL_MODEL_DIR",
"HF_LOCAL_MODEL_DIR",
),
}
env_overrides = {key: value for key, value in env_data.items() if value not in (None, "")}
cli_overrides = {key: value for key, value in (overrides or {}).items() if value is not None}
merged: dict[str, Any] = {}
merged.update(defaults)
merged.update(data)
merged.update(env_overrides)
merged.update(cli_overrides)
resolved_distributed_strategy = str(
merged.get("distributed_strategy", "none") or "none"
).lower()
explicit_use_accelerate = next(
(
source["use_accelerate"]
for source in (cli_overrides, env_overrides, data)
if source.get("use_accelerate") not in (None, "")
),
None,
)
available_base_models = list_training_base_models()
resolved_model_name, resolved_model_preset = resolve_model_selection(
str(defaults["model_name"]),
cli_overrides,
env_overrides,
data,
available_models=available_base_models,
)
config = TrainingConfig(
model_name=resolved_model_name,
model_preset=resolved_model_preset,
branch_name=str(merged["branch_name"]),
branch_focus=str(merged["branch_focus"]),
adapter_type=str(merged["adapter_type"]),
lora_r=int(merged.get("lora_r", 16)),
lora_alpha=int(merged.get("lora_alpha", 32)),
lora_dropout=float(merged.get("lora_dropout", 0.05)),
lora_bias=str(merged.get("lora_bias", "none") or "none"),
peft_target_modules=_parse_list(merged.get("peft_target_modules")),
qlora_quant_type=str(merged.get("qlora_quant_type", "nf4") or "nf4"),
qlora_use_double_quant=_parse_bool(
merged.get("qlora_use_double_quant"),
default=True,
),
qlora_compute_dtype=str(merged.get("qlora_compute_dtype") or "float16").lower(),
dataset_repo=validate_maris_repo(
str(merged["dataset_repo"]),
"MARIS_MEMORY_REPO/MARIS_DATASET_REPO/HF_DATASET_REPO/dataset_repo",
label="dataset repozitorijs",
),
dataset_repos=[],
eval_dataset_repo=(
validate_maris_repo(
str(merged["eval_dataset_repo"]),
"MARIS_EVAL_DATASET_REPO/HF_EVAL_DATASET_REPO/eval_dataset_repo",
label="eval dataset repozitorijs",
)
if merged.get("eval_dataset_repo") not in (None, "")
else ""
),
eval_dataset_repos=[],
output_dir=str(merged["output_dir"]),
hub_model_id=validate_maris_model(
str(merged["hub_model_id"]),
"MARIS_MODEL_REPO/HF_MODEL_REPO/hub_model_id",
),
text_model_id=validate_maris_model(
str(merged["text_model_id"]),
"TEXT_MODEL/text_model_id",
),
image_model_id=validate_maris_model(
str(merged["image_model_id"]),
"IMAGE_MODEL/image_model_id",
),
music_model_id=validate_maris_model(
str(merged["music_model_id"]),
"MUSIC_MODEL/music_model_id",
),
tts_model_id=validate_maris_model(
str(merged["tts_model_id"]),
"TTS_MODEL/tts_model_id",
),
stt_model_id=validate_maris_model(
str(merged["stt_model_id"]),
"STT_MODEL/stt_model_id",
),
video_model_id=validate_maris_model(
str(merged["video_model_id"]),
"VIDEO_MODEL/video_model_id",
),
num_epochs=int(merged["num_epochs"]),
learning_rate=float(merged["learning_rate"]),
per_device_train_batch_size=int(merged["per_device_train_batch_size"]),
per_device_eval_batch_size=int(merged["per_device_eval_batch_size"]),
gradient_accumulation_steps=int(merged["gradient_accumulation_steps"]),
warmup_ratio=float(merged["warmup_ratio"]),
weight_decay=float(merged["weight_decay"]),
logging_steps=int(merged["logging_steps"]),
save_steps=int(merged["save_steps"]),
eval_steps=int(merged["eval_steps"]),
save_total_limit=int(merged["save_total_limit"]),
max_seq_length=int(merged["max_seq_length"]),
validation_split_ratio=float(merged["validation_split_ratio"]),
seed=int(merged["seed"]),
fp16=_parse_bool(merged.get("fp16"), default=False),
bf16=_parse_bool(merged.get("bf16"), default=False),
gradient_checkpointing=_parse_bool(merged.get("gradient_checkpointing"), default=False),
gradient_checkpointing_use_reentrant=_parse_optional_bool(
merged.get("gradient_checkpointing_use_reentrant")
),
distributed_strategy=resolved_distributed_strategy,
distributed_config_path=str(merged.get("distributed_config_path", "") or ""),
use_accelerate=_parse_bool(
explicit_use_accelerate,
default=resolved_distributed_strategy != "none",
),
accelerate_config_path=str(merged.get("accelerate_config_path", "") or ""),
num_processes=int(merged.get("num_processes", 1)),
num_machines=int(merged.get("num_machines", 1)),
machine_rank=int(merged.get("machine_rank", 0)),
main_process_ip=str(merged.get("main_process_ip", "") or ""),
main_process_port=int(merged.get("main_process_port", 29500)),
fsdp_transformer_layer_cls_to_wrap=_parse_list(
merged.get("fsdp_transformer_layer_cls_to_wrap")
),
fsdp_min_num_params=int(merged.get("fsdp_min_num_params", 100_000_000)),
report_to=_parse_list(merged.get("report_to")),
push_to_hub=_parse_bool(merged.get("push_to_hub"), default=False),
save_safetensors=_parse_bool(merged.get("save_safetensors"), default=True),
lr_scheduler_type=str(merged["lr_scheduler_type"]),
benchmark_dataset_path=str(merged.get("benchmark_dataset_path", "") or ""),
benchmark_name=str(
merged.get("benchmark_name", DEFAULT_BENCHMARK_NAME) or DEFAULT_BENCHMARK_NAME
),
benchmark_levels=_parse_list(merged.get("benchmark_levels")) or ["local", "ci", "release"],
benchmark_min_overall=float(merged.get("benchmark_min_overall", 0.7)),
benchmark_gate_enabled=_parse_bool(merged.get("benchmark_gate_enabled"), default=False),
branch_benchmark_targets=_parse_branch_targets(merged.get("branch_benchmark_targets")),
branch_benchmark_names=_parse_branch_benchmark_names(merged.get("branch_benchmark_names")),
branch_benchmark_dataset_paths=_parse_branch_benchmark_dataset_paths(
merged.get("branch_benchmark_dataset_paths")
),
branch_preference_dataset_paths=_parse_branch_preference_dataset_paths(
merged.get("branch_preference_dataset_paths")
),
branch_dataset_filter_rules=_parse_branch_dataset_filter_rules(
merged.get("branch_dataset_filter_rules")
),
preference_dataset_path=str(merged.get("preference_dataset_path", "") or ""),
preference_optimization=str(
merged.get("preference_optimization", "none") or "none"
).lower(),
preference_beta=float(merged.get("preference_beta", 0.1)),
preference_max_prompt_length=int(merged.get("preference_max_prompt_length", 512)),
preference_max_length=int(merged.get("preference_max_length", 1024)),
preference_reference_model=(
validate_maris_model(
str(merged["preference_reference_model"]),
"MARIS_PREFERENCE_REFERENCE_MODEL/HF_PREFERENCE_REFERENCE_MODEL/preference_reference_model",
)
if merged.get("preference_reference_model") not in (None, "")
else ""
),
quality_gate_enabled=_parse_bool(merged.get("quality_gate_enabled"), default=True),
dedupe_enabled=_parse_bool(merged.get("dedupe_enabled"), default=True),
quality_min_text_chars=int(merged.get("quality_min_text_chars", 4)),
scoring_enabled=_parse_bool(merged.get("scoring_enabled"), default=True),
weighted_repetition_enabled=_parse_bool(
merged.get("weighted_repetition_enabled"),
default=True,
),
medium_score_repeat_count=int(merged.get("medium_score_repeat_count", 2)),
high_score_repeat_count=int(merged.get("high_score_repeat_count", 3)),
source_weighting_enabled=_parse_bool(
merged.get("source_weighting_enabled"),
default=True,
),
source_weight_map=_parse_source_weight_map(merged.get("source_weight_map")),
category_weight_map=_parse_category_weight_map(merged.get("category_weight_map")),
max_effective_repeat_count=int(merged.get("max_effective_repeat_count", 6)),
benchmark_feedback_enabled=_parse_bool(
merged.get("benchmark_feedback_enabled"),
default=True,
),
benchmark_feedback_auto_discover=_parse_bool(
merged.get("benchmark_feedback_auto_discover"),
default=True,
),
benchmark_feedback_path=str(merged.get("benchmark_feedback_path", "") or ""),
benchmark_feedback_boost_scale=float(merged.get("benchmark_feedback_boost_scale", 2.0)),
benchmark_feedback_max_multiplier=float(
merged.get("benchmark_feedback_max_multiplier", 1.75)
),
continue_from_latest_artifact=_parse_bool(
merged.get("continue_from_latest_artifact"),
default=False,
),
continue_model_path=str(merged.get("continue_model_path", "") or ""),
)
config.dataset_repos = [
validate_maris_repo(
repo_id,
"MARIS_DATASET_REPOS/HF_DATASET_REPOS/dataset_repos",
label="dataset repozitorijs",
)
for repo_id in _parse_repo_list(
merged.get("dataset_repos"),
default=[config.dataset_repo],
)
]
if config.dataset_repo not in config.dataset_repos:
config.dataset_repos.insert(0, config.dataset_repo)
config.eval_dataset_repos = [
validate_maris_repo(
repo_id,
"MARIS_EVAL_DATASET_REPOS/HF_EVAL_DATASET_REPOS/eval_dataset_repos",
label="eval dataset repozitorijs",
)
for repo_id in _parse_repo_list(
merged.get("eval_dataset_repos"),
default=[config.eval_dataset_repo] if config.eval_dataset_repo else [],
)
]
if config.eval_dataset_repo and config.eval_dataset_repo not in config.eval_dataset_repos:
config.eval_dataset_repos.insert(0, config.eval_dataset_repo)
if config.fp16 and config.bf16:
raise ValueError("Maris training konfigurācijā nevar vienlaikus ieslēgt fp16 un bf16.")
if config.adapter_type not in {"full", "lora", "qlora", "specialist_model"}:
raise ValueError("adapter_type jābūt vienam no: full, lora, qlora, specialist_model.")
if config.distributed_strategy not in {"none", "fsdp", "deepspeed"}:
raise ValueError("distributed_strategy jābūt vienam no: none, fsdp, deepspeed.")
if config.preference_optimization not in {"none", "dpo", "orpo"}:
raise ValueError("preference_optimization jābūt vienam no: none, dpo, orpo.")
if config.preference_optimization != "none" and not config.preference_dataset_path:
raise ValueError(
"Preference optimization vajag preference_dataset_path ar prompt/chosen/rejected datiem."
)
if config.num_processes < 1:
raise ValueError("num_processes jābūt vismaz 1.")
if config.num_machines < 1:
raise ValueError("num_machines jābūt vismaz 1.")
if config.machine_rank < 0:
raise ValueError("machine_rank nedrīkst būt negatīvs.")
if config.main_process_port < 1:
raise ValueError("main_process_port jābūt pozitīvam portam.")
if config.fsdp_min_num_params < 0:
raise ValueError("fsdp_min_num_params nedrīkst būt negatīvs.")
return config
def _parse_branch_targets(value: Any) -> dict[str, dict[str, float]]:
if value in (None, ""):
return {key: item.copy() for key, item in DEFAULT_BRANCH_BENCHMARK_TARGETS.items()}
parsed = json.loads(value) if isinstance(value, str) else value
if not isinstance(parsed, dict):
raise ValueError("branch_benchmark_targets jābūt objektam ar branch -> metric -> score.")
normalized: dict[str, dict[str, float]] = {}
for branch, metrics in parsed.items():
if not isinstance(metrics, dict):
raise ValueError("Katram branch_benchmark_targets ierakstam jābūt objektam.")
normalized[str(branch)] = {str(name): float(score) for name, score in metrics.items()}
return normalized
def _parse_source_weight_map(value: Any) -> dict[str, float]:
    """Build the source-tier weight map, layering overrides on the defaults.

    Args:
        value: ``None`` or ``""`` to use the defaults, a JSON string, or an
            already-decoded mapping of source_tier -> weight.

    Returns:
        A copy of ``DEFAULT_SOURCE_WEIGHT_MAP`` in which every supplied tier
        (key coerced to ``str``, weight coerced to ``float``) overrides or
        extends the default entries.

    Raises:
        ValueError: if the decoded value is not a dict. Invalid JSON and
            non-numeric weights propagate the errors raised by
            ``json.loads`` and ``float``.
    """
    if value in (None, ""):
        return DEFAULT_SOURCE_WEIGHT_MAP.copy()

    overrides = value if not isinstance(value, str) else json.loads(value)
    if not isinstance(overrides, dict):
        raise ValueError("source_weight_map jābūt objektam ar source_tier -> weight.")

    merged = DEFAULT_SOURCE_WEIGHT_MAP.copy()
    merged.update((str(tier), float(weight)) for tier, weight in overrides.items())
    return merged
def _parse_branch_benchmark_names(value: Any) -> dict[str, str]:
if value in (None, ""):
return DEFAULT_BRANCH_BENCHMARK_NAMES.copy()
if isinstance(value, str):
value = json.loads(value)
if not isinstance(value, dict):
raise ValueError("branch_benchmark_names jābūt objektam ar branch -> benchmark name.")
normalized: dict[str, str] = {}
for branch_name, benchmark_name in value.items():
name = str(benchmark_name or "").strip()
if name:
normalized[str(branch_name)] = name
return normalized
def _parse_branch_benchmark_dataset_paths(value: Any) -> dict[str, str]:
    """Resolve per-branch benchmark dataset paths on top of the defaults.

    Args:
        value: ``None`` or ``""`` to use the defaults, a JSON string, or an
            already-decoded mapping of branch -> benchmark dataset path.

    Returns:
        A copy of ``DEFAULT_BRANCH_BENCHMARK_DATASET_PATHS`` with every
        non-blank supplied path (stripped, key coerced to ``str``) applied
        on top. Blank or falsy paths leave the default entry untouched.

    Raises:
        ValueError: if the decoded value is not a dict. Invalid JSON
            propagates the error raised by ``json.loads``.
    """
    if value in (None, ""):
        return DEFAULT_BRANCH_BENCHMARK_DATASET_PATHS.copy()

    overrides = value if not isinstance(value, str) else json.loads(value)
    if not isinstance(overrides, dict):
        raise ValueError(
            "branch_benchmark_dataset_paths jābūt objektam ar branch -> benchmark dataset path."
        )

    paths = DEFAULT_BRANCH_BENCHMARK_DATASET_PATHS.copy()
    for branch, raw_path in overrides.items():
        cleaned = str(raw_path).strip() if raw_path else ""
        if cleaned:
            paths[str(branch)] = cleaned
    return paths
def _parse_branch_preference_dataset_paths(value: Any) -> dict[str, str]:
    """Resolve per-branch preference dataset paths on top of the defaults.

    Args:
        value: ``None`` or ``""`` to use the defaults, a JSON string, or an
            already-decoded mapping of branch -> preference dataset path.

    Returns:
        A copy of ``DEFAULT_BRANCH_PREFERENCE_DATASET_PATHS`` with every
        non-blank supplied path (stripped, key coerced to ``str``) applied
        on top. Blank or falsy paths leave the default entry untouched.

    Raises:
        ValueError: if the decoded value is not a dict. Invalid JSON
            propagates the error raised by ``json.loads``.
    """
    if value in (None, ""):
        return DEFAULT_BRANCH_PREFERENCE_DATASET_PATHS.copy()

    if isinstance(value, str):
        value = json.loads(value)
    if not isinstance(value, dict):
        raise ValueError(
            "branch_preference_dataset_paths jābūt objektam ar branch -> preference dataset path."
        )

    resolved = DEFAULT_BRANCH_PREFERENCE_DATASET_PATHS.copy()
    # Lazily strip each path, then keep only the entries that are non-blank.
    stripped = ((branch, str(raw or "").strip()) for branch, raw in value.items())
    resolved.update((str(branch), path) for branch, path in stripped if path)
    return resolved
def _parse_branch_dataset_filter_rules(value: Any) -> dict[str, dict[str, Any]]:
    """Resolve per-branch dataset filter rules on top of the defaults.

    Args:
        value: ``None`` or ``""`` to use the defaults, a JSON string, or an
            already-decoded mapping of branch -> filter rule map.

    Returns:
        A new dict seeded with a ``.copy()`` of every entry in
        ``DEFAULT_BRANCH_DATASET_FILTER_RULES``. Each supplied branch
        replaces the whole rule map for that branch (rules are not merged
        key by key) and is stored as a fresh ``dict``.

    Raises:
        ValueError: if the decoded value, or any per-branch entry, is not a
            dict. Invalid JSON propagates the error raised by ``json.loads``.
    """
    if value in (None, ""):
        overrides: dict[Any, Any] = {}
    else:
        overrides = json.loads(value) if isinstance(value, str) else value
        if not isinstance(overrides, dict):
            raise ValueError("branch_dataset_filter_rules jābūt objektam ar branch -> filter rule map.")

    rule_map = {
        branch: default_rules.copy()
        for branch, default_rules in DEFAULT_BRANCH_DATASET_FILTER_RULES.items()
    }
    for branch, branch_rules in overrides.items():
        if not isinstance(branch_rules, dict):
            raise ValueError("Katram branch_dataset_filter_rules ierakstam jābūt objektam.")
        rule_map[str(branch)] = dict(branch_rules)
    return rule_map
def _parse_category_weight_map(value: Any) -> dict[str, float]:
if value in (None, ""):
return DEFAULT_CATEGORY_WEIGHT_MAP.copy()
parsed = json.loads(value) if isinstance(value, str) else value
if not isinstance(parsed, dict):
raise ValueError("category_weight_map jābūt objektam ar category -> weight.")
return {str(label): float(weight) for label, weight in parsed.items()}