| """Tests HuggingFace datasets ielādēm.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| import subprocess |
| import sys |
| import types |
| from pathlib import Path |
|
|
| import pytest |
|
|
| import maris_core.data.datasets as datasets_module |
| from maris_core.data.datasets import HFDatasetError, _find_snapshot_data_files, load_hf_dataset |
| from maris_core.data.preprocessing import record_to_training_text |
| from maris_core.data.validator import DatasetValidationError, validate_dataset_dir |
| from maris_core.training.train import train |
|
|
| REPO_ID = "MarisUK/maris-ai-memory" |
| REPO_ROOT = Path(__file__).resolve().parents[2] |
|
|
|
|
class FakeGeneratedDataset(list):
    """List-backed stand-in for the ``Dataset`` class of the faked ``datasets`` module."""

    @classmethod
    def from_generator(cls, generator):
        """Call *generator* and collect every item it yields into a new instance."""
        produced_rows = [row for row in generator()]
        return cls(produced_rows)
|
|
|
|
class FakeDatasetDict(dict):
    """Plain ``dict`` subclass used as the ``DatasetDict`` of the faked ``datasets`` module."""
|
|
|
|
@pytest.mark.parametrize(
    "error_message",
    [
        "The directory doesn't contain any data files",
        f"No (supported) data files found in {REPO_ID}",
    ],
)
def test_load_hf_dataset_falls_back_to_snapshot_for_nested_jsonl_files(
    monkeypatch,
    tmp_path: Path,
    error_message: str,
) -> None:
    """Check the snapshot fallback when the direct load reports no data files.

    The fake ``load_dataset`` raises ``EmptyDatasetError`` with each parametrized
    message. The record stored in a nested ``.jsonl`` file under the fake snapshot
    must come back as a single ``{"text": ...}`` row of the ``train`` split, and
    ``load_dataset`` must have been called exactly once (for the repo id).
    """
    record = {"user": "Sveiki", "assistant": "Sveiks!"}
    snapshot_root = tmp_path / "snapshot"
    jsonl_path = snapshot_root / "data" / "conversation" / "sample.jsonl"
    jsonl_path.parent.mkdir(parents=True)
    jsonl_path.write_text(json.dumps(record, ensure_ascii=False) + "\n", encoding="utf-8")

    recorded_calls: list[tuple[object, object, object]] = []

    class EmptyDatasetError(Exception):
        pass

    def fake_load_dataset(path, *args, **kwargs):
        recorded_calls.append((path, args, kwargs))
        if path == REPO_ID:
            raise EmptyDatasetError(error_message)
        if path == "json":
            return {"train": [str(jsonl_path)]}
        raise AssertionError(f"Unexpected load_dataset call: {path!r}")

    def fake_snapshot_download(**kwargs):
        assert kwargs["repo_id"] == REPO_ID
        assert kwargs["repo_type"] == "dataset"
        return str(snapshot_root)

    # Swap in lightweight fakes for both third-party modules.
    fake_datasets = types.SimpleNamespace(
        Dataset=FakeGeneratedDataset,
        DatasetDict=FakeDatasetDict,
        load_dataset=fake_load_dataset,
    )
    fake_hub = types.SimpleNamespace(snapshot_download=fake_snapshot_download)
    monkeypatch.setitem(sys.modules, "datasets", fake_datasets)
    monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub)
    monkeypatch.delenv("HF_TOKEN", raising=False)

    loaded = load_hf_dataset(REPO_ID)

    assert list(loaded["train"]) == [{"text": record_to_training_text(record)}]
    assert recorded_calls == [(REPO_ID, (), {"token": None})]
|
|
|
|
def test_load_hf_dataset_falls_back_for_wrapped_dataset_generation_error(
    monkeypatch,
    tmp_path: Path,
) -> None:
    """Check the snapshot fallback when the "no data files" error arrives wrapped.

    The fake ``load_dataset`` raises ``DatasetGenerationError`` chained (``from``)
    to an ``EmptyDatasetError``. The nested ``.jsonl`` record must still be
    returned as a ``{"text": ...}`` row of the ``train`` split.
    """
    record = {"user": "Sveiki", "assistant": "Sveiks!"}
    snapshot_root = tmp_path / "snapshot"
    jsonl_path = snapshot_root / "data" / "conversation" / "sample.jsonl"
    jsonl_path.parent.mkdir(parents=True)
    jsonl_path.write_text(json.dumps(record, ensure_ascii=False) + "\n", encoding="utf-8")

    recorded_calls: list[tuple[object, object, object]] = []

    class EmptyDatasetError(Exception):
        pass

    class DatasetGenerationError(Exception):
        pass

    def fake_load_dataset(path, *args, **kwargs):
        recorded_calls.append((path, args, kwargs))
        if path == REPO_ID:
            # Raise-and-catch so the inner error carries a real traceback before
            # it is attached as the cause of the wrapper.
            try:
                raise EmptyDatasetError(f"No (supported) data files found in {REPO_ID}")
            except EmptyDatasetError as source_exc:
                raise DatasetGenerationError(
                    "An error occurred while generating the dataset"
                ) from source_exc
        if path == "json":
            return {"train": [str(jsonl_path)]}
        raise AssertionError(f"Unexpected load_dataset call: {path!r}")

    def fake_snapshot_download(**kwargs):
        assert kwargs["repo_id"] == REPO_ID
        assert kwargs["repo_type"] == "dataset"
        return str(snapshot_root)

    fake_datasets = types.SimpleNamespace(
        Dataset=FakeGeneratedDataset,
        DatasetDict=FakeDatasetDict,
        load_dataset=fake_load_dataset,
    )
    fake_hub = types.SimpleNamespace(snapshot_download=fake_snapshot_download)
    monkeypatch.setitem(sys.modules, "datasets", fake_datasets)
    monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub)
    monkeypatch.delenv("HF_TOKEN", raising=False)

    loaded = load_hf_dataset(REPO_ID)

    assert list(loaded["train"]) == [{"text": record_to_training_text(record)}]
    assert recorded_calls == [(REPO_ID, (), {"token": None})]
|
|
|
|
@pytest.mark.parametrize(
    "error_message",
    [
        "An error occurred while generating the dataset",
        "Dataset generation failed in a different way",
    ],
)
def test_load_hf_dataset_falls_back_for_generic_dataset_generation_error(
    monkeypatch,
    tmp_path: Path,
    error_message: str,
) -> None:
    """Check the snapshot fallback for a bare ``DatasetGenerationError``.

    The fake ``load_dataset`` raises ``DatasetGenerationError`` with each
    parametrized message and no chained cause; any other ``load_dataset`` call
    fails the test. The nested ``.jsonl`` record must be returned as a
    ``{"text": ...}`` row of the ``train`` split.
    """
    record = {"user": "Sveiki", "assistant": "Sveiks!"}
    snapshot_root = tmp_path / "snapshot"
    jsonl_path = snapshot_root / "data" / "conversation" / "sample.jsonl"
    jsonl_path.parent.mkdir(parents=True)
    jsonl_path.write_text(json.dumps(record, ensure_ascii=False) + "\n", encoding="utf-8")

    recorded_calls: list[tuple[object, object, object]] = []

    class DatasetGenerationError(Exception):
        pass

    def fake_load_dataset(path, *args, **kwargs):
        recorded_calls.append((path, args, kwargs))
        if path != REPO_ID:
            raise AssertionError(f"Unexpected load_dataset call: {path!r}")
        raise DatasetGenerationError(error_message)

    def fake_snapshot_download(**kwargs):
        assert kwargs["repo_id"] == REPO_ID
        assert kwargs["repo_type"] == "dataset"
        return str(snapshot_root)

    fake_datasets = types.SimpleNamespace(
        Dataset=FakeGeneratedDataset,
        DatasetDict=FakeDatasetDict,
        load_dataset=fake_load_dataset,
    )
    fake_hub = types.SimpleNamespace(snapshot_download=fake_snapshot_download)
    monkeypatch.setitem(sys.modules, "datasets", fake_datasets)
    monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub)
    monkeypatch.delenv("HF_TOKEN", raising=False)

    loaded = load_hf_dataset(REPO_ID)

    assert list(loaded["train"]) == [{"text": record_to_training_text(record)}]
    assert recorded_calls == [(REPO_ID, (), {"token": None})]
|
|
|
|
def test_load_hf_dataset_falls_back_for_schema_cast_error(monkeypatch, tmp_path: Path) -> None:
    """Check the snapshot fallback when the direct load fails with a cast error.

    Two ``.jsonl`` files whose ``metadata`` objects have different keys are placed
    in the fake snapshot. The fake ``load_dataset`` raises ``DatasetGenerationError``
    chained to a ``TypeError`` ("Couldn't cast array of type ..."). Both records
    must come back, in file order, as ``{"text": ...}`` rows.
    """
    records = [
        {
            "timestamp": "2026-03-30T12:05:00Z",
            "prompt": "A",
            "metadata": {"style": "cinematic"},
        },
        {
            "timestamp": "2026-03-30T12:06:00Z",
            "prompt": "B",
            "metadata": {"image_b64": "..."},
        },
    ]
    snapshot_root = tmp_path / "snapshot"
    image_dir = snapshot_root / "data" / "image"
    image_dir.mkdir(parents=True)
    # One record per file: sample-1.jsonl, sample-2.jsonl.
    for index, record in enumerate(records, start=1):
        (image_dir / f"sample-{index}.jsonl").write_text(
            json.dumps(record, ensure_ascii=False) + "\n",
            encoding="utf-8",
        )

    recorded_calls: list[tuple[object, object, object]] = []

    class DatasetGenerationError(Exception):
        pass

    def fake_load_dataset(path, *args, **kwargs):
        recorded_calls.append((path, args, kwargs))
        if path != REPO_ID:
            raise AssertionError(f"Unexpected load_dataset call: {path!r}")
        try:
            raise TypeError(
                "Couldn't cast array of type\nstruct<image_b64: string>\nto\n{'style': Value('string')}"
            )
        except TypeError as source_exc:
            raise DatasetGenerationError(
                "An error occurred while generating the dataset"
            ) from source_exc

    def fake_snapshot_download(**kwargs):
        assert kwargs["repo_id"] == REPO_ID
        assert kwargs["repo_type"] == "dataset"
        return str(snapshot_root)

    fake_datasets = types.SimpleNamespace(
        Dataset=FakeGeneratedDataset,
        DatasetDict=FakeDatasetDict,
        load_dataset=fake_load_dataset,
    )
    fake_hub = types.SimpleNamespace(snapshot_download=fake_snapshot_download)
    monkeypatch.setitem(sys.modules, "datasets", fake_datasets)
    monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub)
    monkeypatch.delenv("HF_TOKEN", raising=False)

    loaded = load_hf_dataset(REPO_ID)

    assert list(loaded["train"]) == [
        {"text": record_to_training_text(record)} for record in records
    ]
    assert recorded_calls == [(REPO_ID, (), {"token": None})]
|
|
|
|
def test_load_hf_dataset_downloads_repo_data_files_when_snapshot_cache_is_incomplete(
    monkeypatch,
    tmp_path: Path,
) -> None:
    """Check that missing data files are fetched with a second, targeted download.

    The fake snapshot initially contains only ``README.md`` while the fake
    ``HfApi.list_repo_files`` also lists ``data/conversation/sample.jsonl``. The
    fake ``snapshot_download`` creates that file only when it receives
    ``allow_patterns``. The test pins both ``snapshot_download`` calls and the
    resulting ``train`` row.
    """
    record = {"user": "Sveiki", "assistant": "Sveiks!"}
    snapshot_root = tmp_path / "snapshot"
    snapshot_root.mkdir()
    (snapshot_root / "README.md").write_text("metadata only", encoding="utf-8")
    # Not written yet: the fake download below creates it on the targeted call.
    jsonl_path = snapshot_root / "data" / "conversation" / "sample.jsonl"

    load_calls: list[tuple[object, object, object]] = []
    download_calls: list[dict[str, object]] = []

    class EmptyDatasetError(Exception):
        pass

    class FakeHfApi:
        def __init__(self, *, token):
            self.token = token

        def list_repo_files(self, *, repo_id, repo_type):
            assert repo_id == REPO_ID
            assert repo_type == "dataset"
            return ["README.md", "data/conversation/sample.jsonl"]

    def fake_load_dataset(path, *args, **kwargs):
        load_calls.append((path, args, kwargs))
        if path == REPO_ID:
            raise EmptyDatasetError("The directory doesn't contain any data files")
        if path == "json":
            return {"train": [str(jsonl_path)]}
        raise AssertionError(f"Unexpected load_dataset call: {path!r}")

    def fake_snapshot_download(**kwargs):
        download_calls.append(kwargs)
        if kwargs.get("allow_patterns"):
            jsonl_path.parent.mkdir(parents=True, exist_ok=True)
            jsonl_path.write_text(
                json.dumps(record, ensure_ascii=False) + "\n",
                encoding="utf-8",
            )
        return str(snapshot_root)

    fake_datasets = types.SimpleNamespace(
        Dataset=FakeGeneratedDataset,
        DatasetDict=FakeDatasetDict,
        load_dataset=fake_load_dataset,
    )
    fake_hub = types.SimpleNamespace(
        HfApi=FakeHfApi,
        snapshot_download=fake_snapshot_download,
    )
    monkeypatch.setitem(sys.modules, "datasets", fake_datasets)
    monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub)
    monkeypatch.delenv("HF_TOKEN", raising=False)

    loaded = load_hf_dataset(REPO_ID)

    assert list(loaded["train"]) == [{"text": record_to_training_text(record)}]
    assert load_calls == [(REPO_ID, (), {"token": None})]

    plain_download = {"repo_id": REPO_ID, "repo_type": "dataset", "token": None}
    targeted_download = {
        "repo_id": REPO_ID,
        "repo_type": "dataset",
        "token": None,
        "allow_patterns": ["data/conversation/sample.jsonl"],
    }
    assert download_calls == [plain_download, targeted_download]
|
|
|
|
def test_load_hf_dataset_raises_when_snapshot_contains_no_supported_files(
    monkeypatch,
    tmp_path: Path,
) -> None:
    """Check the error raised when neither the snapshot nor the repo has data files.

    The fake snapshot and the fake ``HfApi.list_repo_files`` both contain only
    ``.gitattributes``. ``load_hf_dataset`` must raise ``HFDatasetError`` whose
    message contains each fragment asserted below.
    """
    repo_id = REPO_ID
    snapshot_dir = tmp_path / "snapshot"
    snapshot_dir.mkdir()
    (snapshot_dir / ".gitattributes").write_text("*.jsonl filter=lfs", encoding="utf-8")

    class EmptyDatasetError(Exception):
        pass

    class FakeHfApi:
        def __init__(self, *, token):
            self.token = token

        def list_repo_files(self, *, repo_id, repo_type):
            assert repo_id == REPO_ID
            assert repo_type == "dataset"
            return [".gitattributes"]

    def fake_load_dataset(path, *args, **kwargs):
        if path == repo_id:
            raise EmptyDatasetError("The directory doesn't contain any data files")
        raise AssertionError(f"Unexpected load_dataset call: {path!r}")

    def fake_snapshot_download(**kwargs):
        return str(snapshot_dir)

    monkeypatch.setitem(
        sys.modules,
        "datasets",
        types.SimpleNamespace(load_dataset=fake_load_dataset),
    )
    monkeypatch.setitem(
        sys.modules,
        "huggingface_hub",
        types.SimpleNamespace(HfApi=FakeHfApi, snapshot_download=fake_snapshot_download),
    )

    # pytest.raises replaces the manual try/except/else so this test reads like
    # the other error-path tests in this module and fails with pytest's own
    # "DID NOT RAISE" report when no exception is raised.
    with pytest.raises(HFDatasetError) as excinfo:
        load_hf_dataset(repo_id)

    message = str(excinfo.value)
    assert "nesatur nevienu atbalstītu datu failu" in message
    assert "data/conversation/bootstrap.jsonl" in message
    assert ".gitattributes" in message
    assert "Git LFS" in message
|
|
|
|
def test_load_hf_dataset_raises_when_repo_lists_data_files_but_snapshot_stays_empty(
    monkeypatch,
    tmp_path: Path,
) -> None:
    """Check the error raised when listed data files never appear in the snapshot.

    The fake ``HfApi.list_repo_files`` reports ``data/conversation/sample.jsonl``,
    but the fake ``snapshot_download`` always returns a directory that holds only
    ``.gitattributes``. The ``HFDatasetError`` message must mention both paths.
    """
    snapshot_root = tmp_path / "snapshot"
    snapshot_root.mkdir()
    (snapshot_root / ".gitattributes").write_text("*.jsonl filter=lfs", encoding="utf-8")

    class EmptyDatasetError(Exception):
        pass

    class FakeHfApi:
        def __init__(self, *, token):
            self.token = token

        def list_repo_files(self, *, repo_id, repo_type):
            assert repo_id == REPO_ID
            assert repo_type == "dataset"
            return ["data/conversation/sample.jsonl"]

    def fake_load_dataset(path, *args, **kwargs):
        if path != REPO_ID:
            raise AssertionError(f"Unexpected load_dataset call: {path!r}")
        raise EmptyDatasetError("The directory doesn't contain any data files")

    def fake_snapshot_download(**kwargs):
        return str(snapshot_root)

    fake_datasets = types.SimpleNamespace(load_dataset=fake_load_dataset)
    fake_hub = types.SimpleNamespace(HfApi=FakeHfApi, snapshot_download=fake_snapshot_download)
    monkeypatch.setitem(sys.modules, "datasets", fake_datasets)
    monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub)

    with pytest.raises(HFDatasetError) as excinfo:
        load_hf_dataset(REPO_ID)

    error_text = str(excinfo.value)
    for fragment in (
        "satur atbalstītus datu failus",
        "data/conversation/sample.jsonl",
        ".gitattributes",
    ):
        assert fragment in error_text
|
|
|
|
def test_load_hf_dataset_retries_direct_load_with_fallback_cache_dir(
    monkeypatch,
    tmp_path: Path,
) -> None:
    """Check that a failed direct load is retried with the fallback cache dir.

    The fake ``load_dataset`` raises an ``OSError`` subclass unless it receives a
    ``cache_dir`` keyword. With ``MARIS_HF_FALLBACK_CACHE_DIR`` set, the test pins
    the two ``load_dataset`` calls and that the fallback directory exists afterwards.
    """
    cache_dir = tmp_path / "hf-cache"
    recorded_calls: list[tuple[object, object, object]] = []

    class CacheLockError(OSError):
        pass

    def fake_load_dataset(path, *args, **kwargs):
        recorded_calls.append((path, args, kwargs))
        if path != REPO_ID:
            raise AssertionError(f"Unexpected load_dataset call: {path!r}")
        if "cache_dir" in kwargs:
            return FakeDatasetDict({"train": [{"text": "ok"}]})
        raise CacheLockError(
            5,
            "Input/output error",
            "/data/.cache/huggingface/hub/.locks/datasets--MarisUK--maris-ai-memory/test.lock",
        )

    fake_datasets = types.SimpleNamespace(
        Dataset=FakeGeneratedDataset,
        DatasetDict=FakeDatasetDict,
        load_dataset=fake_load_dataset,
    )
    monkeypatch.setitem(sys.modules, "datasets", fake_datasets)
    monkeypatch.setenv("MARIS_HF_FALLBACK_CACHE_DIR", str(cache_dir))
    monkeypatch.delenv("HF_TOKEN", raising=False)

    loaded = load_hf_dataset(REPO_ID)

    assert list(loaded["train"]) == [{"text": "ok"}]
    first_attempt = (REPO_ID, (), {"token": None})
    retry_attempt = (REPO_ID, (), {"token": None, "cache_dir": str(cache_dir)})
    assert recorded_calls == [first_attempt, retry_attempt]
    assert cache_dir.is_dir()
|
|
|
|
def test_load_hf_dataset_retries_snapshot_download_with_fallback_cache_dir(
    monkeypatch,
    tmp_path: Path,
) -> None:
    """Check that a failed snapshot download is retried with the fallback cache dir.

    The direct load fails with ``EmptyDatasetError``; the fake ``snapshot_download``
    raises an ``OSError`` subclass unless it receives ``cache_dir``. The snapshot
    holds one ``.csv`` file, so the test also pins the follow-up
    ``load_dataset("csv", ...)`` call and its ``cache_dir`` argument.
    """
    cache_dir = tmp_path / "hf-cache"
    snapshot_root = tmp_path / "snapshot"
    csv_path = snapshot_root / "data" / "conversation" / "sample.csv"
    csv_path.parent.mkdir(parents=True)
    csv_path.write_text("user,assistant\nSveiki,Sveiks!\n", encoding="utf-8")

    load_calls: list[tuple[object, object, object]] = []
    download_calls: list[dict[str, object]] = []

    class EmptyDatasetError(Exception):
        pass

    class CacheLockError(OSError):
        pass

    def fake_load_dataset(path, *args, **kwargs):
        load_calls.append((path, args, kwargs))
        if path == REPO_ID:
            raise EmptyDatasetError("The directory doesn't contain any data files")
        if path != "csv":
            raise AssertionError(f"Unexpected load_dataset call: {path!r}")
        assert kwargs["cache_dir"] == str(cache_dir)
        return FakeDatasetDict({"train": [{"text": "ok"}]})

    def fake_snapshot_download(**kwargs):
        download_calls.append(kwargs)
        if "cache_dir" in kwargs:
            return str(snapshot_root)
        raise CacheLockError(
            5,
            "Input/output error",
            "/data/.cache/huggingface/hub/.locks/datasets--MarisUK--maris-ai-memory/test.lock",
        )

    fake_datasets = types.SimpleNamespace(
        Dataset=FakeGeneratedDataset,
        DatasetDict=FakeDatasetDict,
        load_dataset=fake_load_dataset,
    )
    fake_hub = types.SimpleNamespace(snapshot_download=fake_snapshot_download)
    monkeypatch.setitem(sys.modules, "datasets", fake_datasets)
    monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub)
    monkeypatch.setenv("MARIS_HF_FALLBACK_CACHE_DIR", str(cache_dir))
    monkeypatch.delenv("HF_TOKEN", raising=False)

    loaded = load_hf_dataset(REPO_ID)

    assert list(loaded["train"]) == [{"text": "ok"}]

    csv_kwargs = {"data_files": {"train": [str(csv_path)]}, "cache_dir": str(cache_dir)}
    assert load_calls == [
        (REPO_ID, (), {"token": None}),
        ("csv", (), csv_kwargs),
    ]

    base_download = {"repo_id": REPO_ID, "repo_type": "dataset", "token": None}
    assert download_calls == [
        base_download,
        {**base_download, "cache_dir": str(cache_dir)},
    ]
    assert cache_dir.is_dir()
|
|
|
|
def test_find_snapshot_data_files_skips_walk_os_errors(
    monkeypatch,
    tmp_path: Path,
) -> None:
    """Check that an ``OSError`` reported by ``os.walk`` does not abort the scan.

    ``os.walk`` (as seen through ``datasets_module.os``) is replaced by a fake that
    first hands an ``OSError`` to the ``onerror`` callback, if one was supplied,
    and then yields the directory holding one ``.jsonl`` file. The helper must
    still report format ``"json"`` and that single file.
    """
    snapshot_root = tmp_path / "snapshot"
    jsonl_path = snapshot_root / "data" / "conversation" / "sample.jsonl"
    jsonl_path.parent.mkdir(parents=True)
    payload = json.dumps({"user": "Sveiki", "assistant": "Sveiks!"}, ensure_ascii=False)
    jsonl_path.write_text(payload + "\n", encoding="utf-8")

    def fake_walk(
        path: str | os.PathLike[str],
        *,
        topdown: bool = True,
        onerror: object | None = None,
        followlinks: bool = False,
    ):
        # The fake also pins how the helper invokes os.walk.
        assert Path(path) == snapshot_root / "data"
        assert topdown is True
        assert followlinks is False
        simulated_error = OSError(22, "Invalid argument")
        simulated_error.filename = "/proc/19/task/19/net"
        if onerror is not None:
            onerror(simulated_error)
        yield (os.fspath(jsonl_path.parent), [], [jsonl_path.name])

    monkeypatch.setattr(datasets_module.os, "walk", fake_walk)

    detected_format, found_files = _find_snapshot_data_files(snapshot_root)

    assert detected_format == "json"
    assert found_files == [str(jsonl_path)]
|
|
|
|
def test_load_hf_dataset_rejects_filesystem_root_snapshot(monkeypatch) -> None:
    """Check that a snapshot path of ``"/"`` is refused.

    The direct load fails with ``EmptyDatasetError`` and the fake
    ``snapshot_download`` returns ``"/"``. ``load_hf_dataset`` must raise
    ``HFDatasetError`` matching the message fragment asserted below.
    """

    class EmptyDatasetError(Exception):
        pass

    def fake_load_dataset(path: str, *args: object, **kwargs: object):
        del args, kwargs
        if path != REPO_ID:
            raise AssertionError(f"Unexpected load_dataset call: {path!r}")
        raise EmptyDatasetError("The directory doesn't contain any data files")

    def fake_snapshot_download(**kwargs: object) -> str:
        assert kwargs["repo_id"] == REPO_ID
        assert kwargs["repo_type"] == "dataset"
        return "/"

    fake_datasets = types.SimpleNamespace(load_dataset=fake_load_dataset)
    fake_hub = types.SimpleNamespace(snapshot_download=fake_snapshot_download)
    monkeypatch.setitem(sys.modules, "datasets", fake_datasets)
    monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub)

    with pytest.raises(HFDatasetError, match="nav droša rekursīvai skenēšanai"):
        load_hf_dataset(REPO_ID)
|
|
|
|
def test_load_hf_dataset_merges_named_splits_into_train(monkeypatch) -> None:
    """Check that non-``train`` splits end up merged into a single ``train`` split.

    The fake ``load_dataset`` returns a mapping with ``conversation`` and ``image``
    splits. The rows of both must appear, in that order, as ``{"text": ...}``
    entries of ``train``.
    """
    conversation_row = {"user": "Sveiki", "assistant": "Labdien"}
    image_row = {"prompt": "Uzzīmē kaķi", "metadata": {"style": "anime"}}
    recorded_calls: list[tuple[object, object, object]] = []

    def fake_load_dataset(path, *args, **kwargs):
        recorded_calls.append((path, args, kwargs))
        if path != REPO_ID:
            raise AssertionError(f"Unexpected load_dataset call: {path!r}")
        return {
            "conversation": [{"user": "Sveiki", "assistant": "Labdien"}],
            "image": [{"prompt": "Uzzīmē kaķi", "metadata": {"style": "anime"}}],
        }

    fake_datasets = types.SimpleNamespace(
        Dataset=FakeGeneratedDataset,
        DatasetDict=FakeDatasetDict,
        load_dataset=fake_load_dataset,
    )
    monkeypatch.setitem(sys.modules, "datasets", fake_datasets)
    monkeypatch.delenv("HF_TOKEN", raising=False)

    loaded = load_hf_dataset(REPO_ID)

    assert list(loaded["train"]) == [
        {"text": record_to_training_text(row)} for row in (conversation_row, image_row)
    ]
    assert recorded_calls == [(REPO_ID, (), {"token": None})]
|
|
|
|
def test_load_hf_dataset_fallback_normalizes_mixed_top_level_keys_to_text(
    monkeypatch,
    tmp_path: Path,
) -> None:
    """Check that fallback rows with different top-level keys become ``text`` rows.

    The two snapshot records do not share the same top-level keys (``prompt`` vs.
    ``workflow``). The fake ``Dataset.from_generator`` raises unless every
    generated row has exactly the key ``text``, so the test passes only when the
    fallback yields rows of that single shape.
    """
    repo_id = REPO_ID
    snapshot_dir = tmp_path / "snapshot"
    first_record = {
        "timestamp": "2026-03-30T12:05:00Z",
        "prompt": "Izveido plānu",
        "metadata": {"priority": "high"},
    }
    second_record = {
        "timestamp": "2026-03-30T12:06:00Z",
        "workflow": {"name": "planner", "step": 2},
        "metadata": {"status": "running"},
    }
    first_file = snapshot_dir / "data" / "autonomous" / "sample-1.jsonl"
    second_file = snapshot_dir / "data" / "autonomous" / "sample-2.jsonl"
    first_file.parent.mkdir(parents=True)
    first_file.write_text(
        json.dumps(first_record, ensure_ascii=False) + "\n",
        encoding="utf-8",
    )
    second_file.write_text(
        json.dumps(second_record, ensure_ascii=False) + "\n",
        encoding="utf-8",
    )

    # The previously declared local EmptyDatasetError was never raised or
    # referenced in this test, so it has been removed.
    class DatasetGenerationError(Exception):
        pass

    class SchemaCheckingDataset(list):
        @classmethod
        def from_generator(cls, generator):
            rows = list(generator())
            # Reject any row that is not exactly {"text": ...}.
            if any(set(row) != {"text"} for row in rows):
                raise DatasetGenerationError("generator output schema is not stable")
            return cls(rows)

    def fake_load_dataset(path, *args, **kwargs):
        if path == repo_id:
            try:
                raise TypeError(
                    "Couldn't cast\n"
                    "timestamp: timestamp[s]\n"
                    "prompt: string\n"
                    "metadata: struct\n"
                    "| 'workflow'"
                )
            except TypeError as source_exc:
                raise DatasetGenerationError(
                    "An error occurred while generating the dataset"
                ) from source_exc
        raise AssertionError(f"Unexpected load_dataset call: {path!r}")

    def fake_snapshot_download(**kwargs):
        assert kwargs["repo_id"] == repo_id
        assert kwargs["repo_type"] == "dataset"
        return str(snapshot_dir)

    monkeypatch.setitem(
        sys.modules,
        "datasets",
        types.SimpleNamespace(
            Dataset=SchemaCheckingDataset,
            DatasetDict=FakeDatasetDict,
            load_dataset=fake_load_dataset,
        ),
    )
    monkeypatch.setitem(
        sys.modules,
        "huggingface_hub",
        types.SimpleNamespace(snapshot_download=fake_snapshot_download),
    )
    monkeypatch.delenv("HF_TOKEN", raising=False)

    dataset = load_hf_dataset(repo_id)

    assert list(dataset["train"]) == [
        {"text": record_to_training_text(first_record)},
        {"text": record_to_training_text(second_record)},
    ]
|
|
|
|
def test_load_hf_dataset_raises_actionable_error_for_invalid_snapshot_jsonl(
    monkeypatch,
    tmp_path: Path,
) -> None:
    """Check the error raised when a snapshot ``.jsonl`` file is malformed.

    The snapshot file's second line is truncated JSON. The fake
    ``Dataset.from_generator`` wraps any exception raised while consuming the
    generator in ``DatasetGenerationError``. The resulting ``HFDatasetError``
    message must name the file and include the wrapper's message.
    """
    repo_id = REPO_ID
    snapshot_dir = tmp_path / "snapshot"
    data_file = snapshot_dir / "data" / "conversation" / "broken.jsonl"
    data_file.parent.mkdir(parents=True)
    data_file.write_text('{"user":"Sveiki"}\n{"assistant":\n', encoding="utf-8")

    class DatasetGenerationError(Exception):
        pass

    class FakeGeneratedDatasetWithWrapper(list):
        @classmethod
        def from_generator(cls, generator):
            try:
                return cls(list(generator()))
            except Exception as source_exc:
                raise DatasetGenerationError(
                    "An error occurred while generating the dataset"
                ) from source_exc

    class EmptyDatasetError(Exception):
        pass

    def fake_load_dataset(path, *args, **kwargs):
        if path == repo_id:
            raise EmptyDatasetError("The directory doesn't contain any data files")
        raise AssertionError(f"Unexpected load_dataset call: {path!r}")

    def fake_snapshot_download(**kwargs):
        assert kwargs["repo_id"] == repo_id
        assert kwargs["repo_type"] == "dataset"
        return str(snapshot_dir)

    # Uses the module-level FakeDatasetDict: the local copy that used to shadow
    # it here was an identical empty dict subclass.
    monkeypatch.setitem(
        sys.modules,
        "datasets",
        types.SimpleNamespace(
            Dataset=FakeGeneratedDatasetWithWrapper,
            DatasetDict=FakeDatasetDict,
            load_dataset=fake_load_dataset,
        ),
    )
    monkeypatch.setitem(
        sys.modules,
        "huggingface_hub",
        types.SimpleNamespace(snapshot_download=fake_snapshot_download),
    )
    monkeypatch.delenv("HF_TOKEN", raising=False)

    with pytest.raises(HFDatasetError) as excinfo:
        load_hf_dataset(repo_id)

    message = str(excinfo.value)
    assert "datu failus neizdevās nolasīt apmācībai" in message
    assert "data/conversation/broken.jsonl" in message
    assert "An error occurred while generating the dataset" in message
|
|
|
|
def test_load_hf_dataset_raises_actionable_error_for_invalid_snapshot_csv(
    monkeypatch,
    tmp_path: Path,
) -> None:
    """Check the error raised when loading a snapshot ``.csv`` file fails.

    The direct load fails with ``EmptyDatasetError`` and the fake
    ``load_dataset("csv", ...)`` raises ``DatasetGenerationError``. The resulting
    ``HFDatasetError`` message must name the file and include the loader's message.
    """
    snapshot_root = tmp_path / "snapshot"
    csv_path = snapshot_root / "data" / "conversation" / "broken.csv"
    csv_path.parent.mkdir(parents=True)
    csv_path.write_text("user,assistant\nSveiki,Labdien\n", encoding="utf-8")

    class DatasetGenerationError(Exception):
        pass

    class EmptyDatasetError(Exception):
        pass

    def fake_load_dataset(path, *args, **kwargs):
        if path == REPO_ID:
            raise EmptyDatasetError("The directory doesn't contain any data files")
        if path != "csv":
            raise AssertionError(f"Unexpected load_dataset call: {path!r}")
        raise DatasetGenerationError("An error occurred while generating the dataset")

    def fake_snapshot_download(**kwargs):
        assert kwargs["repo_id"] == REPO_ID
        assert kwargs["repo_type"] == "dataset"
        return str(snapshot_root)

    fake_datasets = types.SimpleNamespace(load_dataset=fake_load_dataset)
    fake_hub = types.SimpleNamespace(snapshot_download=fake_snapshot_download)
    monkeypatch.setitem(sys.modules, "datasets", fake_datasets)
    monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub)
    monkeypatch.delenv("HF_TOKEN", raising=False)

    with pytest.raises(HFDatasetError) as excinfo:
        load_hf_dataset(REPO_ID)

    error_text = str(excinfo.value)
    for fragment in (
        "datu failus neizdevās nolasīt apmācībai",
        "data/conversation/broken.csv",
        "An error occurred while generating the dataset",
    ):
        assert fragment in error_text
|
|
|
|
def test_train_exits_with_actionable_message_when_dataset_repo_is_empty(monkeypatch) -> None:
    """Check that ``train()`` turns an ``HFDatasetError`` into ``SystemExit``.

    ``maris_core.training.train.load_hf_dataset`` is replaced by a function that
    always raises ``HFDatasetError``. ``train()`` must exit with ``SystemExit``
    whose string form equals that error's message.
    """
    message = (
        "HF dataset repo MarisUK/maris-ai-memory pašlaik nesatur nevienu atbalstītu datu failu "
        "(.jsonl, .json, .csv, .parquet)."
    )

    # A plain raising function replaces the former
    # ``lambda repo_id: (_ for _ in ()).throw(...)`` trick; the one-positional-
    # argument signature is kept.
    def failing_load_hf_dataset(repo_id):
        raise HFDatasetError(message)

    monkeypatch.setattr(
        "maris_core.training.train.load_hf_dataset",
        failing_load_hf_dataset,
    )

    with pytest.raises(SystemExit) as excinfo:
        train()

    assert str(excinfo.value) == message
|
|
|
|
def test_validate_dataset_dir_accepts_repo_bootstrap_dataset() -> None:
    """Validate the repository's own ``data`` directory and pin its summary figures."""
    expected_counts = {
        "conversation": 28,
        "code": 45,
        "image": 12,
        "music": 12,
        "video": 12,
        "autonomous": 12,
    }
    bootstrap_dir = REPO_ROOT / "data"

    summary = validate_dataset_dir(bootstrap_dir)

    assert summary.files_checked == 6
    assert summary.total_records == 121
    assert summary.duplicate_count == 0
    assert summary.counts_by_category == expected_counts
