# puji4ml's picture
# Upload 30 files
# 2b22a59 verified
"""
core/chunker.py - Text Chunking Strategies (LangChain-based)
============================================================
Uses LangChain's text splitters for robust, production-ready chunking
"""
from dataclasses import dataclass, field
from typing import List

# LangChain text splitters
from langchain_text_splitters import (
    RecursiveCharacterTextSplitter,
    CharacterTextSplitter,
    TokenTextSplitter,
)
@dataclass
class TextChunk:
    """Container for a text chunk with positional metadata.

    start_char/end_char are offsets into the original document text;
    token_count is an estimate (see Chunker._estimate_tokens).
    """
    text: str         # the chunk's content
    chunk_id: int     # 0-based index of the chunk within its document
    start_char: int   # offset of the chunk's first character in the source
    end_char: int     # offset one past the chunk's last character
    token_count: int  # estimated token count for this chunk
    # default_factory gives each instance its own dict instead of the
    # original `= None` default, which mis-typed the field as `dict`.
    metadata: dict = field(default_factory=dict)

    def __post_init__(self):
        # Stay tolerant of callers that explicitly pass metadata=None.
        if self.metadata is None:
            self.metadata = {}
class Chunker:
    """Text chunking with multiple strategies using LangChain.

    chunk_size / chunk_overlap are expressed in (estimated) tokens; the
    character-based splitters convert them to characters using the rough
    4-chars-per-token heuristic shared across this module.
    """

    # Rough heuristic: 1 token ~= 4 characters of English text.
    CHARS_PER_TOKEN = 4

    def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
        """
        Initialize chunker.

        Args:
            chunk_size: Target chunk size in tokens (converted to characters
                for the character-based splitters).
            chunk_overlap: Overlap between consecutive chunks, in tokens.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def chunk(self, text: str, strategy: str = "recursive") -> List[TextChunk]:
        """
        Chunk text using the specified strategy.

        Args:
            text: Text to chunk.
            strategy: Chunking strategy:
                - 'recursive': RecursiveCharacterTextSplitter (RECOMMENDED)
                - 'character': simple character-based splitting
                - 'token': token-based splitting
                - 'sentence': split on sentence boundaries

        Returns:
            List of TextChunk objects.

        Raises:
            ValueError: If `strategy` is not one of the supported names.
        """
        # Dispatch table instead of an if/elif chain; also gives us the
        # full list of valid names in one place.
        dispatch = {
            "recursive": self._chunk_recursive,
            "character": self._chunk_character,
            "token": self._chunk_token,
            "sentence": self._chunk_sentence,
        }
        try:
            handler = dispatch[strategy]
        except KeyError:
            raise ValueError(
                f"Unknown strategy: {strategy}. Use: recursive, character, token, sentence"
            ) from None
        return handler(text)

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count (rough: 1 token ~ 4 chars); never below 1."""
        return max(1, len(text) // self.CHARS_PER_TOKEN)

    def _build_chunks(self, text: str, pieces: List[str], strategy: str) -> List[TextChunk]:
        """Wrap split strings into TextChunk objects with exact char offsets.

        The previous implementation *estimated* offsets by subtracting the
        overlap, which drifted out of sync with the real positions and could
        even go negative.  Here each piece is located in `text` with
        str.find; the search cursor advances only one character past each
        chunk's start so overlapping pieces still resolve correctly.
        """
        chunks: List[TextChunk] = []
        cursor = 0
        for i, piece in enumerate(pieces):
            start = text.find(piece, cursor)
            if start == -1:
                # Splitters may strip surrounding whitespace; retry from the
                # top, and as a last resort fall back to the cursor position.
                start = text.find(piece)
                if start == -1:
                    start = cursor
            chunks.append(TextChunk(
                text=piece,
                chunk_id=i,
                start_char=start,
                end_char=start + len(piece),
                token_count=self._estimate_tokens(piece),
                metadata={'strategy': strategy},
            ))
            cursor = start + 1  # lets the next (overlapping) piece match
        return chunks

    def _chunk_recursive(self, text: str) -> List[TextChunk]:
        """
        Recursive chunking - LangChain's best splitter.
        Tries to split on paragraphs, then sentences, then words.
        """
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size * self.CHARS_PER_TOKEN,
            chunk_overlap=self.chunk_overlap * self.CHARS_PER_TOKEN,
            length_function=len,
            separators=["\n\n", "\n", ". ", " ", ""],
        )
        return self._build_chunks(text, splitter.split_text(text), 'recursive')

    def _chunk_character(self, text: str) -> List[TextChunk]:
        """Simple character-based chunking (splits on newlines)."""
        splitter = CharacterTextSplitter(
            chunk_size=self.chunk_size * self.CHARS_PER_TOKEN,
            chunk_overlap=self.chunk_overlap * self.CHARS_PER_TOKEN,
            separator="\n",
            length_function=len,
        )
        return self._build_chunks(text, splitter.split_text(text), 'character')

    def _chunk_token(self, text: str) -> List[TextChunk]:
        """Token-based chunking (more accurate for LLMs)."""
        # chunk_size/overlap are already in tokens here - no conversion.
        splitter = TokenTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )
        return self._build_chunks(text, splitter.split_text(text), 'token')

    def _chunk_sentence(self, text: str) -> List[TextChunk]:
        """Sentence-oriented chunking via RecursiveCharacterTextSplitter."""
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size * self.CHARS_PER_TOKEN,
            chunk_overlap=self.chunk_overlap * self.CHARS_PER_TOKEN,
            length_function=len,
            separators=[". ", "! ", "? ", "\n\n", "\n", " "],
        )
        return self._build_chunks(text, splitter.split_text(text), 'sentence')

    def get_stats(self, chunks: List[TextChunk]) -> dict:
        """Return summary statistics (in estimated tokens) for `chunks`."""
        if not chunks:
            # Keep the same keys for the empty case so callers can index
            # unconditionally.
            return {
                'num_chunks': 0,
                'avg_tokens': 0,
                'min_tokens': 0,
                'max_tokens': 0,
                'total_tokens': 0,
            }
        token_counts = [c.token_count for c in chunks]
        return {
            'num_chunks': len(chunks),
            'avg_tokens': sum(token_counts) / len(token_counts),
            'min_tokens': min(token_counts),
            'max_tokens': max(token_counts),
            'total_tokens': sum(token_counts),
        }
