# Extracted from a web file viewer (listed revision fda8fb3, file size 4,213 bytes).
from __future__ import annotations
import importlib.util
import re
from types import SimpleNamespace
import pytest
# Probe for torch without importing it, so this module can still be loaded
# in environments where torch is not installed.
HAS_TORCH = importlib.util.find_spec("torch") is not None
if not HAS_TORCH:  # pragma: no cover - environment dependent
    # Placeholder so later `torch is None` checks work without a NameError.
    torch = None
else:
    import torch
class MockTokenizer:
    """Whitespace-splitting tokenizer used as a lightweight test double.

    Every run of non-whitespace characters is one token. Token ids are the
    1-based position of the token in the text, so ids carry no vocabulary
    meaning.
    """

    def __call__(
        self,
        text: str,
        *,
        add_special_tokens: bool = False,
        return_offsets_mapping: bool = False,
        return_tensors: str | None = None,
    ):
        """Tokenize ``text`` and return a dict of encodings.

        Args:
            text: String to split on whitespace.
            add_special_tokens: Accepted for signature compatibility; this
                mock never reads it.
            return_offsets_mapping: When true, add an ``"offset_mapping"``
                entry holding ``[start, end]`` character spans per token.
            return_tensors: ``"pt"`` yields ``torch.long`` tensors with a
                leading batch dimension of 1; any other value yields plain
                Python lists.

        Returns:
            A dict with ``"input_ids"`` and, optionally, ``"offset_mapping"``.
            In the list (non-``"pt"``) form the ids are wrapped in a batch
            list while the offsets are not.

        Raises:
            RuntimeError: If torch is unavailable, regardless of
                ``return_tensors``.
        """
        if torch is None:  # pragma: no cover - environment dependent
            raise RuntimeError("MockTokenizer requires torch.")

        as_tensors = return_tensors == "pt"
        spans = [list(token.span()) for token in re.finditer(r"\S+", text)]
        token_ids = list(range(1, len(spans) + 1))

        encoded = {}
        if as_tensors:
            encoded["input_ids"] = torch.tensor([token_ids], dtype=torch.long)
        else:
            encoded["input_ids"] = [token_ids]

        if not return_offsets_mapping:
            return encoded

        if as_tensors:
            encoded["offset_mapping"] = torch.tensor([spans], dtype=torch.long)
        else:
            # Unbatched on purpose: mirrors the original list-mode output.
            encoded["offset_mapping"] = spans
        return encoded
if torch is not None:

    class FakeSelfAttention(torch.nn.Module):
        """Parameter-free causal self-attention used as a test double.

        Queries, keys and values are all the incoming hidden states; the
        same scores are replicated for every head.
        """

        def __init__(self, hidden_size: int, heads: int) -> None:
            super().__init__()
            self.heads = heads
            self.scale = hidden_size ** -0.5

        def forward(self, hidden_states, attention_mask=None, output_attentions=False, **_kwargs):
            """Apply causal attention and return ``(new_hidden_states, weights)``.

            Args:
                hidden_states: Rank-3 tensor indexed as (batch, position,
                    feature) by the einsum subscripts below.
                attention_mask: Optional additive mask. It is applied only
                    when it has exactly four dimensions; any other mask is
                    ignored.
                output_attentions: Accepted but not consulted; the weights
                    are always returned.
                **_kwargs: Ignored.

            Returns:
                A tuple of the residual-updated hidden states and the
                attention weights of shape (batch, heads, query, key).
            """
            raw_scores = torch.einsum("bqd,bkd->bqk", hidden_states, hidden_states) * self.scale
            per_head = raw_scores.unsqueeze(1).repeat(1, self.heads, 1, 1)

            # True strictly above the diagonal, i.e. for keys that lie in
            # the future relative to the query position.
            future = torch.ones(
                per_head.shape[-2:],
                dtype=torch.bool,
                device=hidden_states.device,
            ).triu(diagonal=1)
            lowest = torch.finfo(per_head.dtype).min
            per_head = per_head.masked_fill(future[None, None], lowest)

            use_additive_mask = attention_mask is not None and attention_mask.dim() == 4
            if use_additive_mask:
                per_head = per_head + attention_mask

            attention = per_head.softmax(dim=-1)
            mixed = torch.einsum("bhqk,bkd->bhqd", attention, hidden_states)
            # Heads are averaged back into a single stream before the residual add.
            return hidden_states + mixed.mean(dim=1), attention

    class FakeLayer(torch.nn.Module):
        """Container exposing a single ``self_attn`` submodule."""

        def __init__(self, hidden_size: int, heads: int) -> None:
            super().__init__()
            self.self_attn = FakeSelfAttention(hidden_size=hidden_size, heads=heads)

    class FakeCausalLM(torch.nn.Module):
        """Tiny causal language-model stand-in: embedding, attention stack, LM head.

        Layers are reachable as ``model.model.layers[i].self_attn``.
        """

        def __init__(self, vocab_size: int = 32, hidden_size: int = 8, num_layers: int = 2, heads: int = 2):
            super().__init__()
            self.config = SimpleNamespace(_attn_implementation="eager")
            self.embed = torch.nn.Embedding(vocab_size, hidden_size)
            stack = torch.nn.ModuleList(
                FakeLayer(hidden_size=hidden_size, heads=heads) for _ in range(num_layers)
            )
            # NOTE: the ModuleList hangs off a SimpleNamespace, not an
            # nn.Module, so torch does not register these layers as
            # submodules of this model.
            self.model = SimpleNamespace(layers=stack)
            self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

        def forward(self, input_ids, attention_mask=None, output_attentions=False, return_dict=True, **_kwargs):
            """Run the stack and return logits plus per-layer attention weights.

            Args:
                input_ids: Token ids fed to the embedding table.
                attention_mask: Forwarded unchanged to every attention layer.
                output_attentions: Forwarded to every layer; attentions are
                    collected and returned regardless of its value.
                return_dict: When true, return a ``SimpleNamespace`` with
                    ``logits`` and ``attentions``; otherwise a 2-tuple.
                **_kwargs: Ignored.

            Returns:
                ``SimpleNamespace(logits=..., attentions=...)`` or
                ``(logits, attentions)``, where ``attentions`` is a tuple
                with one entry per layer.
            """
            states = self.embed(input_ids)
            collected = []
            for block in self.model.layers:
                states, weights = block.self_attn(
                    states,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions,
                )
                collected.append(weights)

            attentions = tuple(collected)
            logits = self.lm_head(states)
            if not return_dict:
                return logits, attentions
            return SimpleNamespace(logits=logits, attentions=attentions)
@pytest.fixture()
def mock_tokenizer() -> MockTokenizer:
    """Provide a fresh ``MockTokenizer``, skipping the test when torch is absent."""
    torch_missing = torch is None
    if torch_missing:  # pragma: no cover - environment dependent
        pytest.skip("torch not installed")
    tokenizer = MockTokenizer()
    return tokenizer