# Spaces: Running  (status text captured from the hosting page; not part of the test module)
| from __future__ import annotations | |
| import argparse | |
| import pytest | |
| import torch | |
| from hakari_bench.models import ( | |
| ModelLoadConfig, | |
| collect_model_metadata, | |
| load_model, | |
| resolve_model_revision, | |
| resolve_attn_implementation, | |
| resolve_torch_dtype, | |
| ) | |
def test_resolve_torch_dtype_defaults_to_bf16() -> None:
    """Each supported dtype alias resolves to the matching ``torch`` dtype object."""
    # Identity (``is``) is used because torch dtypes are singletons.
    assert resolve_torch_dtype("bf16") is torch.bfloat16
    assert resolve_torch_dtype("fp16") is torch.float16
    assert resolve_torch_dtype("fp32") is torch.float32
def test_resolve_attn_implementation_rejects_flash_attn_conflict() -> None:
    """Requesting ``"sdpa"`` together with ``flash_attn2=True`` must raise ``ValueError``."""
    with pytest.raises(ValueError):
        resolve_attn_implementation(attn_implementation="sdpa", flash_attn2=True)
def test_load_model_passes_dense_options(monkeypatch: pytest.MonkeyPatch) -> None:
    """Exercise the dense loading path against a fake SentenceTransformer.

    The fake records its constructor arguments so the test can check what
    ``load_model`` passes in, and exposes the attributes the assertions below
    inspect after loading: ``max_seq_length``, a parameter dtype, and the
    nested config's ``_attn_implementation``.
    """
    calls: list[dict[str, object]] = []

    class FakeSentenceTransformer(torch.nn.Module):
        def __init__(self, model_name_or_path: str, **kwargs: object) -> None:
            super().__init__()
            calls.append({"model_name_or_path": model_name_or_path, **kwargs})
            self.max_seq_length = None
            self.projection = torch.nn.Linear(2, 2)
            # FakeCustomModule is a module-level helper of this test file; the
            # name is only looked up when the fake is instantiated.
            self.inner = FakeCustomModule()

    monkeypatch.setattr(
        "hakari_bench.models._import_sentence_transformer",
        lambda: FakeSentenceTransformer,
    )

    model = load_model(
        ModelLoadConfig(
            model_name_or_path="hotchpotch/model",
            model_type="dense",
            dtype="bf16",
            attn_implementation=None,
            flash_attn2=True,
            device="cpu",
            trust_remote_code=True,
            max_seq_length=128,
        )
    )

    assert isinstance(model, FakeSentenceTransformer)
    assert model.max_seq_length == 128
    assert model.projection.weight.dtype is torch.bfloat16
    # The nested fake config starts out as "sdpa", so this value can only come
    # from load_model updating it.
    assert model.inner.model.config._attn_implementation == "flash_attention_2"
    assert calls == [
        {
            "model_name_or_path": "hotchpotch/model",
            "device": "cpu",
            "revision": None,
            "trust_remote_code": True,
            "model_kwargs": {
                "torch_dtype": torch.bfloat16,
                "attn_implementation": "flash_attention_2",
            },
        }
    ]
def test_load_model_passes_late_interaction_options(monkeypatch: pytest.MonkeyPatch) -> None:
    """Exercise the late-interaction loading path against a fake ColBERT class.

    The expected call shows the ``late_interaction_*`` config fields reaching
    the constructor under their un-prefixed keyword names (``query_length``,
    ``document_prefix``, ...).
    """
    calls: list[dict[str, object]] = []

    class FakeColBERT(torch.nn.Module):
        def __init__(self, model_name_or_path: str, **kwargs: object) -> None:
            super().__init__()
            calls.append({"model_name_or_path": model_name_or_path, **kwargs})
            self.projection = torch.nn.Linear(2, 2)

    monkeypatch.setattr("hakari_bench.models._import_pylate_colbert", lambda: FakeColBERT)

    model = load_model(
        ModelLoadConfig(
            model_name_or_path="lightonai/GTE-ModernColBERT-v1",
            model_type="late-interaction",
            dtype="fp32",
            device="cpu",
            trust_remote_code=True,
            late_interaction_query_length=64,
            late_interaction_document_length=300,
            late_interaction_query_prefix="[QueryMarker]",
            late_interaction_document_prefix="[DocumentMarker]",
            late_interaction_attend_to_expansion_tokens=True,
        )
    )

    assert isinstance(model, FakeColBERT)
    assert model.projection.weight.dtype is torch.float32
    assert calls == [
        {
            "model_name_or_path": "lightonai/GTE-ModernColBERT-v1",
            "device": "cpu",
            "revision": None,
            "trust_remote_code": True,
            "model_kwargs": {"torch_dtype": torch.float32},
            "query_length": 64,
            "document_length": 300,
            "query_prefix": "[QueryMarker]",
            "document_prefix": "[DocumentMarker]",
            "attend_to_expansion_tokens": True,
        }
    ]
