MarisUK's picture
Maris AI model sync
f440f03 verified
"""Pal墨gfunkcijas Maris treni艈u darba telpas UI."""
from __future__ import annotations

import codecs
import json
import os
import re
import signal
import time
from pathlib import Path
from typing import Any

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from maris_core.training.config import list_training_base_models
from maris_core.utils.env import validate_maris_model, validate_maris_repo
# Repo id in "owner/name" form; each part must start with an ASCII alphanumeric.
HF_REPO_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*/[A-Za-z0-9][A-Za-z0-9._-]*$")
# Characters allowed in a relative path below the persistent directory.
SAFE_OUTPUT_SEGMENT_RE = re.compile(r"^[A-Za-z0-9._/-]+$")
# "Epoch 2/5"-style lines: group 1 is the current epoch, group 2 the total.
LOG_EPOCH_RE = re.compile(r"Epoch\s+(\d+(?:\.\d+)?)\s*/\s*(\d+(?:\.\d+)?)", re.IGNORECASE)
# Key/value form such as `'epoch': 1.5` or `epoch=1.5`.
VALUE_EPOCH_RE = re.compile(r"(?:['\"]?epoch['\"]?\s*[:=]\s*)(\d+(?:\.\d+)?)", re.IGNORECASE)
# Training loss. The lookbehind stops the pattern from matching the tail of
# "eval_loss", which would otherwise overwrite the training loss with the
# evaluation loss. Exponent notation is accepted, as for the learning rate.
LOSS_RE = re.compile(
    r"(?<!eval_)(?:['\"]?loss['\"]?\s*[:=]\s*)(\d+(?:\.\d+)?(?:e[+-]?\d+)?)",
    re.IGNORECASE,
)
# Evaluation loss, with the same number syntax as LOSS_RE.
EVAL_LOSS_RE = re.compile(
    r"(?:['\"]?eval_loss['\"]?\s*[:=]\s*)(\d+(?:\.\d+)?(?:e[+-]?\d+)?)",
    re.IGNORECASE,
)
# Learning rate, typically logged in exponent notation (e.g. 2e-05).
LEARNING_RATE_RE = re.compile(
    r"(?:['\"]?learning_rate['\"]?\s*[:=]\s*)(\d+(?:\.\d+)?(?:e[+-]?\d+)?)",
    re.IGNORECASE,
)
# "<done>/<total> [" fragments — presumably tqdm-style progress bars; confirm
# against the trainer's actual output.
STEP_PROGRESS_RE = re.compile(r"(\d+)\s*/\s*(\d+)\s*\[", re.IGNORECASE)
# Marker key that identifies a structured JSON progress event in a log line.
MARIS_PROGRESS_EVENT_KEY = "maris_training_event"
# Keep a larger rolling event window than persisted run history because live status
# parsing needs several recent progress/save/eval events from the current log tail.
MAX_STRUCTURED_EVENTS = 64
# Fallback config path used when no environment override is set.
SPACE_TRAINING_CONFIG_PATH_DEFAULT = "huggingface/training-config.json"
# File names checked by has_completed_training_artifacts(); any one of them
# being present in the output directory counts as a finished run.
SPACE_TRAINING_COMPLETION_MARKERS = (
    "training-metrics.json",
    "trainer_state.json",
    "training-provenance.json",
    "branch-suite.json",
)
def _validate_repo_id(value: str) -> str:
    """Strip *value* and require it to match ``HF_REPO_ID_RE`` ("owner/name").

    Returns:
        The stripped repo id.

    Raises:
        ValueError: If the stripped value does not fully match the pattern.
    """
    repo_id = value.strip()
    if HF_REPO_ID_RE.fullmatch(repo_id) is None:
        raise ValueError("Repo ID j膩b奴t form膩t膩 owner/name.")
    return repo_id
