File size: 2,792 Bytes
fda8fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99b7b64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fda8fb3
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import torch

from app.analysis.sentence_split import SentenceSpan


@dataclass(slots=True)
class TokenizedSentenceMapping:
    """Tokenization of a text together with per-sentence token ranges.

    Instances are produced by ``tokenize_with_sentence_ranges`` in this module.
    """

    # Value of ``encoded["input_ids"]`` from a tokenizer call made with
    # ``return_tensors="pt"``.
    # NOTE(review): presumably shape (1, num_tokens) -- confirm against the tokenizer.
    input_ids: torch.Tensor
    # One ``(start_token, end_token)`` pair per sentence span, in span order;
    # ``end_token`` is exclusive.
    token_ranges: list[tuple[int, int]]
    # Per-token ``(start_char, end_char)`` character offsets into ``text``.
    offsets: list[tuple[int, int]]
    # The text that was tokenized.
    text: str


def truncate_text_to_token_limit(text: str, tokenizer: Any, max_tokens: int) -> str:
    """Return the longest prefix of *text* covered by at most *max_tokens* tokens.

    The text is tokenized once (without special tokens) and cut at the
    character offset where the last kept token ends.  If the first dropped
    token shares characters with the last kept one (for example byte-level
    pieces of a single multi-byte character, which report overlapping
    offsets), the cut backs off to the previous clean token boundary so the
    returned prefix does not drag the dropped tokens back in.

    Args:
        text: Text to truncate.
        tokenizer: Callable invoked as ``tokenizer(text,
            add_special_tokens=False, return_offsets_mapping=True)`` whose
            result exposes ``"offset_mapping"`` as a sequence of
            ``(start_char, end_char)`` pairs.
        max_tokens: Maximum number of tokens to keep; must be positive.

    Returns:
        *text* unchanged when it already fits within the limit, otherwise a
        prefix of it (empty when no clean token boundary exists within the
        limit).

    Raises:
        ValueError: If *max_tokens* is not positive.
    """
    if max_tokens <= 0:
        raise ValueError("max_tokens must be positive.")

    encoded = tokenizer(
        text,
        add_special_tokens=False,
        return_offsets_mapping=True,
    )
    offsets = encoded["offset_mapping"]
    if len(offsets) <= max_tokens:
        return text

    # Walk the cut point back while the first dropped token starts before the
    # last kept token ends: cutting there would keep characters that belong
    # to dropped tokens, so the prefix could exceed the limit.
    keep = max_tokens
    while keep > 0 and offsets[keep][0] < offsets[keep - 1][1]:
        keep -= 1
    if keep == 0:
        return ""

    # NOTE(review): re-tokenizing the prefix is assumed to reproduce the same
    # leading tokens; tokenizers whose merges depend on following context may
    # differ -- confirm for the tokenizer in use.
    return text[: offsets[keep - 1][1]]


def tokenize_with_sentence_ranges(
    text: str,
    sentence_spans: list[SentenceSpan],
    tokenizer: Any,
) -> TokenizedSentenceMapping:
    """Tokenize *text* and assign every token to one of *sentence_spans*.

    Each span is first mapped to the tokens whose character offsets overlap
    it.  The resulting ranges are then stretched so that, taken in order,
    they cover every token exactly once: the first range starts at token 0,
    each range ends where the next one begins, and the last range ends at
    the final token.

    Args:
        text: Text to tokenize.
        sentence_spans: Spans exposing ``start_char`` and ``end_char``,
            listed in the order their token ranges should appear.
        tokenizer: Callable invoked with ``add_special_tokens=False``,
            ``return_offsets_mapping=True`` and ``return_tensors="pt"``; its
            result must provide ``"input_ids"`` and ``"offset_mapping"``.

    Returns:
        A ``TokenizedSentenceMapping`` holding the input ids, one
        ``(start, end)`` token range per span (``end`` exclusive), the
        per-token character offsets and the original text.

    Raises:
        ValueError: If a span overlaps no token, or if a span's first token
            precedes the end of the previous span's tokens.
    """
    encoding = tokenizer(
        text,
        add_special_tokens=False,
        return_offsets_mapping=True,
        return_tensors="pt",
    )
    token_ids = encoding["input_ids"]
    # Index 0 selects the single sequence in the batch returned by the tokenizer.
    offsets = [
        (int(lo), int(hi)) for lo, hi in encoding["offset_mapping"][0].tolist()
    ]

    def _covering_range(span: SentenceSpan) -> tuple[int, int]:
        # Half-open token range spanning every token that overlaps the span.
        span_lo, span_hi = span.start_char, span.end_char
        hits = [
            position
            for position, (tok_lo, tok_hi) in enumerate(offsets)
            if tok_hi > span_lo and tok_lo < span_hi
        ]
        if not hits:
            raise ValueError(
                f"Sentence span {span.start_char}:{span.end_char} mapped to zero tokens."
            )
        return hits[0], hits[-1] + 1

    raw_ranges = [_covering_range(span) for span in sentence_spans]

    # A span may not begin on a token already claimed by its predecessor.
    for (_, prior_stop), (next_start, _) in zip(raw_ranges, raw_ranges[1:]):
        if next_start < prior_stop:
            raise ValueError("Sentence token ranges overlap after alignment.")

    aligned: list[tuple[int, int]] = []
    if raw_ranges:
        # Gaps between spans are absorbed by the preceding range, and the
        # outer edges are pinned to the first and last token.
        starts = [0, *(begin for begin, _ in raw_ranges[1:])]
        stops = [*starts[1:], len(offsets)]
        aligned = list(zip(starts, stops))

    return TokenizedSentenceMapping(
        input_ids=token_ids,
        token_ranges=aligned,
        offsets=offsets,
        text=text,
    )