"""Palīgfunkcijas Maris treniņu darba telpas UI."""
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| import re |
| import signal |
| import time |
| from pathlib import Path |
| from typing import Any |
|
|
| from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator |
|
|
| from maris_core.training.config import list_training_base_models |
| from maris_core.utils.env import validate_maris_model, validate_maris_repo |
|
|
# Hugging Face repository identifier in ``owner/name`` form.
HF_REPO_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*/[A-Za-z0-9][A-Za-z0-9._-]*$")
# Character whitelist for relative storage paths (letters, digits, ``.``, ``_``, ``/``, ``-``).
SAFE_OUTPUT_SEGMENT_RE = re.compile(r"^[A-Za-z0-9._/-]+$")
# "Epoch <current>/<total>" progress lines.
LOG_EPOCH_RE = re.compile(r"Epoch\s+(\d+(?:\.\d+)?)\s*/\s*(\d+(?:\.\d+)?)", re.IGNORECASE)
# ``epoch: <value>`` / ``'epoch'=<value>`` key-value pairs.
VALUE_EPOCH_RE = re.compile(r"(?:['\"]?epoch['\"]?\s*[:=]\s*)(\d+(?:\.\d+)?)", re.IGNORECASE)
# ``loss`` key-value pairs. The negative lookbehind stops the pattern from
# matching the tail of ``eval_loss``, which would otherwise report the
# validation loss as the training loss (EVAL_LOSS_RE handles that key).
LOSS_RE = re.compile(
    r"(?<!eval_)(?:['\"]?loss['\"]?\s*[:=]\s*)(\d+(?:\.\d+)?)",
    re.IGNORECASE,
)
# ``eval_loss`` key-value pairs.
EVAL_LOSS_RE = re.compile(r"(?:['\"]?eval_loss['\"]?\s*[:=]\s*)(\d+(?:\.\d+)?)", re.IGNORECASE)
# ``learning_rate`` key-value pairs, with optional exponent (e.g. ``5e-05``).
LEARNING_RATE_RE = re.compile(
    r"(?:['\"]?learning_rate['\"]?\s*[:=]\s*)(\d+(?:\.\d+)?(?:e[+-]?\d+)?)",
    re.IGNORECASE,
)
# "<step>/<total> [" fragments, as printed by progress bars.
STEP_PROGRESS_RE = re.compile(r"(\d+)\s*/\s*(\d+)\s*\[", re.IGNORECASE)
# Marker that identifies a log line carrying a structured JSON progress event.
MARIS_PROGRESS_EVENT_KEY = "maris_training_event"


# Upper bound on how many structured events are collected from one log text.
MAX_STRUCTURED_EVENTS = 64
# Training config path used when no config-path environment variable is set.
SPACE_TRAINING_CONFIG_PATH_DEFAULT = "huggingface/training-config.json"
# Files whose presence in the output directory marks a finished training run.
SPACE_TRAINING_COMPLETION_MARKERS = (
    "training-metrics.json",
    "trainer_state.json",
    "training-provenance.json",
    "branch-suite.json",
)
|
|
|
|
def _validate_repo_id(value: str) -> str:
    """Return *value* without surrounding whitespace if it is an ``owner/name`` ID.

    Raises:
        ValueError: If the stripped value does not match ``HF_REPO_ID_RE``.
    """
    candidate = value.strip()
    if HF_REPO_ID_RE.fullmatch(candidate) is None:
        raise ValueError("Repo ID j膩b奴t form膩t膩 owner/name.")
    return candidate