class SpaceTrainingRequest(BaseModel):
    """UI request payload for launching a Maris training run."""

    # NOTE(review): the non-ASCII error messages below look mis-encoded
    # (UTF-8 read through another code page); they are kept verbatim here —
    # confirm the file's source encoding.
    model_config = ConfigDict(str_strip_whitespace=True)

    dataset_repo: str = "MarisUK/maris-ai-memory"
    model_repo: str = "MarisUK/maris-ai-master"
    hub_model_id: str = ""
    model_preset: str = "balanced"
    model_name: str = ""
    num_epochs: int = Field(default=3, ge=1, le=100)
    all_branches: bool = False
    push_to_hub: bool = True
    output_subdir: str = "maris-ai-master"
    continue_from_latest_artifact: bool = True
    continue_model_path: str = ""

    @field_validator("dataset_repo")
    @classmethod
    def validate_dataset_repo(cls, value: str) -> str:
        """Check the repo id format, then delegate to ``validate_maris_repo``.

        A ``RuntimeError`` from the delegate is re-raised as ``ValueError``.
        """
        repo_id = _validate_repo_id(value)
        try:
            checked = validate_maris_repo(repo_id, "dataset_repo", label="repozitorijs")
        except RuntimeError as exc:
            raise ValueError(str(exc)) from exc
        return checked

    @field_validator("model_repo")
    @classmethod
    def validate_model_repo(cls, value: str) -> str:
        """Allow a blank value; otherwise validate format and ``validate_maris_model``."""
        if value.strip() == "":
            return ""
        repo_id = _validate_repo_id(value)
        try:
            checked = validate_maris_model(repo_id, "model_repo")
        except RuntimeError as exc:
            raise ValueError(str(exc)) from exc
        return checked

    @field_validator("hub_model_id")
    @classmethod
    def validate_hub_model_id(cls, value: str) -> str:
        """Allow a blank value; otherwise validate format and ``validate_maris_model``."""
        if value.strip() == "":
            return ""
        repo_id = _validate_repo_id(value)
        try:
            checked = validate_maris_model(repo_id, "hub_model_id")
        except RuntimeError as exc:
            raise ValueError(str(exc)) from exc
        return checked

    @field_validator("model_name")
    @classmethod
    def validate_model_name(cls, value: str) -> str:
        """Accept an empty name or one matching ``HF_REPO_ID_RE``."""
        candidate = value.strip()
        if not candidate:
            return candidate
        if HF_REPO_ID_RE.fullmatch(candidate) is None:
            raise ValueError("Custom modelim j膩b奴t form膩t膩 owner/name.")
        return candidate

    @field_validator("model_preset")
    @classmethod
    def validate_model_preset(cls, value: str) -> str:
        """Accept an empty preset or one listed by ``list_training_base_models()``."""
        preset = value.strip()
        if preset == "":
            return ""
        if preset in list_training_base_models():
            return preset
        raise ValueError("Nezin膩ms mode募a presets.")

    @field_validator("output_subdir")
    @classmethod
    def validate_output_subdir(cls, value: str) -> str:
        """Require a non-empty relative path without ``..`` and with safe characters only."""
        subdir = value.strip().strip("/")
        if subdir == "":
            raise ValueError("Output apak拧direktorija nedr墨kst b奴t tuk拧a.")
        climbs_upward = ".." in Path(subdir).parts
        if climbs_upward or SAFE_OUTPUT_SEGMENT_RE.fullmatch(subdir) is None:
            raise ValueError("Output apak拧direktorij膩 dr墨kst b奴t tikai dro拧i ce募a segmenti.")
        return subdir

    @field_validator("continue_model_path")
    @classmethod
    def validate_continue_model_path(cls, value: str) -> str:
        """Accept a blank value, or a relative path without ``..`` and with safe characters."""
        trimmed = value.strip()
        if trimmed == "":
            return ""
        relative = trimmed.strip("/")
        if relative == "":
            # The value consisted solely of slashes.
            raise ValueError("Continue mode募a ce募拧 nedr墨kst b奴t tuk拧s.")
        climbs_upward = ".." in Path(relative).parts
        if climbs_upward or SAFE_OUTPUT_SEGMENT_RE.fullmatch(relative) is None:
            raise ValueError("Continue mode募a ce募膩 dr墨kst b奴t tikai dro拧i ce募a segmenti.")
        return relative

    @model_validator(mode="after")
    def validate_model_selection(self) -> SpaceTrainingRequest:
        """Unify the target repo fields and default the preset when nothing is chosen.

        ``hub_model_id`` takes precedence over ``model_repo``; both fields end up
        holding the resolved value. If neither ``model_name`` nor ``model_preset``
        is set, the preset falls back to ``"balanced"``.
        """
        target_repo = self.hub_model_id or self.model_repo
        if not target_repo:
            raise ValueError("J膩nor膩da hub_model_id vai model_repo.")
        self.hub_model_id = target_repo
        self.model_repo = target_repo
        if not (self.model_name or self.model_preset):
            self.model_preset = "balanced"
        return self
