| """CLI ieeja pilnam Maris AI training pipeline skrējienam.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| from dataclasses import replace |
|
|
| from maris_core.training.config import list_training_base_models, load_training_config |
|
|
# Module-level logger named after this module; main() uses it to log the training metrics.
logger = logging.getLogger(__name__)
|
|
|
|
def _parse_bool_arg(value: str) -> bool:
    """Convert a textual CLI flag value into a ``bool``.

    Surrounding whitespace is ignored and matching is case-insensitive.
    ``1``, ``true``, ``yes`` and ``on`` map to ``True``; ``0``, ``false``,
    ``no`` and ``off`` map to ``False``.

    Raises:
        argparse.ArgumentTypeError: If the value is none of the tokens above.
    """
    token = value.strip().lower()
    truthy = ("1", "true", "yes", "on")
    falsy = ("0", "false", "no", "off")
    if token not in truthy and token not in falsy:
        raise argparse.ArgumentTypeError("Izmanto true/false, yes/no vai 1/0.")
    return token in truthy
|
|
|
|
# Parsed-argument attributes forwarded to load_training_config as overrides.
# Each name is both the argparse destination and the override key.
# --config, --all-branches and --list-base-models are handled by main() itself
# and are deliberately absent.
_OVERRIDE_FIELDS = (
    "model_name",
    "model_preset",
    "dataset_repo",
    "eval_dataset_repo",
    "benchmark_dataset_path",
    "benchmark_name",
    "benchmark_levels",
    "benchmark_min_overall",
    "benchmark_gate_enabled",
    "preference_dataset_path",
    "preference_optimization",
    "preference_beta",
    "preference_max_prompt_length",
    "preference_max_length",
    "preference_reference_model",
    "branch_name",
    "branch_focus",
    "adapter_type",
    "lora_r",
    "lora_alpha",
    "lora_dropout",
    "lora_bias",
    "peft_target_modules",
    "qlora_quant_type",
    "qlora_use_double_quant",
    "qlora_compute_dtype",
    "distributed_strategy",
    "distributed_config_path",
    "use_accelerate",
    "accelerate_config_path",
    "num_processes",
    "num_machines",
    "machine_rank",
    "main_process_ip",
    "main_process_port",
    "fsdp_transformer_layer_cls_to_wrap",
    "fsdp_min_num_params",
    "hub_model_id",
    "output_dir",
    "num_epochs",
    "learning_rate",
    "max_seq_length",
    "push_to_hub",
    "validation_split_ratio",
)


def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser that declares every training CLI option."""
    parser = argparse.ArgumentParser(description="Apmāca Maris AI modeli ar Transformers")
    parser.add_argument("--config", help="JSON konfigurācijas fails")
    parser.add_argument("--model-name", help="Bāzes modelis fine-tuningam")
    parser.add_argument("--model-preset", help="Iepriekš definēts HF bāzes modeļa presets")
    parser.add_argument("--dataset-repo", help="HF dataset repo ID")
    parser.add_argument("--eval-dataset-repo", help="Atsevišķs HF eval dataset repo ID")
    parser.add_argument(
        "--benchmark-dataset-path", help="Lokāls JSON benchmarks release gate un score manifestam"
    )
    parser.add_argument("--benchmark-name", help="Benchmark komplekta nosaukums artefaktiem")
    parser.add_argument(
        "--benchmark-levels",
        help="Comma-separated benchmark līmeņi, piemēram local,ci,release",
    )
    parser.add_argument(
        "--benchmark-min-overall",
        type=float,
        help="Minimālais overall benchmark score release gate vajadzībām",
    )
    parser.add_argument(
        "--benchmark-gate-enabled",
        type=_parse_bool_arg,
        help="Vai training skrējiens jāaptur, ja benchmark gate neiziet",
    )
    parser.add_argument(
        "--preference-dataset-path",
        help="Lokāls JSON preference-feedback datasets auditējamam artifactam",
    )
    parser.add_argument(
        "--preference-optimization",
        help="Preference optimization režīms: none, dpo vai orpo",
    )
    parser.add_argument("--preference-beta", type=float, help="DPO/ORPO beta parametrs")
    parser.add_argument(
        "--preference-max-prompt-length",
        type=int,
        help="Maksimālais prompt tokenu garums preference optimization laikā",
    )
    parser.add_argument(
        "--preference-max-length",
        type=int,
        help="Maksimālais kopējais tokenu garums preference optimization laikā",
    )
    parser.add_argument(
        "--preference-reference-model",
        help="Atsauces modelis DPO preference optimization stadijai",
    )
    parser.add_argument("--branch-name", help="Maris atzara nosaukums")
    parser.add_argument("--branch-focus", help="Atzara specializācijas fokuss")
    parser.add_argument("--adapter-type", help="Adapteru stratēģija, piemēram full vai lora")
    parser.add_argument("--lora-r", type=int, help="LoRA rank parametrs PEFT adapteriem")
    parser.add_argument("--lora-alpha", type=int, help="LoRA alpha parametrs PEFT adapteriem")
    parser.add_argument("--lora-dropout", type=float, help="LoRA dropout parametrs")
    parser.add_argument("--lora-bias", help="LoRA bias režīms, piemēram none vai all")
    parser.add_argument(
        "--peft-target-modules",
        help="Comma-separated PEFT target modules saraksts",
    )
    parser.add_argument("--qlora-quant-type", help="QLoRA quant type, piemēram nf4 vai fp4")
    parser.add_argument(
        "--qlora-use-double-quant",
        type=_parse_bool_arg,
        help="Vai QLoRA izmantot double quantization",
    )
    parser.add_argument(
        "--qlora-compute-dtype",
        help="QLoRA compute dtype, piemēram float16 vai bfloat16",
    )
    parser.add_argument(
        "--distributed-strategy",
        help="Distributed režīms: none, fsdp vai deepspeed",
    )
    parser.add_argument(
        "--distributed-config-path",
        help="Ceļš uz FSDP vai DeepSpeed JSON konfigurāciju",
    )
    parser.add_argument(
        "--use-accelerate",
        type=_parse_bool_arg,
        help="Vai palaist treniņu ar accelerate launcher semantiku",
    )
    parser.add_argument(
        "--accelerate-config-path",
        help="Ceļš uz accelerate launcher YAML konfigurāciju",
    )
    parser.add_argument("--num-processes", type=int, help="Procesu/GPU skaits distributed launcham")
    parser.add_argument("--num-machines", type=int, help="Mašīnu skaits distributed launcham")
    parser.add_argument(
        "--machine-rank", type=int, help="Pašreizējās mašīnas ranks distributed launcham"
    )
    parser.add_argument("--main-process-ip", help="Galvenā procesa IP multi-node launcham")
    parser.add_argument(
        "--main-process-port", type=int, help="Galvenā procesa ports multi-node launcham"
    )
    parser.add_argument(
        "--fsdp-transformer-layer-cls-to-wrap",
        help="Comma-separated transformer layer class saraksts FSDP auto-wrap vajadzībām",
    )
    parser.add_argument(
        "--fsdp-min-num-params",
        type=int,
        help="Minimālais parametru skaits FSDP wrap aktivēšanai",
    )
    parser.add_argument("--hub-model-id", help="Maris model repo ID publicētajam rezultātam")
    parser.add_argument("--output-dir", help="Kur saglabāt apmācīto modeli")
    parser.add_argument("--num-epochs", type=int, help="Epoku skaits")
    parser.add_argument("--learning-rate", type=float, help="Learning rate")
    parser.add_argument("--max-seq-length", type=int, help="Maksimālais tokenu garums")
    parser.add_argument(
        "--push-to-hub",
        type=_parse_bool_arg,
        help="Vai pēc treniņa publicēt pilnu output direktoriju uz Hugging Face Hub",
    )
    parser.add_argument(
        "--all-branches",
        action="store_true",
        help="Palaist branch-specific training pipeline visiem atzariem",
    )
    parser.add_argument(
        "--validation-split-ratio",
        type=float,
        help="Validation split proporcija, ja repo nav validation split",
    )
    parser.add_argument(
        "--list-base-models",
        action="store_true",
        help="Izvada pieejamos bāzes modeļu presetus JSON formātā un beidz darbu",
    )
    return parser


