"""Palīgfunkcijas Maris treniņu darba telpas UI.""" from __future__ import annotations import json import os import re import signal import time from pathlib import Path from typing import Any from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from maris_core.training.config import list_training_base_models from maris_core.utils.env import validate_maris_model, validate_maris_repo HF_REPO_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*/[A-Za-z0-9][A-Za-z0-9._-]*$") SAFE_OUTPUT_SEGMENT_RE = re.compile(r"^[A-Za-z0-9._/-]+$") LOG_EPOCH_RE = re.compile(r"Epoch\s+(\d+(?:\.\d+)?)\s*/\s*(\d+(?:\.\d+)?)", re.IGNORECASE) VALUE_EPOCH_RE = re.compile(r"(?:['\"]?epoch['\"]?\s*[:=]\s*)(\d+(?:\.\d+)?)", re.IGNORECASE) LOSS_RE = re.compile(r"(?:['\"]?loss['\"]?\s*[:=]\s*)(\d+(?:\.\d+)?)", re.IGNORECASE) EVAL_LOSS_RE = re.compile(r"(?:['\"]?eval_loss['\"]?\s*[:=]\s*)(\d+(?:\.\d+)?)", re.IGNORECASE) LEARNING_RATE_RE = re.compile( r"(?:['\"]?learning_rate['\"]?\s*[:=]\s*)(\d+(?:\.\d+)?(?:e[+-]?\d+)?)", re.IGNORECASE, ) STEP_PROGRESS_RE = re.compile(r"(\d+)\s*/\s*(\d+)\s*\[", re.IGNORECASE) MARIS_PROGRESS_EVENT_KEY = "maris_training_event" # Keep a larger rolling event window than persisted run history because live status # parsing needs several recent progress/save/eval events from the current log tail. 
MAX_STRUCTURED_EVENTS = 64
SPACE_TRAINING_CONFIG_PATH_DEFAULT = "huggingface/training-config.json"
# Presence of any one of these files in the output directory counts as a
# finished training run (see has_completed_training_artifacts).
SPACE_TRAINING_COMPLETION_MARKERS = (
    "training-metrics.json",
    "trainer_state.json",
    "training-provenance.json",
    "branch-suite.json",
)


def _validate_repo_id(value: str) -> str:
    """Return *value* stripped, or raise ``ValueError`` if it is not ``owner/name``."""
    normalized = value.strip()
    if not HF_REPO_ID_RE.fullmatch(normalized):
        raise ValueError("Repo ID jābūt formātā owner/name.")
    return normalized


class SpaceTrainingRequest(BaseModel):
    """UI request for launching a Maris training run.

    String fields are whitespace-stripped by the model config. After field
    validation, ``hub_model_id`` and ``model_repo`` are forced to the same
    value (``hub_model_id`` wins when both are given), and ``model_preset``
    falls back to ``"balanced"`` when neither a preset nor a custom
    ``model_name`` was supplied.
    """

    model_config = ConfigDict(str_strip_whitespace=True)

    dataset_repo: str = "MarisUK/maris-ai-memory"
    model_repo: str = "MarisUK/maris-ai-master"
    hub_model_id: str = ""
    model_preset: str = "balanced"
    model_name: str = ""
    num_epochs: int = Field(default=3, ge=1, le=100)
    all_branches: bool = False
    push_to_hub: bool = True
    output_subdir: str = "maris-ai-master"
    continue_from_latest_artifact: bool = True
    continue_model_path: str = ""

    @field_validator("dataset_repo")
    @classmethod
    def validate_dataset_repo(cls, value: str) -> str:
        """Require ``owner/name`` form, then defer to ``validate_maris_repo``."""
        normalized = _validate_repo_id(value)
        try:
            return validate_maris_repo(normalized, "dataset_repo", label="repozitorijs")
        except RuntimeError as exc:
            # Re-raise as ValueError so pydantic reports it as a validation error.
            raise ValueError(str(exc)) from exc

    @field_validator("model_repo")
    @classmethod
    def validate_model_repo(cls, value: str) -> str:
        """Allow an empty value; otherwise validate like a Maris model repo."""
        if not value.strip():
            return ""
        normalized = _validate_repo_id(value)
        try:
            return validate_maris_model(normalized, "model_repo")
        except RuntimeError as exc:
            raise ValueError(str(exc)) from exc

    @field_validator("hub_model_id")
    @classmethod
    def validate_hub_model_id(cls, value: str) -> str:
        """Allow an empty value; otherwise validate like a Maris model repo."""
        if not value.strip():
            return ""
        normalized = _validate_repo_id(value)
        try:
            return validate_maris_model(normalized, "hub_model_id")
        except RuntimeError as exc:
            raise ValueError(str(exc)) from exc

    @field_validator("model_name")
    @classmethod
    def validate_model_name(cls, value: str) -> str:
        """Allow an empty value; a custom base model must be ``owner/name``."""
        normalized = value.strip()
        if normalized and not HF_REPO_ID_RE.fullmatch(normalized):
            raise ValueError("Custom modelim jābūt formātā owner/name.")
        return normalized

    @field_validator("model_preset")
    @classmethod
    def validate_model_preset(cls, value: str) -> str:
        """Allow an empty value; otherwise require a known base-model preset."""
        normalized = value.strip()
        if not normalized:
            return ""
        if normalized not in list_training_base_models():
            raise ValueError("Nezināms modeļa presets.")
        return normalized

    @field_validator("output_subdir")
    @classmethod
    def validate_output_subdir(cls, value: str) -> str:
        """Return the sub-directory without surrounding slashes.

        Raises ``ValueError`` when the result is empty, contains a ``..``
        segment, or uses characters outside ``SAFE_OUTPUT_SEGMENT_RE``.
        """
        normalized = value.strip().strip("/")
        if not normalized:
            raise ValueError("Output apakšdirektorija nedrīkst būt tukša.")
        if ".." in Path(normalized).parts or not SAFE_OUTPUT_SEGMENT_RE.fullmatch(normalized):
            raise ValueError("Output apakšdirektorijā drīkst būt tikai droši ceļa segmenti.")
        return normalized

    @field_validator("continue_model_path")
    @classmethod
    def validate_continue_model_path(cls, value: str) -> str:
        """Allow an empty value; otherwise apply the same path rules as ``output_subdir``."""
        normalized = value.strip()
        if not normalized:
            return ""
        stripped = normalized.strip("/")
        if not stripped:
            # The value consisted only of slashes.
            raise ValueError("Continue modeļa ceļš nedrīkst būt tukšs.")
        parts = Path(stripped).parts
        if ".." in parts or not SAFE_OUTPUT_SEGMENT_RE.fullmatch(stripped):
            raise ValueError("Continue modeļa ceļā drīkst būt tikai droši ceļa segmenti.")
        return stripped

    @model_validator(mode="after")
    def validate_model_selection(self) -> SpaceTrainingRequest:
        """Unify the target repo fields and default the preset.

        Raises ``ValueError`` when both ``hub_model_id`` and ``model_repo``
        are empty.
        """
        resolved_model_repo = self.hub_model_id or self.model_repo
        if not resolved_model_repo:
            raise ValueError("Jānorāda hub_model_id vai model_repo.")
        self.hub_model_id = resolved_model_repo
        self.model_repo = resolved_model_repo
        if not self.model_name and not self.model_preset:
            self.model_preset = "balanced"
        return self


def resolve_output_dir(persistent_dir: str, output_subdir: str) -> Path:
    """Resolve the output path and require it to stay inside persistent storage.

    Both paths are fully resolved before the containment check, so symlinks
    and ``..`` segments are taken into account.

    Raises:
        ValueError: if the resolved target is not located under *persistent_dir*.
    """
    root = Path(persistent_dir).expanduser().resolve()
    target = (root / output_subdir).resolve()
    if os.path.commonpath([str(root), str(target)]) != str(root):
        raise ValueError("Output direktorijai jāatrodas Maris persistent storage ietvaros.")
    return target


def resolve_optional_persistent_path(persistent_dir: str, path_value: str) -> Path | None:
    """Resolve an optional path inside persistent storage.

    Returns ``None`` when *path_value* is empty (after stripping whitespace and
    surrounding slashes); otherwise the resolved absolute path.

    Raises:
        ValueError: if the resolved target is not located under *persistent_dir*.
    """
    normalized = str(path_value or "").strip().strip("/")
    if not normalized:
        return None
    root = Path(persistent_dir).expanduser().resolve()
    target = (root / normalized).resolve()
    if os.path.commonpath([str(root), str(target)]) != str(root):
        raise ValueError(
            "Continue modeļa direktorijai jāatrodas Maris persistent storage ietvaros."
        )
    return target


def build_space_training_command(script_path: str, request: SpaceTrainingRequest) -> list[str]:
    """Build the argument list used to launch the training script.

    The command is returned as a list (no shell string). A custom
    ``model_name`` takes precedence over ``model_preset``.
    """
    command = ["bash", script_path]
    if request.model_name:
        command.extend(["--model-name", request.model_name])
    elif request.model_preset:
        command.extend(["--model-preset", request.model_preset])
    if request.all_branches:
        command.append("--all-branches")
    return command


def build_space_training_env(
    base_env: dict[str, str],
    request: SpaceTrainingRequest,
    persistent_dir: str,
) -> dict[str, str]:
    """Prepare the environment for the training process.

    Returns a copy of *base_env* with every setting written under both its
    ``MARIS_*`` and ``HF_*`` name. The distributed strategy is forced to
    ``"none"`` and any distributed config path is removed. Continue-model and
    base-model/preset variables are set or cleared so that stale values from
    *base_env* cannot leak into the run.

    Raises:
        ValueError: if the output or continue-model path escapes *persistent_dir*.
    """
    output_dir = resolve_output_dir(persistent_dir, request.output_subdir)
    continue_model_dir = resolve_optional_persistent_path(
        persistent_dir, request.continue_model_path
    )
    # First non-empty variable wins; a whitespace-only value falls back to the default.
    config_path = (
        str(
            base_env.get("MARIS_SPACE_TRAIN_CONFIG_PATH")
            or base_env.get("HF_SPACE_TRAINING_CONFIG_PATH")
            or base_env.get("MARIS_TRAIN_CONFIG_PATH")
            or base_env.get("HF_TRAINING_CONFIG_PATH")
            or SPACE_TRAINING_CONFIG_PATH_DEFAULT
        ).strip()
        or SPACE_TRAINING_CONFIG_PATH_DEFAULT
    )
    env = dict(base_env)
    env.update(
        {
            "MARIS_PERSISTENT_DIR": persistent_dir,
            "MARIS_MEMORY_REPO": request.dataset_repo,
            "MARIS_MODEL_REPO": request.hub_model_id,
            "MARIS_TRAIN_CONFIG_PATH": config_path,
            "MARIS_TRAIN_NUM_EPOCHS": str(request.num_epochs),
            "MARIS_TRAIN_PUBLISH": "true" if request.push_to_hub else "false",
            "MARIS_TRAIN_OUTPUT_DIR": str(output_dir),
            "MARIS_LOCAL_MODEL_DIR": str(output_dir),
            "MARIS_TRAIN_CONTINUE_FROM_LATEST": (
                "true" if request.continue_from_latest_artifact else "false"
            ),
            "HF_PERSISTENT_DIR": persistent_dir,
            "HF_DATASET_REPO": request.dataset_repo,
            "HF_MODEL_REPO": request.hub_model_id,
            "HF_TRAINING_CONFIG_PATH": config_path,
            "HF_TRAIN_NUM_EPOCHS": str(request.num_epochs),
            "HF_TRAIN_PUSH_TO_HUB": "true" if request.push_to_hub else "false",
            "HF_TRAIN_OUTPUT_DIR": str(output_dir),
            "HF_LOCAL_MODEL_DIR": str(output_dir),
            "HF_TRAIN_CONTINUE_FROM_LATEST": (
                "true" if request.continue_from_latest_artifact else "false"
            ),
            "MARIS_TRAIN_DISTRIBUTED_STRATEGY": "none",
            "HF_TRAIN_DISTRIBUTED_STRATEGY": "none",
            # Keep a caller-provided value; default to unbuffered output.
            "PYTHONUNBUFFERED": env.get("PYTHONUNBUFFERED", "1"),
        }
    )
    env.pop("MARIS_TRAIN_DISTRIBUTED_CONFIG_PATH", None)
    env.pop("HF_TRAIN_DISTRIBUTED_CONFIG_PATH", None)
    if continue_model_dir is not None:
        env["MARIS_TRAIN_CONTINUE_MODEL_PATH"] = str(continue_model_dir)
        env["HF_TRAIN_CONTINUE_MODEL_PATH"] = str(continue_model_dir)
    else:
        env.pop("MARIS_TRAIN_CONTINUE_MODEL_PATH", None)
        env.pop("HF_TRAIN_CONTINUE_MODEL_PATH", None)
    if request.model_name:
        env["MARIS_TRAIN_BASE_MODEL"] = request.model_name
        env["HF_TRAIN_BASE_MODEL"] = request.model_name
        env.pop("MARIS_TRAIN_MODEL_PRESET", None)
        env.pop("HF_TRAIN_MODEL_PRESET", None)
    elif request.model_preset:
        env["MARIS_TRAIN_MODEL_PRESET"] = request.model_preset
        env["HF_TRAIN_MODEL_PRESET"] = request.model_preset
        env.pop("MARIS_TRAIN_BASE_MODEL", None)
        env.pop("HF_TRAIN_BASE_MODEL", None)
    return env


def has_completed_training_artifacts(output_dir: Path) -> bool:
    """Return True if *output_dir* contains at least one completion marker file."""
    return any(
        output_dir.joinpath(marker).is_file() for marker in SPACE_TRAINING_COMPLETION_MARKERS
    )


def tail_log(log_path: str | Path, *, max_chars: int = 16000) -> str:
    """Return at most the last *max_chars* characters of the log for the UI.

    Returns an empty string when the file does not exist or *max_chars* is not
    positive.
    """
    # ``content[-0:]`` would be the whole file, so handle non-positive limits first.
    if max_chars <= 0:
        return ""
    path = Path(log_path)
    if not path.is_file():
        return ""
    content = path.read_text(encoding="utf-8", errors="replace")
    return content[-max_chars:]


def read_log_since(log_path: str | Path, offset: int, *, max_chars: int = 8192) -> tuple[str, int]:
    """Read only the new part of the log, starting at *offset*.

    Returns ``(chunk, next_offset)``; pass ``next_offset`` back in on the next
    call. A missing file yields ``("", 0)``. *offset* is clamped to the range
    ``[0, file size]``.
    """
    path = Path(log_path)
    if not path.is_file():
        return "", 0
    file_size = path.stat().st_size
    next_offset = max(0, min(offset, file_size))
    with path.open(encoding="utf-8", errors="replace") as handle:
        handle.seek(next_offset)
        chunk = handle.read(max_chars)
        next_offset = handle.tell()
    return chunk, next_offset


def _last_match(pattern: re.Pattern[str], text: str) -> re.Match[str] | None:
    """Return the last match of *pattern* in *text*, or ``None`` if there is none."""
    found = None
    for found in pattern.finditer(text):
        pass
    return found


def parse_training_progress(
    log_text: str,
    *,
    request: dict[str, Any] | None = None,
    running: bool,
    exit_code: int | None,
) -> dict[str, Any]:
    """Heuristically derive training progress from the log and process state.

    Metrics are taken from the most recent structured event (a JSON log line
    flagged with ``MARIS_PROGRESS_EVENT_KEY``) when present; each value that is
    still missing is then recovered from the last matching plain-text pattern
    in the log.

    Args:
        log_text: Log text to analyse (typically the tail of the log).
        request: Optional launch request as a dict; only ``num_epochs`` is read.
        running: Whether the training process is still alive.
        exit_code: Exit code of the process, or ``None`` if not available.

    Returns:
        Dict with ``percent``, ``stage``, ``label``, ``current_epoch``,
        ``total_epochs``, ``loss``, ``eval_loss``, ``learning_rate``,
        ``current_step``, ``total_steps``, ``eta_seconds`` and
        ``events_detected``.
    """
    lower_log = log_text.lower()
    # _as_int tolerates a missing, None or malformed num_epochs instead of raising.
    total_epochs = (_as_int(request.get("num_epochs")) or None) if request else None
    current_epoch: float | None = None
    detected_total_epochs = total_epochs
    current_step: int | None = None
    total_steps: int | None = None
    loss: float | None = None
    eval_loss: float | None = None
    learning_rate: float | None = None
    eta_seconds: float | None = None
    structured_stage: str | None = None
    structured_label: str | None = None

    structured_events = _extract_structured_training_events(log_text)
    if structured_events:
        latest_event = structured_events[-1]
        current_epoch = _as_float(latest_event.get("epoch"))
        detected_total_epochs = (
            _as_int(latest_event.get("total_epochs")) or detected_total_epochs or total_epochs
        )
        current_step = _as_int(latest_event.get("step"))
        total_steps = _as_int(latest_event.get("total_steps"))
        loss = _as_float(latest_event.get("loss"))
        eval_loss = _as_float(latest_event.get("eval_loss"))
        learning_rate = _as_float(latest_event.get("learning_rate"))
        eta_seconds = _as_float(latest_event.get("eta_seconds"))
        structured_stage = _normalize_stage(latest_event.get("stage"))
        structured_label = _normalize_label(latest_event.get("label"))

    # Plain-text fallbacks: the last occurrence in the log wins.
    if current_epoch is None:
        epoch_match = _last_match(LOG_EPOCH_RE, log_text)
        if epoch_match is not None:
            current_epoch = float(epoch_match.group(1))
            detected_total_epochs = int(float(epoch_match.group(2)))
        else:
            value_match = _last_match(VALUE_EPOCH_RE, log_text)
            if value_match is not None:
                current_epoch = float(value_match.group(1))
    if current_step is None or total_steps is None:
        step_match = _last_match(STEP_PROGRESS_RE, log_text)
        if step_match is not None:
            current_step = int(step_match.group(1))
            total_steps = int(step_match.group(2))
    if loss is None:
        loss_match = _last_match(LOSS_RE, log_text)
        loss = float(loss_match.group(1)) if loss_match else None
    if eval_loss is None:
        eval_loss_match = _last_match(EVAL_LOSS_RE, log_text)
        eval_loss = float(eval_loss_match.group(1)) if eval_loss_match else None
    if learning_rate is None:
        learning_rate_match = _last_match(LEARNING_RATE_RE, log_text)
        learning_rate = float(learning_rate_match.group(1)) if learning_rate_match else None

    percent = 0
    stage = structured_stage or "queued"
    label = structured_label or "Gaida treniņa startu"

    # Priority order: user stop, process exit state, structured stage, then
    # log-text heuristics, and finally the bare "process is running" fallback.
    if "stop requested by user" in lower_log or "training stopped by user" in lower_log:
        stage = "stopped"
        label = "Treniņš apturēts pēc pieprasījuma"
    elif exit_code == 0:
        stage = "completed"
        label = "Treniņš pabeigts veiksmīgi"
        percent = 100
    elif exit_code is not None and not running:
        stage = "failed"
        label = f"Treniņš beidzās ar kļūdu (exit {exit_code})"
        percent = 100
    elif stage == "publishing":
        label = structured_label or "Publicē modeli origin repozitorijā"
        percent = 96
    elif stage == "saving":
        label = structured_label or "Saglabā modeļa artefaktus"
        percent = 92
    elif stage == "evaluating":
        label = structured_label or "Veic validāciju un eval metriku aprēķinu"
        percent = 88 if current_step else 82
    elif stage == "benchmarking":
        label = structured_label or "Palaiž benchmark un release gate pārbaudes"
        percent = 94
    elif stage == "preparing":
        label = structured_label or "Sagatavo datus, modeli un cache"
        percent = 20
    elif any(token in lower_log for token in ("uploading", "pushing", "export_to_hf")):
        stage = "publishing"
        label = "Publicē modeli origin repozitorijā"
        percent = 96
    elif current_step is not None and total_steps:
        stage = structured_stage or "training"
        progress_ratio = min(current_step / max(total_steps, 1), 1.0)
        # Training maps onto the 35..90 band, clamped to 35..95.
        percent = min(95, max(35, int(35 + progress_ratio * 55)))
        label = structured_label or f"Trenē modeli · solis {current_step}/{total_steps}"
    elif current_epoch is not None:
        stage = structured_stage or "training"
        epoch_total = detected_total_epochs or total_epochs
        if epoch_total:
            percent = min(95, max(35, int(35 + min(current_epoch / epoch_total, 1.0) * 55)))
            label = structured_label or f"Trenē modeli · epoha {current_epoch:g}/{epoch_total}"
        else:
            percent = 65
            label = structured_label or f"Trenē modeli · epoha {current_epoch:g}"
    elif any(
        token in lower_log
        for token in ("tokeniz", "validation split", "dataset", "snapshot", "download", "cache")
    ):
        stage = structured_stage or "preparing"
        label = structured_label or "Sagatavo datus, modeli un cache"
        percent = 20
    elif running:
        stage = structured_stage or "starting"
        label = structured_label or "Inicializē treniņu"
        percent = 5

    return {
        "percent": percent,
        "stage": stage,
        "label": label,
        "current_epoch": current_epoch,
        "total_epochs": detected_total_epochs or total_epochs,
        "loss": loss,
        "eval_loss": eval_loss,
        "learning_rate": learning_rate,
        "current_step": current_step,
        "total_steps": total_steps,
        "eta_seconds": eta_seconds,
        "events_detected": len(structured_events),
    }


def _extract_structured_training_events(log_text: str) -> list[dict[str, Any]]:
    """Return up to ``MAX_STRUCTURED_EVENTS`` most recent events, oldest first.

    A line counts as an event when it contains ``MARIS_PROGRESS_EVENT_KEY`` and
    everything from its first ``{`` to the end of the line parses as a JSON
    object with a truthy value under that key. Other lines are skipped.
    """
    events: list[dict[str, Any]] = []
    # Walk backwards so the scan can stop once the newest events are collected.
    for line in reversed(log_text.splitlines()):
        if MARIS_PROGRESS_EVENT_KEY not in line:
            continue
        json_start = line.find("{")
        if json_start < 0:
            continue
        try:
            payload = json.loads(line[json_start:])
        except json.JSONDecodeError:
            continue
        if isinstance(payload, dict) and payload.get(MARIS_PROGRESS_EVENT_KEY):
            events.append(payload)
            if len(events) >= MAX_STRUCTURED_EVENTS:
                break
    events.reverse()
    return events


def _as_float(value: Any) -> float | None:
    """Convert *value* to ``float``; booleans, ``None``, ``""`` and bad input give ``None``."""
    if isinstance(value, bool) or value in (None, ""):
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _as_int(value: Any) -> int | None:
    """Convert *value* to ``int`` (truncating); unusable input gives ``None``.

    Booleans, ``None``, ``""``, non-numeric values, NaN and infinities all
    yield ``None``.
    """
    if isinstance(value, bool) or value in (None, ""):
        return None
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        # int(float("inf")) raises OverflowError; json.loads accepts "Infinity".
        return None


def _normalize_stage(value: Any) -> str | None:
    """Return the stage stripped and lower-cased, or ``None`` if empty or not a string."""
    if not isinstance(value, str):
        return None
    normalized = value.strip().lower()
    return normalized or None


def _normalize_label(value: Any) -> str | None:
    """Return the label stripped, or ``None`` if empty or not a string."""
    if not isinstance(value, str):
        return None
    normalized = value.strip()
    return normalized or None


def _signal_process_tree(process: Any, *, force: bool) -> None:
    """Send SIGTERM (or SIGKILL when *force*) to the process group of *process*.

    Falls back to signalling only the process itself when process groups are
    unavailable on this platform, or when the group signal fails while the
    process is still alive.
    """
    if hasattr(os, "killpg"):
        try:
            os.killpg(process.pid, signal.SIGKILL if force else signal.SIGTERM)
            return
        except ProcessLookupError:
            if process.poll() is not None:
                return  # The process already exited; nothing left to signal.
            # Still alive but no group with this id exists, presumably because
            # the child is not a process-group leader. Signal it directly below.
    try:
        if force:
            process.kill()
        else:
            process.terminate()
    except ProcessLookupError:
        pass  # Exited between the liveness check and the signal.


def terminate_process_tree(process: Any, *, grace_seconds: float = 10.0) -> int | None:
    """Stop a process and its child processes as gracefully as possible.

    Sends a termination request, polls for up to *grace_seconds*, then force
    kills and waits one more second.

    Args:
        process: ``subprocess.Popen``-like object (``pid``, ``poll``, ``wait``,
            ``terminate``, ``kill``, ``returncode`` are used).
        grace_seconds: How long to wait after the termination request before
            force-killing.

    Returns:
        The exit code, or ``None`` if the process could not be confirmed dead.
    """
    if process.poll() is not None:
        return process.returncode

    _signal_process_tree(process, force=False)

    deadline = time.monotonic() + grace_seconds
    while time.monotonic() < deadline:
        exit_code = process.poll()
        if exit_code is not None:
            return exit_code
        time.sleep(0.1)

    _signal_process_tree(process, force=True)
    try:
        process.wait(timeout=1)
    except Exception:
        # Best effort: report whatever state is observable rather than raising.
        return process.poll()
    return process.returncode


def list_space_model_choices() -> dict[str, dict[str, Any]]:
    """Return the base models available to the UI (delegates to the training config)."""
    return list_training_base_models()