# File: ai_exec/src/data_processing/text_segmenter.py
# Repository page residue (not part of the module): "Chaitanya-aitf's picture",
# "Upload 38 files", commit 45ee481 (verified).
"""
Text Segmenter Module

Split blog content into semantic chunks for training.
Preserves complete thoughts and handles various content structures.

Example usage:
    segmenter = TextSegmenter(target_tokens=384, overlap_tokens=50)
    segments = segmenter.segment_posts(blog_posts)
"""
import re
from dataclasses import dataclass, field
from typing import Optional
from loguru import logger
try:
import tiktoken
TIKTOKEN_AVAILABLE = True
except ImportError:
TIKTOKEN_AVAILABLE = False
logger.warning("tiktoken not available, using approximate token counting")
@dataclass
class TextSegment:
    """A single chunk of text prepared for training, with its provenance."""

    content: str
    token_count: int
    source_post_index: int
    source_post_title: str
    segment_index: int
    # True when the segment ends at a natural boundary.
    is_complete: bool
    metadata: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return the segment's fields as a plain dict for serialization."""
        # Explicit name list keeps the key order fixed and returns the stored
        # ``metadata`` object itself (no copy is made).
        exported_names = (
            "content",
            "token_count",
            "source_post_index",
            "source_post_title",
            "segment_index",
            "is_complete",
            "metadata",
        )
        return {name: getattr(self, name) for name in exported_names}
