File size: 10,050 Bytes
f440f03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""CLI ieeja pilnam Maris AI training pipeline skrējienam."""

from __future__ import annotations

import argparse
import json
import logging
from dataclasses import replace

from maris_core.training.config import list_training_base_models, load_training_config

logger = logging.getLogger(__name__)


def _parse_bool_arg(value: str) -> bool:
    """Parsē CLI boolean vērtību no true/false, yes/no vai 1/0 formāta."""
    normalized = value.strip().lower()
    if normalized in {"1", "true", "yes", "on"}:
        return True
    if normalized in {"0", "false", "no", "off"}:
        return False
    raise argparse.ArgumentTypeError("Izmanto true/false, yes/no vai 1/0.")


# Names forwarded to ``load_training_config`` as overrides. Each entry is both
# the argparse ``dest`` of a CLI option and the override key of the same name;
# the order is the order in which the keys are inserted into the overrides dict.
_OVERRIDE_FIELDS = (
    "model_name",
    "model_preset",
    "dataset_repo",
    "eval_dataset_repo",
    "benchmark_dataset_path",
    "benchmark_name",
    "benchmark_levels",
    "benchmark_min_overall",
    "benchmark_gate_enabled",
    "preference_dataset_path",
    "preference_optimization",
    "preference_beta",
    "preference_max_prompt_length",
    "preference_max_length",
    "preference_reference_model",
    "branch_name",
    "branch_focus",
    "adapter_type",
    "lora_r",
    "lora_alpha",
    "lora_dropout",
    "lora_bias",
    "peft_target_modules",
    "qlora_quant_type",
    "qlora_use_double_quant",
    "qlora_compute_dtype",
    "distributed_strategy",
    "distributed_config_path",
    "use_accelerate",
    "accelerate_config_path",
    "num_processes",
    "num_machines",
    "machine_rank",
    "main_process_ip",
    "main_process_port",
    "fsdp_transformer_layer_cls_to_wrap",
    "fsdp_min_num_params",
    "hub_model_id",
    "output_dir",
    "num_epochs",
    "learning_rate",
    "max_seq_length",
    "push_to_hub",
    "validation_split_ratio",
)


def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser with every training CLI option registered."""
    parser = argparse.ArgumentParser(description="Apmāca Maris AI modeli ar Transformers")

    # Configuration source, base model and datasets.
    parser.add_argument("--config", help="JSON konfigurācijas fails")
    parser.add_argument("--model-name", help="Bāzes modelis fine-tuningam")
    parser.add_argument("--model-preset", help="Iepriekš definēts HF bāzes modeļa presets")
    parser.add_argument("--dataset-repo", help="HF dataset repo ID")
    parser.add_argument("--eval-dataset-repo", help="Atsevišķs HF eval dataset repo ID")

    # Benchmark / release gate options.
    parser.add_argument(
        "--benchmark-dataset-path", help="Lokāls JSON benchmarks release gate un score manifestam"
    )
    parser.add_argument("--benchmark-name", help="Benchmark komplekta nosaukums artefaktiem")
    parser.add_argument(
        "--benchmark-levels",
        help="Comma-separated benchmark līmeņi, piemēram local,ci,release",
    )
    parser.add_argument(
        "--benchmark-min-overall",
        type=float,
        help="Minimālais overall benchmark score release gate vajadzībām",
    )
    parser.add_argument(
        "--benchmark-gate-enabled",
        type=_parse_bool_arg,
        help="Vai training skrējiens jāaptur, ja benchmark gate neiziet",
    )

    # Preference optimization options.
    parser.add_argument(
        "--preference-dataset-path",
        help="Lokāls JSON preference-feedback datasets auditējamam artifactam",
    )
    parser.add_argument(
        "--preference-optimization",
        help="Preference optimization režīms: none, dpo vai orpo",
    )
    parser.add_argument("--preference-beta", type=float, help="DPO/ORPO beta parametrs")
    parser.add_argument(
        "--preference-max-prompt-length",
        type=int,
        help="Maksimālais prompt tokenu garums preference optimization laikā",
    )
    parser.add_argument(
        "--preference-max-length",
        type=int,
        help="Maksimālais kopējais tokenu garums preference optimization laikā",
    )
    parser.add_argument(
        "--preference-reference-model",
        help="Atsauces modelis DPO preference optimization stadijai",
    )

    # Branch and adapter (LoRA / QLoRA) options.
    parser.add_argument("--branch-name", help="Maris atzara nosaukums")
    parser.add_argument("--branch-focus", help="Atzara specializācijas fokuss")
    parser.add_argument("--adapter-type", help="Adapteru stratēģija, piemēram full vai lora")
    parser.add_argument("--lora-r", type=int, help="LoRA rank parametrs PEFT adapteriem")
    parser.add_argument("--lora-alpha", type=int, help="LoRA alpha parametrs PEFT adapteriem")
    parser.add_argument("--lora-dropout", type=float, help="LoRA dropout parametrs")
    parser.add_argument("--lora-bias", help="LoRA bias režīms, piemēram none vai all")
    parser.add_argument(
        "--peft-target-modules",
        help="Comma-separated PEFT target modules saraksts",
    )
    parser.add_argument("--qlora-quant-type", help="QLoRA quant type, piemēram nf4 vai fp4")
    parser.add_argument(
        "--qlora-use-double-quant",
        type=_parse_bool_arg,
        help="Vai QLoRA izmantot double quantization",
    )
    parser.add_argument(
        "--qlora-compute-dtype",
        help="QLoRA compute dtype, piemēram float16 vai bfloat16",
    )

    # Distributed launch options.
    parser.add_argument(
        "--distributed-strategy",
        help="Distributed režīms: none, fsdp vai deepspeed",
    )
    parser.add_argument(
        "--distributed-config-path",
        help="Ceļš uz FSDP vai DeepSpeed JSON konfigurāciju",
    )
    parser.add_argument(
        "--use-accelerate",
        type=_parse_bool_arg,
        help="Vai palaist treniņu ar accelerate launcher semantiku",
    )
    parser.add_argument(
        "--accelerate-config-path",
        help="Ceļš uz accelerate launcher YAML konfigurāciju",
    )
    parser.add_argument("--num-processes", type=int, help="Procesu/GPU skaits distributed launcham")
    parser.add_argument("--num-machines", type=int, help="Mašīnu skaits distributed launcham")
    parser.add_argument(
        "--machine-rank", type=int, help="Pašreizējās mašīnas ranks distributed launcham"
    )
    parser.add_argument("--main-process-ip", help="Galvenā procesa IP multi-node launcham")
    parser.add_argument(
        "--main-process-port", type=int, help="Galvenā procesa ports multi-node launcham"
    )
    parser.add_argument(
        "--fsdp-transformer-layer-cls-to-wrap",
        help="Comma-separated transformer layer class saraksts FSDP auto-wrap vajadzībām",
    )
    parser.add_argument(
        "--fsdp-min-num-params",
        type=int,
        help="Minimālais parametru skaits FSDP wrap aktivēšanai",
    )

    # Output, hyperparameters and run-mode switches.
    parser.add_argument("--hub-model-id", help="Maris model repo ID publicētajam rezultātam")
    parser.add_argument("--output-dir", help="Kur saglabāt apmācīto modeli")
    parser.add_argument("--num-epochs", type=int, help="Epoku skaits")
    parser.add_argument("--learning-rate", type=float, help="Learning rate")
    parser.add_argument("--max-seq-length", type=int, help="Maksimālais tokenu garums")
    parser.add_argument(
        "--push-to-hub",
        type=_parse_bool_arg,
        help="Vai pēc treniņa publicēt pilnu output direktoriju uz Hugging Face Hub",
    )
    parser.add_argument(
        "--all-branches",
        action="store_true",
        help="Palaist branch-specific training pipeline visiem atzariem",
    )
    parser.add_argument(
        "--validation-split-ratio",
        type=float,
        help="Validation split proporcija, ja repo nav validation split",
    )
    parser.add_argument(
        "--list-base-models",
        action="store_true",
        help="Izvada pieejamos bāzes modeļu presetus JSON formātā un beidz darbu",
    )
    return parser