|
|
|
|
def test_validate_dataset_dir_rejects_invalid_records_and_duplicates(tmp_path: Path) -> None:
    """Check that validation reports each kind of problem in a bad ``code`` file.

    Two records are written to ``data/code/sample.jsonl``. The second has a
    non-date timestamp, type ``video``, an empty ``metadata`` object, no ``source``
    key and the same prompt as the first. ``DatasetValidationError.issues`` must
    contain an entry for every fragment checked below.
    """
    valid_record = {
        "timestamp": "2026-04-06T01:00:00Z",
        "type": "code",
        "prompt": "Write a validator",
        "metadata": {"language": "python"},
        "source": "test-bootstrap",
    }
    broken_record = {
        "timestamp": "not-a-date",
        "type": "video",
        "prompt": "Write a validator",
        "metadata": {},
    }

    dataset_root = tmp_path / "data"
    code_folder = dataset_root / "code"
    code_folder.mkdir(parents=True)
    jsonl_text = "".join(json.dumps(item) + "\n" for item in (valid_record, broken_record))
    (code_folder / "sample.jsonl").write_text(jsonl_text, encoding="utf-8")

    with pytest.raises(DatasetValidationError) as exc_info:
        validate_dataset_dir(dataset_root)

    reported = exc_info.value.issues
    for fragment in (
        "timestamp",
        "lauks 'type' ir 'video'",
        "lauka 'source'",
        "laukā 'metadata'",
        "dublikāts",
    ):
        assert any(fragment in issue for issue in reported)
|
|
|
|
def test_validate_datasets_cli_validates_repo_bootstrap_dataset() -> None:
    """Run ``validate_datasets.py`` on the repository ``data`` directory as a subprocess.

    The script must exit with status 0 and print the success message asserted below.
    """
    script_path = REPO_ROOT / "core-python" / "scripts" / "validate_datasets.py"

    assert REPO_ROOT.is_dir()
    assert script_path.is_file()

    # The expected output contains non-ASCII characters. Pin UTF-8 on both sides
    # of the pipe so the result does not depend on the locale encoding of the
    # machine running the tests.
    child_env = {**os.environ, "PYTHONIOENCODING": "utf-8"}

    result = subprocess.run(
        [
            sys.executable,
            str(script_path),
            str(REPO_ROOT / "data"),
        ],
        check=False,
        capture_output=True,
        text=True,
        encoding="utf-8",
        env=child_env,
    )

    # Show the script's stderr when it exits with a non-zero status.
    assert result.returncode == 0, result.stderr
    assert "Dataset validācija veiksmīga" in result.stdout
|
|