# maris-ai-master / core-python / tests / test_training_hf_compat.py
# MarisUK's picture
# Maris AI model sync
# f440f03 verified
"""Tests for the Maris Hugging Face compatibility layer."""
from __future__ import annotations
import json
import sys
import types
from pathlib import Path
from maris_core.training.hf_compat import (
MARIS_COMPATIBILITY_ARTIFACT_NAME,
apply_maris_compatibility_identity,
maris_hf_compatible_path,
write_maris_compatibility_artifact,
)
def test_write_maris_compatibility_artifact_sanitizes_loader_fields(tmp_path: Path) -> None:
    """Check that loader-identifying fields are rewritten to Maris identifiers.

    A model directory is seeded with ``config.json``, ``tokenizer_config.json``
    and ``adapter_config.json`` carrying Qwen2 identifiers. After calling
    ``write_maris_compatibility_artifact`` followed by
    ``apply_maris_compatibility_identity``, the test expects the manifest to
    list those three files and each file to expose Maris-branded values.
    """
    model_dir = tmp_path / "model"
    model_dir.mkdir()

    # Seed files, written in this order: config, tokenizer config, adapter config.
    seed_payloads = {
        "config.json": {
            "_name_or_path": "MarisUK/maris-ai-master",
            "model_type": "qwen2",
            "architectures": ["Qwen2ForCausalLM"],
            "tokenizer_class": "Qwen2TokenizerFast",
            "auto_map": {"AutoModelForCausalLM": "qwen2.modeling_qwen2.Qwen2ForCausalLM"},
        },
        "tokenizer_config.json": {"tokenizer_class": "Qwen2TokenizerFast"},
        "adapter_config.json": {
            "base_model_class": "Qwen2ForCausalLM",
            "parent_library": "transformers.models.qwen2.modeling_qwen2",
        },
    }
    for file_name, payload in seed_payloads.items():
        (model_dir / file_name).write_text(json.dumps(payload), encoding="utf-8")

    write_maris_compatibility_artifact(model_dir, maris_model_id="MarisUK/maris-ai-master")
    apply_maris_compatibility_identity(model_dir)

    def read_json(file_name):
        """Parse a UTF-8 JSON file located in the model directory."""
        return json.loads((model_dir / file_name).read_text(encoding="utf-8"))

    manifest = read_json(MARIS_COMPATIBILITY_ARTIFACT_NAME)
    config = read_json("config.json")
    tokenizer_config = read_json("tokenizer_config.json")
    adapter_config = read_json("adapter_config.json")

    # The manifest records which files were covered by the compatibility pass.
    assert manifest["artifact_type"] == "maris-hf-compatibility"
    expected_artifacts = ["adapter_config.json", "config.json", "tokenizer_config.json"]
    assert sorted(manifest["artifacts"]) == expected_artifacts

    # The on-disk files now advertise the Maris identity.
    assert config["model_type"] == "maris"
    assert config["architectures"] == ["MarisCompatibleCausalLM"]
    assert tokenizer_config["tokenizer_class"] == "MarisCompatibleTokenizer"
    assert adapter_config["base_model_class"] == "MarisCompatibleCausalLM"
    assert adapter_config["parent_library"] == "maris.compat"
def test_maris_hf_compatible_path_restores_remote_snapshot(monkeypatch, tmp_path: Path) -> None:
    """Check that a stubbed remote snapshot is restored when the env flag is set.

    ``huggingface_hub`` is replaced in ``sys.modules`` by a stub whose
    ``snapshot_download`` returns a local Maris-branded snapshot directory.
    With ``MARIS_HF_COMPAT_ALLOW_REMOTE_SNAPSHOT`` set to ``"true"``, the test
    expects the yielded path to differ from the snapshot directory and to hold
    a ``config.json`` with the qwen2 identity, while the snapshot directory's
    own ``config.json`` still reports ``"maris"``.
    """
    repo_id = "MarisUK/maris-ai-master"
    snapshot_dir = tmp_path / "snapshot"
    snapshot_dir.mkdir()

    upstream_config = {
        "_name_or_path": "MarisUK/maris-ai-master",
        "model_type": "qwen2",
        "architectures": ["Qwen2ForCausalLM"],
    }
    (snapshot_dir / "config.json").write_text(json.dumps(upstream_config), encoding="utf-8")
    write_maris_compatibility_artifact(snapshot_dir, maris_model_id=repo_id)
    apply_maris_compatibility_identity(snapshot_dir)

    def fake_snapshot_download(**kwargs):
        """Return the local snapshot only for the expected model repository."""
        if kwargs["repo_id"] == "MarisUK/maris-ai-master" and kwargs["repo_type"] == "model":
            return str(snapshot_dir)
        return None

    fake_hub = types.SimpleNamespace(snapshot_download=fake_snapshot_download)
    monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub)
    monkeypatch.setenv("MARIS_HF_COMPAT_ALLOW_REMOTE_SNAPSHOT", "true")

    with maris_hf_compatible_path(repo_id) as compatible_path:
        restored_dir = Path(compatible_path)
        assert restored_dir != snapshot_dir
        restored_text = (restored_dir / "config.json").read_text(encoding="utf-8")
        restored_config = json.loads(restored_text)
        assert restored_config["model_type"] == "qwen2"
        assert restored_config["architectures"] == ["Qwen2ForCausalLM"]

    # The snapshot directory itself is expected to keep the Maris identity.
    snapshot_text = (snapshot_dir / "config.json").read_text(encoding="utf-8")
    original_config = json.loads(snapshot_text)
    assert original_config["model_type"] == "maris"
