Spaces:
Paused
Paused
"""
Text Segmenter Module

Split blog content into semantic chunks for training.
Preserves complete thoughts and handles various content structures.

Example usage:
    segmenter = TextSegmenter(target_tokens=384, overlap_tokens=50)
    segments = segmenter.segment_posts(blog_posts)
"""
import re
from dataclasses import dataclass, field
from typing import Optional

from loguru import logger

# tiktoken is optional: when missing, token counts fall back to a
# character-based approximation (see TextSegmenter.count_tokens).
try:
    import tiktoken

    TIKTOKEN_AVAILABLE = True
except ImportError:
    TIKTOKEN_AVAILABLE = False
    logger.warning("tiktoken not available, using approximate token counting")
@dataclass
class TextSegment:
    """Represents a segment of text for training.

    Fields are plain data; `to_dict` produces a JSON-serializable view.
    The `@dataclass` decorator is required: instances are created with
    keyword arguments and `metadata` relies on `field(default_factory=dict)`
    (without the decorator, construction raises TypeError and `field(...)`
    is just an inert class attribute).
    """

    content: str
    token_count: int
    source_post_index: int
    source_post_title: str
    segment_index: int
    is_complete: bool  # Whether segment ends at a natural boundary
    metadata: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Convert to dictionary for serialization."""
        return {
            "content": self.content,
            "token_count": self.token_count,
            "source_post_index": self.source_post_index,
            "source_post_title": self.source_post_title,
            "segment_index": self.segment_index,
            "is_complete": self.is_complete,
            "metadata": self.metadata,
        }
class TextSegmenter:
    """
    Split text into semantic chunks suitable for LLM training.

    Features:
    - Paragraph-level segmentation
    - Preserves complete thoughts/arguments
    - Handles lists, quotes, code blocks
    - Configurable target size with overlap

    Example:
        >>> segmenter = TextSegmenter(target_tokens=384)
        >>> segments = segmenter.segment_text("Long blog post content...")
        >>> for seg in segments:
        ...     print(f"Segment {seg.segment_index}: {seg.token_count} tokens")
    """

    # Patterns for content structure detection
    LIST_ITEM_PATTERN = re.compile(r"^[\s]*[-*•]\s+", re.MULTILINE)
    NUMBERED_LIST_PATTERN = re.compile(r"^[\s]*\d+[.)\]]\s+", re.MULTILINE)
    CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```", re.MULTILINE)
    BLOCKQUOTE_PATTERN = re.compile(r"^>\s+", re.MULTILINE)

    # Sentence boundary: split on the whitespace *after* terminal
    # punctuation. The lookbehind keeps '.', '!', '?' attached to each
    # sentence; the previous pattern (r"[.!?]+[\s]+") consumed the
    # punctuation in re.split, so rejoined segments silently lost it.
    SENTENCE_END_PATTERN = re.compile(r"(?<=[.!?])\s+")

    def __init__(
        self,
        target_tokens: int = 384,
        min_tokens: int = 100,
        max_tokens: int = 512,
        overlap_tokens: int = 50,
        encoding_name: str = "cl100k_base",
    ):
        """
        Initialize the text segmenter.

        Args:
            target_tokens: Target token count per segment (256-512 recommended)
            min_tokens: Minimum tokens for a valid segment
            max_tokens: Maximum tokens before forcing a split
            overlap_tokens: Token overlap between consecutive segments
            encoding_name: Tiktoken encoding name for token counting
        """
        self.target_tokens = target_tokens
        self.min_tokens = min_tokens
        self.max_tokens = max_tokens
        self.overlap_tokens = overlap_tokens

        # Initialize tokenizer; self.encoding stays None when tiktoken is
        # unavailable or fails to load, triggering the approximation path.
        if TIKTOKEN_AVAILABLE:
            try:
                self.encoding = tiktoken.get_encoding(encoding_name)
                logger.debug(f"Using tiktoken encoding: {encoding_name}")
            except Exception as e:
                logger.warning(f"Failed to load tiktoken: {e}, using approximation")
                self.encoding = None
        else:
            self.encoding = None

    def count_tokens(self, text: str) -> int:
        """
        Count tokens in text.

        Args:
            text: Text to count tokens for

        Returns:
            Token count (exact via tiktoken when available, otherwise a
            conservative character-based estimate)
        """
        if self.encoding:
            return len(self.encoding.encode(text))
        # Approximate: ~4 chars per token for English,
        # ~2 chars per token for Japanese/mixed content.
        # len // 3 is a conservative middle ground.
        return len(text) // 3

    def segment_posts(self, posts: list) -> list[TextSegment]:
        """
        Segment multiple blog posts.

        Args:
            posts: List of BlogPost objects (must expose .content, .index,
                .title — assumed from usage; TODO confirm against caller)

        Returns:
            List of TextSegment objects
        """
        all_segments = []
        for post in posts:
            post_segments = self.segment_text(
                text=post.content,
                source_post_index=post.index,
                source_post_title=post.title,
            )
            all_segments.extend(post_segments)
        logger.info(f"Created {len(all_segments)} segments from {len(posts)} posts")
        return all_segments

    def segment_text(
        self,
        text: str,
        source_post_index: int = 0,
        source_post_title: str = "Unknown",
    ) -> list[TextSegment]:
        """
        Segment a single text into chunks.

        Args:
            text: Text content to segment
            source_post_index: Index of source post
            source_post_title: Title of source post

        Returns:
            List of TextSegment objects (empty for blank input)
        """
        if not text.strip():
            return []
        # First split into typed paragraphs, then pack them into segments.
        paragraphs = self._split_into_paragraphs(text)
        return self._group_paragraphs(paragraphs, source_post_index, source_post_title)

    def _split_into_paragraphs(self, text: str) -> list[dict]:
        """
        Split text into paragraphs while preserving structure.

        Args:
            text: Text to split

        Returns:
            List of paragraph dicts with keys "content", "type", "tokens"
        """
        # Protect fenced code blocks from the paragraph split by swapping
        # each for a unique placeholder. count=1 is essential: if the same
        # code block appears twice, an unbounded replace would consume both
        # occurrences with the first placeholder, leaving later placeholders
        # unmatched on restore.
        code_blocks = self.CODE_BLOCK_PATTERN.findall(text)
        for i, block in enumerate(code_blocks):
            text = text.replace(block, f"__CODE_BLOCK_{i}__", 1)

        # Paragraphs are runs separated by blank lines.
        raw_paragraphs = re.split(r"\n{2,}", text)
        paragraphs = []
        for para in raw_paragraphs:
            para = para.strip()
            if not para:
                continue
            # Restore any code-block placeholders in this paragraph.
            for i, block in enumerate(code_blocks):
                para = para.replace(f"__CODE_BLOCK_{i}__", block)
            paragraphs.append({
                "content": para,
                "type": self._detect_paragraph_type(para),
                "tokens": self.count_tokens(para),
            })
        return paragraphs

    def _detect_paragraph_type(self, text: str) -> str:
        """Detect the type of paragraph for better segmentation."""
        if self.CODE_BLOCK_PATTERN.search(text):
            return "code"
        if self.LIST_ITEM_PATTERN.match(text):
            return "list"
        if self.NUMBERED_LIST_PATTERN.match(text):
            return "numbered_list"
        if self.BLOCKQUOTE_PATTERN.match(text):
            return "quote"
        if text.startswith("#"):
            return "header"
        return "text"

    def _group_paragraphs(
        self,
        paragraphs: list[dict],
        source_post_index: int,
        source_post_title: str,
    ) -> list[TextSegment]:
        """
        Group paragraphs into segments of appropriate size.

        Oversized paragraphs (> max_tokens) are split at sentence
        boundaries; normal paragraphs are packed greedily up to
        target_tokens, carrying overlap_tokens of trailing context into
        the next segment. A too-short final segment (< min_tokens) is
        merged into the previous one rather than emitted or dropped.

        Args:
            paragraphs: List of paragraph dicts
            source_post_index: Index of source post
            source_post_title: Title of source post

        Returns:
            List of TextSegment objects
        """
        segments = []
        current_content = []
        current_tokens = 0
        segment_index = 0

        for para in paragraphs:
            para_tokens = para["tokens"]

            # A single paragraph above max_tokens gets sentence-split.
            if para_tokens > self.max_tokens:
                # Flush the in-progress segment first (marked incomplete:
                # it was cut short by the oversized paragraph).
                if current_content:
                    segments.append(self._create_segment(
                        content="\n\n".join(current_content),
                        tokens=current_tokens,
                        source_post_index=source_post_index,
                        source_post_title=source_post_title,
                        segment_index=segment_index,
                        is_complete=False,
                    ))
                    segment_index += 1
                    current_content = []
                    current_tokens = 0
                sub_segments = self._split_large_paragraph(
                    para["content"],
                    source_post_index,
                    source_post_title,
                    segment_index,
                )
                segments.extend(sub_segments)
                segment_index += len(sub_segments)
                continue

            # Close the current segment once adding this paragraph would
            # push it past the target.
            if current_tokens + para_tokens > self.target_tokens and current_content:
                segments.append(self._create_segment(
                    content="\n\n".join(current_content),
                    tokens=current_tokens,
                    source_post_index=source_post_index,
                    source_post_title=source_post_title,
                    segment_index=segment_index,
                    is_complete=True,
                ))
                segment_index += 1

                # Seed the next segment with trailing overlap context.
                if self.overlap_tokens > 0 and current_content:
                    overlap_content = self._get_overlap_content(
                        current_content, self.overlap_tokens
                    )
                    current_content = [overlap_content] if overlap_content else []
                    current_tokens = (
                        self.count_tokens(overlap_content) if overlap_content else 0
                    )
                else:
                    current_content = []
                    current_tokens = 0

            current_content.append(para["content"])
            current_tokens += para_tokens

        # Flush the trailing segment.
        if current_content:
            if current_tokens >= self.min_tokens:
                segments.append(self._create_segment(
                    content="\n\n".join(current_content),
                    tokens=current_tokens,
                    source_post_index=source_post_index,
                    source_post_title=source_post_title,
                    segment_index=segment_index,
                    is_complete=True,
                ))
            elif segments:
                # Too short to stand alone: merge into the previous segment.
                last_segment = segments[-1]
                merged_content = (
                    last_segment.content + "\n\n" + "\n\n".join(current_content)
                )
                segments[-1] = self._create_segment(
                    content=merged_content,
                    tokens=self.count_tokens(merged_content),
                    source_post_index=source_post_index,
                    source_post_title=source_post_title,
                    segment_index=last_segment.segment_index,
                    is_complete=True,
                )
            # NOTE(review): when the whole text is shorter than min_tokens
            # and there is no previous segment, the content is dropped —
            # presumably intentional; confirm against training needs.
        return segments

    def _split_large_paragraph(
        self,
        text: str,
        source_post_index: int,
        source_post_title: str,
        start_index: int,
    ) -> list[TextSegment]:
        """
        Split a large paragraph into smaller segments at sentence boundaries.

        Args:
            text: Text to split
            source_post_index: Source post index
            source_post_title: Source post title
            start_index: Starting segment index

        Returns:
            List of TextSegment objects (all but the last marked incomplete,
            since they end mid-paragraph)
        """
        sentences = self.SENTENCE_END_PATTERN.split(text)
        sentences = [s.strip() for s in sentences if s.strip()]

        segments = []
        current_sentences = []
        current_tokens = 0
        segment_index = start_index

        for sentence in sentences:
            sent_tokens = self.count_tokens(sentence)
            if current_tokens + sent_tokens > self.target_tokens and current_sentences:
                content = " ".join(current_sentences)
                segments.append(self._create_segment(
                    content=content,
                    tokens=current_tokens,
                    source_post_index=source_post_index,
                    source_post_title=source_post_title,
                    segment_index=segment_index,
                    is_complete=False,
                ))
                segment_index += 1
                current_sentences = []
                current_tokens = 0
            current_sentences.append(sentence)
            current_tokens += sent_tokens

        # Trailing sentences form the final (complete) sub-segment.
        if current_sentences:
            content = " ".join(current_sentences)
            segments.append(self._create_segment(
                content=content,
                tokens=self.count_tokens(content),
                source_post_index=source_post_index,
                source_post_title=source_post_title,
                segment_index=segment_index,
                is_complete=True,
            ))
        return segments

    def _get_overlap_content(self, paragraphs: list[str], target_tokens: int) -> str:
        """
        Get content from the end of paragraphs for overlap.

        Walks backwards from the last paragraph, taking whole paragraphs
        while they fit, then trailing sentences of the first paragraph that
        does not fit.

        Args:
            paragraphs: List of paragraph strings
            target_tokens: Target tokens for overlap

        Returns:
            Overlap content string ("" when nothing fits)
        """
        overlap_parts = []
        current_tokens = 0
        for para in reversed(paragraphs):
            para_tokens = self.count_tokens(para)
            if current_tokens + para_tokens <= target_tokens:
                overlap_parts.insert(0, para)
                current_tokens += para_tokens
            else:
                # Take a partial overlap: the last sentences of this paragraph.
                sentences = self.SENTENCE_END_PATTERN.split(para)
                for sent in reversed(sentences):
                    sent = sent.strip()
                    if not sent:
                        continue
                    sent_tokens = self.count_tokens(sent)
                    if current_tokens + sent_tokens <= target_tokens:
                        overlap_parts.insert(0, sent)
                        current_tokens += sent_tokens
                    else:
                        break
                break
        return " ".join(overlap_parts) if overlap_parts else ""

    def _create_segment(
        self,
        content: str,
        tokens: int,
        source_post_index: int,
        source_post_title: str,
        segment_index: int,
        is_complete: bool,
    ) -> TextSegment:
        """Create a TextSegment object."""
        return TextSegment(
            content=content,
            token_count=tokens,
            source_post_index=source_post_index,
            source_post_title=source_post_title,
            segment_index=segment_index,
            is_complete=is_complete,
        )
def main():
    """CLI entry point for testing the segmenter."""
    import argparse
    import json

    parser = argparse.ArgumentParser(
        description="Segment text into chunks for LLM training",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python text_segmenter.py input.txt --output segments.json
  python text_segmenter.py input.txt --target-tokens 256
  python text_segmenter.py input.txt --overlap 30
""",
    )
    parser.add_argument("input", help="Input text file")
    parser.add_argument("--output", "-o", help="Output JSON file")
    # Integer tuning knobs share the same shape; declare them as a table.
    int_options = (
        ("--target-tokens", 384, "Target tokens per segment (default: 384)"),
        ("--min-tokens", 100, "Minimum tokens per segment (default: 100)"),
        ("--max-tokens", 512, "Maximum tokens per segment (default: 512)"),
        ("--overlap", 50, "Overlap tokens between segments (default: 50)"),
    )
    for flag, default, help_text in int_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)
    args = parser.parse_args()

    segmenter = TextSegmenter(
        target_tokens=args.target_tokens,
        min_tokens=args.min_tokens,
        max_tokens=args.max_tokens,
        overlap_tokens=args.overlap,
    )

    with open(args.input, "r", encoding="utf-8") as infile:
        source_text = infile.read()
    segments = segmenter.segment_text(source_text)

    # Human-readable summary on stdout.
    print(f"\nCreated {len(segments)} segments:")
    print("-" * 50)
    for segment in segments:
        print(f"\n[{segment.segment_index}] {segment.token_count} tokens (complete: {segment.is_complete})")
        print(f"  {segment.content[:80]}...")

    # Optional JSON dump of all segments.
    if args.output:
        payload = [segment.to_dict() for segment in segments]
        with open(args.output, "w", encoding="utf-8") as outfile:
            json.dump(payload, outfile, indent=2, ensure_ascii=False)
        print(f"\nSaved to: {args.output}")


if __name__ == "__main__":
    main()