# maris-ai-master/core-python/scripts/train_model.py
# MarisUK
# Maris AI model sync
# f440f03 (verified)
"""CLI ieeja pilnam Maris AI training pipeline skrējienam."""
from __future__ import annotations
import argparse
import json
import logging
from dataclasses import replace
from maris_core.training.config import list_training_base_models, load_training_config
# Module-scoped logger; handlers/level come from logging.basicConfig in main().
logger = logging.getLogger(name=__name__)
def _parse_bool_arg(value: str) -> bool:
    """Convert a CLI boolean token into a ``bool`` (argparse ``type=`` hook).

    Matching is case-insensitive and ignores surrounding whitespace.
    ``1``/``true``/``yes``/``on`` map to ``True``;
    ``0``/``false``/``no``/``off`` map to ``False``.

    Raises:
        argparse.ArgumentTypeError: for any other token, so argparse reports
            it as a usage error for the offending option.
    """
    token = value.strip().lower()
    lookup = {
        "1": True,
        "true": True,
        "yes": True,
        "on": True,
        "0": False,
        "false": False,
        "no": False,
        "off": False,
    }
    parsed = lookup.get(token)
    if parsed is None:
        raise argparse.ArgumentTypeError("Izmanto true/false, yes/no vai 1/0.")
    return parsed
# Config fields that may be overridden from the command line. Each entry is both
# the argparse ``dest`` attribute defined in _build_parser() and the key passed to
# load_training_config(overrides=...); keeping one list stops the two from drifting.
# The order matches the override mapping the CLI has always produced.
_CONFIG_OVERRIDE_FIELDS = (
    "model_name",
    "model_preset",
    "dataset_repo",
    "eval_dataset_repo",
    "benchmark_dataset_path",
    "benchmark_name",
    "benchmark_levels",
    "benchmark_min_overall",
    "benchmark_gate_enabled",
    "preference_dataset_path",
    "preference_optimization",
    "preference_beta",
    "preference_max_prompt_length",
    "preference_max_length",
    "preference_reference_model",
    "branch_name",
    "branch_focus",
    "adapter_type",
    "lora_r",
    "lora_alpha",
    "lora_dropout",
    "lora_bias",
    "peft_target_modules",
    "qlora_quant_type",
    "qlora_use_double_quant",
    "qlora_compute_dtype",
    "distributed_strategy",
    "distributed_config_path",
    "use_accelerate",
    "accelerate_config_path",
    "num_processes",
    "num_machines",
    "machine_rank",
    "main_process_ip",
    "main_process_port",
    "fsdp_transformer_layer_cls_to_wrap",
    "fsdp_min_num_params",
    "hub_model_id",
    "output_dir",
    "num_epochs",
    "learning_rate",
    "max_seq_length",
    "push_to_hub",
    "validation_split_ratio",
)


