File size: 1,774 Bytes
fda8fb3
 
 
99b7b64
fda8fb3
 
 
99b7b64
 
fda8fb3
 
 
 
 
 
 
 
 
 
 
99b7b64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from __future__ import annotations

import pytest
from app.analysis.sentence_split import SentenceSpan

pytest.importorskip("torch")

import torch

from app.analysis.sentence_split import split_sentences
from app.analysis.token_boundaries import tokenize_with_sentence_ranges


def test_token_boundaries_cover_full_sequence(mock_tokenizer) -> None:
    """Check that per-sentence token ranges tile the whole token sequence.

    A two-sentence sample is split with ``split_sentences`` and mapped with
    ``tokenize_with_sentence_ranges`` using the ``mock_tokenizer`` fixture
    (defined outside this module). The expected values imply the fixture
    yields four tokens for this sample, two per sentence.
    """
    sample = "Alpha beta. Gamma delta."
    sentences = split_sentences(sample)
    mapped = tokenize_with_sentence_ranges(sample, sentences, mock_tokenizer)

    # Half-open ranges that meet at index 2: no token is skipped or shared.
    expected_ranges = [(0, 2), (2, 4)]
    assert mapped.token_ranges == expected_ranges

    # A single batch row holding all four token ids.
    expected_shape = (1, 4)
    assert mapped.input_ids.shape == expected_shape


def test_token_boundaries_absorb_separator_gaps() -> None:
    """Check that a separator token is assigned to the sentence it trails.

    The stub tokenizer below emits four tokens with fixed character offsets.
    The third token covers characters 10-12 (``". "``), which sit at the end
    of the first sentence span (0-12). The expected ranges therefore place
    tokens 0-2 in the first sentence and token 3 in the second.
    """

    class GapTokenizer:
        """Callable stub returning canned ids and offsets; ``text`` is ignored."""

        def __call__(
            self,
            text: str,
            *,
            add_special_tokens: bool = False,
            return_offsets_mapping: bool = False,
            return_tensors: str | None = None,
        ):
            token_ids = [1, 2, 3, 4]
            # Last offset ends at 17, one short of the 18-character sample.
            char_offsets = [[0, 5], [5, 10], [10, 12], [12, 17]]
            as_tensors = return_tensors == "pt"

            if as_tensors:
                encoded = {"input_ids": torch.tensor([token_ids], dtype=torch.long)}
            else:
                encoded = {"input_ids": [token_ids]}

            if return_offsets_mapping:
                if as_tensors:
                    encoded["offset_mapping"] = torch.tensor(
                        [char_offsets], dtype=torch.long
                    )
                else:
                    # Plain-list mode returns the offsets without a batch axis.
                    encoded["offset_mapping"] = char_offsets
            return encoded

    sample = "Alpha beta. Gamma."
    first_sentence = SentenceSpan(text="Alpha beta. ", start_char=0, end_char=12)
    second_sentence = SentenceSpan(text="Gamma.", start_char=12, end_char=18)
    mapped = tokenize_with_sentence_ranges(
        sample, [first_sentence, second_sentence], GapTokenizer()
    )

    expected_ranges = [(0, 3), (3, 4)]
    assert mapped.token_ranges == expected_ranges