cot-anc / tests /test_token_boundaries.py
BART-ender's picture
Fix sentence token alignment
99b7b64 verified
from __future__ import annotations
import pytest
from app.analysis.sentence_split import SentenceSpan
pytest.importorskip("torch")
import torch
from app.analysis.sentence_split import split_sentences
from app.analysis.token_boundaries import tokenize_with_sentence_ranges
def test_token_boundaries_cover_full_sequence(mock_tokenizer) -> None:
    """Check the token ranges produced for a plain two-sentence input.

    The sample text is split with ``split_sentences`` and handed to
    ``tokenize_with_sentence_ranges`` together with the ``mock_tokenizer``
    fixture.  The resulting ranges must be ``(0, 2)`` and ``(2, 4)`` -- i.e.
    contiguous and spanning all four token positions -- and ``input_ids``
    must have shape ``(1, 4)``.
    """
    sample = "Alpha beta. Gamma delta."
    sentence_spans = split_sentences(sample)

    result = tokenize_with_sentence_ranges(sample, sentence_spans, mock_tokenizer)

    # Each sentence owns two tokens and the second range starts where the
    # first one ends, so no token position is left unassigned.
    assert result.token_ranges == [(0, 2), (2, 4)]
    assert result.input_ids.shape == (1, 4)
def test_token_boundaries_absorb_separator_gaps() -> None:
    """Check that a separator-only token is assigned to the preceding sentence.

    A stub tokenizer returns four fixed tokens for ``"Alpha beta. Gamma."``.
    The third token spans characters 10-12 (the ``". "`` between the two
    sentences), which lies inside the first sentence span (0-12).  The
    expected ranges are therefore ``(0, 3)`` for the first sentence and
    ``(3, 4)`` for the second.
    """

    class GapTokenizer:
        """Callable stub that ignores its input and returns canned encodings."""

        def __call__(
            self,
            text: str,
            *,
            add_special_tokens: bool = False,
            return_offsets_mapping: bool = False,
            return_tensors: str | None = None,
        ):
            """Return fixed ids (and optionally offsets) for any ``text``.

            ``text`` and ``add_special_tokens`` are accepted but never read.
            With ``return_tensors == "pt"`` the values are ``torch.long``
            tensors with a leading batch dimension; otherwise plain lists
            are returned.
            """
            token_ids = [1, 2, 3, 4]
            # Character offsets per token: "Alpha", " beta", ". ", "Gamma".
            # The closing "." at index 17 is not covered by any token.
            char_offsets = [[0, 5], [5, 10], [10, 12], [12, 17]]
            as_tensors = return_tensors == "pt"

            if as_tensors:
                encoded = {"input_ids": torch.tensor([token_ids], dtype=torch.long)}
            else:
                encoded = {"input_ids": [token_ids]}

            if not return_offsets_mapping:
                return encoded

            if as_tensors:
                encoded["offset_mapping"] = torch.tensor(
                    [char_offsets], dtype=torch.long
                )
            else:
                # NOTE(review): unlike "input_ids", the list form of the
                # offsets carries no batch dimension -- confirm that this is
                # what the code under test expects on the non-"pt" path.
                encoded["offset_mapping"] = char_offsets
            return encoded

    sample = "Alpha beta. Gamma."
    first_sentence = SentenceSpan(text="Alpha beta. ", start_char=0, end_char=12)
    second_sentence = SentenceSpan(text="Gamma.", start_char=12, end_char=18)
    sentence_spans = [first_sentence, second_sentence]

    result = tokenize_with_sentence_ranges(sample, sentence_spans, GapTokenizer())

    # Tokens 0-2 (including the separator token) fall in sentence one,
    # leaving only token 3 for sentence two.
    assert result.token_ranges == [(0, 3), (3, 4)]