# app/models/text_chunker.py
import logging
import os
import re
from dataclasses import dataclass
from typing import List, Optional
import nltk
from nltk.tokenize import sent_tokenize
nltk_data_path = os.environ.get('NLTK_DATA', '/app/nltk_data')
nltk.data.path.append(nltk_data_path)
# Ensure the punkt sentence tokenizer is available, downloading it if necessary.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    try:
        nltk.download('punkt', download_dir=nltk_data_path)
    except Exception as e:
        logging.warning(f"Failed to download NLTK data: {e}")
logger = logging.getLogger(__name__)
@dataclass
class TextChunk:
"""Class to represent a chunk of text with metadata"""
text: str
index: int
token_count: int
is_partial_sentence: bool = False
original_start: int = 0
original_end: int = 0
class TextChunker:
"""
A utility class for chunking large texts into smaller pieces while preserving
sentence boundaries and context where possible.
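
    Example (an illustrative sketch; `long_text` and `translate` are
    caller-supplied placeholders, not part of this module):

        chunker = TextChunker(max_tokens=450, overlap_tokens=50)
        chunks = chunker.create_chunks(long_text)
        translations = [translate(chunk.text) for chunk in chunks]
        combined = chunker.combine_translations(long_text, chunks, translations)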
"""
def __init__(
self,
max_tokens: int = 450,
overlap_tokens: int = 50,
preserve_paragraphs: bool = True
):
"""
Initialize the TextChunker.
Args:
max_tokens: Maximum number of tokens per chunk
overlap_tokens: Number of tokens to overlap between chunks
preserve_paragraphs: Whether to try to preserve paragraph boundaries
"""
self.max_tokens = max_tokens
self.overlap_tokens = overlap_tokens
self.preserve_paragraphs = preserve_paragraphs

    def preprocess_text(self, text: str) -> str:
        """Clean and normalize text before chunking."""
        if not text:
            return ""
        # Collapse blank lines so paragraphs are separated by a single newline
        text = re.sub(r'\n\s*\n', '\n', text)
        # Replace carriage returns, tabs and other non-newline whitespace with spaces
        text = re.sub(r'[\r\t\f\v]', ' ', text)
        # Squeeze runs of spaces and trim spaces around newlines
        text = re.sub(r' +', ' ', text)
        text = re.sub(r' *\n *', '\n', text)
        text = text.strip()
        # Normalize bullet markers to a consistent "• " prefix
        text = re.sub(r'•\s*', '• ', text)
        text = re.sub(r'^\s*[-*]\s+', '• ', text, flags=re.MULTILINE)
        return text

    def estimate_tokens(self, text: str) -> int:
        """
        Estimate the number of tokens in a text string.

        This is a rough approximation - actual token count may vary by tokenizer.
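        Words and punctuation are counted separately, so "Hello, world!" is
        estimated as 4 tokens: "Hello", ",", "world", "!".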
"""
words = re.findall(r'\b\w+\b|[^\w\s]', text)
return len(words)

    def split_into_sentences(self, text: str) -> List[str]:
        """Split text into sentences using NLTK."""
        try:
            return sent_tokenize(text)
        except Exception as e:
            logger.warning(f"Error in sentence tokenization: {e}")
            # Fallback: naive split on periods if NLTK tokenization fails
            return [s.strip() + '.' for s in text.split('.') if s.strip()]

    def get_chunk_text(self, sentences: List[str], start_idx: int, max_tokens: int) -> tuple:
        """
        Get chunk text starting from start_idx that fits within max_tokens.

        Returns tuple of (chunk_text, end_idx, is_partial_sentence).
        """
        current_tokens = 0
        current_sentences = []
        is_partial = False
        for i in range(start_idx, len(sentences)):
            sentence = sentences[i]
            sentence_tokens = self.estimate_tokens(sentence)
            if sentence_tokens > max_tokens:
                if not current_sentences:
                    # The very first sentence of this chunk is already too long:
                    # keep as many whole words as fit and flag the chunk as partial.
                    words = sentence.split()
                    current_chunk = []
                    word_count = 0
                    for word in words:
                        word_tokens = self.estimate_tokens(word)
                        if word_count + word_tokens <= max_tokens:
                            current_chunk.append(word)
                            word_count += word_tokens
                        else:
                            break
                    chunk_text = ' '.join(current_chunk)
                    is_partial = True
                    return chunk_text, i, is_partial
                # An oversized sentence follows the accumulated ones: close the
                # current chunk here and handle that sentence on the next call.
                break
            if current_tokens + sentence_tokens > max_tokens and current_sentences:
                break
            current_sentences.append(sentence)
            current_tokens += sentence_tokens
        return ' '.join(current_sentences), start_idx + len(current_sentences), is_partial

    def create_chunks(self, text: str) -> List[TextChunk]:
        """
        Split text into chunks that respect sentence boundaries where possible.

        Args:
            text: Input text to be chunked

        Returns:
            List of TextChunk objects
        """
        text = self.preprocess_text(text)
        if not text:
            return []
        chunks = []
        current_idx = 0
        if self.preserve_paragraphs:
            paragraphs = text.split('\n')
        else:
            paragraphs = [text]
        for para in paragraphs:
            if not para.strip():
                continue
            sentences = self.split_into_sentences(para)
            para_start = 0
            while para_start < len(sentences):
                chunk_text, next_start, is_partial = self.get_chunk_text(
                    sentences, para_start, self.max_tokens
                )
                if not chunk_text:
                    break
                # Locate the chunk in the preprocessed text; fall back to 0 when the
                # chunk cannot be found verbatim (e.g. a word-truncated partial sentence).
                original_start = text.find(chunk_text)
                if original_start == -1:
                    original_start = 0
                original_end = original_start + len(chunk_text)
                chunks.append(TextChunk(
                    text=chunk_text,
                    index=current_idx,
                    token_count=self.estimate_tokens(chunk_text),
                    is_partial_sentence=is_partial,
                    original_start=original_start,
                    original_end=original_end
                ))
                current_idx += 1
                # After a truncated (partial) sentence, advance past it so the same
                # oversized sentence is not re-processed indefinitely.
                para_start = next_start if not is_partial else next_start + 1
        return chunks

    def combine_translations(self, original_text: str, chunks: List[TextChunk],
                             translations: List[str]) -> str:
        """
        Combine translated chunks back into a single text, handling overlaps.

        Args:
            original_text: Original input text
            chunks: List of TextChunk objects
            translations: List of translated text chunks

        Returns:
            Combined translated text
        """
        if len(chunks) != len(translations):
            raise ValueError("Number of chunks and translations must match")
        if len(chunks) == 0:
            return ""
        if len(chunks) == 1:
            return translations[0]
        result = []
        for i, (chunk, translation) in enumerate(zip(chunks, translations)):
            if i > 0 and chunk.is_partial_sentence:
                # Drop text duplicated from the end of the previous translation.
                prev_translation = translations[i - 1]
                overlap = self._find_overlap(prev_translation, translation)
                if overlap:
                    translation = translation[len(overlap):]
            result.append(translation)
        return ' '.join(result)

    def _find_overlap(self, text1: str, text2: str, min_length: int = 10) -> Optional[str]:
        """Find overlapping text between two strings."""
        if not text1 or not text2:
            return None
        end_text = text1[-100:]    # Look at last 100 chars
        start_text = text2[:100]   # Look at first 100 chars
        overlap = None
        # Search from the longest possible overlap down to min_length characters.
        for length in range(min(len(end_text), len(start_text)), min_length - 1, -1):
            if end_text[-length:] == start_text[:length]:
                overlap = start_text[:length]
                break
        return overlap
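

# Illustrative, minimal usage sketch (not part of the original module): it chunks a
# small sample string with a deliberately low max_tokens and rejoins the "translations"
# (here just the untouched chunk text), showing how create_chunks and
# combine_translations fit together. The sample text and limits are placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    sample = (
        "This is the first sentence of a small demo. Here is a second sentence. "
        "And a third one to make the text a little longer.\n"
        "A second paragraph follows with one more sentence."
    )
    chunker = TextChunker(max_tokens=20, overlap_tokens=5)
    demo_chunks = chunker.create_chunks(sample)
    for chunk in demo_chunks:
        print(f"chunk {chunk.index}: {chunk.token_count} tokens, partial={chunk.is_partial_sentence}")
    # Identity "translation" keeps the demo self-contained.
    rejoined = chunker.combine_translations(sample, demo_chunks, [c.text for c in demo_chunks])
    print(rejoined)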