def test_maris_hf_compatible_path_can_force_remote_snapshot_restore(
    monkeypatch, tmp_path: Path
) -> None:
    """Check that ``allow_remote_snapshot=True`` triggers a snapshot restore.

    The test sets no environment variable; it passes
    ``allow_remote_snapshot=True`` directly and expects the yielded directory
    to contain a ``config.json`` with the qwen2 identity.
    """
    repo_id = "MarisUK/maris-ai-master"
    snapshot_dir = tmp_path / "snapshot"
    snapshot_dir.mkdir()

    upstream_config = {
        "_name_or_path": "MarisUK/maris-ai-master",
        "model_type": "qwen2",
        "architectures": ["Qwen2ForCausalLM"],
    }
    (snapshot_dir / "config.json").write_text(json.dumps(upstream_config), encoding="utf-8")
    write_maris_compatibility_artifact(snapshot_dir, maris_model_id=repo_id)
    apply_maris_compatibility_identity(snapshot_dir)

    def fake_snapshot_download(**kwargs):
        """Return the local snapshot only for the expected model repository."""
        if kwargs["repo_id"] == "MarisUK/maris-ai-master" and kwargs["repo_type"] == "model":
            return str(snapshot_dir)
        return None

    fake_hub = types.SimpleNamespace(snapshot_download=fake_snapshot_download)
    monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub)

    with maris_hf_compatible_path(repo_id, allow_remote_snapshot=True) as compatible_path:
        restored_dir = Path(compatible_path)
        restored_text = (restored_dir / "config.json").read_text(encoding="utf-8")
        restored_config = json.loads(restored_text)
        assert restored_config["model_type"] == "qwen2"
        assert restored_config["architectures"] == ["Qwen2ForCausalLM"]
def test_maris_hf_compatible_path_restores_custom_runtime_snapshot(
    monkeypatch, tmp_path: Path
) -> None:
    """Check that a snapshot for a non-default repository id is restored.

    Uses the ``custom-user/maris-runtime`` repository id with a llama-based
    config and sets ``MARIS_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT`` to
    ``"true"``. The yielded directory is expected to contain a ``config.json``
    with the llama identity.
    """
    repo_id = "custom-user/maris-runtime"
    snapshot_dir = tmp_path / "custom-runtime"
    snapshot_dir.mkdir()

    upstream_config = {
        "_name_or_path": "custom-user/maris-runtime",
        "model_type": "llama",
        "architectures": ["LlamaForCausalLM"],
    }
    (snapshot_dir / "config.json").write_text(json.dumps(upstream_config), encoding="utf-8")
    write_maris_compatibility_artifact(snapshot_dir, maris_model_id=repo_id)
    apply_maris_compatibility_identity(snapshot_dir)

    def fake_snapshot_download(**kwargs):
        """Return the local snapshot only for the expected model repository."""
        if kwargs["repo_id"] == "custom-user/maris-runtime" and kwargs["repo_type"] == "model":
            return str(snapshot_dir)
        return None

    fake_hub = types.SimpleNamespace(snapshot_download=fake_snapshot_download)
    monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub)
    monkeypatch.setenv("MARIS_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT", "true")

    with maris_hf_compatible_path(repo_id) as compatible_path:
        restored_dir = Path(compatible_path)
        restored_text = (restored_dir / "config.json").read_text(encoding="utf-8")
        restored_config = json.loads(restored_text)
        assert restored_config["model_type"] == "llama"
        assert restored_config["architectures"] == ["LlamaForCausalLM"]
def test_maris_hf_compatible_path_returns_remote_snapshot_path_without_restore_artifact(
    monkeypatch,
    tmp_path: Path,
) -> None:
    """Check that a snapshot without a compatibility artifact is returned as-is.

    The snapshot directory holds only a plain ``config.json``; no compatibility
    helper is run on it. With ``MARIS_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT``
    set to ``"true"``, the yielded value is expected to equal the snapshot
    directory path as a string.
    """
    snapshot_dir = tmp_path / "plain-runtime"
    snapshot_dir.mkdir()
    plain_config = json.dumps({"model_type": "llama"})
    (snapshot_dir / "config.json").write_text(plain_config, encoding="utf-8")

    def fake_snapshot_download(**kwargs):
        """Return the local snapshot only for the expected model repository."""
        if kwargs["repo_id"] == "custom-user/plain-runtime" and kwargs["repo_type"] == "model":
            return str(snapshot_dir)
        return None

    fake_hub = types.SimpleNamespace(snapshot_download=fake_snapshot_download)
    monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub)
    monkeypatch.setenv("MARIS_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT", "true")

    expected_path = str(snapshot_dir)
    with maris_hf_compatible_path("custom-user/plain-runtime") as compatible_path:
        assert compatible_path == expected_path