File size: 4,213 Bytes
fda8fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from __future__ import annotations

import importlib.util
import re
from types import SimpleNamespace

import pytest

# Probe for torch without importing it, so this module can be collected in
# environments where torch is not installed.
HAS_TORCH = importlib.util.find_spec("torch") is not None
# ``torch`` stays None when the package is absent; guarded code checks for that.
torch = None
if HAS_TORCH:
    import torch


class MockTokenizer:
    def __call__(
        self,
        text: str,
        *,
        add_special_tokens: bool = False,
        return_offsets_mapping: bool = False,
        return_tensors: str | None = None,
    ):
        if torch is None:  # pragma: no cover - environment dependent
            raise RuntimeError("MockTokenizer requires torch.")
        matches = list(re.finditer(r"\S+", text))
        input_ids = [index + 1 for index, _match in enumerate(matches)]
        result = {
            "input_ids": torch.tensor([input_ids], dtype=torch.long)
            if return_tensors == "pt"
            else [input_ids]
        }
        if return_offsets_mapping:
            offsets = [[match.start(), match.end()] for match in matches]
            result["offset_mapping"] = (
                torch.tensor([offsets], dtype=torch.long)
                if return_tensors == "pt"
                else offsets
            )
        return result


# The model fakes subclass torch.nn.Module, so they are only defined when torch
# is importable (``torch`` is None otherwise).
if torch is not None:

    class FakeSelfAttention(torch.nn.Module):
        """Parameter-free causal self-attention standing in for a layer's ``self_attn``.

        ``hidden_states`` is used directly as queries, keys and values (there
        are no projections), so the module has no learnable parameters.
        """

        def __init__(self, hidden_size: int, heads: int) -> None:
            super().__init__()
            self.heads = heads
            # 1/sqrt(hidden_size) dot-product scaling.
            self.scale = hidden_size ** -0.5

        def forward(self, hidden_states, attention_mask=None, output_attentions=False, **_kwargs):
            """Attend causally over ``hidden_states``.

            Args:
                hidden_states: 3-D tensor ``(batch, seq, hidden)`` (rank fixed
                    by the einsum below).
                attention_mask: Optional mask. Only a 4-D mask is used: it is
                    *added* to the ``(batch, heads, seq, seq)`` scores.
                    Presumably an additive mask (0 keep, large negative drop);
                    TODO confirm with callers. Masks of other ranks are ignored.
                output_attentions: Accepted for signature compatibility; ignored.
                    The attention probabilities are always returned.

            Returns:
                ``(hidden_states + context, attention)``, where ``context`` is
                the attention output averaged over heads and ``attention`` has
                shape ``(batch, heads, seq, seq)``.
            """
            scores = torch.einsum("bqd,bkd->bqk", hidden_states, hidden_states) * self.scale
            # Every head receives identical scores; the head axis presumably
            # exists only to mimic a (batch, heads, q, k) attention layout.
            scores = scores.unsqueeze(1).repeat(1, self.heads, 1, 1)
            # Mask future positions (key index > query index) with the most
            # negative finite value of the score dtype.
            causal_mask = torch.triu(
                torch.ones(
                    scores.shape[-2:],
                    dtype=torch.bool,
                    device=hidden_states.device,
                ),
                diagonal=1,
            )
            scores = scores.masked_fill(
                causal_mask.unsqueeze(0).unsqueeze(0),
                torch.finfo(scores.dtype).min,
            )
            if attention_mask is not None and attention_mask.dim() == 4:
                scores = scores + attention_mask
            attention = torch.softmax(scores, dim=-1)
            context = torch.einsum("bhqk,bkd->bhqd", attention, hidden_states).mean(dim=1)
            return hidden_states + context, attention


    class FakeLayer(torch.nn.Module):
        """Minimal decoder layer exposing only a ``self_attn`` submodule.

        It defines no ``forward`` of its own; ``FakeCausalLM`` calls
        ``layer.self_attn`` directly.
        """

        def __init__(self, hidden_size: int, heads: int) -> None:
            super().__init__()
            self.self_attn = FakeSelfAttention(hidden_size=hidden_size, heads=heads)


    class FakeCausalLM(torch.nn.Module):
        """Tiny causal LM whose layers live at ``model.model.layers[i].self_attn``.

        Embeds ``input_ids``, runs each layer's ``self_attn`` in sequence, and
        projects the final hidden states to vocabulary logits with ``lm_head``.
        """

        def __init__(self, vocab_size: int = 32, hidden_size: int = 8, num_layers: int = 2, heads: int = 2):
            super().__init__()
            self.config = SimpleNamespace(_attn_implementation="eager")
            self.embed = torch.nn.Embedding(vocab_size, hidden_size)
            # Use a real Module container, not a SimpleNamespace, so the layers
            # are registered submodules. They then appear in named_modules()
            # (as "model.layers.<i>.self_attn"), and .eval()/.train()/.to() on
            # this model propagate to them. The container holds no parameters,
            # so parameter initialisation order is unchanged.
            self.model = torch.nn.Module()
            self.model.layers = torch.nn.ModuleList(
                [FakeLayer(hidden_size=hidden_size, heads=heads) for _ in range(num_layers)]
            )
            self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

        def forward(self, input_ids, attention_mask=None, output_attentions=False, return_dict=True, **_kwargs):
            """Run the fake model.

            Args:
                input_ids: Integer token ids; each must be ``< vocab_size``.
                attention_mask: Forwarded unchanged to every ``self_attn``.
                output_attentions: Forwarded to every ``self_attn`` (which
                    ignores it); attentions are always collected.
                return_dict: Selects the return form (see below).

            Returns:
                ``SimpleNamespace(logits=..., attentions=...)`` when
                ``return_dict`` is true, otherwise the tuple
                ``(logits, attentions)``. ``attentions`` is a tuple with one
                tensor per layer.
            """
            hidden_states = self.embed(input_ids)
            attentions = []
            for layer in self.model.layers:
                hidden_states, attention = layer.self_attn(
                    hidden_states,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions,
                )
                attentions.append(attention)
            logits = self.lm_head(hidden_states)
            if return_dict:
                return SimpleNamespace(logits=logits, attentions=tuple(attentions))
            return logits, tuple(attentions)


@pytest.fixture()
def mock_tokenizer() -> MockTokenizer:
    """Return a fresh ``MockTokenizer`` for the requesting test.

    The test is skipped when torch is unavailable, because the tokenizer raises
    without it.
    """
    if torch is not None:
        return MockTokenizer()
    pytest.skip("torch not installed")  # pragma: no cover - environment dependent