def _collect_overrides(args: argparse.Namespace) -> dict[str, object]:
    """Map parsed CLI values onto the override keys given to ``load_training_config``.

    Options that were not supplied keep argparse's default of ``None`` and are
    passed through unchanged.
    """
    return {name: getattr(args, name) for name in _OVERRIDE_FIELDS}


def main(argv: list[str] | None = None) -> int:
    """Run one training invocation driven by a JSON config plus CLI overrides.

    With ``--list-base-models`` the available base-model presets are printed
    as JSON and nothing is trained. Otherwise the configuration is loaded,
    either a single run or (with ``--all-branches``) the branch suite is
    executed, and the resulting metrics are logged and printed as JSON.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour of a no-argument call.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: With status 2 when a ``FileNotFoundError``, ``ImportError``
            or ``ValueError`` is raised while listing, loading or training.
            Argparse itself also exits on invalid arguments or ``--help``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    logging.basicConfig(level=logging.INFO)
    try:
        if args.list_base_models:
            print(json.dumps(list_training_base_models(), indent=2, ensure_ascii=False))
            return 0

        # Imported only on the training path; an ImportError raised here is
        # reported as a CLI error by the handler below.
        from maris_core.training.train import train_branch_suite, train_with_config

        config = load_training_config(args.config, overrides=_collect_overrides(args))
        if args.all_branches:
            # The suite always receives a copy with push_to_hub forced off,
            # regardless of config or --push-to-hub.
            # NOTE(review): presumably publishing is handled per branch - confirm.
            metrics = train_branch_suite(replace(config, push_to_hub=False))
        else:
            metrics = train_with_config(config)
        logger.info("Training metrics: %s", metrics)
        print(json.dumps(metrics, indent=2, ensure_ascii=False))
        return 0
    except (FileNotFoundError, ImportError, ValueError) as exc:
        # parser.exit raises SystemExit, so no return value is needed here.
        parser.exit(2, f"{parser.prog}: error: {exc}\n")
|
|
|
|
# Script entry point: run the CLI and use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
|