"""Tests for the Maris Hugging Face compatibility layer."""
|
|
| from __future__ import annotations |
|
|
| import json |
| import sys |
| import types |
| from pathlib import Path |
|
|
| from maris_core.training.hf_compat import ( |
| MARIS_COMPATIBILITY_ARTIFACT_NAME, |
| apply_maris_compatibility_identity, |
| maris_hf_compatible_path, |
| write_maris_compatibility_artifact, |
| ) |
|
|
|
|
def test_write_maris_compatibility_artifact_sanitizes_loader_fields(tmp_path: Path) -> None:
    """Artifact writing plus identity application rebrands every loader-facing file.

    Seeds a Qwen2-style checkpoint directory with ``config.json``,
    ``tokenizer_config.json`` and ``adapter_config.json``. It then writes the
    Maris compatibility artifact and applies the Maris identity. Finally it
    checks that the manifest lists all three files and that their
    model-type/class fields now carry the Maris-compatible names.
    """
    model_dir = tmp_path / "model"
    model_dir.mkdir()

    def dump(filename: str, payload: object) -> None:
        # Serialize one JSON fixture file into the checkpoint directory.
        (model_dir / filename).write_text(json.dumps(payload), encoding="utf-8")

    def load(filename: str) -> dict:
        # Parse one JSON file back out of the checkpoint directory.
        return json.loads((model_dir / filename).read_text(encoding="utf-8"))

    dump(
        "config.json",
        {
            "_name_or_path": "MarisUK/maris-ai-master",
            "model_type": "qwen2",
            "architectures": ["Qwen2ForCausalLM"],
            "tokenizer_class": "Qwen2TokenizerFast",
            "auto_map": {"AutoModelForCausalLM": "qwen2.modeling_qwen2.Qwen2ForCausalLM"},
        },
    )
    dump("tokenizer_config.json", {"tokenizer_class": "Qwen2TokenizerFast"})
    dump(
        "adapter_config.json",
        {
            "base_model_class": "Qwen2ForCausalLM",
            "parent_library": "transformers.models.qwen2.modeling_qwen2",
        },
    )

    write_maris_compatibility_artifact(model_dir, maris_model_id="MarisUK/maris-ai-master")
    apply_maris_compatibility_identity(model_dir)

    # The manifest must record every file whose identity was rewritten.
    manifest = load(MARIS_COMPATIBILITY_ARTIFACT_NAME)
    assert manifest["artifact_type"] == "maris-hf-compatibility"
    assert sorted(manifest["artifacts"]) == [
        "adapter_config.json",
        "config.json",
        "tokenizer_config.json",
    ]

    config = load("config.json")
    assert config["model_type"] == "maris"
    assert config["architectures"] == ["MarisCompatibleCausalLM"]

    tokenizer = load("tokenizer_config.json")
    assert tokenizer["tokenizer_class"] == "MarisCompatibleTokenizer"

    adapter = load("adapter_config.json")
    assert adapter["base_model_class"] == "MarisCompatibleCausalLM"
    assert adapter["parent_library"] == "maris.compat"
|
|
|
|
def test_maris_hf_compatible_path_restores_remote_snapshot(monkeypatch, tmp_path: Path) -> None:
    """Opting in via ``MARIS_HF_COMPAT_ALLOW_REMOTE_SNAPSHOT`` restores a downloaded snapshot.

    A fake ``huggingface_hub`` module serves a Maris-branded snapshot. Inside
    the context manager, the yielded path must be a different directory whose
    ``config.json`` carries the original Qwen2 identity. The downloaded
    snapshot itself must keep the Maris identity.
    """
    snapshot_dir = tmp_path / "snapshot"
    snapshot_dir.mkdir()
    (snapshot_dir / "config.json").write_text(
        json.dumps(
            {
                "_name_or_path": "MarisUK/maris-ai-master",
                "model_type": "qwen2",
                "architectures": ["Qwen2ForCausalLM"],
            }
        ),
        encoding="utf-8",
    )
    write_maris_compatibility_artifact(snapshot_dir, maris_model_id="MarisUK/maris-ai-master")
    apply_maris_compatibility_identity(snapshot_dir)

    # Fake hub: only the expected model repo resolves to the local snapshot.
    monkeypatch.setitem(
        sys.modules,
        "huggingface_hub",
        types.SimpleNamespace(
            snapshot_download=lambda **kwargs: (
                str(snapshot_dir)
                if kwargs["repo_id"] == "MarisUK/maris-ai-master" and kwargs["repo_type"] == "model"
                else None
            )
        ),
    )
    # Make the HF-specific variable the only opt-in. An ambient value for the
    # runtime variant could otherwise mask a broken HF variable.
    monkeypatch.delenv("MARIS_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT", raising=False)
    monkeypatch.setenv("MARIS_HF_COMPAT_ALLOW_REMOTE_SNAPSHOT", "true")

    with maris_hf_compatible_path("MarisUK/maris-ai-master") as compatible_path:
        restored_dir = Path(compatible_path)
        # Restoration must happen in a separate directory, not in place.
        assert restored_dir != snapshot_dir
        restored_config = json.loads(
            restored_dir.joinpath("config.json").read_text(encoding="utf-8")
        )
        assert restored_config["model_type"] == "qwen2"
        assert restored_config["architectures"] == ["Qwen2ForCausalLM"]

    # The downloaded snapshot itself must be left untouched.
    original_config = json.loads((snapshot_dir / "config.json").read_text(encoding="utf-8"))
    assert original_config["model_type"] == "maris"
