# cot-anc / app / analysis / token_boundaries.py
# BART-ender's picture
# Fix sentence token alignment
# 99b7b64 verified
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import torch
from app.analysis.sentence_split import SentenceSpan
@dataclass(slots=True)
class TokenizedSentenceMapping:
input_ids: torch.Tensor
token_ranges: list[tuple[int, int]]
offsets: list[tuple[int, int]]
text: str
def truncate_text_to_token_limit(text: str, tokenizer: Any, max_tokens: int) -> str:
    """Return ``text`` cut so that it covers at most ``max_tokens`` tokens.

    The text is tokenized without special tokens and with character offsets
    requested. If the token count is within the limit the original string is
    returned unchanged; otherwise the text is sliced at the end character of
    the last token that still fits.

    Args:
        text: The string to shorten.
        tokenizer: Callable accepting ``add_special_tokens`` and
            ``return_offsets_mapping`` keywords and returning a mapping with
            an ``"offset_mapping"`` entry.
            NOTE(review): presumably a Hugging Face fast tokenizer — confirm.
        max_tokens: Maximum number of tokens to keep.

    Returns:
        ``text`` itself, or a prefix of it.

    Raises:
        ValueError: If ``max_tokens`` is zero or negative.
    """
    if max_tokens <= 0:
        raise ValueError("max_tokens must be positive.")

    char_spans = tokenizer(
        text,
        add_special_tokens=False,
        return_offsets_mapping=True,
    )["offset_mapping"]

    if len(char_spans) > max_tokens:
        # Cut right after the last token that still fits within the limit.
        last_kept = char_spans[max_tokens - 1]
        return text[: last_kept[1]]
    return text
def tokenize_with_sentence_ranges(
    text: str,
    sentence_spans: list[SentenceSpan],
    tokenizer: Any,
) -> TokenizedSentenceMapping:
    """Tokenize ``text`` and assign every token to one sentence.

    Each sentence span is first mapped to the tokens whose character offsets
    overlap it. The resulting ranges are then stitched together so that they
    tile the whole token sequence: the first range is anchored at token 0,
    any gap between two consecutive ranges is given to the earlier sentence,
    and the last range is stretched to the final token.

    Args:
        text: The text to tokenize.
        sentence_spans: Sentence spans exposing ``start_char`` and
            ``end_char``, in the order the sentences occur.
        tokenizer: Callable accepting ``add_special_tokens``,
            ``return_offsets_mapping`` and ``return_tensors`` keywords.
            NOTE(review): presumably a Hugging Face fast tokenizer — confirm.

    Returns:
        A ``TokenizedSentenceMapping`` holding the token ids, one
        ``(start, end)`` token range per sentence (end exclusive), the
        per-token character offsets, and the original text. With no sentence
        spans the range list is empty.

    Raises:
        ValueError: If a span overlaps no token, or if a range starts before
            the previous one ended (e.g. one token overlapping two spans).
    """
    encoding = tokenizer(
        text,
        add_special_tokens=False,
        return_offsets_mapping=True,
        return_tensors="pt",
    )
    token_ids = encoding["input_ids"]
    # Index 0 selects the first row of the offset mapping.
    # NOTE(review): assumes a single-sequence batch — confirm.
    char_offsets = [
        (int(lo), int(hi)) for lo, hi in encoding["offset_mapping"][0].tolist()
    ]

    # Pass 1: raw token range for each sentence, from character overlap.
    raw_ranges: list[tuple[int, int]] = []
    for span in sentence_spans:
        hits = [
            position
            for position, (tok_start, tok_end) in enumerate(char_offsets)
            if tok_end > span.start_char and tok_start < span.end_char
        ]
        if not hits:
            raise ValueError(
                f"Sentence span {span.start_char}:{span.end_char} mapped to zero tokens."
            )
        raw_ranges.append((hits[0], hits[-1] + 1))

    if not raw_ranges:
        return TokenizedSentenceMapping(
            input_ids=token_ids,
            token_ranges=raw_ranges,
            offsets=char_offsets,
            text=text,
        )

    # Pass 2: make the ranges contiguous and cover every token.
    _, first_end = raw_ranges[0]
    aligned: list[tuple[int, int]] = [(0, first_end)]  # first sentence owns leading tokens
    cursor = first_end
    for start, end in raw_ranges[1:]:
        if start < cursor:
            raise ValueError("Sentence token ranges overlap after alignment.")
        if start > cursor:
            # Tokens between two sentences go to the earlier sentence.
            prev_start, _ = aligned[-1]
            aligned[-1] = (prev_start, start)
        aligned.append((start, end))
        cursor = end

    # The last sentence owns any trailing tokens.
    tail_start, _ = aligned[-1]
    aligned[-1] = (tail_start, len(char_offsets))

    return TokenizedSentenceMapping(
        input_ids=token_ids,
        token_ranges=aligned,
        offsets=char_offsets,
        text=text,
    )