# leaderboard/tests/test_models.py
# hotchpotch's picture
# Deploy security hardening
# 7ff9600 verified
from __future__ import annotations
import argparse
import pytest
import torch
from hakari_bench.models import (
ModelLoadConfig,
collect_model_metadata,
load_model,
resolve_model_revision,
resolve_attn_implementation,
resolve_torch_dtype,
)
def test_resolve_torch_dtype_defaults_to_bf16() -> None:
    """Check that each dtype alias resolves to the matching torch dtype object.

    NOTE(review): the test name mentions a bf16 default, but only explicit
    aliases are exercised here; a default-argument check may be missing.
    """
    assert resolve_torch_dtype("bf16") is torch.bfloat16
    assert resolve_torch_dtype("fp16") is torch.float16
    assert resolve_torch_dtype("fp32") is torch.float32
def test_resolve_attn_implementation_rejects_flash_attn_conflict() -> None:
    """Asking for ``sdpa`` together with ``flash_attn2=True`` must raise ``ValueError``."""
    with pytest.raises(ValueError):
        resolve_attn_implementation(attn_implementation="sdpa", flash_attn2=True)
def test_load_model_passes_dense_options(monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify the constructor kwargs and post-load state for a dense model.

    The real SentenceTransformer import is replaced by a recording fake, so
    the test can compare exactly what ``load_model`` passes to the constructor
    and inspect what it changes on the returned instance.
    """
    calls: list[dict[str, object]] = []

    class FakeSentenceTransformer(torch.nn.Module):
        """Stand-in that records its constructor arguments in ``calls``."""

        def __init__(self, model_name_or_path: str, **kwargs: object) -> None:
            super().__init__()
            calls.append({"model_name_or_path": model_name_or_path, **kwargs})
            # Starts unset; the assertion below expects load_model to fill it in.
            self.max_seq_length = None
            # A real parameter so the dtype assertion has a tensor to look at.
            self.projection = torch.nn.Linear(2, 2)
            # Nested module whose config starts as "sdpa" (see FakeCustomModule).
            self.inner = FakeCustomModule()

    monkeypatch.setattr(
        "hakari_bench.models._import_sentence_transformer",
        lambda: FakeSentenceTransformer,
    )

    model = load_model(
        ModelLoadConfig(
            model_name_or_path="hotchpotch/model",
            model_type="dense",
            dtype="bf16",
            attn_implementation=None,
            flash_attn2=True,
            device="cpu",
            trust_remote_code=True,
            max_seq_length=128,
        )
    )

    assert isinstance(model, FakeSentenceTransformer)
    assert model.max_seq_length == 128
    assert model.projection.weight.dtype is torch.bfloat16
    # The nested config was "sdpa" at construction time, so this value must
    # have been written during load_model.
    assert model.inner.model.config._attn_implementation == "flash_attention_2"
    assert calls == [
        {
            "model_name_or_path": "hotchpotch/model",
            "device": "cpu",
            "revision": None,
            "trust_remote_code": True,
            "model_kwargs": {
                "torch_dtype": torch.bfloat16,
                "attn_implementation": "flash_attention_2",
            },
        }
    ]
def test_load_model_passes_late_interaction_options(monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify the constructor kwargs used for a late-interaction model.

    The ColBERT import is replaced by a recording fake. The ``late_interaction_*``
    config fields are expected to arrive at the constructor without that prefix.
    """
    calls: list[dict[str, object]] = []

    class FakeColBERT(torch.nn.Module):
        """Stand-in that records its constructor arguments in ``calls``."""

        def __init__(self, model_name_or_path: str, **kwargs: object) -> None:
            super().__init__()
            calls.append({"model_name_or_path": model_name_or_path, **kwargs})
            # A real parameter so the dtype assertion has a tensor to look at.
            self.projection = torch.nn.Linear(2, 2)

    monkeypatch.setattr("hakari_bench.models._import_pylate_colbert", lambda: FakeColBERT)

    model = load_model(
        ModelLoadConfig(
            model_name_or_path="lightonai/GTE-ModernColBERT-v1",
            model_type="late-interaction",
            dtype="fp32",
            device="cpu",
            trust_remote_code=True,
            late_interaction_query_length=64,
            late_interaction_document_length=300,
            late_interaction_query_prefix="[QueryMarker]",
            late_interaction_document_prefix="[DocumentMarker]",
            late_interaction_attend_to_expansion_tokens=True,
        )
    )

    assert isinstance(model, FakeColBERT)
    assert model.projection.weight.dtype is torch.float32
    assert calls == [
        {
            "model_name_or_path": "lightonai/GTE-ModernColBERT-v1",
            "device": "cpu",
            "revision": None,
            "trust_remote_code": True,
            "model_kwargs": {"torch_dtype": torch.float32},
            "query_length": 64,
            "document_length": 300,
            "query_prefix": "[QueryMarker]",
            "document_prefix": "[DocumentMarker]",
            "attend_to_expansion_tokens": True,
        }
    ]
def test_load_model_reranker_passes_cross_encoder_kwargs(monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify that ``cross_encoder_kwargs`` reach the CrossEncoder constructor.

    The expected call shows the user-supplied ``model_kwargs`` merged with the
    resolved ``torch_dtype`` rather than replaced by it, and the remaining
    extra kwargs passed through as-is.
    """
    calls: list[dict[str, object]] = []

    class FakeCrossEncoder(torch.nn.Module):
        """Stand-in that records its constructor arguments in ``calls``."""

        def __init__(self, model_name_or_path: str, **kwargs: object) -> None:
            super().__init__()
            calls.append({"model_name_or_path": model_name_or_path, **kwargs})
            self.model = torch.nn.Linear(2, 1)

    monkeypatch.setattr("hakari_bench.models._import_cross_encoder", lambda: FakeCrossEncoder)

    model = load_model(
        ModelLoadConfig(
            model_name_or_path="Qwen/Qwen3-Reranker-0.6B",
            model_type="reranker",
            dtype="bf16",
            # The fake only records this string; it never moves tensors itself.
            device="cuda:0",
            trust_remote_code=True,
            cross_encoder_kwargs={
                "prompts": {"retrieval": "Retrieve relevant passages"},
                "default_prompt_name": "retrieval",
                "model_kwargs": {"attn_implementation": "sdpa"},
            },
        )
    )

    assert isinstance(model, FakeCrossEncoder)
    assert calls == [
        {
            "model_name_or_path": "Qwen/Qwen3-Reranker-0.6B",
            "prompts": {"retrieval": "Retrieve relevant passages"},
            "default_prompt_name": "retrieval",
            "device": "cuda:0",
            "revision": None,
            "trust_remote_code": True,
            "model_kwargs": {
                "torch_dtype": torch.bfloat16,
                "attn_implementation": "sdpa",
            },
        }
    ]
def test_load_model_passes_model_revision_to_huggingface_loaders(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Verify that ``model_revision`` is forwarded to the loader as ``revision``."""
    calls: list[dict[str, object]] = []

    class FakeSentenceTransformer(torch.nn.Module):
        """Stand-in that records its constructor arguments in ``calls``."""

        def __init__(self, model_name_or_path: str, **kwargs: object) -> None:
            super().__init__()
            calls.append({"model_name_or_path": model_name_or_path, **kwargs})
            self.projection = torch.nn.Linear(2, 2)

    monkeypatch.setattr(
        "hakari_bench.models._import_sentence_transformer",
        lambda: FakeSentenceTransformer,
    )

    load_model(
        ModelLoadConfig(
            model_name_or_path="hotchpotch/model",
            model_type="dense",
            dtype="bf16",
            device="cpu",
            trust_remote_code=True,
            model_revision="abc123",
        )
    )

    # Fail with a readable message instead of an IndexError if the fake
    # constructor was never invoked.
    assert calls, "load_model did not construct the patched SentenceTransformer"
    assert calls[0]["revision"] == "abc123"
def test_resolve_model_revision_uses_full_huggingface_sha(monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify that the resolved revision is the full SHA reported by the Hub API.

    ``HfApi`` is replaced with a fake whose ``model_info`` checks the arguments
    it receives and returns an object carrying a 40-character ``sha``.
    """

    class FakeInfo:
        sha = "0123456789abcdef0123456789abcdef01234567"

    class FakeHfApi:
        def model_info(self, *, repo_id: str, revision: str | None = None) -> FakeInfo:
            assert repo_id == "hotchpotch/model"
            assert revision == "main"
            return FakeInfo()

    monkeypatch.setattr("hakari_bench.models.HfApi", FakeHfApi)
    # resolve_model_revision exposes cache_clear(), so results are presumably
    # memoised; start from a clean cache so the fake API is actually consulted.
    resolve_model_revision.cache_clear()
    try:
        assert resolve_model_revision("hotchpotch/model", requested_revision="main") == {
            "requested": "main",
            "resolved": "0123456789abcdef0123456789abcdef01234567",
            "source": "huggingface_hub",
        }
    finally:
        # monkeypatch restores HfApi but not the cache; drop the entry produced
        # by the fake so it cannot leak into later tests.
        resolve_model_revision.cache_clear()
def test_collect_model_metadata_counts_parameters() -> None:
    """Verify parameter totals and identity fields for a plain ``Sequential`` model.

    The model has no recognisable input-embedding table, so the embedding and
    active parameter counts are expected to be ``None``.
    """
    # Linear(3, 2) -> 6 weights + 2 biases; Linear(2, 1) -> 2 weights + 1 bias; 11 total.
    model = torch.nn.Sequential(torch.nn.Linear(3, 2), torch.nn.Linear(2, 1))
    args = argparse.Namespace(
        model="toy",
        model_type="dense",
        dtype="bf16",
        device="cpu",
        trust_remote_code=False,
        attn_implementation=None,
        flash_attn2=False,
        model_revision=None,
    )

    metadata = collect_model_metadata(model, args)

    assert metadata["method"] == "dense"
    assert metadata["id"] == "toy"
    assert metadata["source"] == {"type": "huggingface", "name": "toy"}
    assert metadata["total_parameters"] == 11
    assert metadata["trainable_parameters"] == 11
    assert metadata["embedding_parameters"] is None
    assert metadata["active_parameters"] is None
def test_collect_model_metadata_records_model_revision(monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify that the resolved revision SHA is added to the ``source`` metadata.

    ``resolve_model_revision`` is stubbed out so the test needs no network
    access; the stub echoes the requested revision and returns a fixed SHA.
    """
    model = torch.nn.Sequential(torch.nn.Linear(3, 2))
    args = argparse.Namespace(
        model="hotchpotch/model",
        model_type="dense",
        dtype="bf16",
        device="cpu",
        trust_remote_code=False,
        attn_implementation=None,
        flash_attn2=False,
        model_revision="main",
        model_source={
            "type": "huggingface",
            "name": "hotchpotch/model",
            "revision_requested": "main",
        },
    )
    monkeypatch.setattr(
        "hakari_bench.models.resolve_model_revision",
        lambda model_id, requested_revision=None: {
            "requested": requested_revision,
            "resolved": "0123456789abcdef0123456789abcdef01234567",
            "source": "huggingface_hub",
        },
    )

    metadata = collect_model_metadata(model, args)

    # The supplied model_source is expected back with the resolved SHA added.
    assert metadata["source"] == {
        "type": "huggingface",
        "name": "hotchpotch/model",
        "revision_requested": "main",
        "revision": "0123456789abcdef0123456789abcdef01234567",
    }
def test_collect_model_metadata_counts_active_parameters_from_input_embeddings() -> None:
    """Verify embedding/active counts when the backbone is stored as ``auto_model``.

    The fake backbone holds 20 embedding parameters out of 38 in total, so 18
    are expected to be reported as active.
    """
    model = _FakeSentenceTransformerLike(backbone_attr="auto_model")
    args = _metadata_args()

    metadata = collect_model_metadata(model, args)

    assert metadata["total_parameters"] == 38
    assert metadata["trainable_parameters"] == 38
    assert metadata["embedding_parameters"] == 20
    assert metadata["active_parameters"] == 18
def test_collect_model_metadata_counts_active_parameters_from_st_model_attribute() -> None:
    """Verify embedding/active counts when the backbone is stored as ``model``.

    Same fake backbone as the ``auto_model`` variant, reached through a
    different attribute name; the counts must match.
    """
    model = _FakeSentenceTransformerLike(backbone_attr="model")
    args = _metadata_args()

    metadata = collect_model_metadata(model, args)

    assert metadata["total_parameters"] == 38
    assert metadata["embedding_parameters"] == 20
    assert metadata["active_parameters"] == 18
def test_collect_model_metadata_counts_static_embedding_table_as_input_embeddings() -> None:
    """Verify that a static-embedding model counts its whole table as input embeddings.

    The fake consists solely of a 5x4 embedding table, so all 20 parameters are
    embedding parameters and none are active.
    """
    model = _FakeStaticSentenceTransformerLike()
    args = _metadata_args()

    metadata = collect_model_metadata(model, args)

    assert metadata["total_parameters"] == 20
    assert metadata["embedding_parameters"] == 20
    assert metadata["active_parameters"] == 0
def test_collect_model_metadata_records_late_interaction_metadata() -> None:
    """Verify the extra metadata emitted for a late-interaction model.

    The prefix, length and expansion values in the expected block mirror the
    class attributes of ``_FakeLateInteractionModel``.
    """
    model = _FakeLateInteractionModel()
    args = _metadata_args()
    # The shared namespace defaults to "dense"; switch it for this scenario.
    args.model_type = "late-interaction"

    metadata = collect_model_metadata(model, args)

    assert metadata["backend_library"] == "pylate"
    assert metadata["similarity_fn_name"] == "MaxSim"
    assert metadata["late_interaction"] == {
        "architecture": "colbert",
        "scoring": "maxsim",
        "query_prefix": "[Q] ",
        "document_prefix": "[D] ",
        "query_length": 32,
        "document_length": 300,
        "do_query_expansion": True,
        "attend_to_expansion_tokens": False,
    }
def _metadata_args(**overrides: object) -> argparse.Namespace:
    """Build the CLI-style namespace shared by the metadata tests.

    Args:
        **overrides: Optional attribute values that replace (or extend) the
            defaults, e.g. ``_metadata_args(model_type="late-interaction")``.
            Calling with no arguments returns the original default namespace.

    Returns:
        A fresh ``argparse.Namespace`` on every call, so callers may also
        mutate the result without affecting other tests.
    """
    defaults: dict[str, object] = {
        "model": "toy",
        "model_type": "dense",
        "dtype": "bf16",
        "device": "cpu",
        "trust_remote_code": False,
        "attn_implementation": None,
        "flash_attn2": False,
    }
    return argparse.Namespace(**{**defaults, **overrides})
class _FakeSentenceTransformerLike(torch.nn.Module):
    """Minimal container with one transformer-style module at index 0.

    Supports integer indexing (``model[0]``) over its registered submodules.
    """

    def __init__(self, *, backbone_attr: str) -> None:
        """Register a ``_FakeTransformerModule`` under the name ``"0"``.

        Args:
            backbone_attr: Attribute name under which the transformer module
                stores its backbone.
        """
        super().__init__()
        self.add_module("0", _FakeTransformerModule(backbone_attr=backbone_attr))

    def __getitem__(self, index: int) -> torch.nn.Module:
        """Return the submodule registered under ``str(index)``.

        Raises:
            KeyError: If no submodule is registered under that name.
        """
        module = self._modules[str(index)]
        # _modules values are Optional; narrow the type for callers.
        assert isinstance(module, torch.nn.Module)
        return module
class _FakeStaticSentenceTransformerLike(torch.nn.Module):
    """Minimal container with one static-embedding module at index 0.

    Supports integer indexing (``model[0]``) over its registered submodules.
    """

    def __init__(self) -> None:
        """Register a ``_FakeStaticEmbeddingModule`` under the name ``"0"``."""
        super().__init__()
        self.add_module("0", _FakeStaticEmbeddingModule())

    def __getitem__(self, index: int) -> torch.nn.Module:
        """Return the submodule registered under ``str(index)``.

        Raises:
            KeyError: If no submodule is registered under that name.
        """
        module = self._modules[str(index)]
        # _modules values are Optional; narrow the type for callers.
        assert isinstance(module, torch.nn.Module)
        return module
class _FakeStaticEmbeddingModule(torch.nn.Module):
    """Module whose only parameters are a 5x4 embedding table (20 parameters)."""

    def __init__(self) -> None:
        super().__init__()
        self.embedding = torch.nn.Embedding(5, 4)
class _FakeLateInteractionModel(torch.nn.Module):
    """Fake late-interaction model exposing fixed configuration attributes.

    The class-level values are the ones the late-interaction metadata test
    expects to see reported back.
    """

    similarity_fn_name = "MaxSim"
    max_seq_length = None
    query_prefix = "[Q] "
    document_prefix = "[D] "
    query_length = 32
    document_length = 300
    do_query_expansion = True
    attend_to_expansion_tokens = False

    def __init__(self) -> None:
        super().__init__()
        # Gives the model some real parameters (4 weights + 2 biases).
        self.projection = torch.nn.Linear(2, 2)
class _FakeTransformerModule(torch.nn.Module):
    """Wrapper that stores a ``_FakeBackbone`` under a caller-chosen attribute name."""

    def __init__(self, *, backbone_attr: str) -> None:
        """Register the backbone as a submodule.

        Args:
            backbone_attr: Name of the attribute that holds the backbone
                (the tests use ``"auto_model"`` and ``"model"``).
        """
        super().__init__()
        self.add_module(backbone_attr, _FakeBackbone())
class _FakeBackbone(torch.nn.Module):
    """Tiny backbone with a known parameter budget.

    Parameters: 20 in the 5x4 embedding table, 12 in the bias-free 4->3 linear
    layer and 6 in the LayerNorm over 3 features, i.e. 38 in total.
    """

    def __init__(self) -> None:
        super().__init__()
        self.embeddings = torch.nn.Module()
        self.input_embeddings = torch.nn.Embedding(5, 4)
        # Same Embedding object registered under a second path; the table is
        # shared, not duplicated.
        self.embeddings.tok_embeddings = self.input_embeddings
        self.layers = torch.nn.ModuleList([torch.nn.Linear(4, 3, bias=False)])
        self.final_norm = torch.nn.LayerNorm(3)

    def get_input_embeddings(self) -> torch.nn.Embedding:
        """Return the token embedding table."""
        return self.input_embeddings
class _FakeConfig:
    """Config stub whose ``_attn_implementation`` starts as ``"sdpa"``."""

    def __init__(self) -> None:
        self._attn_implementation = "sdpa"
class _FakeNestedModel(torch.nn.Module):
    """Parameter-free module that carries a ``_FakeConfig`` as ``config``."""

    def __init__(self) -> None:
        super().__init__()
        self.config = _FakeConfig()
class FakeCustomModule(torch.nn.Module):
    """Module that nests a ``_FakeNestedModel`` under ``model``.

    The dense-model test uses it to check that the attention implementation is
    written to a config reachable at ``<module>.model.config``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.model = _FakeNestedModel()