# maris-ai-master/core-python/tests/test_space_ui.py
# Synced by MarisUK — commit f440f03 "Maris AI model sync" (verified).
"""Tests Maris treniņu UI palīgfunkcijām."""
from __future__ import annotations
from pathlib import Path
import pytest
from maris_core.training.space_ui import (
SpaceTrainingRequest,
build_space_training_command,
build_space_training_env,
has_completed_training_artifacts,
list_space_model_choices,
parse_training_progress,
read_log_since,
resolve_optional_persistent_path,
resolve_output_dir,
)
def test_space_training_request_rejects_invalid_repo_id() -> None:
    """A malformed dataset repo id ("invalid repo") must raise ValueError."""
    malformed_repo = "invalid repo"
    with pytest.raises(ValueError):
        SpaceTrainingRequest(dataset_repo=malformed_repo)
def test_space_training_request_rejects_non_maris_repo_ids() -> None:
    """Non-Maris repo ids are rejected for both the dataset and the model repo."""
    foreign_repo_kwargs = (
        {"dataset_repo": "someone-else/not-maris-memory"},
        {"model_repo": "someone-else/not-maris-model"},
    )
    for kwargs in foreign_repo_kwargs:
        # Each field is validated independently, so each gets its own raises-block.
        with pytest.raises(ValueError):
            SpaceTrainingRequest(**kwargs)
def test_resolve_output_dir_keeps_path_inside_persistent_root(tmp_path: Path) -> None:
    """A relative subdirectory is resolved underneath the persistent root."""
    persistent_root = tmp_path
    resolved = resolve_output_dir(str(persistent_root), "runs/session-1")
    assert resolved == persistent_root.joinpath("runs", "session-1")
def test_resolve_output_dir_rejects_escape_attempt(tmp_path: Path) -> None:
    """A ``..`` subdirectory pointing outside the persistent root raises ValueError."""
    escaping_subdir = "../escape"
    with pytest.raises(ValueError):
        resolve_output_dir(str(tmp_path), escaping_subdir)
def test_build_space_training_command_prefers_custom_model_name() -> None:
    """An explicit model name is passed as ``--model-name``; the preset adds no flag."""
    custom_model = "Qwen/Qwen2.5-1.5B-Instruct"
    script_path = "/tmp/train-hf.sh"
    request = SpaceTrainingRequest(model_preset="coding", model_name=custom_model)

    command = build_space_training_command(script_path, request)

    assert command == ["bash", script_path, "--model-name", custom_model]
def test_space_training_request_accepts_custom_model_without_preset() -> None:
    """An empty preset is kept unchanged when a custom model name is supplied."""
    custom_model = "meta-llama/Llama-3.2-3B-Instruct"
    request = SpaceTrainingRequest(model_preset="", model_name=custom_model)
    assert (request.model_preset, request.model_name) == ("", custom_model)
def test_build_space_training_env_uses_preset_and_persistent_storage(tmp_path: Path) -> None:
    """Request fields are mapped onto HF_/MARIS_ env vars rooted at the persistent dir."""
    request = SpaceTrainingRequest(
        model_preset="coding",
        hub_model_id="MarisUK/maris-ai-lv",
        output_subdir="runs/coder",
        continue_model_path="runs/checkpoints",
        push_to_hub=False,
    )

    env = build_space_training_env({}, request, str(tmp_path))

    output_dir = str(tmp_path / "runs" / "coder")
    default_config = "huggingface/training-config.json"
    expected = {
        "HF_PERSISTENT_DIR": str(tmp_path),
        # Training output and the local model dir both point at output_subdir.
        "HF_TRAIN_OUTPUT_DIR": output_dir,
        "HF_LOCAL_MODEL_DIR": output_dir,
        "HF_MODEL_REPO": "MarisUK/maris-ai-lv",
        "HF_TRAIN_MODEL_PRESET": "coding",
        "HF_TRAINING_CONFIG_PATH": default_config,
        "MARIS_TRAIN_CONFIG_PATH": default_config,
        "HF_TRAIN_PUSH_TO_HUB": "false",
        "HF_TRAIN_CONTINUE_FROM_LATEST": "true",
        "HF_TRAIN_CONTINUE_MODEL_PATH": str(tmp_path / "runs" / "checkpoints"),
        "HF_TRAIN_DISTRIBUTED_STRATEGY": "none",
        "MARIS_TRAIN_DISTRIBUTED_STRATEGY": "none",
        "PYTHONUNBUFFERED": "1",
    }
    # Index env directly so a missing key fails loudly instead of comparing None.
    assert {key: env[key] for key in expected} == expected
