"""CLI ieeja pilnam Maris AI training pipeline skrējienam.""" from __future__ import annotations import argparse import json import logging from dataclasses import replace from maris_core.training.config import list_training_base_models, load_training_config logger = logging.getLogger(__name__) def _parse_bool_arg(value: str) -> bool: """Parsē CLI boolean vērtību no true/false, yes/no vai 1/0 formāta.""" normalized = value.strip().lower() if normalized in {"1", "true", "yes", "on"}: return True if normalized in {"0", "false", "no", "off"}: return False raise argparse.ArgumentTypeError("Izmanto true/false, yes/no vai 1/0.") def main() -> int: """Izpilda vienu pilnu apmācības skrējienu pēc JSON konfigurācijas.""" parser = argparse.ArgumentParser(description="Apmāca Maris AI modeli ar Transformers") parser.add_argument("--config", help="JSON konfigurācijas fails") parser.add_argument("--model-name", help="Bāzes modelis fine-tuningam") parser.add_argument("--model-preset", help="Iepriekš definēts HF bāzes modeļa presets") parser.add_argument("--dataset-repo", help="HF dataset repo ID") parser.add_argument("--eval-dataset-repo", help="Atsevišķs HF eval dataset repo ID") parser.add_argument( "--benchmark-dataset-path", help="Lokāls JSON benchmarks release gate un score manifestam" ) parser.add_argument("--benchmark-name", help="Benchmark komplekta nosaukums artefaktiem") parser.add_argument( "--benchmark-levels", help="Comma-separated benchmark līmeņi, piemēram local,ci,release", ) parser.add_argument( "--benchmark-min-overall", type=float, help="Minimālais overall benchmark score release gate vajadzībām", ) parser.add_argument( "--benchmark-gate-enabled", type=_parse_bool_arg, help="Vai training skrējiens jāaptur, ja benchmark gate neiziet", ) parser.add_argument( "--preference-dataset-path", help="Lokāls JSON preference-feedback datasets auditējamam artifactam", ) parser.add_argument( "--preference-optimization", help="Preference optimization režīms: none, dpo vai orpo", ) parser.add_argument("--preference-beta", type=float, help="DPO/ORPO beta parametrs") parser.add_argument( "--preference-max-prompt-length", type=int, help="Maksimālais prompt tokenu garums preference optimization laikā", ) parser.add_argument( "--preference-max-length", type=int, help="Maksimālais kopējais tokenu garums preference optimization laikā", ) parser.add_argument( "--preference-reference-model", help="Atsauces modelis DPO preference optimization stadijai", ) parser.add_argument("--branch-name", help="Maris atzara nosaukums") parser.add_argument("--branch-focus", help="Atzara specializācijas fokuss") parser.add_argument("--adapter-type", help="Adapteru stratēģija, piemēram full vai lora") parser.add_argument("--lora-r", type=int, help="LoRA rank parametrs PEFT adapteriem") parser.add_argument("--lora-alpha", type=int, help="LoRA alpha parametrs PEFT adapteriem") parser.add_argument("--lora-dropout", type=float, help="LoRA dropout parametrs") parser.add_argument("--lora-bias", help="LoRA bias režīms, piemēram none vai all") parser.add_argument( "--peft-target-modules", help="Comma-separated PEFT target modules saraksts", ) parser.add_argument("--qlora-quant-type", help="QLoRA quant type, piemēram nf4 vai fp4") parser.add_argument( "--qlora-use-double-quant", type=_parse_bool_arg, help="Vai QLoRA izmantot double quantization", ) parser.add_argument( "--qlora-compute-dtype", help="QLoRA compute dtype, piemēram float16 vai bfloat16", ) parser.add_argument( "--distributed-strategy", help="Distributed režīms: none, fsdp vai deepspeed", ) parser.add_argument( "--distributed-config-path", help="Ceļš uz FSDP vai DeepSpeed JSON konfigurāciju", ) parser.add_argument( "--use-accelerate", type=_parse_bool_arg, help="Vai palaist treniņu ar accelerate launcher semantiku", ) parser.add_argument( "--accelerate-config-path", help="Ceļš uz accelerate launcher YAML konfigurāciju", ) parser.add_argument("--num-processes", type=int, help="Procesu/GPU skaits distributed launcham") parser.add_argument("--num-machines", type=int, help="Mašīnu skaits distributed launcham") parser.add_argument( "--machine-rank", type=int, help="Pašreizējās mašīnas ranks distributed launcham" ) parser.add_argument("--main-process-ip", help="Galvenā procesa IP multi-node launcham") parser.add_argument( "--main-process-port", type=int, help="Galvenā procesa ports multi-node launcham" ) parser.add_argument( "--fsdp-transformer-layer-cls-to-wrap", help="Comma-separated transformer layer class saraksts FSDP auto-wrap vajadzībām", ) parser.add_argument( "--fsdp-min-num-params", type=int, help="Minimālais parametru skaits FSDP wrap aktivēšanai", ) parser.add_argument("--hub-model-id", help="Maris model repo ID publicētajam rezultātam") parser.add_argument("--output-dir", help="Kur saglabāt apmācīto modeli") parser.add_argument("--num-epochs", type=int, help="Epoku skaits") parser.add_argument("--learning-rate", type=float, help="Learning rate") parser.add_argument("--max-seq-length", type=int, help="Maksimālais tokenu garums") parser.add_argument( "--push-to-hub", type=_parse_bool_arg, help="Vai pēc treniņa publicēt pilnu output direktoriju uz Hugging Face Hub", ) parser.add_argument( "--all-branches", action="store_true", help="Palaist branch-specific training pipeline visiem atzariem", ) parser.add_argument( "--validation-split-ratio", type=float, help="Validation split proporcija, ja repo nav validation split", ) parser.add_argument( "--list-base-models", action="store_true", help="Izvada pieejamos bāzes modeļu presetus JSON formātā un beidz darbu", ) args = parser.parse_args() logging.basicConfig(level=logging.INFO) try: if args.list_base_models: print(json.dumps(list_training_base_models(), indent=2, ensure_ascii=False)) return 0 from maris_core.training.train import train_branch_suite, train_with_config config = load_training_config( args.config, overrides={ "model_name": args.model_name, "model_preset": args.model_preset, "dataset_repo": args.dataset_repo, "eval_dataset_repo": args.eval_dataset_repo, "benchmark_dataset_path": args.benchmark_dataset_path, "benchmark_name": args.benchmark_name, "benchmark_levels": args.benchmark_levels, "benchmark_min_overall": args.benchmark_min_overall, "benchmark_gate_enabled": args.benchmark_gate_enabled, "preference_dataset_path": args.preference_dataset_path, "preference_optimization": args.preference_optimization, "preference_beta": args.preference_beta, "preference_max_prompt_length": args.preference_max_prompt_length, "preference_max_length": args.preference_max_length, "preference_reference_model": args.preference_reference_model, "branch_name": args.branch_name, "branch_focus": args.branch_focus, "adapter_type": args.adapter_type, "lora_r": args.lora_r, "lora_alpha": args.lora_alpha, "lora_dropout": args.lora_dropout, "lora_bias": args.lora_bias, "peft_target_modules": args.peft_target_modules, "qlora_quant_type": args.qlora_quant_type, "qlora_use_double_quant": args.qlora_use_double_quant, "qlora_compute_dtype": args.qlora_compute_dtype, "distributed_strategy": args.distributed_strategy, "distributed_config_path": args.distributed_config_path, "use_accelerate": args.use_accelerate, "accelerate_config_path": args.accelerate_config_path, "num_processes": args.num_processes, "num_machines": args.num_machines, "machine_rank": args.machine_rank, "main_process_ip": args.main_process_ip, "main_process_port": args.main_process_port, "fsdp_transformer_layer_cls_to_wrap": args.fsdp_transformer_layer_cls_to_wrap, "fsdp_min_num_params": args.fsdp_min_num_params, "hub_model_id": args.hub_model_id, "output_dir": args.output_dir, "num_epochs": args.num_epochs, "learning_rate": args.learning_rate, "max_seq_length": args.max_seq_length, "push_to_hub": args.push_to_hub, "validation_split_ratio": args.validation_split_ratio, }, ) execution_config = replace(config, push_to_hub=False) if args.all_branches else config metrics = ( train_branch_suite(execution_config) if args.all_branches else train_with_config(config) ) logger.info("Training metrics: %s", metrics) print(json.dumps(metrics, indent=2, ensure_ascii=False)) return 0 except (FileNotFoundError, ImportError, ValueError) as exc: parser.exit(2, f"{parser.prog}: error: {exc}\n") if __name__ == "__main__": raise SystemExit(main())