|
|
|
|
class SpaceTrainingRequest(BaseModel):
    """UI request payload for launching a Maris training run."""

    # String inputs are stripped of surrounding whitespace before validation.
    model_config = ConfigDict(str_strip_whitespace=True)

    dataset_repo: str = "MarisUK/maris-ai-memory"
    model_repo: str = "MarisUK/maris-ai-master"
    hub_model_id: str = ""
    model_preset: str = "balanced"
    model_name: str = ""
    num_epochs: int = Field(default=3, ge=1, le=100)
    all_branches: bool = False
    push_to_hub: bool = True
    output_subdir: str = "maris-ai-master"
    continue_from_latest_artifact: bool = True
    continue_model_path: str = ""

    @field_validator("dataset_repo")
    @classmethod
    def validate_dataset_repo(cls, value: str) -> str:
        """Check the ``owner/name`` shape, then defer to ``validate_maris_repo``."""
        repo_id = _validate_repo_id(value)
        try:
            checked = validate_maris_repo(repo_id, "dataset_repo", label="repozitorijs")
        except RuntimeError as exc:
            # Re-raised as ValueError so it is reported as a validation failure.
            raise ValueError(str(exc)) from exc
        return checked

    @field_validator("model_repo")
    @classmethod
    def validate_model_repo(cls, value: str) -> str:
        """Accept an empty value; otherwise validate it as a Maris model repo."""
        if value.strip() == "":
            return ""
        repo_id = _validate_repo_id(value)
        try:
            checked = validate_maris_model(repo_id, "model_repo")
        except RuntimeError as exc:
            raise ValueError(str(exc)) from exc
        return checked

    @field_validator("hub_model_id")
    @classmethod
    def validate_hub_model_id(cls, value: str) -> str:
        """Accept an empty value; otherwise validate it as a Maris model repo."""
        if value.strip() == "":
            return ""
        repo_id = _validate_repo_id(value)
        try:
            checked = validate_maris_model(repo_id, "hub_model_id")
        except RuntimeError as exc:
            raise ValueError(str(exc)) from exc
        return checked

    @field_validator("model_name")
    @classmethod
    def validate_model_name(cls, value: str) -> str:
        """Accept an empty value or a custom base model in ``owner/name`` form."""
        candidate = value.strip()
        if not candidate:
            return candidate
        if HF_REPO_ID_RE.fullmatch(candidate) is None:
            raise ValueError("Custom modelim j膩b奴t form膩t膩 owner/name.")
        return candidate

    @field_validator("model_preset")
    @classmethod
    def validate_model_preset(cls, value: str) -> str:
        """Accept an empty value or a preset listed by ``list_training_base_models``."""
        preset = value.strip()
        if preset == "":
            return ""
        if preset in list_training_base_models():
            return preset
        raise ValueError("Nezin膩ms mode募a presets.")

    @field_validator("output_subdir")
    @classmethod
    def validate_output_subdir(cls, value: str) -> str:
        """Require a non-empty relative path with no ``..`` and only safe characters."""
        subdir = value.strip().strip("/")
        if not subdir:
            raise ValueError("Output apak拧direktorija nedr墨kst b奴t tuk拧a.")
        climbs_out = ".." in Path(subdir).parts
        if climbs_out or SAFE_OUTPUT_SEGMENT_RE.fullmatch(subdir) is None:
            raise ValueError("Output apak拧direktorij膩 dr墨kst b奴t tikai dro拧i ce募a segmenti.")
        return subdir

    @field_validator("continue_model_path")
    @classmethod
    def validate_continue_model_path(cls, value: str) -> str:
        """Accept an empty value or a relative path with no ``..`` and only safe characters."""
        trimmed = value.strip()
        if not trimmed:
            return ""
        relative = trimmed.strip("/")
        if not relative:
            # The value consisted solely of slashes.
            raise ValueError("Continue mode募a ce募拧 nedr墨kst b奴t tuk拧s.")
        climbs_out = ".." in Path(relative).parts
        if climbs_out or SAFE_OUTPUT_SEGMENT_RE.fullmatch(relative) is None:
            raise ValueError("Continue mode募a ce募膩 dr墨kst b奴t tikai dro拧i ce募a segmenti.")
        return relative

    @model_validator(mode="after")
    def validate_model_selection(self) -> SpaceTrainingRequest:
        """Unify the target repo fields and ensure a base model is selected.

        ``hub_model_id`` takes precedence over ``model_repo``; the chosen value
        is written back to both fields. When neither ``model_name`` nor
        ``model_preset`` is set, the preset defaults to ``"balanced"``.
        """
        target_repo = self.hub_model_id if self.hub_model_id else self.model_repo
        if not target_repo:
            raise ValueError("J膩nor膩da hub_model_id vai model_repo.")
        self.hub_model_id = target_repo
        self.model_repo = target_repo
        if not (self.model_name or self.model_preset):
            self.model_preset = "balanced"
        return self