def test_build_space_training_env_clears_inherited_distributed_overrides(tmp_path: Path) -> None:
    """Inherited distributed settings are reset to ``none`` and their config paths removed."""
    request = SpaceTrainingRequest(model_preset="balanced")
    inherited_env = {
        "HF_TRAIN_DISTRIBUTED_STRATEGY": "deepspeed",
        "MARIS_TRAIN_DISTRIBUTED_STRATEGY": "fsdp",
        "HF_TRAIN_DISTRIBUTED_CONFIG_PATH": "/tmp/deepspeed.json",
        "MARIS_TRAIN_DISTRIBUTED_CONFIG_PATH": "/tmp/fsdp.json",
    }

    env = build_space_training_env(inherited_env, request, str(tmp_path))

    for strategy_key in ("HF_TRAIN_DISTRIBUTED_STRATEGY", "MARIS_TRAIN_DISTRIBUTED_STRATEGY"):
        assert env[strategy_key] == "none"
    for config_key in ("HF_TRAIN_DISTRIBUTED_CONFIG_PATH", "MARIS_TRAIN_DISTRIBUTED_CONFIG_PATH"):
        assert config_key not in env
def test_build_space_training_env_allows_explicit_space_config_override(tmp_path: Path) -> None:
    """``MARIS_SPACE_TRAIN_CONFIG_PATH`` in the base env overrides both config-path vars."""
    custom_config = "huggingface/custom-space-config.json"
    request = SpaceTrainingRequest(model_preset="balanced")

    env = build_space_training_env(
        {"MARIS_SPACE_TRAIN_CONFIG_PATH": custom_config}, request, str(tmp_path)
    )

    assert env["HF_TRAINING_CONFIG_PATH"] == env["MARIS_TRAIN_CONFIG_PATH"] == custom_config
def test_has_completed_training_artifacts_detects_finished_space_run(tmp_path: Path) -> None:
    """An empty run dir is not complete; adding ``training-metrics.json`` makes it complete."""
    run_dir = tmp_path.joinpath("runs", "demo")
    run_dir.mkdir(parents=True)
    metrics_file = run_dir / "training-metrics.json"

    assert has_completed_training_artifacts(run_dir) is False
    metrics_file.write_text("{}", encoding="utf-8")
    assert has_completed_training_artifacts(run_dir) is True
def test_list_space_model_choices_exposes_presets() -> None:
    """The four built-in presets are all present among the model choices."""
    required_presets = {"balanced", "reasoning", "coding", "lightweight"}
    missing = required_presets.difference(list_space_model_choices())
    assert not missing
