"""Tests Maris Hugging Face compatibility layer.""" from __future__ import annotations import json import sys import types from pathlib import Path from maris_core.training.hf_compat import ( MARIS_COMPATIBILITY_ARTIFACT_NAME, apply_maris_compatibility_identity, maris_hf_compatible_path, write_maris_compatibility_artifact, ) def test_write_maris_compatibility_artifact_sanitizes_loader_fields(tmp_path: Path) -> None: output_dir = tmp_path / "model" output_dir.mkdir() (output_dir / "config.json").write_text( json.dumps( { "_name_or_path": "MarisUK/maris-ai-master", "model_type": "qwen2", "architectures": ["Qwen2ForCausalLM"], "tokenizer_class": "Qwen2TokenizerFast", "auto_map": {"AutoModelForCausalLM": "qwen2.modeling_qwen2.Qwen2ForCausalLM"}, } ), encoding="utf-8", ) (output_dir / "tokenizer_config.json").write_text( json.dumps({"tokenizer_class": "Qwen2TokenizerFast"}), encoding="utf-8", ) (output_dir / "adapter_config.json").write_text( json.dumps( { "base_model_class": "Qwen2ForCausalLM", "parent_library": "transformers.models.qwen2.modeling_qwen2", } ), encoding="utf-8", ) write_maris_compatibility_artifact(output_dir, maris_model_id="MarisUK/maris-ai-master") apply_maris_compatibility_identity(output_dir) manifest = json.loads( (output_dir / MARIS_COMPATIBILITY_ARTIFACT_NAME).read_text(encoding="utf-8") ) config_payload = json.loads((output_dir / "config.json").read_text(encoding="utf-8")) tokenizer_payload = json.loads( (output_dir / "tokenizer_config.json").read_text(encoding="utf-8") ) adapter_payload = json.loads((output_dir / "adapter_config.json").read_text(encoding="utf-8")) assert manifest["artifact_type"] == "maris-hf-compatibility" assert sorted(manifest["artifacts"]) == [ "adapter_config.json", "config.json", "tokenizer_config.json", ] assert config_payload["model_type"] == "maris" assert config_payload["architectures"] == ["MarisCompatibleCausalLM"] assert tokenizer_payload["tokenizer_class"] == "MarisCompatibleTokenizer" assert adapter_payload["base_model_class"] == "MarisCompatibleCausalLM" assert adapter_payload["parent_library"] == "maris.compat" def test_maris_hf_compatible_path_restores_remote_snapshot(monkeypatch, tmp_path: Path) -> None: snapshot_dir = tmp_path / "snapshot" snapshot_dir.mkdir() (snapshot_dir / "config.json").write_text( json.dumps( { "_name_or_path": "MarisUK/maris-ai-master", "model_type": "qwen2", "architectures": ["Qwen2ForCausalLM"], } ), encoding="utf-8", ) write_maris_compatibility_artifact(snapshot_dir, maris_model_id="MarisUK/maris-ai-master") apply_maris_compatibility_identity(snapshot_dir) monkeypatch.setitem( sys.modules, "huggingface_hub", types.SimpleNamespace( snapshot_download=lambda **kwargs: ( str(snapshot_dir) if kwargs["repo_id"] == "MarisUK/maris-ai-master" and kwargs["repo_type"] == "model" else None ) ), ) monkeypatch.setenv("MARIS_HF_COMPAT_ALLOW_REMOTE_SNAPSHOT", "true") with maris_hf_compatible_path("MarisUK/maris-ai-master") as compatible_path: restored_dir = Path(compatible_path) assert restored_dir != snapshot_dir restored_config = json.loads( restored_dir.joinpath("config.json").read_text(encoding="utf-8") ) assert restored_config["model_type"] == "qwen2" assert restored_config["architectures"] == ["Qwen2ForCausalLM"] original_config = json.loads((snapshot_dir / "config.json").read_text(encoding="utf-8")) assert original_config["model_type"] == "maris" def test_maris_hf_compatible_path_can_force_remote_snapshot_restore( monkeypatch, tmp_path: Path ) -> None: snapshot_dir = tmp_path / "snapshot" snapshot_dir.mkdir() (snapshot_dir / "config.json").write_text( json.dumps( { "_name_or_path": "MarisUK/maris-ai-master", "model_type": "qwen2", "architectures": ["Qwen2ForCausalLM"], } ), encoding="utf-8", ) write_maris_compatibility_artifact(snapshot_dir, maris_model_id="MarisUK/maris-ai-master") apply_maris_compatibility_identity(snapshot_dir) monkeypatch.setitem( sys.modules, "huggingface_hub", types.SimpleNamespace( snapshot_download=lambda **kwargs: ( str(snapshot_dir) if kwargs["repo_id"] == "MarisUK/maris-ai-master" and kwargs["repo_type"] == "model" else None ) ), ) with maris_hf_compatible_path( "MarisUK/maris-ai-master", allow_remote_snapshot=True ) as compatible_path: restored_dir = Path(compatible_path) restored_config = json.loads( restored_dir.joinpath("config.json").read_text(encoding="utf-8") ) assert restored_config["model_type"] == "qwen2" assert restored_config["architectures"] == ["Qwen2ForCausalLM"] def test_maris_hf_compatible_path_restores_custom_runtime_snapshot( monkeypatch, tmp_path: Path ) -> None: snapshot_dir = tmp_path / "custom-runtime" snapshot_dir.mkdir() (snapshot_dir / "config.json").write_text( json.dumps( { "_name_or_path": "custom-user/maris-runtime", "model_type": "llama", "architectures": ["LlamaForCausalLM"], } ), encoding="utf-8", ) write_maris_compatibility_artifact(snapshot_dir, maris_model_id="custom-user/maris-runtime") apply_maris_compatibility_identity(snapshot_dir) monkeypatch.setitem( sys.modules, "huggingface_hub", types.SimpleNamespace( snapshot_download=lambda **kwargs: ( str(snapshot_dir) if kwargs["repo_id"] == "custom-user/maris-runtime" and kwargs["repo_type"] == "model" else None ) ), ) monkeypatch.setenv("MARIS_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT", "true") with maris_hf_compatible_path("custom-user/maris-runtime") as compatible_path: restored_dir = Path(compatible_path) restored_config = json.loads( restored_dir.joinpath("config.json").read_text(encoding="utf-8") ) assert restored_config["model_type"] == "llama" assert restored_config["architectures"] == ["LlamaForCausalLM"] def test_maris_hf_compatible_path_returns_remote_snapshot_path_without_restore_artifact( monkeypatch, tmp_path: Path, ) -> None: snapshot_dir = tmp_path / "plain-runtime" snapshot_dir.mkdir() (snapshot_dir / "config.json").write_text(json.dumps({"model_type": "llama"}), encoding="utf-8") monkeypatch.setitem( sys.modules, "huggingface_hub", types.SimpleNamespace( snapshot_download=lambda **kwargs: ( str(snapshot_dir) if kwargs["repo_id"] == "custom-user/plain-runtime" and kwargs["repo_type"] == "model" else None ) ), ) monkeypatch.setenv("MARIS_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT", "true") with maris_hf_compatible_path("custom-user/plain-runtime") as compatible_path: assert compatible_path == str(snapshot_dir)