def test_load_model_reranker_passes_cross_encoder_kwargs(monkeypatch: pytest.MonkeyPatch) -> None:
    """Exercise the reranker loading path against a fake CrossEncoder class.

    The expected call shows the user-supplied ``cross_encoder_kwargs`` reaching
    the constructor, with its ``model_kwargs`` entry combined with the resolved
    ``torch_dtype`` rather than replaced by it.
    """
    calls: list[dict[str, object]] = []

    class FakeCrossEncoder(torch.nn.Module):
        def __init__(self, model_name_or_path: str, **kwargs: object) -> None:
            super().__init__()
            calls.append({"model_name_or_path": model_name_or_path, **kwargs})
            self.model = torch.nn.Linear(2, 1)

    monkeypatch.setattr("hakari_bench.models._import_cross_encoder", lambda: FakeCrossEncoder)

    model = load_model(
        ModelLoadConfig(
            model_name_or_path="Qwen/Qwen3-Reranker-0.6B",
            model_type="reranker",
            dtype="bf16",
            device="cuda:0",
            trust_remote_code=True,
            cross_encoder_kwargs={
                "prompts": {"retrieval": "Retrieve relevant passages"},
                "default_prompt_name": "retrieval",
                "model_kwargs": {"attn_implementation": "sdpa"},
            },
        )
    )

    assert isinstance(model, FakeCrossEncoder)
    assert calls == [
        {
            "model_name_or_path": "Qwen/Qwen3-Reranker-0.6B",
            "prompts": {"retrieval": "Retrieve relevant passages"},
            "default_prompt_name": "retrieval",
            "device": "cuda:0",
            "revision": None,
            "trust_remote_code": True,
            "model_kwargs": {
                "torch_dtype": torch.bfloat16,
                "attn_implementation": "sdpa",
            },
        }
    ]
def test_load_model_passes_model_revision_to_huggingface_loaders(monkeypatch: pytest.MonkeyPatch) -> None:
    """``ModelLoadConfig.model_revision`` reaches the model constructor as ``revision``."""
    calls: list[dict[str, object]] = []

    class FakeSentenceTransformer(torch.nn.Module):
        def __init__(self, model_name_or_path: str, **kwargs: object) -> None:
            super().__init__()
            calls.append({"model_name_or_path": model_name_or_path, **kwargs})
            self.projection = torch.nn.Linear(2, 2)

    monkeypatch.setattr(
        "hakari_bench.models._import_sentence_transformer",
        lambda: FakeSentenceTransformer,
    )

    load_model(
        ModelLoadConfig(
            model_name_or_path="hotchpotch/model",
            model_type="dense",
            dtype="bf16",
            device="cpu",
            trust_remote_code=True,
            model_revision="abc123",
        )
    )

    assert calls[0]["revision"] == "abc123"
def test_resolve_model_revision_uses_full_huggingface_sha(monkeypatch: pytest.MonkeyPatch) -> None:
    """``resolve_model_revision`` reports the full commit SHA returned by the hub API.

    ``HfApi`` is replaced with a fake whose ``model_info`` checks the arguments
    it receives and returns a fixed 40-character SHA.
    """

    class FakeInfo:
        sha = "0123456789abcdef0123456789abcdef01234567"

    class FakeHfApi:
        def model_info(self, *, repo_id: str, revision: str | None = None) -> FakeInfo:
            assert repo_id == "hotchpotch/model"
            assert revision == "main"
            return FakeInfo()

    monkeypatch.setattr("hakari_bench.models.HfApi", FakeHfApi)
    # Start from an empty cache so the call below really goes through FakeHfApi.
    resolve_model_revision.cache_clear()
    try:
        assert resolve_model_revision("hotchpotch/model", requested_revision="main") == {
            "requested": "main",
            "resolved": "0123456789abcdef0123456789abcdef01234567",
            "source": "huggingface_hub",
        }
    finally:
        # monkeypatch restores HfApi, but not the cache: drop the entry computed
        # from the fake so later tests never see it.
        resolve_model_revision.cache_clear()
