# Source: AI_Toolkit/src/core/TokenChunker.py
# Provenance: uploaded by NavyDevilDoc ("Upload 10 files", commit c0f31c1, verified)
"""
TokenChunker.py
A module for token-based document chunking with configurable overlap and preprocessing.
Features:
- Token-based document splitting with overlap
- Content validation and token counting
- Smart boundary detection to preserve word integrity
- Compatible with multiple tokenizer types (tiktoken, transformers, basic)
"""
import logging
import re
from typing import List, Optional, Dict, Any
from langchain_core.documents import Document
from core.BaseChunker import BaseChunker
logger = logging.getLogger(__name__)
class TokenChunker(BaseChunker):
"""Handles document chunking at the token level with configurable overlap."""
def __init__(
    self,
    model_name=None,
    embedding_model=None,
    chunk_size: int = 256,
    chunk_overlap: int = 50,
    min_chunk_size: int = 50
):
    """
    Initialize token chunker with specified models and parameters.

    Args:
        model_name: Name of the model for tokenization
        embedding_model: Model for generating embeddings
        chunk_size: Maximum tokens per chunk (must be positive)
        chunk_overlap: Number of tokens to overlap between chunks
            (must be non-negative and less than chunk_size)
        min_chunk_size: Minimum tokens for a valid chunk (must be positive)

    Raises:
        ValueError: If any chunking parameter is out of range or the
            overlap is not strictly smaller than the chunk size.
    """
    super().__init__(model_name, embedding_model)
    # Fail fast on misconfiguration: previously chunk_size <= 0 or a
    # negative overlap could slip through (e.g. chunk_size=0 with
    # chunk_overlap=-1 passed both original checks) and produce a
    # degenerate chunker downstream.
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if chunk_overlap < 0:
        raise ValueError("chunk_overlap must be non-negative")
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be less than chunk_size")
    if min_chunk_size <= 0:
        raise ValueError("min_chunk_size must be positive")
    self.chunk_size = chunk_size
    self.chunk_overlap = chunk_overlap
    self.min_chunk_size = min_chunk_size
    # Per-document log of created/skipped chunks; reset by the
    # process_* entry points before each run.
    self.chunk_stats = []
    logger.info(f"TokenChunker initialized: chunk_size={chunk_size}, overlap={chunk_overlap}, min_size={min_chunk_size}")
def _smart_tokenize(self, text: str) -> List[str]:
    """
    Tokenize text while keeping word boundaries recoverable.

    Args:
        text: The text content to tokenize

    Returns:
        List of tokens that can be cleanly rejoined by _detokenize;
        empty list for blank input.
    """
    if not text.strip():
        return []
    try:
        if self.uses_tiktoken:
            # tiktoken ids don't map back to clean text spans, so use
            # the boundary-aware approximation instead.
            return self._tiktoken_boundary_aware_split(text)
        if hasattr(self.tokenizer, 'tokenize'):
            # transformers-style tokenizer: normalize its subword pieces.
            raw_pieces = self.tokenizer.tokenize(text)
            return self._clean_subword_tokens(raw_pieces)
        # No usable tokenizer interface: plain word splitting.
        return self._word_boundary_split(text)
    except Exception as e:
        logger.warning(f"Tokenization failed: {e}. Using word boundary fallback.")
        return self._word_boundary_split(text)
def _tiktoken_boundary_aware_split(self, text: str) -> List[str]:
    """
    Approximate tiktoken tokenization while preserving text boundaries.

    Args:
        text: Input text

    Returns:
        List of text segments that approximate the tokenizer's tokens
    """
    # True token count, used to judge how fine the split must be.
    expected = self.count_tokens(text)
    # Coarse pass: alternate runs of non-space and whitespace.
    coarse = re.findall(r'\S+|\s+', text)
    # Accept the coarse split when its length is within 30% of the
    # real token count; whitespace-only pieces are dropped either way.
    if abs(len(coarse) - expected) / max(expected, 1) < 0.3:
        return [piece for piece in coarse if piece.strip()]
    # Fine pass: words, individual punctuation marks, whitespace runs.
    fine = re.findall(r'\w+|[^\w\s]|\s+', text)
    return [piece for piece in fine if piece.strip()]
def _clean_subword_tokens(self, tokens: List[str]) -> List[str]:
    """
    Normalize subword tokens so they rejoin into readable text.

    Args:
        tokens: Raw tokens from a transformers-style tokenizer

    Returns:
        Cleaned tokens with marker prefixes resolved and
        empty/whitespace-only entries removed
    """
    cleaned: List[str] = []
    for raw in tokens:
        if raw.startswith('##'):
            # BERT-style continuation piece: strip the marker.
            cleaned.append(raw[2:])
        elif raw.startswith('▁'):
            # SentencePiece word-start marker: translate to a space.
            cleaned.append(' ' + raw[1:])
        else:
            cleaned.append(raw)
    return [piece for piece in cleaned if piece.strip()]