def resolve_output_dir(persistent_dir: str, output_subdir: str) -> Path:
    """Resolve *output_subdir* against the persistent storage root.

    Both the root (after ``~`` expansion) and the joined target are resolved
    before comparison.

    Raises:
        ValueError: If the resolved target is not the root or below it.
    """
    storage_root = Path(persistent_dir).expanduser().resolve()
    candidate = storage_root.joinpath(output_subdir).resolve()
    shared_prefix = os.path.commonpath([str(storage_root), str(candidate)])
    if shared_prefix == str(storage_root):
        return candidate
    raise ValueError("Output direktorijai j膩atrodas Maris persistent storage ietvaros.")
def resolve_optional_persistent_path(persistent_dir: str, path_value: str) -> Path | None:
    """Resolve an optional path below the persistent storage root.

    Returns:
        ``None`` when *path_value* is falsy or reduces to nothing after
        stripping whitespace and surrounding slashes; otherwise the resolved
        path.

    Raises:
        ValueError: If the resolved target is not the root or below it.
    """
    relative = str(path_value or "").strip().strip("/")
    if relative == "":
        return None
    storage_root = Path(persistent_dir).expanduser().resolve()
    candidate = storage_root.joinpath(relative).resolve()
    shared_prefix = os.path.commonpath([str(storage_root), str(candidate)])
    if shared_prefix == str(storage_root):
        return candidate
    raise ValueError(
        "Continue mode募a direktorijai j膩atrodas Maris persistent storage ietvaros."
    )
def build_space_training_command(script_path: str, request: SpaceTrainingRequest) -> list[str]:
    """Build the argv list that launches the training script through ``bash``.

    An explicit ``model_name`` takes precedence over ``model_preset``;
    ``--all-branches`` is appended when the request asks for it.
    """
    model_args: list[str] = []
    if request.model_name:
        model_args = ["--model-name", request.model_name]
    elif request.model_preset:
        model_args = ["--model-preset", request.model_preset]
    branch_args = ["--all-branches"] if request.all_branches else []
    return ["bash", script_path, *model_args, *branch_args]