def test_collect_model_metadata_counts_parameters() -> None:
    """Metadata for a plain ``Sequential`` reports parameter totals and no embedding split."""
    # Linear(3, 2) has 3*2 + 2 = 8 parameters and Linear(2, 1) has 2*1 + 1 = 3,
    # hence the expected total of 11.
    model = torch.nn.Sequential(torch.nn.Linear(3, 2), torch.nn.Linear(2, 1))
    args = argparse.Namespace(
        model="toy",
        model_type="dense",
        dtype="bf16",
        device="cpu",
        trust_remote_code=False,
        attn_implementation=None,
        flash_attn2=False,
        model_revision=None,
    )

    metadata = collect_model_metadata(model, args)

    assert metadata["method"] == "dense"
    assert metadata["id"] == "toy"
    assert metadata["source"] == {"type": "huggingface", "name": "toy"}
    assert metadata["total_parameters"] == 11
    assert metadata["trainable_parameters"] == 11
    assert metadata["embedding_parameters"] is None
    assert metadata["active_parameters"] is None
def test_collect_model_metadata_records_model_revision(monkeypatch: pytest.MonkeyPatch) -> None:
    """The metadata ``source`` entry gains the resolved revision SHA.

    ``resolve_model_revision`` is stubbed so the test never performs a real
    lookup; the stub echoes the requested revision and returns a fixed SHA.
    """
    model = torch.nn.Sequential(torch.nn.Linear(3, 2))
    args = argparse.Namespace(
        model="hotchpotch/model",
        model_type="dense",
        dtype="bf16",
        device="cpu",
        trust_remote_code=False,
        attn_implementation=None,
        flash_attn2=False,
        model_revision="main",
        model_source={"type": "huggingface", "name": "hotchpotch/model", "revision_requested": "main"},
    )
    monkeypatch.setattr(
        "hakari_bench.models.resolve_model_revision",
        lambda model_id, requested_revision=None: {
            "requested": requested_revision,
            "resolved": "0123456789abcdef0123456789abcdef01234567",
            "source": "huggingface_hub",
        },
    )

    metadata = collect_model_metadata(model, args)

    assert metadata["source"] == {
        "type": "huggingface",
        "name": "hotchpotch/model",
        "revision_requested": "main",
        "revision": "0123456789abcdef0123456789abcdef01234567",
    }
def test_collect_model_metadata_counts_active_parameters_from_input_embeddings() -> None:
    """Embedding and active counts are reported when the backbone sits under ``auto_model``."""
    model = _FakeSentenceTransformerLike(backbone_attr="auto_model")
    args = _metadata_args()

    metadata = collect_model_metadata(model, args)

    # The expected values satisfy active = total - embedding (38 - 20 = 18).
    assert metadata["total_parameters"] == 38
    assert metadata["trainable_parameters"] == 38
    assert metadata["embedding_parameters"] == 20
    assert metadata["active_parameters"] == 18
def test_collect_model_metadata_counts_active_parameters_from_st_model_attribute() -> None:
    """Embedding and active counts are reported when the backbone sits under ``model``."""
    model = _FakeSentenceTransformerLike(backbone_attr="model")
    args = _metadata_args()

    metadata = collect_model_metadata(model, args)

    assert metadata["total_parameters"] == 38
    assert metadata["embedding_parameters"] == 20
    assert metadata["active_parameters"] == 18
def test_collect_model_metadata_counts_static_embedding_table_as_input_embeddings() -> None:
    """A model made only of an embedding table reports zero active parameters."""
    model = _FakeStaticSentenceTransformerLike()
    args = _metadata_args()

    metadata = collect_model_metadata(model, args)

    # Every parameter belongs to the embedding table, so nothing is left over.
    assert metadata["total_parameters"] == 20
    assert metadata["embedding_parameters"] == 20
    assert metadata["active_parameters"] == 0
def test_collect_model_metadata_records_late_interaction_metadata() -> None:
    """Late-interaction models get backend, similarity and a ``late_interaction`` section."""
    model = _FakeLateInteractionModel()
    args = _metadata_args()
    # The shared helper builds a dense namespace; switch it for this test.
    args.model_type = "late-interaction"

    metadata = collect_model_metadata(model, args)

    assert metadata["backend_library"] == "pylate"
    assert metadata["similarity_fn_name"] == "MaxSim"
    assert metadata["late_interaction"] == {
        "architecture": "colbert",
        "scoring": "maxsim",
        "query_prefix": "[Q] ",
        "document_prefix": "[D] ",
        "query_length": 32,
        "document_length": 300,
        "do_query_expansion": True,
        "attend_to_expansion_tokens": False,
    }