|
|
|
|
def resolve_output_dir(persistent_dir: str, output_subdir: str) -> Path:
    """Resolve *output_subdir* under *persistent_dir* and confine it to that root.

    Both paths are fully resolved before the containment check, so ``..``
    segments and symlinks cannot move the result outside the storage root.

    Raises:
        ValueError: If the resolved target lies outside the resolved root.
    """
    storage_root = Path(persistent_dir).expanduser().resolve()
    resolved = (storage_root / output_subdir).resolve()
    root_text = str(storage_root)
    shared_prefix = os.path.commonpath([root_text, str(resolved)])
    if shared_prefix == root_text:
        return resolved
    raise ValueError("Output direktorijai j膩atrodas Maris persistent storage ietvaros.")
|
|
|
|
def resolve_optional_persistent_path(persistent_dir: str, path_value: str) -> Path | None:
    """Resolve an optional relative path inside *persistent_dir*.

    Returns:
        ``None`` when *path_value* is falsy or reduces to nothing once
        whitespace and surrounding slashes are stripped; otherwise the fully
        resolved path.

    Raises:
        ValueError: If the resolved target lies outside the resolved root.
    """
    relative = str(path_value or "").strip().strip("/")
    if relative == "":
        return None
    storage_root = Path(persistent_dir).expanduser().resolve()
    resolved = (storage_root / relative).resolve()
    root_text = str(storage_root)
    if os.path.commonpath([root_text, str(resolved)]) == root_text:
        return resolved
    raise ValueError(
        "Continue mode募a direktorijai j膩atrodas Maris persistent storage ietvaros."
    )
|
|
|
|
def build_space_training_command(script_path: str, request: SpaceTrainingRequest) -> list[str]:
    """Build the argument list that runs the training script through ``bash``.

    A custom ``model_name`` takes precedence over ``model_preset``; the
    ``--all-branches`` flag is appended when ``request.all_branches`` is set.
    The command is returned as a list of arguments, not a shell string.
    """
    model_args: list[str] = []
    if request.model_name:
        model_args = ["--model-name", request.model_name]
    elif request.model_preset:
        model_args = ["--model-preset", request.model_preset]
    branch_args = ["--all-branches"] if request.all_branches else []
    return ["bash", script_path, *model_args, *branch_args]