def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the training CLI (help texts are user-facing)."""
    parser = argparse.ArgumentParser(description="Apmāca Maris AI modeli ar Transformers")
    parser.add_argument("--config", help="JSON konfigurācijas fails")
    parser.add_argument("--model-name", help="Bāzes modelis fine-tuningam")
    parser.add_argument("--model-preset", help="Iepriekš definēts HF bāzes modeļa presets")
    parser.add_argument("--dataset-repo", help="HF dataset repo ID")
    parser.add_argument("--eval-dataset-repo", help="Atsevišķs HF eval dataset repo ID")
    parser.add_argument(
        "--benchmark-dataset-path", help="Lokāls JSON benchmarks release gate un score manifestam"
    )
    parser.add_argument("--benchmark-name", help="Benchmark komplekta nosaukums artefaktiem")
    parser.add_argument(
        "--benchmark-levels",
        help="Comma-separated benchmark līmeņi, piemēram local,ci,release",
    )
    parser.add_argument(
        "--benchmark-min-overall",
        type=float,
        help="Minimālais overall benchmark score release gate vajadzībām",
    )
    parser.add_argument(
        "--benchmark-gate-enabled",
        type=_parse_bool_arg,
        help="Vai training skrējiens jāaptur, ja benchmark gate neiziet",
    )
    parser.add_argument(
        "--preference-dataset-path",
        help="Lokāls JSON preference-feedback datasets auditējamam artifactam",
    )
    parser.add_argument(
        "--preference-optimization",
        help="Preference optimization režīms: none, dpo vai orpo",
    )
    parser.add_argument("--preference-beta", type=float, help="DPO/ORPO beta parametrs")
    parser.add_argument(
        "--preference-max-prompt-length",
        type=int,
        help="Maksimālais prompt tokenu garums preference optimization laikā",
    )
    parser.add_argument(
        "--preference-max-length",
        type=int,
        help="Maksimālais kopējais tokenu garums preference optimization laikā",
    )
    parser.add_argument(
        "--preference-reference-model",
        help="Atsauces modelis DPO preference optimization stadijai",
    )
    parser.add_argument("--branch-name", help="Maris atzara nosaukums")
    parser.add_argument("--branch-focus", help="Atzara specializācijas fokuss")
    parser.add_argument("--adapter-type", help="Adapteru stratēģija, piemēram full vai lora")
    parser.add_argument("--lora-r", type=int, help="LoRA rank parametrs PEFT adapteriem")
    parser.add_argument("--lora-alpha", type=int, help="LoRA alpha parametrs PEFT adapteriem")
    parser.add_argument("--lora-dropout", type=float, help="LoRA dropout parametrs")
    parser.add_argument("--lora-bias", help="LoRA bias režīms, piemēram none vai all")
    parser.add_argument(
        "--peft-target-modules",
        help="Comma-separated PEFT target modules saraksts",
    )
    parser.add_argument("--qlora-quant-type", help="QLoRA quant type, piemēram nf4 vai fp4")
    parser.add_argument(
        "--qlora-use-double-quant",
        type=_parse_bool_arg,
        help="Vai QLoRA izmantot double quantization",
    )
    parser.add_argument(
        "--qlora-compute-dtype",
        help="QLoRA compute dtype, piemēram float16 vai bfloat16",
    )
    parser.add_argument(
        "--distributed-strategy",
        help="Distributed režīms: none, fsdp vai deepspeed",
    )
    parser.add_argument(
        "--distributed-config-path",
        help="Ceļš uz FSDP vai DeepSpeed JSON konfigurāciju",
    )
    parser.add_argument(
        "--use-accelerate",
        type=_parse_bool_arg,
        help="Vai palaist treniņu ar accelerate launcher semantiku",
    )
    parser.add_argument(
        "--accelerate-config-path",
        help="Ceļš uz accelerate launcher YAML konfigurāciju",
    )
    parser.add_argument("--num-processes", type=int, help="Procesu/GPU skaits distributed launcham")
    parser.add_argument("--num-machines", type=int, help="Mašīnu skaits distributed launcham")
    parser.add_argument(
        "--machine-rank", type=int, help="Pašreizējās mašīnas ranks distributed launcham"
    )
    parser.add_argument("--main-process-ip", help="Galvenā procesa IP multi-node launcham")
    parser.add_argument(
        "--main-process-port", type=int, help="Galvenā procesa ports multi-node launcham"
    )
    parser.add_argument(
        "--fsdp-transformer-layer-cls-to-wrap",
        help="Comma-separated transformer layer class saraksts FSDP auto-wrap vajadzībām",
    )
    parser.add_argument(
        "--fsdp-min-num-params",
        type=int,
        help="Minimālais parametru skaits FSDP wrap aktivēšanai",
    )
    parser.add_argument("--hub-model-id", help="Maris model repo ID publicētajam rezultātam")
    parser.add_argument("--output-dir", help="Kur saglabāt apmācīto modeli")
    parser.add_argument("--num-epochs", type=int, help="Epoku skaits")
    parser.add_argument("--learning-rate", type=float, help="Learning rate")
    parser.add_argument("--max-seq-length", type=int, help="Maksimālais tokenu garums")
    parser.add_argument(
        "--push-to-hub",
        type=_parse_bool_arg,
        help="Vai pēc treniņa publicēt pilnu output direktoriju uz Hugging Face Hub",
    )
    parser.add_argument(
        "--all-branches",
        action="store_true",
        help="Palaist branch-specific training pipeline visiem atzariem",
    )
    parser.add_argument(
        "--validation-split-ratio",
        type=float,
        help="Validation split proporcija, ja repo nav validation split",
    )
    parser.add_argument(
        "--list-base-models",
        action="store_true",
        help="Izvada pieejamos bāzes modeļu presetus JSON formātā un beidz darbu",
    )
    return parser


def _collect_overrides(args: argparse.Namespace) -> dict[str, object]:
    """Map every overridable config field to its parsed CLI value.

    Options the user did not pass are included with value ``None``.
    NOTE(review): presumably load_training_config treats ``None`` as
    "keep the JSON/default value" — confirm.
    """
    return {field: getattr(args, field) for field in _CONFIG_OVERRIDE_FIELDS}


def main() -> int:
    """Run one full training pass (or the all-branches suite) from CLI args + JSON config.

    With ``--list-base-models`` the available base-model presets are printed as
    JSON and the function returns immediately. Otherwise the config is loaded,
    CLI overrides are applied, training runs, and the resulting metrics are
    logged and printed as JSON.

    Returns:
        0 on success.

    Raises:
        SystemExit: with status 2 (via ``parser.exit``) when loading/training
            raises FileNotFoundError, ImportError or ValueError; argparse also
            exits with status 2 on invalid arguments.
    """
    parser = _build_parser()
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)
    try:
        if args.list_base_models:
            print(json.dumps(list_training_base_models(), indent=2, ensure_ascii=False))
            return 0

        # Imported only after the --list-base-models early exit, so listing presets
        # does not require maris_core.training.train; an ImportError here is
        # reported as a CLI error by the handler below.
        from maris_core.training.train import train_branch_suite, train_with_config

        config = load_training_config(args.config, overrides=_collect_overrides(args))
        if args.all_branches:
            # push_to_hub is always forced off for the suite run.
            # NOTE(review): presumably train_branch_suite handles publication
            # itself — confirm.
            metrics = train_branch_suite(replace(config, push_to_hub=False))
        else:
            metrics = train_with_config(config)

        logger.info("Training metrics: %s", metrics)
        print(json.dumps(metrics, indent=2, ensure_ascii=False))
        return 0
    except (FileNotFoundError, ImportError, ValueError) as exc:
        parser.exit(2, f"{parser.prog}: error: {exc}\n")
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)