| """Tests human-in-the-loop training artifact staging helpers.""" |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import pytest |
|
|
| from maris_core.training.human_training import ( |
| HumanTrainingRequest, |
| build_human_training_launch_spec, |
| load_human_training_manifest, |
| publish_human_training_artifacts, |
| stage_human_training_artifacts, |
| ) |
|
|
|
|
def test_human_training_request_requires_input_signal() -> None:
    """Expect ``ValueError`` when every input collection of the request is empty."""
    # Every signal-bearing field is present but empty, so the request
    # carries nothing to train on.
    empty_inputs = {
        "profile_facts": [],
        "profile_preferences": [],
        "response_instructions": [],
        "conversation_examples": [],
        "preference_pairs": [],
        "eval_examples": [],
    }

    with pytest.raises(ValueError):
        HumanTrainingRequest(
            dataset_repo="example-user/memory-dataset",
            model_repo="example-user/custom-model",
            **empty_inputs,
        )
|
|
|
|
def test_stage_human_training_artifacts_builds_staging_manifest(tmp_path: Path) -> None:
    """Stage a fully populated request, then reload it and build a launch spec.

    The test checks the manifest flags and counts returned by staging, that
    the manifest can be loaded back by its run id from the same directory,
    and that the launch spec built from the reloaded manifest carries the
    model id, output subdirectory and continue-from path of the request.
    """
    # The same exchange is supplied twice on purpose: the assertions below
    # expect the input summary to count both entries and the quality report
    # to record at least one removed duplicate.
    repeated_exchange = {
        "user": "Kas ir mana valodas preference?",
        "assistant": "Tu dod priekšroku latviešu valodai.",
    }
    profile_pair = {
        "prompt": "Apraksti manu profilu.",
        "chosen": "Tu vēlies īsas atbildes latviski.",
        "rejected": "Es neko nezinu par tavām preferencēm.",
        "confidence": 0.9,
    }
    eval_example = {
        "prompt": "Kā tu atbildēsi turpmāk?",
        "completion": "Īsi, strukturēti un latviski.",
    }
    request = HumanTrainingRequest(
        dataset_repo="example-user/memory-dataset",
        hub_model_id="example-user/custom-model",
        output_subdir="runs/self-profile",
        continue_model_path="runs/checkpoints/latest",
        profile_facts=["Man patīk īsas tehniskas atbildes."],
        profile_preferences=["Atbildi latviski."],
        response_instructions=["Ja iespējams, dod strukturētu kopsavilkumu."],
        conversation_examples=[dict(repeated_exchange), dict(repeated_exchange)],
        preference_pairs=[profile_pair],
        eval_examples=[eval_example],
    )

    staged = stage_human_training_artifacts(request, persistent_dir=str(tmp_path))

    assert staged["artifact_type"] == "human-training-manifest"
    assert staged["ready_for_review"] is True
    assert staged["ready_for_training"] is True
    assert staged["input_summary"]["conversation_examples"] == 2
    assert staged["quality_report"]["duplicates_removed"] >= 1
    for artifact_name in ("train_dataset", "preference_dataset", "eval_dataset"):
        assert artifact_name in staged["artifacts"]

    # Round-trip: reload the manifest by run id from the staging directory.
    run_id = staged["run_id"]
    reloaded = load_human_training_manifest(str(tmp_path), run_id)
    assert reloaded["run_id"] == run_id

    spec = build_human_training_launch_spec(reloaded)
    assert spec.hub_model_id == "example-user/custom-model"
    assert spec.output_subdir == "runs/self-profile"
    assert spec.continue_model_path == "runs/checkpoints/latest"
|
|
|
|
def test_publish_human_training_artifacts_uploads_all_repo_files(tmp_path: Path) -> None:
    """Publish a staged manifest through a recording uploader and check the targets.

    Every recorded upload must go to the request's dataset repo with repo
    type ``dataset``, and at least one uploaded path must start with
    ``data/human-training/``.
    """
    request = HumanTrainingRequest(
        dataset_repo="example-user/memory-dataset",
        model_repo="example-user/custom-model",
        profile_facts=["Man patīk konkrēti tehniski skaidrojumi."],
    )
    staged = stage_human_training_artifacts(request, persistent_dir=str(tmp_path))
    upload_log: list[tuple[str, str, str]] = []

    def recording_save_file(**kwargs):
        """Record (repo_id, repo_type, path_in_repo) and return a save-result dict."""
        target = (kwargs["repo_id"], kwargs["repo_type"], kwargs["path_in_repo"])
        upload_log.append(target)
        return {"saved": True, "path": kwargs["path_in_repo"]}

    published = publish_human_training_artifacts(staged, save_file=recording_save_file)

    assert published
    assert all(repo_id == "example-user/memory-dataset" for repo_id, _, _ in upload_log)
    assert all(repo_type == "dataset" for _, repo_type, _ in upload_log)
    # The all() checks pass vacuously on an empty log; this any() is what
    # requires at least one upload to have been recorded.
    assert any(
        path_in_repo.startswith("data/human-training/")
        for _, _, path_in_repo in upload_log
    )
|
|
|
|
def test_human_training_request_keeps_model_repo_compatibility_alias() -> None:
    """A request built with ``model_repo`` exposes the value as ``hub_model_id`` too."""
    repo_id = "example-user/custom-model"

    req = HumanTrainingRequest(
        dataset_repo="example-user/memory-dataset",
        model_repo=repo_id,
        profile_facts=["Svarīga ir profesionāla latviešu valoda."],
    )

    # Both the newer attribute and the legacy alias must report the same repo.
    assert req.hub_model_id == repo_id
    assert req.model_repo == repo_id
|
|