|
|
|
|
def build_space_training_env(
    base_env: dict[str, str],
    request: SpaceTrainingRequest,
    persistent_dir: str,
) -> dict[str, str]:
    """Build the environment for the Maris training subprocess.

    Works on a copy of *base_env* and overlays the request settings under both
    the ``MARIS_*`` and the ``HF_*`` variable names; *base_env* is not modified.

    Raises:
        ValueError: If the output directory or the continue-model path
            resolves outside *persistent_dir*.
    """
    output_dir = resolve_output_dir(persistent_dir, request.output_subdir)
    continue_model_dir = resolve_optional_persistent_path(
        persistent_dir, request.continue_model_path
    )

    # The first non-empty config-path variable wins; otherwise use the default.
    configured_path = SPACE_TRAINING_CONFIG_PATH_DEFAULT
    for config_key in (
        "MARIS_SPACE_TRAIN_CONFIG_PATH",
        "HF_SPACE_TRAINING_CONFIG_PATH",
        "MARIS_TRAIN_CONFIG_PATH",
        "HF_TRAINING_CONFIG_PATH",
    ):
        candidate = base_env.get(config_key)
        if candidate:
            configured_path = candidate
            break
    # A whitespace-only value also falls back to the default.
    config_path = str(configured_path).strip() or SPACE_TRAINING_CONFIG_PATH_DEFAULT

    output_dir_text = str(output_dir)
    epochs_text = str(request.num_epochs)
    publish_flag = "true" if request.push_to_hub else "false"
    resume_flag = "true" if request.continue_from_latest_artifact else "false"

    env = dict(base_env)
    # An inherited PYTHONUNBUFFERED value is kept; "1" is only the default.
    unbuffered = env.get("PYTHONUNBUFFERED", "1")
    overrides = {
        "MARIS_PERSISTENT_DIR": persistent_dir,
        "MARIS_MEMORY_REPO": request.dataset_repo,
        "MARIS_MODEL_REPO": request.hub_model_id,
        "MARIS_TRAIN_CONFIG_PATH": config_path,
        "MARIS_TRAIN_NUM_EPOCHS": epochs_text,
        "MARIS_TRAIN_PUBLISH": publish_flag,
        "MARIS_TRAIN_OUTPUT_DIR": output_dir_text,
        "MARIS_LOCAL_MODEL_DIR": output_dir_text,
        "MARIS_TRAIN_CONTINUE_FROM_LATEST": resume_flag,
        "HF_PERSISTENT_DIR": persistent_dir,
        "HF_DATASET_REPO": request.dataset_repo,
        "HF_MODEL_REPO": request.hub_model_id,
        "HF_TRAINING_CONFIG_PATH": config_path,
        "HF_TRAIN_NUM_EPOCHS": epochs_text,
        "HF_TRAIN_PUSH_TO_HUB": publish_flag,
        "HF_TRAIN_OUTPUT_DIR": output_dir_text,
        "HF_LOCAL_MODEL_DIR": output_dir_text,
        "HF_TRAIN_CONTINUE_FROM_LATEST": resume_flag,
        "MARIS_TRAIN_DISTRIBUTED_STRATEGY": "none",
        "HF_TRAIN_DISTRIBUTED_STRATEGY": "none",
        "PYTHONUNBUFFERED": unbuffered,
    }
    env.update(overrides)

    # The distributed strategy is forced to "none", so drop any inherited
    # distributed config path.
    for stale_key in ("MARIS_TRAIN_DISTRIBUTED_CONFIG_PATH", "HF_TRAIN_DISTRIBUTED_CONFIG_PATH"):
        env.pop(stale_key, None)

    continue_keys = ("MARIS_TRAIN_CONTINUE_MODEL_PATH", "HF_TRAIN_CONTINUE_MODEL_PATH")
    if continue_model_dir is None:
        for key in continue_keys:
            env.pop(key, None)
    else:
        for key in continue_keys:
            env[key] = str(continue_model_dir)

    # Base model and preset are mutually exclusive: setting one clears the other.
    base_model_keys = ("MARIS_TRAIN_BASE_MODEL", "HF_TRAIN_BASE_MODEL")
    preset_keys = ("MARIS_TRAIN_MODEL_PRESET", "HF_TRAIN_MODEL_PRESET")
    if request.model_name:
        for key in base_model_keys:
            env[key] = request.model_name
        for key in preset_keys:
            env.pop(key, None)
    elif request.model_preset:
        for key in preset_keys:
            env[key] = request.model_preset
        for key in base_model_keys:
            env.pop(key, None)
    return env
|
|
|
|
def has_completed_training_artifacts(output_dir: Path) -> bool:
    """Report whether *output_dir* contains at least one completion marker file.

    The marker file names come from ``SPACE_TRAINING_COMPLETION_MARKERS``.
    """
    for marker_name in SPACE_TRAINING_COMPLETION_MARKERS:
        if output_dir.joinpath(marker_name).is_file():
            return True
    return False
|
|
|
|
def tail_log(log_path: str | Path, *, max_chars: int = 16000) -> str:
    """Return the last *max_chars* characters of a log file for display in the UI.

    Args:
        log_path: Path of the log file.
        max_chars: Maximum number of trailing characters to return.

    Returns:
        The tail of the file decoded as UTF-8 (undecodable bytes are replaced),
        or ``""`` when the file does not exist or *max_chars* is not positive.
    """
    # Without this guard ``content[-0:]`` would return the entire file and a
    # negative value would cut off the head instead of yielding a tail.
    if max_chars <= 0:
        return ""
    path = Path(log_path)
    if not path.is_file():
        return ""
    # The whole file is read into memory before slicing.
    content = path.read_text(encoding="utf-8", errors="replace")
    return content[-max_chars:]