def _word_boundary_split(self, text: str) -> List[str]:
    """
    Fallback tokenization: split on word boundaries.

    Args:
        text: Input text

    Returns:
        Words plus the punctuation marks . ! ? ; , as separate tokens.
        Other characters (whitespace, remaining punctuation) are dropped.
    """
    return re.findall(r'\w+|[.!?;,]', text)
def _detokenize(self, tokens: List[str]) -> str:
    """
    Reconstruct text from tokens, handling different tokenizer types.

    Args:
        tokens: List of token strings

    Returns:
        Reconstructed text ("" for an empty token list)
    """
    if not tokens:
        return ""
    if self.uses_tiktoken or not hasattr(self.tokenizer, 'tokenize'):
        # tiktoken / basic tokens: space-join, but attach the listed
        # punctuation marks directly to the preceding text.
        out = ""
        for idx, tok in enumerate(tokens):
            if not tok.strip():
                continue  # ignore whitespace-only tokens
            if idx == 0:
                out = tok
            elif tok in '.,!?;:':
                out += tok
            else:
                out += " " + tok
        return out
    # transformers tokens: concatenate (subword pieces carry their own
    # spacing), then tidy spacing around punctuation and collapse runs.
    joined = "".join(tokens)
    joined = re.sub(r'\s+([.!?;,])', r'\1', joined)
    joined = re.sub(r'\s+', ' ', joined)
    return joined.strip()
def _create_token_chunks(self, tokens: List[str]) -> List[List[str]]:
    """
    Slice tokens into overlapping windows of at most chunk_size.

    Args:
        tokens: List of token strings

    Returns:
        List of token windows; windows shorter than min_chunk_size
        (typically the document tail) are logged and dropped.
    """
    if not tokens:
        return []
    windows: List[List[str]] = []
    total = len(tokens)
    start = 0
    while start < total:
        end = min(start + self.chunk_size, total)
        window = tokens[start:end]
        # Record every candidate window; keep only those at or above
        # the configured size floor.
        if len(window) >= self.min_chunk_size:
            windows.append(window)
            self.chunk_stats.append(f"Created chunk with {len(window)} tokens")
        else:
            self.chunk_stats.append(f"Skipped small chunk with {len(window)} tokens")
        if end >= total:
            break
        # Step the next window back by the overlap; guard against a
        # non-positive start so the loop always advances.
        start = end - self.chunk_overlap
        if start <= 0:
            start = end
    return windows
def _process_single_chunk(self, chunk_tokens: List[str], chunk_index: int,
                          source_metadata: Dict[str, Any]) -> Optional[Document]:
    """
    Turn one token window into a Document with enriched metadata.

    Args:
        chunk_tokens: Tokens belonging to this chunk
        chunk_index: Position of the chunk within the document
        source_metadata: Metadata inherited from the source document

    Returns:
        Document with reconstructed text and chunk metadata, or None
        when the reconstructed text fails content validation.
    """
    text = self._detokenize(chunk_tokens)
    if not self.is_content_valid(text, min_tokens=self.min_chunk_size):
        self.chunk_stats.append(f"Chunk {chunk_index} failed validation")
        return None
    stats = self.analyze_text(text)
    # Start from a copy of the source metadata and layer on
    # chunk-level details.
    metadata = dict(source_metadata)
    metadata["chunk_index"] = chunk_index
    metadata["chunk_type"] = "token"
    metadata["chunking_method"] = "token_based"
    metadata["token_count"] = len(chunk_tokens)
    metadata["char_count"] = stats["char_count"]
    metadata["sentence_count"] = stats["sentence_count"]
    metadata["word_count"] = stats["word_count"]
    metadata["chunk_size_limit"] = self.chunk_size
    metadata["chunk_overlap"] = self.chunk_overlap
    return Document(page_content=text, metadata=metadata)
