from __future__ import annotations import importlib.util import re from types import SimpleNamespace import pytest HAS_TORCH = importlib.util.find_spec("torch") is not None if HAS_TORCH: import torch else: # pragma: no cover - environment dependent torch = None class MockTokenizer: def __call__( self, text: str, *, add_special_tokens: bool = False, return_offsets_mapping: bool = False, return_tensors: str | None = None, ): if torch is None: # pragma: no cover - environment dependent raise RuntimeError("MockTokenizer requires torch.") matches = list(re.finditer(r"\S+", text)) input_ids = [index + 1 for index, _match in enumerate(matches)] result = { "input_ids": torch.tensor([input_ids], dtype=torch.long) if return_tensors == "pt" else [input_ids] } if return_offsets_mapping: offsets = [[match.start(), match.end()] for match in matches] result["offset_mapping"] = ( torch.tensor([offsets], dtype=torch.long) if return_tensors == "pt" else offsets ) return result if torch is not None: class FakeSelfAttention(torch.nn.Module): def __init__(self, hidden_size: int, heads: int) -> None: super().__init__() self.heads = heads self.scale = hidden_size ** -0.5 def forward(self, hidden_states, attention_mask=None, output_attentions=False, **_kwargs): scores = torch.einsum("bqd,bkd->bqk", hidden_states, hidden_states) * self.scale scores = scores.unsqueeze(1).repeat(1, self.heads, 1, 1) causal_mask = torch.triu( torch.ones( scores.shape[-2:], dtype=torch.bool, device=hidden_states.device, ), diagonal=1, ) scores = scores.masked_fill( causal_mask.unsqueeze(0).unsqueeze(0), torch.finfo(scores.dtype).min, ) if attention_mask is not None and attention_mask.dim() == 4: scores = scores + attention_mask attention = torch.softmax(scores, dim=-1) context = torch.einsum("bhqk,bkd->bhqd", attention, hidden_states).mean(dim=1) return hidden_states + context, attention class FakeLayer(torch.nn.Module): def __init__(self, hidden_size: int, heads: int) -> None: super().__init__() self.self_attn = FakeSelfAttention(hidden_size=hidden_size, heads=heads) class FakeCausalLM(torch.nn.Module): def __init__(self, vocab_size: int = 32, hidden_size: int = 8, num_layers: int = 2, heads: int = 2): super().__init__() self.config = SimpleNamespace(_attn_implementation="eager") self.embed = torch.nn.Embedding(vocab_size, hidden_size) self.model = SimpleNamespace( layers=torch.nn.ModuleList( [FakeLayer(hidden_size=hidden_size, heads=heads) for _ in range(num_layers)] ) ) self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False) def forward(self, input_ids, attention_mask=None, output_attentions=False, return_dict=True, **_kwargs): hidden_states = self.embed(input_ids) attentions = [] for layer in self.model.layers: hidden_states, attention = layer.self_attn( hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, ) attentions.append(attention) logits = self.lm_head(hidden_states) if return_dict: return SimpleNamespace(logits=logits, attentions=tuple(attentions)) return logits, tuple(attentions) @pytest.fixture() def mock_tokenizer() -> MockTokenizer: if torch is None: # pragma: no cover - environment dependent pytest.skip("torch not installed") return MockTokenizer()