|
|
|
|
def read_log_since(log_path: str | Path, offset: int, *, max_chars: int = 8192) -> tuple[str, int]:
    """Read the part of a log file that follows *offset*.

    Args:
        log_path: Path of the log file.
        offset: Position to continue from; clamped to ``[0, file size]``.
        max_chars: Maximum number of characters to read in this call.

    Returns:
        ``(chunk, next_offset)`` where *next_offset* is the stream position
        after the read. A missing file yields ``("", 0)``.
    """
    log_file = Path(log_path)
    if not log_file.is_file():
        return "", 0

    size = log_file.stat().st_size
    start = offset
    if start > size:
        start = size
    if start < 0:
        start = 0

    with log_file.open(encoding="utf-8", errors="replace") as stream:
        stream.seek(start)
        text = stream.read(max_chars)
        position = stream.tell()

    return text, position
|
|
|
|
def parse_training_progress(
    log_text: str,
    *,
    request: dict[str, Any] | None = None,
    running: bool,
    exit_code: int | None,
) -> dict[str, Any]:
    """Heuristically derive training progress from log text and process state.

    Metrics are taken from the most recent structured progress event when the
    log contains one; any metric that event does not provide falls back to the
    last matching regex hit in the raw log text.

    Args:
        log_text: Training log contents to inspect.
        request: Optional mapping describing the launch request; only its
            ``num_epochs`` entry is read, as the expected epoch count.
        running: Whether the training process is still alive.
        exit_code: Exit code of the process, or ``None`` if it has not exited.

    Returns:
        A dict with the keys ``percent``, ``stage``, ``label``,
        ``current_epoch``, ``total_epochs``, ``loss``, ``eval_loss``,
        ``learning_rate``, ``current_step``, ``total_steps``, ``eta_seconds``
        and ``events_detected``.
    """
    lower_log = log_text.lower()
    # A missing, empty or malformed ``num_epochs`` means "unknown" instead of
    # raising TypeError/ValueError while the UI polls for progress.
    total_epochs = (_as_int(request.get("num_epochs")) or None) if request else None
    current_epoch: float | None = None
    detected_total_epochs = total_epochs
    current_step: int | None = None
    total_steps: int | None = None
    loss: float | None = None
    eval_loss: float | None = None
    learning_rate: float | None = None
    eta_seconds: float | None = None
    structured_stage: str | None = None
    structured_label: str | None = None

    # Structured events are the preferred source; only the newest one is used.
    structured_events = _extract_structured_training_events(log_text)
    if structured_events:
        latest_event = structured_events[-1]
        current_epoch = _as_float(latest_event.get("epoch"))
        detected_total_epochs = (
            _as_int(latest_event.get("total_epochs")) or detected_total_epochs or total_epochs
        )
        current_step = _as_int(latest_event.get("step"))
        total_steps = _as_int(latest_event.get("total_steps"))
        loss = _as_float(latest_event.get("loss"))
        eval_loss = _as_float(latest_event.get("eval_loss"))
        learning_rate = _as_float(latest_event.get("learning_rate"))
        eta_seconds = _as_float(latest_event.get("eta_seconds"))
        structured_stage = _normalize_stage(latest_event.get("stage"))
        structured_label = _normalize_label(latest_event.get("label"))

    # Regex fallbacks: in every case the last match in the log wins.
    epoch_matches = list(LOG_EPOCH_RE.finditer(log_text))
    if current_epoch is None and epoch_matches:
        last_match = epoch_matches[-1]
        current_epoch = float(last_match.group(1))
        detected_total_epochs = int(float(last_match.group(2)))
    elif current_epoch is None:
        value_matches = list(VALUE_EPOCH_RE.finditer(log_text))
        if value_matches:
            current_epoch = float(value_matches[-1].group(1))

    if current_step is None or total_steps is None:
        step_matches = list(STEP_PROGRESS_RE.finditer(log_text))
        if step_matches:
            last_match = step_matches[-1]
            current_step = int(last_match.group(1))
            total_steps = int(last_match.group(2))

    if loss is None:
        loss_matches = list(LOSS_RE.finditer(log_text))
        loss = float(loss_matches[-1].group(1)) if loss_matches else None

    if eval_loss is None:
        eval_loss_matches = list(EVAL_LOSS_RE.finditer(log_text))
        eval_loss = float(eval_loss_matches[-1].group(1)) if eval_loss_matches else None

    if learning_rate is None:
        learning_rate_matches = list(LEARNING_RATE_RE.finditer(log_text))
        learning_rate = float(learning_rate_matches[-1].group(1)) if learning_rate_matches else None

    percent = 0
    stage = structured_stage or "queued"
    label = structured_label or "Gaida treni艈a startu"

    # Branch order is the priority order: a user stop and the process exit
    # state override whatever stage the log content suggests.
    if "stop requested by user" in lower_log or "training stopped by user" in lower_log:
        stage = "stopped"
        label = "Treni艈拧 aptur膿ts p膿c piepras墨juma"
    elif exit_code == 0:
        stage = "completed"
        label = "Treni艈拧 pabeigts veiksm墨gi"
        percent = 100
    elif exit_code is not None and not running:
        stage = "failed"
        label = f"Treni艈拧 beidz膩s ar k募奴du (exit {exit_code})"
        percent = 100
    elif stage == "publishing":
        label = structured_label or "Public膿 modeli origin repozitorij膩"
        percent = 96
    elif stage == "saving":
        label = structured_label or "Saglab膩 mode募a artefaktus"
        percent = 92
    elif stage == "evaluating":
        label = structured_label or "Veic valid膩ciju un eval metriku apr膿姆inu"
        percent = 88 if current_step else 82
    elif stage == "benchmarking":
        label = structured_label or "Palai啪 benchmark un release gate p膩rbaudes"
        percent = 94
    elif stage == "preparing":
        label = structured_label or "Sagatavo datus, modeli un cache"
        percent = 20
    elif any(token in lower_log for token in ("uploading", "pushing", "export_to_hf")):
        stage = "publishing"
        label = "Public膿 modeli origin repozitorij膩"
        percent = 96
    elif current_step is not None and total_steps:
        stage = structured_stage or "training"
        # Step progress is mapped onto the 35-90 % band and capped at 95 %.
        progress_ratio = min(current_step / max(total_steps, 1), 1.0)
        percent = min(95, max(35, int(35 + progress_ratio * 55)))
        label = structured_label or f"Tren膿 modeli 路 solis {current_step}/{total_steps}"
    elif current_epoch is not None:
        stage = structured_stage or "training"
        epoch_total = detected_total_epochs or total_epochs
        if epoch_total:
            percent = min(95, max(35, int(35 + min(current_epoch / epoch_total, 1.0) * 55)))
            label = structured_label or f"Tren膿 modeli 路 epoha {current_epoch:g}/{epoch_total}"
        else:
            # Epoch count unknown: report a fixed mid-range value.
            percent = 65
            label = structured_label or f"Tren膿 modeli 路 epoha {current_epoch:g}"
    elif any(
        token in lower_log
        for token in ("tokeniz", "validation split", "dataset", "snapshot", "download", "cache")
    ):
        stage = structured_stage or "preparing"
        label = structured_label or "Sagatavo datus, modeli un cache"
        percent = 20
    elif running:
        stage = structured_stage or "starting"
        label = structured_label or "Inicializ膿 treni艈u"
        percent = 5

    return {
        "percent": percent,
        "stage": stage,
        "label": label,
        "current_epoch": current_epoch,
        "total_epochs": detected_total_epochs or total_epochs,
        "loss": loss,
        "eval_loss": eval_loss,
        "learning_rate": learning_rate,
        "current_step": current_step,
        "total_steps": total_steps,
        "eta_seconds": eta_seconds,
        "events_detected": len(structured_events),
    }