def main(argv: list[str] | None = None) -> int:
    """Run one training invocation from a JSON config plus CLI overrides.

    With ``--list-base-models`` the available base-model presets are printed
    as JSON and nothing is trained. Otherwise the configuration is loaded via
    ``load_training_config`` and handed to ``train_branch_suite`` (when
    ``--all-branches`` is given) or ``train_with_config``; the returned
    metrics are logged and printed as JSON.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is the behavior of a plain ``main()``
            call; passing a list allows programmatic invocation.

    Returns:
        ``0`` when the requested action completed.

    Raises:
        SystemExit: With status 2 when a ``FileNotFoundError``,
            ``ImportError`` or ``ValueError`` is raised while loading the
            configuration or training (the message is written to stderr by
            ``parser.exit``). Argparse itself also exits on invalid arguments
            or ``--help``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    logging.basicConfig(level=logging.INFO)
    try:
        if args.list_base_models:
            print(json.dumps(list_training_base_models(), indent=2, ensure_ascii=False))
            return 0

        # Imported here rather than at module level: the import only happens
        # after the --list-base-models early return, and an ImportError is
        # reported through the handler below instead of a traceback.
        from maris_core.training.train import train_branch_suite, train_with_config

        # Options not given on the command line are forwarded as None
        # (argparse's default); how None overrides are treated is up to
        # load_training_config — presumably ignored, TODO confirm.
        config = load_training_config(
            args.config,
            overrides={name: getattr(args, name) for name in _OVERRIDE_FIELDS},
        )
        if args.all_branches:
            # The branch suite always receives a copy of the config with
            # push_to_hub disabled, regardless of the config file or --push-to-hub.
            metrics = train_branch_suite(replace(config, push_to_hub=False))
        else:
            metrics = train_with_config(config)
        logger.info("Training metrics: %s", metrics)
        print(json.dumps(metrics, indent=2, ensure_ascii=False))
        return 0
    except (FileNotFoundError, ImportError, ValueError) as exc:
        # parser.exit raises SystemExit(2), so no return statement is needed here.
        parser.exit(2, f"{parser.prog}: error: {exc}\n")


# Script entry point: run the CLI and use main()'s integer result as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())