def build_space_training_env(
    base_env: dict[str, str],
    request: SpaceTrainingRequest,
    persistent_dir: str,
) -> dict[str, str]:
    """Build the environment mapping for the training subprocess.

    Starts from a copy of *base_env* and sets each value under both its
    ``MARIS_*`` and ``HF_*`` name. Distributed-training settings are forced
    off, and stale continue-path / base-model / preset variables are removed
    when the request does not set them.

    Raises:
        ValueError: Propagated from the path resolvers when the output or
            continue path escapes *persistent_dir*.
    """
    output_dir = resolve_output_dir(persistent_dir, request.output_subdir)
    resume_dir = resolve_optional_persistent_path(persistent_dir, request.continue_model_path)

    # The first non-empty override wins; a blank result falls back to the default.
    configured = SPACE_TRAINING_CONFIG_PATH_DEFAULT
    for override_key in (
        "MARIS_SPACE_TRAIN_CONFIG_PATH",
        "HF_SPACE_TRAINING_CONFIG_PATH",
        "MARIS_TRAIN_CONFIG_PATH",
        "HF_TRAINING_CONFIG_PATH",
    ):
        override = base_env.get(override_key)
        if override:
            configured = override
            break
    config_path = str(configured).strip() or SPACE_TRAINING_CONFIG_PATH_DEFAULT

    output_path = str(output_dir)
    epochs = str(request.num_epochs)
    publish_flag = "true" if request.push_to_hub else "false"
    resume_latest_flag = "true" if request.continue_from_latest_artifact else "false"

    env = dict(base_env)
    overrides = {
        "MARIS_PERSISTENT_DIR": persistent_dir,
        "MARIS_MEMORY_REPO": request.dataset_repo,
        "MARIS_MODEL_REPO": request.hub_model_id,
        "MARIS_TRAIN_CONFIG_PATH": config_path,
        "MARIS_TRAIN_NUM_EPOCHS": epochs,
        "MARIS_TRAIN_PUBLISH": publish_flag,
        "MARIS_TRAIN_OUTPUT_DIR": output_path,
        "MARIS_LOCAL_MODEL_DIR": output_path,
        "MARIS_TRAIN_CONTINUE_FROM_LATEST": resume_latest_flag,
        "HF_PERSISTENT_DIR": persistent_dir,
        "HF_DATASET_REPO": request.dataset_repo,
        "HF_MODEL_REPO": request.hub_model_id,
        "HF_TRAINING_CONFIG_PATH": config_path,
        "HF_TRAIN_NUM_EPOCHS": epochs,
        "HF_TRAIN_PUSH_TO_HUB": publish_flag,
        "HF_TRAIN_OUTPUT_DIR": output_path,
        "HF_LOCAL_MODEL_DIR": output_path,
        "HF_TRAIN_CONTINUE_FROM_LATEST": resume_latest_flag,
        "MARIS_TRAIN_DISTRIBUTED_STRATEGY": "none",
        "HF_TRAIN_DISTRIBUTED_STRATEGY": "none",
        # Respect a caller-provided value; default to unbuffered output.
        "PYTHONUNBUFFERED": env.get("PYTHONUNBUFFERED", "1"),
    }
    env.update(overrides)

    for stale_key in ("MARIS_TRAIN_DISTRIBUTED_CONFIG_PATH", "HF_TRAIN_DISTRIBUTED_CONFIG_PATH"):
        env.pop(stale_key, None)

    if resume_dir is None:
        env.pop("MARIS_TRAIN_CONTINUE_MODEL_PATH", None)
        env.pop("HF_TRAIN_CONTINUE_MODEL_PATH", None)
    else:
        env["MARIS_TRAIN_CONTINUE_MODEL_PATH"] = str(resume_dir)
        env["HF_TRAIN_CONTINUE_MODEL_PATH"] = str(resume_dir)

    # An explicit base model and a preset are mutually exclusive in the env.
    if request.model_name:
        env["MARIS_TRAIN_BASE_MODEL"] = request.model_name
        env["HF_TRAIN_BASE_MODEL"] = request.model_name
        env.pop("MARIS_TRAIN_MODEL_PRESET", None)
        env.pop("HF_TRAIN_MODEL_PRESET", None)
    elif request.model_preset:
        env["MARIS_TRAIN_MODEL_PRESET"] = request.model_preset
        env["HF_TRAIN_MODEL_PRESET"] = request.model_preset
        env.pop("MARIS_TRAIN_BASE_MODEL", None)
        env.pop("HF_TRAIN_BASE_MODEL", None)
    return env
def has_completed_training_artifacts(output_dir: Path) -> bool:
    """Return True if any completion marker file exists directly in *output_dir*."""
    for marker_name in SPACE_TRAINING_COMPLETION_MARKERS:
        if (output_dir / marker_name).is_file():
            return True
    return False
def tail_log(log_path: str | Path, *, max_chars: int = 16000) -> str:
    """Return the last *max_chars* characters of a log file for the UI.

    Only the end of the file is read, so the cost does not grow with the size
    of the log. Bytes are decoded as UTF-8 with undecodable bytes replaced, and
    ``\\r\\n`` / ``\\r`` are normalised to ``\\n`` (matching text-mode reads).

    Args:
        log_path: Path to the log file.
        max_chars: Upper bound on the number of characters returned.

    Returns:
        The log tail, or ``""`` when the path is not a regular file or
        *max_chars* is not positive.
    """
    path = Path(log_path)
    # A non-positive limit must yield nothing; `text[-0:]` would return it all.
    if max_chars <= 0 or not path.is_file():
        return ""
    with path.open("rb") as handle:
        size = handle.seek(0, os.SEEK_END)
        # A UTF-8 character is at most 4 bytes, so this window always contains
        # at least max_chars complete characters when the file is long enough.
        handle.seek(max(0, size - 4 * max_chars))
        raw = handle.read()
    text = raw.decode("utf-8", errors="replace")
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    # Any partial character at the start of the window is cut off here.
    return text[-max_chars:]