|
|
|
|
def _extract_structured_training_events(log_text: str) -> list[dict[str, Any]]:
    """Collect the most recent structured progress events from *log_text*.

    A line qualifies when it contains ``MARIS_PROGRESS_EVENT_KEY`` and the text
    from its first ``{`` onward parses as a JSON object with a truthy value
    under that key. The log is scanned from the end and at most
    ``MAX_STRUCTURED_EVENTS`` events are kept.

    Returns:
        The collected events in chronological (oldest-first) order.
    """
    newest_first: list[dict[str, Any]] = []
    for raw_line in reversed(log_text.splitlines()):
        if MARIS_PROGRESS_EVENT_KEY not in raw_line:
            continue
        brace_at = raw_line.find("{")
        if brace_at < 0:
            continue
        try:
            decoded = json.loads(raw_line[brace_at:])
        except json.JSONDecodeError:
            # Not a well-formed event line; skip it.
            continue
        if not isinstance(decoded, dict) or not decoded.get(MARIS_PROGRESS_EVENT_KEY):
            continue
        newest_first.append(decoded)
        if len(newest_first) >= MAX_STRUCTURED_EVENTS:
            break
    return newest_first[::-1]
|
|
|
|
def _as_float(value: Any) -> float | None:
    """Convert *value* to ``float``; return ``None`` when that is not possible.

    Booleans, ``None`` and the empty string are treated as missing.
    """
    if isinstance(value, bool):
        return None
    if value in (None, ""):
        return None
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    return number
|
|
|
|
def _as_int(value: Any) -> int | None:
    """Convert *value* to ``int`` (truncating toward zero); ``None`` on failure.

    Booleans, ``None`` and the empty string are treated as missing. The value
    is parsed as a float first, so numeric strings such as ``"3.7"`` are
    accepted.
    """
    if isinstance(value, bool) or value in (None, ""):
        return None
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers infinities (``json.loads`` accepts ``Infinity``);
        # ValueError covers NaN and non-numeric text.
        return None
