# cot-anc/tests/conftest.py
# Commit fda8fb3 (verified) by BART-ender: "Deploy Thought Anchors"
from __future__ import annotations
import importlib.util
import re
from types import SimpleNamespace
import pytest
# Probe for torch without importing it, so this conftest still loads on
# machines without torch; `torch` is then None and torch-only helpers bail out.
HAS_TORCH = importlib.util.find_spec("torch") is not None
if not HAS_TORCH:  # pragma: no cover - environment dependent
    torch = None
else:
    import torch
class MockTokenizer:
def __call__(
self,
text: str,
*,
add_special_tokens: bool = False,
return_offsets_mapping: bool = False,
return_tensors: str | None = None,
):
if torch is None: # pragma: no cover - environment dependent
raise RuntimeError("MockTokenizer requires torch.")
matches = list(re.finditer(r"\S+", text))
input_ids = [index + 1 for index, _match in enumerate(matches)]
result = {
"input_ids": torch.tensor([input_ids], dtype=torch.long)
if return_tensors == "pt"
else [input_ids]
}
if return_offsets_mapping:
offsets = [[match.start(), match.end()] for match in matches]
result["offset_mapping"] = (
torch.tensor([offsets], dtype=torch.long)
if return_tensors == "pt"
else offsets
)
return result
if torch is not None:

    class FakeSelfAttention(torch.nn.Module):
        """Parameter-free causal self-attention used by ``FakeCausalLM``.

        Scores are scaled dot products of the hidden states with themselves and
        are identical across all ``heads``.
        """

        def __init__(self, hidden_size: int, heads: int) -> None:
            super().__init__()
            self.heads = heads
            self.scale = hidden_size ** -0.5

        def forward(self, hidden_states, attention_mask=None, output_attentions=False, **_kwargs):
            """Attend causally over ``hidden_states``.

            Args:
                hidden_states: Tensor indexed as ``(batch, seq, dim)``.
                attention_mask: Optional additive mask. It is applied only when
                    4-D, added to the ``(batch, heads, seq, seq)`` scores after
                    causal masking; masks of any other rank are ignored.
                output_attentions: Accepted for interface compatibility; the
                    attention weights are returned regardless.

            Returns:
                ``(hidden_states + context, attention)``, where ``attention``
                has shape ``(batch, heads, seq, seq)``.
            """
            scores = torch.einsum("bqd,bkd->bqk", hidden_states, hidden_states) * self.scale
            # Every head gets the same scores; repeating only provides the
            # (batch, heads, q, k) layout callers expect from attention weights.
            scores = scores.unsqueeze(1).repeat(1, self.heads, 1, 1)
            causal_mask = torch.triu(
                torch.ones(
                    scores.shape[-2:],
                    dtype=torch.bool,
                    device=hidden_states.device,
                ),
                diagonal=1,
            )
            # Future positions get the dtype's most negative finite value, so the
            # softmax drives their weights to zero.
            scores = scores.masked_fill(
                causal_mask.unsqueeze(0).unsqueeze(0),
                torch.finfo(scores.dtype).min,
            )
            if attention_mask is not None and attention_mask.dim() == 4:
                scores = scores + attention_mask
            attention = torch.softmax(scores, dim=-1)
            # Average the per-head contexts back down to (batch, seq, dim).
            context = torch.einsum("bhqk,bkd->bhqd", attention, hidden_states).mean(dim=1)
            return hidden_states + context, attention

    class FakeLayer(torch.nn.Module):
        """Decoder-layer stand-in that only exposes ``self_attn``.

        It defines no ``forward``; ``FakeCausalLM`` calls ``self_attn`` directly.
        """

        def __init__(self, hidden_size: int, heads: int) -> None:
            super().__init__()
            self.self_attn = FakeSelfAttention(hidden_size=hidden_size, heads=heads)

    class FakeDecoder(torch.nn.Module):
        """Holds the ``layers`` list reached as ``FakeCausalLM.model.layers``.

        Being an ``nn.Module`` (rather than a plain namespace) registers the
        layers as submodules of the outer model, so ``named_modules()``,
        ``eval()``/``train()``, ``to()`` and ``apply()`` reach them, and module
        names take the form ``model.layers.<i>.self_attn``.
        """

        def __init__(self, hidden_size: int, num_layers: int, heads: int) -> None:
            super().__init__()
            self.layers = torch.nn.ModuleList(
                [FakeLayer(hidden_size=hidden_size, heads=heads) for _ in range(num_layers)]
            )

    class FakeCausalLM(torch.nn.Module):
        """Tiny causal LM: embedding, ``num_layers`` fake attention layers, LM head.

        Exposes ``config._attn_implementation`` and ``model.layers[i].self_attn``,
        presumably the attributes the code under test inspects on real
        Hugging Face decoder models (TODO confirm against callers).
        """

        def __init__(self, vocab_size: int = 32, hidden_size: int = 8, num_layers: int = 2, heads: int = 2):
            super().__init__()
            self.config = SimpleNamespace(_attn_implementation="eager")
            self.embed = torch.nn.Embedding(vocab_size, hidden_size)
            # A registered submodule (not a SimpleNamespace) so module-tree
            # operations on this model reach the attention layers. It holds no
            # parameters, so parameter order and init RNG use are unchanged.
            self.model = FakeDecoder(hidden_size=hidden_size, num_layers=num_layers, heads=heads)
            self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

        def forward(self, input_ids, attention_mask=None, output_attentions=False, return_dict=True, **_kwargs):
            """Run the embedding, every fake attention layer, then the LM head.

            Args:
                input_ids: Integer token ids, each ``< vocab_size``.
                attention_mask: Passed to every layer, which applies it only
                    when it is 4-D.
                output_attentions: Forwarded to the layers; attention weights
                    are collected regardless.
                return_dict: If true, return a namespace with ``logits`` and
                    ``attentions``; otherwise return ``(logits, attentions)``.
            """
            hidden_states = self.embed(input_ids)
            attentions = []
            for layer in self.model.layers:
                hidden_states, attention = layer.self_attn(
                    hidden_states,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions,
                )
                attentions.append(attention)
            logits = self.lm_head(hidden_states)
            if return_dict:
                return SimpleNamespace(logits=logits, attentions=tuple(attentions))
            return logits, tuple(attentions)
@pytest.fixture()
def mock_tokenizer() -> MockTokenizer:
    """Provide a fresh ``MockTokenizer``; skip the requesting test without torch."""
    if torch is not None:
        return MockTokenizer()
    pytest.skip("torch not installed")  # pragma: no cover - environment dependent