Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import importlib.util | |
| import re | |
| from types import SimpleNamespace | |
| import pytest | |
| HAS_TORCH = importlib.util.find_spec("torch") is not None | |
| if HAS_TORCH: | |
| import torch | |
| else: # pragma: no cover - environment dependent | |
| torch = None | |
class MockTokenizer:
    """Whitespace tokenizer stand-in that returns a tokenizer-style dict.

    Each run of non-whitespace characters in the input is one token. Token
    ids are positional (1, 2, 3, ...) and do not depend on token content.
    """

    def __call__(
        self,
        text: str,
        *,
        add_special_tokens: bool = False,
        return_offsets_mapping: bool = False,
        return_tensors: str | None = None,
    ):
        """Split *text* on whitespace and return ids (and optionally offsets).

        Args:
            text: The string to tokenize.
            add_special_tokens: Accepted for call compatibility; not used.
            return_offsets_mapping: When true, include ``"offset_mapping"``
                holding the ``[start, end]`` character span of every token
                (``end`` is exclusive, as reported by ``re.Match.end``).
            return_tensors: ``"pt"`` yields ``torch.long`` tensors with a
                leading batch dimension of 1; any other value yields lists.

        Returns:
            A dict with ``"input_ids"`` and, if requested,
            ``"offset_mapping"``. In tensor mode the shapes are ``(1, n)``
            and ``(1, n, 2)``. In list mode ``"input_ids"`` is ``[[...]]``
            while ``"offset_mapping"`` is the flat list of pairs.

        Raises:
            RuntimeError: If torch could not be imported.
        """
        if torch is None:  # pragma: no cover - environment dependent
            raise RuntimeError("MockTokenizer requires torch.")

        as_tensors = return_tensors == "pt"
        matches = list(re.finditer(r"\S+", text))
        input_ids = list(range(1, len(matches) + 1))

        result = {
            "input_ids": (
                torch.tensor([input_ids], dtype=torch.long)
                if as_tensors
                else [input_ids]
            )
        }

        if return_offsets_mapping:
            offsets = [[match.start(), match.end()] for match in matches]
            if as_tensors:
                # An explicit reshape keeps the result rank-3 for empty input:
                # torch.tensor([[]]) would be (1, 0), not (1, 0, 2).
                result["offset_mapping"] = torch.tensor(
                    offsets, dtype=torch.long
                ).reshape(1, len(offsets), 2)
            else:
                result["offset_mapping"] = offsets
        return result
if torch is not None:

    class FakeSelfAttention(torch.nn.Module):
        """Causal self-attention stand-in with no learnable weights.

        The incoming hidden states serve as queries, keys and values, and
        every head receives the same scores.
        """

        def __init__(self, hidden_size: int, heads: int) -> None:
            super().__init__()
            self.heads = heads
            self.scale = hidden_size ** -0.5

        def forward(self, hidden_states, attention_mask=None, output_attentions=False, **_kwargs):
            """Return ``(hidden_states + context, attention)``.

            ``attention`` has shape (batch, heads, query, key). It is always
            returned; ``output_attentions`` is accepted but not consulted.
            ``attention_mask`` is added to the scores only when it is 4-D.
            """
            raw_scores = torch.einsum("bqd,bkd->bqk", hidden_states, hidden_states) * self.scale
            # Copy the single score map once per head.
            per_head = raw_scores[:, None].repeat(1, self.heads, 1, 1)

            # True strictly above the diagonal marks the future positions.
            query_len, key_len = per_head.shape[-2:]
            future = torch.ones(
                query_len,
                key_len,
                dtype=torch.bool,
                device=hidden_states.device,
            ).triu(diagonal=1)
            lowest = torch.finfo(per_head.dtype).min
            per_head = per_head.masked_fill(future[None, None], lowest)

            has_4d_mask = attention_mask is not None and attention_mask.dim() == 4
            if has_4d_mask:
                per_head = per_head + attention_mask

            weights = per_head.softmax(dim=-1)
            # Mix the values per head, then average the heads back together.
            mixed = torch.einsum("bhqk,bkd->bhqd", weights, hidden_states)
            return hidden_states + mixed.mean(dim=1), weights

    class FakeLayer(torch.nn.Module):
        """Layer shell that exposes a ``self_attn`` submodule and nothing else."""

        def __init__(self, hidden_size: int, heads: int) -> None:
            super().__init__()
            self.self_attn = FakeSelfAttention(hidden_size=hidden_size, heads=heads)

    class FakeCausalLM(torch.nn.Module):
        """Small causal LM fake: embedding, attention-only layers, linear head.

        The layers are reachable as ``self.model.layers``. The holder is a
        plain ``SimpleNamespace``, not an ``nn.Module``.
        """

        def __init__(self, vocab_size: int = 32, hidden_size: int = 8, num_layers: int = 2, heads: int = 2):
            super().__init__()
            self.config = SimpleNamespace(_attn_implementation="eager")
            self.embed = torch.nn.Embedding(vocab_size, hidden_size)
            stack = torch.nn.ModuleList(
                FakeLayer(hidden_size=hidden_size, heads=heads) for _ in range(num_layers)
            )
            self.model = SimpleNamespace(layers=stack)
            self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

        def forward(self, input_ids, attention_mask=None, output_attentions=False, return_dict=True, **_kwargs):
            """Run every layer's attention in order and project to logits.

            Returns a namespace with ``logits`` and ``attentions`` when
            ``return_dict`` is true, otherwise the tuple
            ``(logits, attentions)``. ``attentions`` holds one entry per layer.
            """
            states = self.embed(input_ids)
            per_layer = []
            for block in self.model.layers:
                states, weights = block.self_attn(
                    states,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions,
                )
                per_layer.append(weights)

            attentions = tuple(per_layer)
            logits = self.lm_head(states)
            if not return_dict:
                return logits, attentions
            return SimpleNamespace(logits=logits, attentions=attentions)
def mock_tokenizer() -> MockTokenizer:
    """Return a new ``MockTokenizer``, calling ``pytest.skip`` if torch is absent."""
    # NOTE(review): no fixture decorator is visible on this function here;
    # confirm whether one is expected.
    torch_missing = torch is None
    if torch_missing:  # pragma: no cover - environment dependent
        pytest.skip("torch not installed")
    return MockTokenizer()