def test_list_space_model_choices_can_include_large_external_models(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Extra presets given as JSON in ``MARIS_TRAIN_EXTRA_MODELS`` show up as choices."""
    extra_models_json = (
        '{"qwen-880b":{"model_name":"Qwen/Qwen3-880B-Instruct",'
        '"label":"Qwen ultra preset",'
        '"description":"Large external preset for giant-model experiments."}}'
    )
    # monkeypatch restores the environment after the test.
    monkeypatch.setenv("MARIS_TRAIN_EXTRA_MODELS", extra_models_json)

    qwen_choice = list_space_model_choices()["qwen-880b"]

    assert qwen_choice["model_name"] == "Qwen/Qwen3-880B-Instruct"
    assert qwen_choice["label"] == "Qwen ultra preset"
def test_space_training_request_defaults_to_balanced_model_selection() -> None:
    """With neither a preset nor a model name, the request falls back to ``balanced``."""
    request = SpaceTrainingRequest(model_preset="", model_name="")
    assert (request.model_preset, request.model_name) == ("balanced", "")
def test_space_training_request_accepts_separate_hub_model_id() -> None:
    """With an empty ``model_repo``, it ends up equal to the given ``hub_model_id``."""
    hub_id = "MarisUK/maris-ai-lv"
    request = SpaceTrainingRequest(
        model_repo="",
        hub_model_id=hub_id,
        model_preset="",
        model_name="meta-llama/Llama-3.2-3B-Instruct",
    )
    assert request.hub_model_id == hub_id
    assert request.model_repo == hub_id
def test_resolve_optional_persistent_path_returns_none_for_empty_value(tmp_path: Path) -> None:
    """An empty relative path resolves to ``None`` rather than to the root itself."""
    resolved = resolve_optional_persistent_path(str(tmp_path), "")
    assert resolved is None
def test_read_log_since_reads_only_delta(tmp_path: Path) -> None:
    """Reading from a previously returned offset yields only newer text.

    Covers both the idle case (no new data: empty chunk, unchanged offset)
    and the append case (only the appended text comes back and the offset
    moves forward).
    """
    log_path = tmp_path / "train.log"
    log_path.write_text("line-1\nline-2\n", encoding="utf-8")

    first_chunk, first_offset = read_log_since(log_path, 0)
    assert first_chunk == "line-1\nline-2\n"

    # Nothing appended yet: no text, and the offset must not move.
    idle_chunk, idle_offset = read_log_since(log_path, first_offset)
    assert idle_chunk == ""
    assert idle_offset == first_offset

    # Append after the first read; only this new text should be returned.
    with log_path.open("a", encoding="utf-8") as handle:
        handle.write("line-3\n")

    delta_chunk, delta_offset = read_log_since(log_path, first_offset)
    assert delta_chunk == "line-3\n"
    assert delta_offset > first_offset
def test_parse_training_progress_detects_epoch_and_loss() -> None:
    """A plain ``Epoch x/y`` line plus a loss dict line are parsed while running."""
    raw_log = "Epoch 2/4\n{'loss': 0.125, 'epoch': 2.0}\n"

    progress = parse_training_progress(
        raw_log, request={"num_epochs": 4}, running=True, exit_code=None
    )

    assert progress["stage"] == "training"
    # Epoch 2 of 4 should place the run at least 60 % through.
    assert progress["percent"] >= 60
    expected = {"current_epoch": 2.0, "total_epochs": 4, "loss": 0.125}
    assert {key: progress[key] for key in expected} == expected
def test_parse_training_progress_reports_structured_preparing_stage() -> None:
    """A ``maris_training_event`` JSON line drives stage, label and percent."""
    # Label text is Latvian for "Loading the tokenizer and model".
    prepare_event = '{"maris_training_event": true, "event": "prepare_model", "stage": "preparing", "label": "Ielādē tokenizeri un modeli"}\n'

    progress = parse_training_progress(
        prepare_event, request={"num_epochs": 3}, running=True, exit_code=None
    )

    expected = {
        "stage": "preparing",
        "label": "Ielādē tokenizeri un modeli",
        "percent": 20,
        "events_detected": 1,
    }
    assert {key: progress[key] for key in expected} == expected
def test_parse_training_progress_reports_completion() -> None:
    """A stopped process with exit code 0 is reported as completed at 100 %."""
    finished_log = "Training complete\n"
    progress = parse_training_progress(
        finished_log, request={"num_epochs": 3}, running=False, exit_code=0
    )
    assert (progress["stage"], progress["percent"]) == ("completed", 100)
def test_parse_training_progress_prefers_structured_events() -> None:
    """Fields from a structured event take precedence over a later plain ``Epoch`` line."""
    # Label text is Latvian for "Training model · step 12/40".
    structured_event = '{"maris_training_event": true, "event": "log", "stage": "training", "label": "Trenē modeli · solis 12/40", "epoch": 1.5, "total_epochs": 4, "step": 12, "total_steps": 40, "loss": 0.2451, "eval_loss": 0.1987, "learning_rate": 0.0002, "eta_seconds": 180}'
    # The plain-text line says epoch 1, but the event's 1.5 must win.
    log_text = structured_event + "\n" + "Epoch 1/4"

    progress = parse_training_progress(
        log_text, request={"num_epochs": 4}, running=True, exit_code=None
    )

    expected = {
        "stage": "training",
        "label": "Trenē modeli · solis 12/40",
        "current_epoch": 1.5,
        "total_epochs": 4,
        "current_step": 12,
        "total_steps": 40,
        "loss": 0.2451,
        "eval_loss": 0.1987,
        "learning_rate": 0.0002,
        "eta_seconds": 180,
        "events_detected": 1,
    }
    assert {key: progress[key] for key in expected} == expected