from __future__ import annotations from dataclasses import dataclass from typing import Any import torch from app.analysis.sentence_split import SentenceSpan @dataclass(slots=True) class TokenizedSentenceMapping: input_ids: torch.Tensor token_ranges: list[tuple[int, int]] offsets: list[tuple[int, int]] text: str def truncate_text_to_token_limit(text: str, tokenizer: Any, max_tokens: int) -> str: if max_tokens <= 0: raise ValueError("max_tokens must be positive.") encoded = tokenizer( text, add_special_tokens=False, return_offsets_mapping=True, ) offsets = encoded["offset_mapping"] if len(offsets) <= max_tokens: return text end_char = offsets[max_tokens - 1][1] return text[:end_char] def tokenize_with_sentence_ranges( text: str, sentence_spans: list[SentenceSpan], tokenizer: Any, ) -> TokenizedSentenceMapping: encoded = tokenizer( text, add_special_tokens=False, return_offsets_mapping=True, return_tensors="pt", ) input_ids = encoded["input_ids"] raw_offsets = encoded["offset_mapping"][0].tolist() offsets = [(int(start), int(end)) for start, end in raw_offsets] token_ranges: list[tuple[int, int]] = [] for span in sentence_spans: overlapping = [ token_index for token_index, (token_start, token_end) in enumerate(offsets) if token_end > span.start_char and token_start < span.end_char ] if not overlapping: raise ValueError( f"Sentence span {span.start_char}:{span.end_char} mapped to zero tokens." ) token_ranges.append((overlapping[0], overlapping[-1] + 1)) if token_ranges: adjusted_ranges: list[tuple[int, int]] = [] previous_end = 0 for index, (start, end) in enumerate(token_ranges): if index == 0 and start > 0: start = 0 if start < previous_end: raise ValueError("Sentence token ranges overlap after alignment.") if start > previous_end and adjusted_ranges: adjusted_start, _ = adjusted_ranges[-1] adjusted_ranges[-1] = (adjusted_start, start) adjusted_ranges.append((start, end)) previous_end = end if adjusted_ranges: first_start, first_end = adjusted_ranges[0] adjusted_ranges[0] = (0, first_end) last_start, last_end = adjusted_ranges[-1] adjusted_ranges[-1] = (last_start, len(offsets)) token_ranges = adjusted_ranges return TokenizedSentenceMapping( input_ids=input_ids, token_ranges=token_ranges, offsets=offsets, text=text, )