def token_process_document(self, file_path: str, preprocess: bool = True) -> List[Document]:
    """
    Process document using token-based chunking with overlap.

    Pages are loaded, validated, optionally preprocessed, then merged
    into one text so chunks may span page breaks; the merged text is
    tokenized once and split into overlapping token windows.

    Args:
        file_path: Path to the document file
        preprocess: Whether to preprocess text content

    Returns:
        List of Document objects, one per valid token chunk

    Raises:
        Exception: Any error from loading/tokenizing is logged and re-raised.
    """
    try:
        self.chunk_stats = []  # Reset stats for this document
        raw_pages = self.load_document(file_path)
        processed_chunks = []
        logger.info(f"Processing document with {len(raw_pages)} pages using token chunking")
        # Combine all pages into a single text for token-based processing
        full_text = ""
        combined_metadata = {}
        page_info = []  # Track which pages contributed to the text
        for page_idx, page in enumerate(raw_pages):
            content = page.page_content
            # Skip pages whose raw content fails validation
            if not self.is_content_valid(content):
                logger.debug(f"Skipping invalid content on page {page_idx + 1}")
                continue
            # Preprocessing can strip content, so re-validate afterwards
            if preprocess:
                content = self.preprocess_text(content)
                if not self.is_content_valid(content):
                    continue
            # Track page information for combined metadata below
            page_info.append({
                "page_number": page_idx + 1,
                "original_metadata": page.metadata
            })
            # Combine text with a blank line separating pages
            if full_text:
                full_text += "\n\n" + content
            else:
                full_text = content
                # Use metadata from first valid page as base
                combined_metadata = page.metadata.copy()
        # Update combined metadata to reflect all pages
        if page_info:
            combined_metadata.update({
                "total_pages_processed": len(page_info),
                "page_range": f"{page_info[0]['page_number']}-{page_info[-1]['page_number']}",
                # Page numbers stored as strings for metadata compatibility
                "source_pages": [str(p["page_number"]) for p in page_info]
            })
            # Remove the single "page" field since this represents multiple pages
            combined_metadata.pop("page", None)
        if not full_text.strip():
            logger.warning("No valid content found in document")
            return []
        # Tokenize the entire merged document once
        all_tokens = self._smart_tokenize(full_text)
        logger.info(f"Document tokenized into {len(all_tokens)} tokens")
        if len(all_tokens) < self.min_chunk_size:
            logger.warning(f"Document too short for chunking ({len(all_tokens)} tokens)")
            return []
        # Create overlapping token chunks
        token_chunks = self._create_token_chunks(all_tokens)
        logger.info(f"Created {len(token_chunks)} token chunks")
        # Convert token chunks to Document objects (invalid chunks return None)
        for chunk_idx, chunk_tokens in enumerate(token_chunks):
            chunk_doc = self._process_single_chunk(
                chunk_tokens,
                chunk_idx,
                combined_metadata
            )
            if chunk_doc:
                processed_chunks.append(chunk_doc)
        # Output per-chunk processing statistics accumulated above
        if self.chunk_stats:
            logger.info("\n".join(self.chunk_stats))
        logger.info(f"Processed {len(processed_chunks)} valid token chunks")
        return processed_chunks
    except Exception as e:
        logger.error(f"Error in token_process_document: {e}")
        raise
def process_document(self, file_path: str, preprocess: bool = True) -> List[Document]:
    """
    Process a document with the token chunking strategy.

    Thin wrapper implementing the abstract BaseChunker entry point;
    all work is delegated to token_process_document.

    Args:
        file_path: Path to the document file
        preprocess: Whether to preprocess text content

    Returns:
        List of Document objects, one per valid token chunk
    """
    return self.token_process_document(file_path, preprocess)
def process_text_file(self, file_path: str, preprocess: bool = True) -> List[Document]:
    """
    Process text file directly using token-based chunking with overlap.

    Unlike token_process_document there is no page handling: the file is
    loaded as a single string, cleaned, tokenized, and windowed.

    Args:
        file_path: Path to the text file
        preprocess: Whether to preprocess text content

    Returns:
        List of Document objects, one per valid token chunk

    Raises:
        Exception: Any error while processing is logged and re-raised.
    """
    try:
        # Local imports keep these off the module import path
        from pathlib import Path
        from datetime import datetime
        self.chunk_stats = []  # Reset stats for this document
        # Load the text file directly
        content = self.load_text_file(file_path)
        # Clean the text using the same logic as PDF conversion
        content = self.clean_text_for_processing(content)
        # Basic validation before any further work
        if not self.is_content_valid(content):
            logger.warning("Text file content failed validation")
            return []
        # Light preprocessing if requested (no header/footer removal for txt files)
        if preprocess:
            # Only apply basic text cleaning, not aggressive preprocessing
            content = ' '.join(content.split())  # Normalize whitespace
        # Create file-level metadata attached to every chunk
        file_path_obj = Path(file_path)
        file_metadata = {
            "source": file_path,
            "file_name": file_path_obj.name,
            "file_type": "txt",
            "total_characters": len(content),
            # NOTE(review): naive local timestamp — confirm whether UTC is expected
            "processing_timestamp": datetime.now().isoformat(),
        }
        logger.info(f"Processing text file: {file_path_obj.name} ({len(content)} characters)")
        # Tokenize the entire document once
        all_tokens = self._smart_tokenize(content)
        logger.info(f"Text file tokenized into {len(all_tokens)} tokens")
        if len(all_tokens) < self.min_chunk_size:
            logger.warning(f"Text file too short for chunking ({len(all_tokens)} tokens)")
            return []
        # Create overlapping token chunks
        token_chunks = self._create_token_chunks(all_tokens)
        logger.info(f"Created {len(token_chunks)} token chunks from text file")
        # Convert token chunks to Document objects (invalid chunks return None)
        processed_chunks = []
        for chunk_idx, chunk_tokens in enumerate(token_chunks):
            chunk_doc = self._process_single_chunk(
                chunk_tokens,
                chunk_idx,
                file_metadata
            )
            if chunk_doc:
                processed_chunks.append(chunk_doc)
        # Output per-chunk processing statistics accumulated above
        if self.chunk_stats:
            logger.info("\n".join(self.chunk_stats))
        logger.info(f"Processed {len(processed_chunks)} valid token chunks from text file")
        return processed_chunks
    except Exception as e:
        logger.error(f"Error processing text file: {e}")
        raise