def read_log_since(log_path: str | Path, offset: int, *, max_chars: int = 8192) -> tuple[str, int]:
    """Read the part of a log file that follows byte position *offset*.

    The file is read in binary mode so that *offset* and the returned position
    are plain byte offsets. Text-mode ``tell()`` returns an opaque cookie that
    may not be comparable with the file size, which could make a later call
    skip data.

    Args:
        log_path: Path to the log file.
        offset: Byte offset to resume from; clamped to ``[0, file size]``.
        max_chars: Maximum number of characters to return; a negative value
            reads to the end of the file.

    Returns:
        ``(chunk, next_offset)``. *chunk* is decoded as UTF-8 with undecodable
        bytes replaced and ``\\r\\n`` / ``\\r`` normalised to ``\\n``.
        *next_offset* is the byte offset just past the returned data.
        ``("", 0)`` is returned when the path is not a regular file.
    """
    path = Path(log_path)
    if not path.is_file():
        return "", 0
    with path.open("rb") as handle:
        file_size = handle.seek(0, os.SEEK_END)
        start = max(0, min(offset, file_size))
        handle.seek(start)
        # A UTF-8 character is at most 4 bytes, so this is enough for max_chars.
        raw = handle.read(4 * max_chars if max_chars >= 0 else -1)

    # surrogateescape round-trips every byte, which lets the consumed byte
    # count be recovered after slicing by characters. Incremental decoding
    # leaves a trailing, still-incomplete character for the next call.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="surrogateescape")
    lossless = decoder.decode(raw)
    if max_chars >= 0:
        lossless = lossless[:max_chars]
    if len(lossless) > 1 and lossless.endswith("\r"):
        # Defer a trailing "\r" so a "\r\n" pair is not split across two reads.
        lossless = lossless[:-1]

    consumed = lossless.encode("utf-8", errors="surrogateescape")
    chunk = consumed.decode("utf-8", errors="replace")
    chunk = chunk.replace("\r\n", "\n").replace("\r", "\n")
    return chunk, start + len(consumed)
