""" Text Segmenter Module Split blog content into semantic chunks for training. Preserves complete thoughts and handles various content structures. Example usage: segmenter = TextSegmenter(target_tokens=384, overlap_tokens=50) segments = segmenter.segment_posts(blog_posts) """ import re from dataclasses import dataclass, field from typing import Optional from loguru import logger try: import tiktoken TIKTOKEN_AVAILABLE = True except ImportError: TIKTOKEN_AVAILABLE = False logger.warning("tiktoken not available, using approximate token counting") @dataclass class TextSegment: """Represents a segment of text for training.""" content: str token_count: int source_post_index: int source_post_title: str segment_index: int is_complete: bool # Whether segment ends at a natural boundary metadata: dict = field(default_factory=dict) def to_dict(self) -> dict: """Convert to dictionary for serialization.""" return { "content": self.content, "token_count": self.token_count, "source_post_index": self.source_post_index, "source_post_title": self.source_post_title, "segment_index": self.segment_index, "is_complete": self.is_complete, "metadata": self.metadata, } class TextSegmenter: """ Split text into semantic chunks suitable for LLM training. Features: - Paragraph-level segmentation - Preserves complete thoughts/arguments - Handles lists, quotes, code blocks - Configurable target size with overlap Example: >>> segmenter = TextSegmenter(target_tokens=384) >>> segments = segmenter.segment_text("Long blog post content...") >>> for seg in segments: ... print(f"Segment {seg.segment_index}: {seg.token_count} tokens") """ # Patterns for content structure detection LIST_ITEM_PATTERN = re.compile(r"^[\s]*[-*•]\s+", re.MULTILINE) NUMBERED_LIST_PATTERN = re.compile(r"^[\s]*\d+[.)\]]\s+", re.MULTILINE) CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```", re.MULTILINE) BLOCKQUOTE_PATTERN = re.compile(r"^>\s+", re.MULTILINE) # Sentence boundary pattern SENTENCE_END_PATTERN = re.compile(r"[.!?]+[\s]+") def __init__( self, target_tokens: int = 384, min_tokens: int = 100, max_tokens: int = 512, overlap_tokens: int = 50, encoding_name: str = "cl100k_base", ): """ Initialize the text segmenter. Args: target_tokens: Target token count per segment (256-512 recommended) min_tokens: Minimum tokens for a valid segment max_tokens: Maximum tokens before forcing a split overlap_tokens: Token overlap between consecutive segments encoding_name: Tiktoken encoding name for token counting """ self.target_tokens = target_tokens self.min_tokens = min_tokens self.max_tokens = max_tokens self.overlap_tokens = overlap_tokens # Initialize tokenizer if TIKTOKEN_AVAILABLE: try: self.encoding = tiktoken.get_encoding(encoding_name) logger.debug(f"Using tiktoken encoding: {encoding_name}") except Exception as e: logger.warning(f"Failed to load tiktoken: {e}, using approximation") self.encoding = None else: self.encoding = None def count_tokens(self, text: str) -> int: """ Count tokens in text. Args: text: Text to count tokens for Returns: Token count """ if self.encoding: return len(self.encoding.encode(text)) else: # Approximate: ~4 chars per token for English # Adjust for Japanese/mixed content (~2 chars per token) # Use a conservative estimate return len(text) // 3 def segment_posts(self, posts: list) -> list[TextSegment]: """ Segment multiple blog posts. Args: posts: List of BlogPost objects Returns: List of TextSegment objects """ all_segments = [] for post in posts: post_segments = self.segment_text( text=post.content, source_post_index=post.index, source_post_title=post.title, ) all_segments.extend(post_segments) logger.info(f"Created {len(all_segments)} segments from {len(posts)} posts") return all_segments def segment_text( self, text: str, source_post_index: int = 0, source_post_title: str = "Unknown", ) -> list[TextSegment]: """ Segment a single text into chunks. Args: text: Text content to segment source_post_index: Index of source post source_post_title: Title of source post Returns: List of TextSegment objects """ if not text.strip(): return [] # First, split into paragraphs paragraphs = self._split_into_paragraphs(text) # Then, group paragraphs into segments segments = self._group_paragraphs( paragraphs, source_post_index, source_post_title ) return segments def _split_into_paragraphs(self, text: str) -> list[dict]: """ Split text into paragraphs while preserving structure. Args: text: Text to split Returns: List of paragraph dicts with content and metadata """ # Preserve code blocks as single units code_blocks = self.CODE_BLOCK_PATTERN.findall(text) for i, block in enumerate(code_blocks): text = text.replace(block, f"__CODE_BLOCK_{i}__") # Split on double newlines raw_paragraphs = re.split(r"\n{2,}", text) paragraphs = [] for para in raw_paragraphs: para = para.strip() if not para: continue # Restore code blocks for i, block in enumerate(code_blocks): para = para.replace(f"__CODE_BLOCK_{i}__", block) # Determine paragraph type para_type = self._detect_paragraph_type(para) paragraphs.append({ "content": para, "type": para_type, "tokens": self.count_tokens(para), }) return paragraphs def _detect_paragraph_type(self, text: str) -> str: """Detect the type of paragraph for better segmentation.""" if self.CODE_BLOCK_PATTERN.search(text): return "code" if self.LIST_ITEM_PATTERN.match(text): return "list" if self.NUMBERED_LIST_PATTERN.match(text): return "numbered_list" if self.BLOCKQUOTE_PATTERN.match(text): return "quote" if text.startswith("#"): return "header" return "text" def _group_paragraphs( self, paragraphs: list[dict], source_post_index: int, source_post_title: str, ) -> list[TextSegment]: """ Group paragraphs into segments of appropriate size. Args: paragraphs: List of paragraph dicts source_post_index: Index of source post source_post_title: Title of source post Returns: List of TextSegment objects """ segments = [] current_content = [] current_tokens = 0 segment_index = 0 for i, para in enumerate(paragraphs): para_tokens = para["tokens"] # If single paragraph exceeds max, split it if para_tokens > self.max_tokens: # First, save current segment if not empty if current_content: segments.append(self._create_segment( content="\n\n".join(current_content), tokens=current_tokens, source_post_index=source_post_index, source_post_title=source_post_title, segment_index=segment_index, is_complete=False, )) segment_index += 1 current_content = [] current_tokens = 0 # Split large paragraph sub_segments = self._split_large_paragraph( para["content"], source_post_index, source_post_title, segment_index, ) segments.extend(sub_segments) segment_index += len(sub_segments) continue # Check if adding this paragraph would exceed target if current_tokens + para_tokens > self.target_tokens and current_content: # Save current segment segments.append(self._create_segment( content="\n\n".join(current_content), tokens=current_tokens, source_post_index=source_post_index, source_post_title=source_post_title, segment_index=segment_index, is_complete=True, )) segment_index += 1 # Start new segment with overlap if configured if self.overlap_tokens > 0 and current_content: overlap_content = self._get_overlap_content( current_content, self.overlap_tokens ) current_content = [overlap_content] if overlap_content else [] current_tokens = self.count_tokens(overlap_content) if overlap_content else 0 else: current_content = [] current_tokens = 0 current_content.append(para["content"]) current_tokens += para_tokens # Don't forget the last segment if current_content: # Only add if meets minimum token requirement if current_tokens >= self.min_tokens: segments.append(self._create_segment( content="\n\n".join(current_content), tokens=current_tokens, source_post_index=source_post_index, source_post_title=source_post_title, segment_index=segment_index, is_complete=True, )) elif segments: # Merge with previous segment if too short last_segment = segments[-1] merged_content = last_segment.content + "\n\n" + "\n\n".join(current_content) segments[-1] = self._create_segment( content=merged_content, tokens=self.count_tokens(merged_content), source_post_index=source_post_index, source_post_title=source_post_title, segment_index=last_segment.segment_index, is_complete=True, ) return segments def _split_large_paragraph( self, text: str, source_post_index: int, source_post_title: str, start_index: int, ) -> list[TextSegment]: """ Split a large paragraph into smaller segments at sentence boundaries. Args: text: Text to split source_post_index: Source post index source_post_title: Source post title start_index: Starting segment index Returns: List of TextSegment objects """ # Split into sentences sentences = self.SENTENCE_END_PATTERN.split(text) sentences = [s.strip() for s in sentences if s.strip()] segments = [] current_sentences = [] current_tokens = 0 segment_index = start_index for sentence in sentences: sent_tokens = self.count_tokens(sentence) if current_tokens + sent_tokens > self.target_tokens and current_sentences: # Save current segment content = " ".join(current_sentences) segments.append(self._create_segment( content=content, tokens=current_tokens, source_post_index=source_post_index, source_post_title=source_post_title, segment_index=segment_index, is_complete=False, )) segment_index += 1 current_sentences = [] current_tokens = 0 current_sentences.append(sentence) current_tokens += sent_tokens # Last segment if current_sentences: content = " ".join(current_sentences) segments.append(self._create_segment( content=content, tokens=self.count_tokens(content), source_post_index=source_post_index, source_post_title=source_post_title, segment_index=segment_index, is_complete=True, )) return segments def _get_overlap_content(self, paragraphs: list[str], target_tokens: int) -> str: """ Get content from the end of paragraphs for overlap. Args: paragraphs: List of paragraph strings target_tokens: Target tokens for overlap Returns: Overlap content string """ # Start from the last paragraph and work backwards overlap_parts = [] current_tokens = 0 for para in reversed(paragraphs): para_tokens = self.count_tokens(para) if current_tokens + para_tokens <= target_tokens: overlap_parts.insert(0, para) current_tokens += para_tokens else: # Take partial from this paragraph (last sentences) sentences = self.SENTENCE_END_PATTERN.split(para) for sent in reversed(sentences): sent = sent.strip() if not sent: continue sent_tokens = self.count_tokens(sent) if current_tokens + sent_tokens <= target_tokens: overlap_parts.insert(0, sent) current_tokens += sent_tokens else: break break return " ".join(overlap_parts) if overlap_parts else "" def _create_segment( self, content: str, tokens: int, source_post_index: int, source_post_title: str, segment_index: int, is_complete: bool, ) -> TextSegment: """Create a TextSegment object.""" return TextSegment( content=content, token_count=tokens, source_post_index=source_post_index, source_post_title=source_post_title, segment_index=segment_index, is_complete=is_complete, ) def main(): """CLI entry point for testing the segmenter.""" import argparse import json parser = argparse.ArgumentParser( description="Segment text into chunks for LLM training", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: python text_segmenter.py input.txt --output segments.json python text_segmenter.py input.txt --target-tokens 256 python text_segmenter.py input.txt --overlap 30 """, ) parser.add_argument("input", help="Input text file") parser.add_argument("--output", "-o", help="Output JSON file") parser.add_argument( "--target-tokens", type=int, default=384, help="Target tokens per segment (default: 384)", ) parser.add_argument( "--min-tokens", type=int, default=100, help="Minimum tokens per segment (default: 100)", ) parser.add_argument( "--max-tokens", type=int, default=512, help="Maximum tokens per segment (default: 512)", ) parser.add_argument( "--overlap", type=int, default=50, help="Overlap tokens between segments (default: 50)", ) args = parser.parse_args() segmenter = TextSegmenter( target_tokens=args.target_tokens, min_tokens=args.min_tokens, max_tokens=args.max_tokens, overlap_tokens=args.overlap, ) with open(args.input, "r", encoding="utf-8") as f: text = f.read() segments = segmenter.segment_text(text) print(f"\nCreated {len(segments)} segments:") print("-" * 50) for seg in segments: print(f"\n[{seg.segment_index}] {seg.token_count} tokens (complete: {seg.is_complete})") print(f" {seg.content[:80]}...") if args.output: output_data = [s.to_dict() for s in segments] with open(args.output, "w", encoding="utf-8") as f: json.dump(output_data, f, indent=2, ensure_ascii=False) print(f"\nSaved to: {args.output}") if __name__ == "__main__": main()