def _metadata_args() -> argparse.Namespace:
    """Return a fresh namespace of default dense-model arguments for metadata tests.

    A new object is built on every call, so callers may mutate the result
    (for example to change ``model_type``) without affecting other tests.
    """
    return argparse.Namespace(
        model="toy",
        model_type="dense",
        dtype="bf16",
        device="cpu",
        trust_remote_code=False,
        attn_implementation=None,
        flash_attn2=False,
    )
class _FakeSentenceTransformerLike(torch.nn.Module):
    """Container fake holding one ``_FakeTransformerModule`` child registered as ``"0"``.

    ``backbone_attr`` is forwarded to the child and names the attribute under
    which the child stores its backbone.
    """

    def __init__(self, *, backbone_attr: str) -> None:
        super().__init__()
        self.add_module("0", _FakeTransformerModule(backbone_attr=backbone_attr))

    def __getitem__(self, index: int) -> torch.nn.Module:
        """Return the child registered under ``str(index)``."""
        module = self._modules[str(index)]
        # ``_modules`` values are typed as optional; narrow for type checkers.
        assert isinstance(module, torch.nn.Module)
        return module
class _FakeStaticSentenceTransformerLike(torch.nn.Module):
    """Container fake holding one ``_FakeStaticEmbeddingModule`` child registered as ``"0"``."""

    def __init__(self) -> None:
        super().__init__()
        self.add_module("0", _FakeStaticEmbeddingModule())

    def __getitem__(self, index: int) -> torch.nn.Module:
        """Return the child registered under ``str(index)``."""
        module = self._modules[str(index)]
        # ``_modules`` values are typed as optional; narrow for type checkers.
        assert isinstance(module, torch.nn.Module)
        return module
class _FakeStaticEmbeddingModule(torch.nn.Module):
    """Module whose only parameters are a 5 x 4 embedding table (20 values)."""

    def __init__(self) -> None:
        super().__init__()
        self.embedding = torch.nn.Embedding(5, 4)
class _FakeLateInteractionModel(torch.nn.Module):
    """Fake late-interaction model exposing fixed configuration as class attributes."""

    # Values read back by the late-interaction metadata test.
    similarity_fn_name = "MaxSim"
    max_seq_length = None
    query_prefix = "[Q] "
    document_prefix = "[D] "
    query_length = 32
    document_length = 300
    do_query_expansion = True
    attend_to_expansion_tokens = False

    def __init__(self) -> None:
        super().__init__()
        # A small layer so the module owns some parameters.
        self.projection = torch.nn.Linear(2, 2)
class _FakeTransformerModule(torch.nn.Module):
    """Module that registers a ``_FakeBackbone`` under a caller-chosen attribute name."""

    def __init__(self, *, backbone_attr: str) -> None:
        super().__init__()
        # ``add_module`` lets the attribute name be decided at runtime.
        self.add_module(backbone_attr, _FakeBackbone())
class _FakeBackbone(torch.nn.Module):
    """Tiny backbone fake: an input embedding table, one linear layer and a norm.

    Parameter budget: Embedding(5, 4) = 20, Linear(4, 3, bias=False) = 12 and
    LayerNorm(3) = 6, i.e. 38 in total. The embedding table is reachable both
    as ``input_embeddings`` and as ``embeddings.tok_embeddings``; it is the
    same object in both places.
    """

    def __init__(self) -> None:
        super().__init__()
        self.embeddings = torch.nn.Module()
        self.input_embeddings = torch.nn.Embedding(5, 4)
        # Alias of the same Embedding instance, not a second table.
        self.embeddings.tok_embeddings = self.input_embeddings
        self.layers = torch.nn.ModuleList([torch.nn.Linear(4, 3, bias=False)])
        self.final_norm = torch.nn.LayerNorm(3)

    def get_input_embeddings(self) -> torch.nn.Embedding:
        """Return the input embedding table."""
        return self.input_embeddings
class _FakeConfig:
    """Config fake whose ``_attn_implementation`` starts out as ``"sdpa"``."""

    def __init__(self) -> None:
        self._attn_implementation = "sdpa"
class _FakeNestedModel(torch.nn.Module):
    """Parameter-free module that carries a ``_FakeConfig`` as ``config``."""

    def __init__(self) -> None:
        super().__init__()
        self.config = _FakeConfig()
class FakeCustomModule(torch.nn.Module):
    """Wrapper fake exposing a ``_FakeNestedModel`` as ``model``.

    This gives tests a ``.model.config._attn_implementation`` attribute chain
    to inspect.
    """

    def __init__(self) -> None:
        super().__init__()
        self.model = _FakeNestedModel()