def parse_training_progress(
    log_text: str,
    *,
    request: dict[str, Any] | None = None,
    running: bool,
    exit_code: int | None,
) -> dict[str, Any]:
    """Heuristically derive training progress from log text and process state.

    Values from the most recent structured progress event take priority; any
    metric the event does not provide is filled in from the last matching
    regex hit in the log text. The stage, label and percentage are then chosen
    by the first matching rule in the chain below.

    Args:
        log_text: Log text to inspect.
        request: Optional request mapping; ``num_epochs`` is used as the
            fallback epoch total.
        running: Whether the training process is still alive.
        exit_code: Exit code of the process, or ``None`` if unknown.

    Returns:
        A dict with the keys ``percent``, ``stage``, ``label``,
        ``current_epoch``, ``total_epochs``, ``loss``, ``eval_loss``,
        ``learning_rate``, ``current_step``, ``total_steps``, ``eta_seconds``
        and ``events_detected``.
    """
    log_lower = log_text.lower()
    requested_epochs = (int(request.get("num_epochs", 0)) or None) if request else None

    epoch_now: float | None = None
    epoch_total = requested_epochs
    step_now: int | None = None
    step_total: int | None = None
    loss: float | None = None
    eval_loss: float | None = None
    learning_rate: float | None = None
    eta_seconds: float | None = None
    event_stage: str | None = None
    event_label: str | None = None

    events = _extract_structured_training_events(log_text)
    if events:
        newest = events[-1]
        epoch_now = _as_float(newest.get("epoch"))
        epoch_total = _as_int(newest.get("total_epochs")) or epoch_total or requested_epochs
        step_now = _as_int(newest.get("step"))
        step_total = _as_int(newest.get("total_steps"))
        loss = _as_float(newest.get("loss"))
        eval_loss = _as_float(newest.get("eval_loss"))
        learning_rate = _as_float(newest.get("learning_rate"))
        eta_seconds = _as_float(newest.get("eta_seconds"))
        event_stage = _normalize_stage(newest.get("stage"))
        event_label = _normalize_label(newest.get("label"))

    def last_float(pattern: Any) -> float | None:
        # Value of the last regex hit in the log, if any.
        hits = pattern.findall(log_text)
        return float(hits[-1]) if hits else None

    def scaled_percent(fraction: float) -> int:
        # Map a 0..1 training fraction onto the 35..90 band (capped at 95).
        return min(95, max(35, int(35 + fraction * 55)))

    # Regex fallbacks for anything the structured event did not supply.
    if epoch_now is None:
        ratio_hits = LOG_EPOCH_RE.findall(log_text)
        if ratio_hits:
            seen_epoch, seen_total = ratio_hits[-1]
            epoch_now = float(seen_epoch)
            epoch_total = int(float(seen_total))
        else:
            epoch_now = last_float(VALUE_EPOCH_RE)
    if step_now is None or step_total is None:
        step_hits = STEP_PROGRESS_RE.findall(log_text)
        if step_hits:
            done, planned = step_hits[-1]
            step_now = int(done)
            step_total = int(planned)
    if loss is None:
        loss = last_float(LOSS_RE)
    if eval_loss is None:
        eval_loss = last_float(EVAL_LOSS_RE)
    if learning_rate is None:
        learning_rate = last_float(LEARNING_RATE_RE)

    progress_pct = 0
    phase = event_stage or "queued"
    caption = event_label or "Gaida treni艈a startu"

    # The first matching rule wins: user stop, then exit code, then the
    # structured stage, then log-text heuristics.
    if "stop requested by user" in log_lower or "training stopped by user" in log_lower:
        phase = "stopped"
        caption = "Treni艈拧 aptur膿ts p膿c piepras墨juma"
    elif exit_code == 0:
        phase = "completed"
        caption = "Treni艈拧 pabeigts veiksm墨gi"
        progress_pct = 100
    elif exit_code is not None and not running:
        phase = "failed"
        caption = f"Treni艈拧 beidz膩s ar k募奴du (exit {exit_code})"
        progress_pct = 100
    elif phase == "publishing":
        caption = event_label or "Public膿 modeli origin repozitorij膩"
        progress_pct = 96
    elif phase == "saving":
        caption = event_label or "Saglab膩 mode募a artefaktus"
        progress_pct = 92
    elif phase == "evaluating":
        caption = event_label or "Veic valid膩ciju un eval metriku apr膿姆inu"
        progress_pct = 88 if step_now else 82
    elif phase == "benchmarking":
        caption = event_label or "Palai啪 benchmark un release gate p膩rbaudes"
        progress_pct = 94
    elif phase == "preparing":
        caption = event_label or "Sagatavo datus, modeli un cache"
        progress_pct = 20
    elif any(token in log_lower for token in ("uploading", "pushing", "export_to_hf")):
        phase = "publishing"
        caption = "Public膿 modeli origin repozitorij膩"
        progress_pct = 96
    elif step_now is not None and step_total:
        phase = event_stage or "training"
        progress_pct = scaled_percent(min(step_now / max(step_total, 1), 1.0))
        caption = event_label or f"Tren膿 modeli 路 solis {step_now}/{step_total}"
    elif epoch_now is not None:
        phase = event_stage or "training"
        known_total = epoch_total or requested_epochs
        if known_total:
            progress_pct = scaled_percent(min(epoch_now / known_total, 1.0))
            caption = event_label or f"Tren膿 modeli 路 epoha {epoch_now:g}/{known_total}"
        else:
            progress_pct = 65
            caption = event_label or f"Tren膿 modeli 路 epoha {epoch_now:g}"
    elif any(
        token in log_lower
        for token in ("tokeniz", "validation split", "dataset", "snapshot", "download", "cache")
    ):
        phase = event_stage or "preparing"
        caption = event_label or "Sagatavo datus, modeli un cache"
        progress_pct = 20
    elif running:
        phase = event_stage or "starting"
        caption = event_label or "Inicializ膿 treni艈u"
        progress_pct = 5

    return {
        "percent": progress_pct,
        "stage": phase,
        "label": caption,
        "current_epoch": epoch_now,
        "total_epochs": epoch_total or requested_epochs,
        "loss": loss,
        "eval_loss": eval_loss,
        "learning_rate": learning_rate,
        "current_step": step_now,
        "total_steps": step_total,
        "eta_seconds": eta_seconds,
        "events_detected": len(events),
    }
