blood-test-explainer / tests /test_model_paths.py
Codex
Default extraction to the fine-tuned MiniCPM-V Hub checkpoint.
c1a5fb4
Raw
History Blame Contribute Delete
3.29 kB
import sys
import tempfile
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from src.model_paths import (
BASE_HF_REPO,
DEFAULT_HF_REPO,
is_transformers_model_dir,
resolve_transformers_model_source,
)
def test_resolve_uses_hub_download_when_no_local_weights(monkeypatch):
    """Resolving the default repo with no local weights yields a Hub download.

    BTE_MODELS_DIR points at an empty directory, HF_HOME at a path that is
    never created, and the snapshot lookup is stubbed to find nothing.
    """

    def _no_complete_snapshot(repo_id, hub_cache):
        # Stand-in for src.model_paths.latest_complete_snapshot: never finds one.
        return None

    with tempfile.TemporaryDirectory() as scratch:
        scratch_root = Path(scratch)
        models_dir = scratch_root / "models"
        models_dir.mkdir()

        monkeypatch.setenv("BTE_MODELS_DIR", str(models_dir))
        monkeypatch.setenv("HF_HOME", str(scratch_root / "hf"))
        monkeypatch.setattr(
            "src.model_paths.latest_complete_snapshot",
            _no_complete_snapshot,
        )

        resolved = resolve_transformers_model_source(DEFAULT_HF_REPO)

        assert resolved.local_files_only is False
        assert resolved.origin == "hub-download"
        assert resolved.model_id == DEFAULT_HF_REPO
def test_resolve_defaults_to_finetuned_repo(monkeypatch):
    """Calling the resolver with no argument selects DEFAULT_HF_REPO via the Hub."""

    def _no_complete_snapshot(repo_id, hub_cache):
        # Stand-in for src.model_paths.latest_complete_snapshot: never finds one.
        return None

    with tempfile.TemporaryDirectory() as scratch:
        scratch_root = Path(scratch)
        models_dir = scratch_root / "models"
        models_dir.mkdir()

        monkeypatch.setenv("BTE_MODELS_DIR", str(models_dir))
        monkeypatch.setenv("HF_HOME", str(scratch_root / "hf"))
        # Clear ZEROGPU_MODEL_ID so an ambient value cannot influence the
        # no-argument default (presumably it overrides the repo id).
        monkeypatch.delenv("ZEROGPU_MODEL_ID", raising=False)
        monkeypatch.setattr(
            "src.model_paths.latest_complete_snapshot",
            _no_complete_snapshot,
        )

        resolved = resolve_transformers_model_source()

        assert resolved.model_id == DEFAULT_HF_REPO
        assert resolved.origin == "hub-download"
def test_resolve_base_repo_for_eval_baseline(monkeypatch):
    """Explicitly requesting BASE_HF_REPO resolves to that repo id."""

    def _no_complete_snapshot(repo_id, hub_cache):
        # Stand-in for src.model_paths.latest_complete_snapshot: never finds one.
        return None

    with tempfile.TemporaryDirectory() as scratch:
        scratch_root = Path(scratch)
        models_dir = scratch_root / "models"
        models_dir.mkdir()

        monkeypatch.setenv("BTE_MODELS_DIR", str(models_dir))
        monkeypatch.setenv("HF_HOME", str(scratch_root / "hf"))
        monkeypatch.setattr(
            "src.model_paths.latest_complete_snapshot",
            _no_complete_snapshot,
        )

        resolved = resolve_transformers_model_source(BASE_HF_REPO)

        assert resolved.model_id == BASE_HF_REPO
def test_resolve_uses_local_dir_when_complete():
    """A directory with config.json and model.safetensors resolves as local-only."""
    with tempfile.TemporaryDirectory() as scratch:
        checkpoint = Path(scratch).joinpath("MiniCPM-V-4.6")
        checkpoint.mkdir()
        # Minimal placeholders: an empty JSON config and a few dummy weight bytes.
        checkpoint.joinpath("config.json").write_text("{}", encoding="utf-8")
        checkpoint.joinpath("model.safetensors").write_bytes(b"test")

        resolved = resolve_transformers_model_source(str(checkpoint))

        assert resolved.local_files_only is True
        assert resolved.origin == "local-dir"
        # model_id is compared after path normalisation on both sides.
        assert Path(resolved.model_id) == checkpoint.resolve()
def test_is_transformers_model_dir_requires_weights():
    """A directory holding only config.json (no weight file) is rejected."""
    with tempfile.TemporaryDirectory() as scratch:
        config_only = Path(scratch).joinpath("partial")
        config_only.mkdir()
        config_only.joinpath("config.json").write_text("{}", encoding="utf-8")

        assert is_transformers_model_dir(config_only) is False
if __name__ == "__main__":
    # Allow `python tests/test_model_paths.py` without pytest. Several tests
    # take pytest's `monkeypatch` fixture, so calling them with no arguments
    # raised TypeError. Supply a minimal stdlib stand-in instead.
    import os
    from contextlib import ExitStack
    from unittest import mock

    class _StdlibMonkeyPatch:
        """Subset of pytest's monkeypatch API used by this module.

        Every change is registered on *stack* and undone when it closes.
        """

        def __init__(self, stack):
            self._stack = stack
            # Snapshot os.environ; patch.dict restores it on exit.
            stack.enter_context(mock.patch.dict(os.environ))

        def setenv(self, name, value):
            os.environ[name] = value

        def delenv(self, name, raising=True):
            if raising:
                del os.environ[name]  # KeyError if absent, like pytest
            else:
                os.environ.pop(name, None)

        def setattr(self, target, value):
            # target is a dotted "module.attr" string, as used in this file.
            self._stack.enter_context(mock.patch(target, value))

    for _needs_monkeypatch in (
        test_resolve_uses_hub_download_when_no_local_weights,
        test_resolve_defaults_to_finetuned_repo,
        test_resolve_base_repo_for_eval_baseline,
    ):
        # Fresh, isolated patch scope per test so env/attr changes don't leak.
        with ExitStack() as _stack:
            _needs_monkeypatch(_StdlibMonkeyPatch(_stack))

    test_resolve_uses_local_dir_when_complete()
    test_is_transformers_model_dir_requires_weights()
    print("test_model_paths: ok")