|
|
|
|
def test_maris_hf_compatible_path_can_force_remote_snapshot_restore(
    monkeypatch, tmp_path: Path
) -> None:
    """``allow_remote_snapshot=True`` restores a remote snapshot without any env opt-in.

    Both opt-in environment variables are removed first, so the keyword
    argument is the only thing that can enable the remote snapshot path.
    Otherwise an ambient ``...=true`` in a developer shell or CI would let
    this test pass even if the keyword were ignored.
    """
    snapshot_dir = tmp_path / "snapshot"
    snapshot_dir.mkdir()
    (snapshot_dir / "config.json").write_text(
        json.dumps(
            {
                "_name_or_path": "MarisUK/maris-ai-master",
                "model_type": "qwen2",
                "architectures": ["Qwen2ForCausalLM"],
            }
        ),
        encoding="utf-8",
    )
    write_maris_compatibility_artifact(snapshot_dir, maris_model_id="MarisUK/maris-ai-master")
    apply_maris_compatibility_identity(snapshot_dir)

    # Fake hub: only the expected model repo resolves to the local snapshot.
    monkeypatch.setitem(
        sys.modules,
        "huggingface_hub",
        types.SimpleNamespace(
            snapshot_download=lambda **kwargs: (
                str(snapshot_dir)
                if kwargs["repo_id"] == "MarisUK/maris-ai-master" and kwargs["repo_type"] == "model"
                else None
            )
        ),
    )
    # Isolate from the process environment. monkeypatch puts back any removed
    # values when the test finishes.
    monkeypatch.delenv("MARIS_HF_COMPAT_ALLOW_REMOTE_SNAPSHOT", raising=False)
    monkeypatch.delenv("MARIS_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT", raising=False)

    with maris_hf_compatible_path(
        "MarisUK/maris-ai-master", allow_remote_snapshot=True
    ) as compatible_path:
        restored_dir = Path(compatible_path)
        restored_config = json.loads(
            restored_dir.joinpath("config.json").read_text(encoding="utf-8")
        )
        assert restored_config["model_type"] == "qwen2"
        assert restored_config["architectures"] == ["Qwen2ForCausalLM"]
|
|
|
|
def test_maris_hf_compatible_path_restores_custom_runtime_snapshot(
    monkeypatch, tmp_path: Path
) -> None:
    """Opting in via ``MARIS_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT`` restores a custom repo.

    A Llama-based snapshot is published under a non-MarisUK repo id and
    rebranded with the Maris identity. Inside the context manager, the
    yielded directory must expose the original Llama ``config.json``.
    """
    snapshot_dir = tmp_path / "custom-runtime"
    snapshot_dir.mkdir()
    (snapshot_dir / "config.json").write_text(
        json.dumps(
            {
                "_name_or_path": "custom-user/maris-runtime",
                "model_type": "llama",
                "architectures": ["LlamaForCausalLM"],
            }
        ),
        encoding="utf-8",
    )
    write_maris_compatibility_artifact(snapshot_dir, maris_model_id="custom-user/maris-runtime")
    apply_maris_compatibility_identity(snapshot_dir)

    # Fake hub: only the expected model repo resolves to the local snapshot.
    monkeypatch.setitem(
        sys.modules,
        "huggingface_hub",
        types.SimpleNamespace(
            snapshot_download=lambda **kwargs: (
                str(snapshot_dir)
                if kwargs["repo_id"] == "custom-user/maris-runtime"
                and kwargs["repo_type"] == "model"
                else None
            )
        ),
    )
    # Make the runtime variable the only opt-in. An ambient HF-specific value
    # could otherwise mask a broken runtime variable.
    monkeypatch.delenv("MARIS_HF_COMPAT_ALLOW_REMOTE_SNAPSHOT", raising=False)
    monkeypatch.setenv("MARIS_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT", "true")

    with maris_hf_compatible_path("custom-user/maris-runtime") as compatible_path:
        restored_dir = Path(compatible_path)
        restored_config = json.loads(
            restored_dir.joinpath("config.json").read_text(encoding="utf-8")
        )
        assert restored_config["model_type"] == "llama"
        assert restored_config["architectures"] == ["LlamaForCausalLM"]
|
|
|
|
def test_maris_hf_compatible_path_returns_remote_snapshot_path_without_restore_artifact(
    monkeypatch,
    tmp_path: Path,
) -> None:
    """A downloaded snapshot with no compatibility artifact is yielded unchanged.

    The fake hub serves a plain ``llama`` checkpoint that never went through
    the Maris compatibility artifact writer. The context manager is expected
    to hand back the snapshot directory path as-is.
    """
    plain_dir = tmp_path / "plain-runtime"
    plain_dir.mkdir()
    plain_config = json.dumps({"model_type": "llama"})
    plain_dir.joinpath("config.json").write_text(plain_config, encoding="utf-8")

    def fake_snapshot_download(**kwargs):
        # Resolve only the expected model repo; anything else yields None.
        # The `and` short-circuits, so repo_type is read only when repo_id matches.
        is_expected_repo = (
            kwargs["repo_id"] == "custom-user/plain-runtime"
            and kwargs["repo_type"] == "model"
        )
        return str(plain_dir) if is_expected_repo else None

    fake_hub = types.SimpleNamespace(snapshot_download=fake_snapshot_download)
    monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub)
    monkeypatch.setenv("MARIS_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT", "true")

    with maris_hf_compatible_path("custom-user/plain-runtime") as compatible_path:
        assert compatible_path == str(plain_dir)
|
|