def _extract_structured_training_events(log_text: str) -> list[dict[str, Any]]:
    """Collect the most recent structured progress events from *log_text*.

    A line qualifies when it contains ``MARIS_PROGRESS_EVENT_KEY`` and the text
    from its first ``{`` to the end of the line parses as a JSON object with a
    truthy value under that key. Lines are scanned newest-first and scanning
    stops after ``MAX_STRUCTURED_EVENTS`` hits; the result is returned in
    chronological order.
    """
    newest_first: list[dict[str, Any]] = []
    for raw_line in reversed(log_text.splitlines()):
        brace_at = raw_line.find("{") if MARIS_PROGRESS_EVENT_KEY in raw_line else -1
        if brace_at < 0:
            continue
        try:
            candidate = json.loads(raw_line[brace_at:])
        except json.JSONDecodeError:
            continue
        if not isinstance(candidate, dict) or not candidate.get(MARIS_PROGRESS_EVENT_KEY):
            continue
        newest_first.append(candidate)
        if len(newest_first) >= MAX_STRUCTURED_EVENTS:
            break
    return newest_first[::-1]
def _as_float(value: Any) -> float | None:
if isinstance(value, bool) or value in (None, ""):
return None
try:
return float(value)
except (TypeError, ValueError):
return None
def _as_int(value: Any) -> int | None:
if isinstance(value, bool) or value in (None, ""):
return None
try:
return int(float(value))
except (TypeError, ValueError):
return None
def _normalize_stage(value: Any) -> str | None:
if not isinstance(value, str):
return None
normalized = value.strip().lower()
return normalized or None
def _normalize_label(value: Any) -> str | None:
if not isinstance(value, str):
return None
normalized = value.strip()
return normalized or None
def _signal_process_tree(process: Any, *, force: bool) -> None:
    """Send SIGTERM (SIGKILL when *force*) to the process group, else to the process itself."""
    if hasattr(os, "killpg"):
        try:
            os.killpg(process.pid, signal.SIGKILL if force else signal.SIGTERM)
        except ProcessLookupError:
            # No group with this id: the process has already exited, or it was
            # not started as a process-group leader. Signal it directly below
            # so a live process is never left running.
            pass
        else:
            return
    try:
        if force:
            process.kill()
        else:
            process.terminate()
    except ProcessLookupError:
        # The process disappeared in the meantime; nothing left to signal.
        pass


def terminate_process_tree(process: Any, *, grace_seconds: float = 10.0) -> int | None:
    """Stop a process and, where possible, its children, as gracefully as possible.

    The process group whose id equals ``process.pid`` is sent SIGTERM (this
    reaches children only if the process was started as a group leader, e.g.
    with ``start_new_session=True``). If no such group exists, the process
    itself is terminated instead. After *grace_seconds* without an exit the
    same is repeated with SIGKILL.

    Args:
        process: A ``subprocess.Popen``-like object exposing ``pid``,
            ``poll()``, ``wait()``, ``terminate()``, ``kill()`` and
            ``returncode``.
        grace_seconds: How long to wait for a clean exit before killing.

    Returns:
        The exit code, or ``None`` if the process still has not been reaped.
    """
    if process.poll() is not None:
        return process.returncode

    _signal_process_tree(process, force=False)
    deadline = time.monotonic() + grace_seconds
    while time.monotonic() < deadline:
        exit_code = process.poll()
        if exit_code is not None:
            return exit_code
        time.sleep(0.1)

    _signal_process_tree(process, force=True)
    try:
        process.wait(timeout=1)
    except Exception:
        # Deliberately best-effort: report whatever state is known rather than
        # failing the caller's stop request.
        return process.poll()
    return process.returncode
def list_space_model_choices() -> dict[str, dict[str, Any]]:
    """Return the base models offered in the UI, as given by ``list_training_base_models()``."""
    available_models = list_training_base_models()
    return available_models