|
|
|
|
def _normalize_stage(value: Any) -> str | None:
    """Return *value* stripped and lower-cased, or ``None`` if it is not a non-blank string."""
    if isinstance(value, str):
        stage = value.strip().lower()
        if stage:
            return stage
    return None
|
|
|
|
def _normalize_label(value: Any) -> str | None:
    """Return *value* stripped of surrounding whitespace, or ``None`` if it is not a non-blank string."""
    if isinstance(value, str):
        text = value.strip()
        if text:
            return text
    return None
|
|
|
|
def _signal_process_tree(process: Any, *, force: bool) -> None:
    """Signal the process group of *process*, falling back to the process itself."""
    if hasattr(os, "killpg"):
        group_signal = signal.SIGKILL if force else signal.SIGTERM
        try:
            os.killpg(process.pid, group_signal)
        except ProcessLookupError:
            # No process group carries this ID: the process either is not a
            # group leader or is already gone. Signal it directly instead of
            # giving up, so a live process is not left running.
            pass
        else:
            return
    if force:
        process.kill()
    else:
        process.terminate()


def terminate_process_tree(process: Any, *, grace_seconds: float = 10.0) -> int | None:
    """Stop a process and, where possible, its child processes.

    Sends a terminate request, polls for up to *grace_seconds*, then escalates
    to a kill. Where ``os.killpg`` is available the signal goes to the process
    group whose ID equals ``process.pid``; if no such group exists, or on
    platforms without ``os.killpg``, only the process itself is signalled.

    Args:
        process: Object exposing ``pid``, ``returncode``, ``poll()``,
            ``wait()``, ``terminate()`` and ``kill()``.
        grace_seconds: How long to wait for a graceful exit before killing.

    Returns:
        The exit code, or ``None`` if the process has still not exited.
    """
    if process.poll() is not None:
        return process.returncode

    try:
        _signal_process_tree(process, force=False)
    except ProcessLookupError:
        return process.poll()

    deadline = time.monotonic() + grace_seconds
    while time.monotonic() < deadline:
        exit_code = process.poll()
        if exit_code is not None:
            return exit_code
        time.sleep(0.1)

    try:
        _signal_process_tree(process, force=True)
    except ProcessLookupError:
        return process.poll()

    try:
        process.wait(timeout=1)
    except Exception:
        # Deliberate best effort: ``process`` is duck-typed, so the timeout
        # exception type is unknown here; report whatever poll() sees.
        return process.poll()

    return process.returncode
|
|
|
|
def list_space_model_choices() -> dict[str, dict[str, Any]]:
    """Return the base models available for selection in the UI.

    Delegates to ``list_training_base_models`` and returns its result unchanged.
    """
    available_models = list_training_base_models()
    return available_models
|
|