# ============================================================================
# USAGE EXAMPLE
# ============================================================================

def _demo() -> None:
    """Exercise every chunking strategy against a small markdown sample."""
    sample = """# Introduction to RAG Systems
Retrieval-Augmented Generation (RAG) is a powerful technique that combines information retrieval with text generation.
## How RAG Works
RAG systems work by first retrieving relevant documents from a knowledge base, then using those documents as context for generation.
### Key Components
1. Document chunking
2. Embedding generation
3. Vector storage
4. Retrieval
5. Generation
## Benefits
RAG systems provide more accurate and factual responses compared to pure generative models."""

    print("βœ‚οΈ Chunker Test (LangChain-based)")
    print("=" * 80)

    strategy_names = ["recursive", "character", "token", "sentence"]
    total = len(strategy_names)
    for position, name in enumerate(strategy_names, start=1):
        print(f"\n[{position}/{total}] πŸ“Š Strategy: {name}")
        print("-" * 80)
        try:
            demo_chunker = Chunker(chunk_size=100, chunk_overlap=20)
            pieces = demo_chunker.chunk(sample, strategy=name)
            summary = demo_chunker.get_stats(pieces)
            print(f" βœ… Chunks: {summary['num_chunks']}")
            print(f" πŸ“Š Avg tokens: {summary['avg_tokens']:.1f}")
            print(f" πŸ“ˆ Range: {summary['min_tokens']}-{summary['max_tokens']} tokens")
            print(f" πŸ“ Total tokens: {summary['total_tokens']}")
            # Dump each chunk so the split boundaries are visible.
            for ordinal, piece in enumerate(pieces, start=1):
                print(f"\n Chunk {ordinal} ({piece.token_count} tokens):")
                print(f" '''{piece.text}'''")
        except Exception as e:
            print(f" ❌ Error: {e}")
            import traceback
            traceback.print_exc()


if __name__ == "__main__":
    _demo()