class TextSegmenter:
    """
    Split text into semantic chunks suitable for LLM training.

    Features:
    - Paragraph-level segmentation
    - Preserves complete thoughts/arguments
    - Handles lists, quotes, code blocks
    - Configurable target size with overlap

    Example:
        >>> segmenter = TextSegmenter(target_tokens=384)
        >>> segments = segmenter.segment_text("Long blog post content...")
        >>> for seg in segments:
        ...     print(f"Segment {seg.segment_index}: {seg.token_count} tokens")
    """

    # Patterns for content structure detection
    LIST_ITEM_PATTERN = re.compile(r"^[\s]*[-*•]\s+", re.MULTILINE)
    NUMBERED_LIST_PATTERN = re.compile(r"^[\s]*\d+[.)\]]\s+", re.MULTILINE)
    CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```", re.MULTILINE)
    BLOCKQUOTE_PATTERN = re.compile(r"^>\s+", re.MULTILINE)

    # Sentence boundary pattern
    SENTENCE_END_PATTERN = re.compile(r"[.!?]+[\s]+")

    # Internal placeholder used while code blocks are hidden from the
    # paragraph splitter. NUL delimiters make a clash with real post text
    # (e.g. a literal "__CODE_BLOCK_0__") practically impossible.
    _CODE_PLACEHOLDER_PATTERN = re.compile(r"\x00CODE_BLOCK_(\d+)\x00")

    def __init__(
        self,
        target_tokens: int = 384,
        min_tokens: int = 100,
        max_tokens: int = 512,
        overlap_tokens: int = 50,
        encoding_name: str = "cl100k_base",
    ):
        """
        Initialize the text segmenter.

        Args:
            target_tokens: Target token count per segment (256-512 recommended)
            min_tokens: Minimum tokens for a valid segment
            max_tokens: Maximum tokens before forcing a split
            overlap_tokens: Token overlap between consecutive segments
            encoding_name: Tiktoken encoding name for token counting
        """
        self.target_tokens = target_tokens
        self.min_tokens = min_tokens
        self.max_tokens = max_tokens
        self.overlap_tokens = overlap_tokens

        # Fall back to the character-based approximation whenever the real
        # tokenizer is missing or cannot be loaded.
        self.encoding = None
        if TIKTOKEN_AVAILABLE:
            try:
                self.encoding = tiktoken.get_encoding(encoding_name)
                logger.debug(f"Using tiktoken encoding: {encoding_name}")
            except Exception as e:
                logger.warning(f"Failed to load tiktoken: {e}, using approximation")
                self.encoding = None

    def count_tokens(self, text: str) -> int:
        """
        Count tokens in text.

        Args:
            text: Text to count tokens for

        Returns:
            Token count (exact when a tiktoken encoding is loaded, otherwise
            an approximation of one token per three characters)
        """
        if self.encoding is not None:
            # disallowed_special=() makes strings such as "<|endoftext|>"
            # count as ordinary text instead of raising ValueError.
            return len(self.encoding.encode(text, disallowed_special=()))
        # Approximate: ~4 chars per token for English, ~2 for Japanese/mixed
        # content; 3 is a conservative middle ground.
        return len(text) // 3

    def segment_posts(self, posts: list) -> "list[TextSegment]":
        """
        Segment multiple blog posts.

        Args:
            posts: List of post objects exposing ``content``, ``index`` and
                ``title`` attributes

        Returns:
            List of TextSegment objects, in post order
        """
        all_segments = []
        for post in posts:
            all_segments.extend(
                self.segment_text(
                    text=post.content,
                    source_post_index=post.index,
                    source_post_title=post.title,
                )
            )
        logger.info(f"Created {len(all_segments)} segments from {len(posts)} posts")
        return all_segments

    def segment_text(
        self,
        text: str,
        source_post_index: int = 0,
        source_post_title: str = "Unknown",
    ) -> "list[TextSegment]":
        """
        Segment a single text into chunks.

        Args:
            text: Text content to segment
            source_post_index: Index of source post
            source_post_title: Title of source post

        Returns:
            List of TextSegment objects (empty for missing or blank text)
        """
        # Treat None like empty content rather than crashing on .strip().
        if not text or not text.strip():
            return []

        paragraphs = self._split_into_paragraphs(text)
        return self._group_paragraphs(paragraphs, source_post_index, source_post_title)

    def _split_into_paragraphs(self, text: str) -> list[dict]:
        """
        Split text into paragraphs while preserving structure.

        Fenced code blocks are hidden behind placeholders first so that blank
        lines inside them do not split the block.

        Args:
            text: Text to split

        Returns:
            List of paragraph dicts with ``content``, ``type`` and ``tokens``
        """
        code_blocks: list[str] = []

        def stash(match):
            # Remember the block and leave an indexed placeholder behind.
            code_blocks.append(match.group(0))
            return f"\x00CODE_BLOCK_{len(code_blocks) - 1}\x00"

        def restore(match):
            index = int(match.group(1))
            # Leave anything that is not one of our placeholders untouched.
            return code_blocks[index] if index < len(code_blocks) else match.group(0)

        protected = self.CODE_BLOCK_PATTERN.sub(stash, text)

        paragraphs = []
        # Paragraphs are separated by one or more blank lines.
        for raw in re.split(r"\n{2,}", protected):
            para = raw.strip()
            if not para:
                continue
            if code_blocks:
                para = self._CODE_PLACEHOLDER_PATTERN.sub(restore, para)
            paragraphs.append({
                "content": para,
                "type": self._detect_paragraph_type(para),
                "tokens": self.count_tokens(para),
            })
        return paragraphs

    def _detect_paragraph_type(self, text: str) -> str:
        """Classify a paragraph as code, list, numbered_list, quote, header or text."""
        if self.CODE_BLOCK_PATTERN.search(text):
            return "code"
        if self.LIST_ITEM_PATTERN.match(text):
            return "list"
        if self.NUMBERED_LIST_PATTERN.match(text):
            return "numbered_list"
        if self.BLOCKQUOTE_PATTERN.match(text):
            return "quote"
        if text.startswith("#"):
            return "header"
        return "text"

    def _split_sentences(self, text: str) -> list[str]:
        """
        Split text into sentences, keeping each sentence's end punctuation.

        ``SENTENCE_END_PATTERN.split`` would discard the ".", "!" or "?" that
        terminates every sentence, so boundaries are located with ``finditer``
        and only the trailing whitespace is removed.

        Args:
            text: Text to split

        Returns:
            List of non-empty, stripped sentences
        """
        sentences = []
        start = 0
        for match in self.SENTENCE_END_PATTERN.finditer(text):
            sentence = text[start:match.end()].strip()
            if sentence:
                sentences.append(sentence)
            start = match.end()
        tail = text[start:].strip()
        if tail:
            sentences.append(tail)
        return sentences

    def _group_paragraphs(
        self,
        paragraphs: list[dict],
        source_post_index: int,
        source_post_title: str,
    ) -> "list[TextSegment]":
        """
        Group paragraphs into segments of appropriate size.

        Paragraphs are accumulated until adding the next one would exceed
        ``target_tokens``. Paragraphs larger than ``max_tokens`` are split at
        sentence boundaries. A trailing remainder shorter than ``min_tokens``
        is merged into the previous segment, or dropped when there is none.

        Args:
            paragraphs: List of paragraph dicts
            source_post_index: Index of source post
            source_post_title: Title of source post

        Returns:
            List of TextSegment objects
        """

        def make(content, tokens, index, is_complete):
            return self._create_segment(
                content=content,
                tokens=tokens,
                source_post_index=source_post_index,
                source_post_title=source_post_title,
                segment_index=index,
                is_complete=is_complete,
            )

        segments = []
        current_content: list[str] = []
        current_tokens = 0
        segment_index = 0
        # True while current_content[0] is overlap text copied from the end of
        # segments[-1]; needed so a final merge does not duplicate that text.
        has_overlap = False

        for para in paragraphs:
            para_tokens = para["tokens"]

            # A single paragraph above the hard limit is split on its own.
            if para_tokens > self.max_tokens:
                if current_content:
                    segments.append(
                        make("\n\n".join(current_content), current_tokens, segment_index, False)
                    )
                    segment_index += 1
                    current_content = []
                    current_tokens = 0
                    has_overlap = False
                sub_segments = self._split_large_paragraph(
                    para["content"],
                    source_post_index,
                    source_post_title,
                    segment_index,
                )
                segments.extend(sub_segments)
                segment_index += len(sub_segments)
                continue

            # Close the running segment when this paragraph would overshoot.
            if current_content and current_tokens + para_tokens > self.target_tokens:
                segments.append(
                    make("\n\n".join(current_content), current_tokens, segment_index, True)
                )
                segment_index += 1

                overlap = ""
                if self.overlap_tokens > 0:
                    overlap = self._get_overlap_content(current_content, self.overlap_tokens)
                current_content = [overlap] if overlap else []
                current_tokens = self.count_tokens(overlap) if overlap else 0
                has_overlap = bool(overlap)

            current_content.append(para["content"])
            current_tokens += para_tokens

        if not current_content:
            return segments

        if current_tokens >= self.min_tokens:
            segments.append(
                make("\n\n".join(current_content), current_tokens, segment_index, True)
            )
        elif segments:
            # Too short to stand alone: fold into the previous segment. The
            # overlap prefix is already the tail of that segment, so skip it.
            new_parts = current_content[1:] if has_overlap else current_content
            previous = segments[-1]
            merged_content = "\n\n".join([previous.content, *new_parts])
            segments[-1] = make(
                merged_content,
                self.count_tokens(merged_content),
                previous.segment_index,
                True,
            )
        return segments

    def _split_large_paragraph(
        self,
        text: str,
        source_post_index: int,
        source_post_title: str,
        start_index: int,
    ) -> "list[TextSegment]":
        """
        Split a large paragraph into smaller segments at sentence boundaries.

        Sentences are joined with single spaces. A single sentence that is
        itself larger than ``target_tokens`` is emitted unsplit.

        Args:
            text: Text to split
            source_post_index: Source post index
            source_post_title: Source post title
            start_index: Starting segment index

        Returns:
            List of TextSegment objects; all but the last are marked
            ``is_complete=False``
        """
        segments = []
        current_sentences: list[str] = []
        current_tokens = 0
        segment_index = start_index

        for sentence in self._split_sentences(text):
            sent_tokens = self.count_tokens(sentence)
            if current_sentences and current_tokens + sent_tokens > self.target_tokens:
                content = " ".join(current_sentences)
                segments.append(self._create_segment(
                    content=content,
                    # Count the emitted text so every sub-segment reports its
                    # size the same way as the final one below.
                    tokens=self.count_tokens(content),
                    source_post_index=source_post_index,
                    source_post_title=source_post_title,
                    segment_index=segment_index,
                    is_complete=False,
                ))
                segment_index += 1
                current_sentences = []
                current_tokens = 0
            current_sentences.append(sentence)
            current_tokens += sent_tokens

        if current_sentences:
            content = " ".join(current_sentences)
            segments.append(self._create_segment(
                content=content,
                tokens=self.count_tokens(content),
                source_post_index=source_post_index,
                source_post_title=source_post_title,
                segment_index=segment_index,
                is_complete=True,
            ))
        return segments

    def _get_overlap_content(self, paragraphs: list[str], target_tokens: int) -> str:
        """
        Get content from the end of paragraphs for overlap.

        Whole paragraphs are taken from the end while they fit the budget; the
        first paragraph that does not fit contributes its trailing sentences
        (with punctuation intact) and ends the search.

        Args:
            paragraphs: List of paragraph strings
            target_tokens: Target tokens for overlap

        Returns:
            Overlap content string, parts joined by single spaces; empty when
            nothing fits
        """
        overlap_parts: list[str] = []
        current_tokens = 0

        for para in reversed(paragraphs):
            para_tokens = self.count_tokens(para)
            if current_tokens + para_tokens <= target_tokens:
                overlap_parts.insert(0, para)
                current_tokens += para_tokens
                continue

            # Paragraph too big: take as many of its last sentences as fit.
            for sent in reversed(self._split_sentences(para)):
                sent_tokens = self.count_tokens(sent)
                if current_tokens + sent_tokens > target_tokens:
                    break
                overlap_parts.insert(0, sent)
                current_tokens += sent_tokens
            break

        return " ".join(overlap_parts) if overlap_parts else ""

    def _create_segment(
        self,
        content: str,
        tokens: int,
        source_post_index: int,
        source_post_title: str,
        segment_index: int,
        is_complete: bool,
    ) -> "TextSegment":
        """Create a TextSegment object from the given fields."""
        return TextSegment(
            content=content,
            token_count=tokens,
            source_post_index=source_post_index,
            source_post_title=source_post_title,
            segment_index=segment_index,
            is_complete=is_complete,
        )
def main():
    """Command-line driver: segment one text file and optionally save JSON."""
    import argparse
    import json

    parser = argparse.ArgumentParser(
        description="Segment text into chunks for LLM training",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python text_segmenter.py input.txt --output segments.json
python text_segmenter.py input.txt --target-tokens 256
python text_segmenter.py input.txt --overlap 30
""",
    )
    parser.add_argument("input", help="Input text file")
    parser.add_argument("--output", "-o", help="Output JSON file")

    # All size-related options are plain integers; register them in one pass.
    integer_options = (
        ("--target-tokens", 384, "Target tokens per segment (default: 384)"),
        ("--min-tokens", 100, "Minimum tokens per segment (default: 100)"),
        ("--max-tokens", 512, "Maximum tokens per segment (default: 512)"),
        ("--overlap", 50, "Overlap tokens between segments (default: 50)"),
    )
    for flag, default_value, help_text in integer_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)

    options = parser.parse_args()

    segmenter = TextSegmenter(
        target_tokens=options.target_tokens,
        min_tokens=options.min_tokens,
        max_tokens=options.max_tokens,
        overlap_tokens=options.overlap,
    )

    with open(options.input, "r", encoding="utf-8") as source_file:
        raw_text = source_file.read()

    results = segmenter.segment_text(raw_text)

    # Short human-readable summary: index, size, completeness and a preview.
    print(f"\nCreated {len(results)} segments:")
    print("-" * 50)
    for item in results:
        print(f"\n[{item.segment_index}] {item.token_count} tokens (complete: {item.is_complete})")
        print(f" {item.content[:80]}...")

    if not options.output:
        return

    serialized = [item.to_dict() for item in results]
    with open(options.output, "w", encoding="utf-8") as sink:
        json.dump(serialized, sink, indent=2, ensure_ascii=False)
    print(f"\nSaved to: {options.output}")


if __name__ == "__main__":
    main()