File size: 17,477 Bytes
45ee481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
"""
Text Segmenter Module

Split blog content into semantic chunks for training.
Preserves complete thoughts and handles various content structures.

Example usage:
    segmenter = TextSegmenter(target_tokens=384, overlap_tokens=50)
    segments = segmenter.segment_posts(blog_posts)
"""

import re
from dataclasses import dataclass, field
from typing import Optional

from loguru import logger

try:
    import tiktoken

    TIKTOKEN_AVAILABLE = True
except ImportError:
    TIKTOKEN_AVAILABLE = False
    logger.warning("tiktoken not available, using approximate token counting")


@dataclass
class TextSegment:
    """One chunk of text cut from a post, together with its provenance."""

    content: str  # the chunk's text
    token_count: int  # token count recorded for the chunk
    source_post_index: int  # index of the post the chunk came from
    source_post_title: str  # title of the post the chunk came from
    segment_index: int  # position of the chunk within its post
    is_complete: bool  # Whether segment ends at a natural boundary
    metadata: dict = field(default_factory=dict)  # free-form extras

    def to_dict(self) -> dict:
        """Return the segment's fields as a plain dict for serialization.

        The ``metadata`` value is the same dict object held by the segment,
        not a copy.
        """
        exported = (
            "content",
            "token_count",
            "source_post_index",
            "source_post_title",
            "segment_index",
            "is_complete",
            "metadata",
        )
        return {key: getattr(self, key) for key in exported}


class TextSegmenter:
    """
    Split text into semantic chunks suitable for LLM training.

    Features:
    - Paragraph-level segmentation
    - Preserves complete thoughts/arguments
    - Handles lists, quotes, code blocks
    - Configurable target size with overlap

    Example:
        >>> segmenter = TextSegmenter(target_tokens=384)
        >>> segments = segmenter.segment_text("Long blog post content...")
        >>> for seg in segments:
        ...     print(f"Segment {seg.segment_index}: {seg.token_count} tokens")
    """

    # Patterns for content structure detection
    LIST_ITEM_PATTERN = re.compile(r"^[\s]*[-*•]\s+", re.MULTILINE)
    NUMBERED_LIST_PATTERN = re.compile(r"^[\s]*\d+[.)\]]\s+", re.MULTILINE)
    CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```", re.MULTILINE)
    BLOCKQUOTE_PATTERN = re.compile(r"^>\s+", re.MULTILINE)

    # Sentence boundary pattern
    SENTENCE_END_PATTERN = re.compile(r"[.!?]+[\s]+")

    def __init__(
        self,
        target_tokens: int = 384,
        min_tokens: int = 100,
        max_tokens: int = 512,
        overlap_tokens: int = 50,
        encoding_name: str = "cl100k_base",
    ):
        """
        Store the size limits and choose a token-counting backend.

        Args:
            target_tokens: Target token count per segment (256-512 recommended)
            min_tokens: Minimum tokens for a valid segment
            max_tokens: Maximum tokens before forcing a split
            overlap_tokens: Token overlap between consecutive segments
            encoding_name: Tiktoken encoding name for token counting
        """
        self.target_tokens = target_tokens
        self.min_tokens = min_tokens
        self.max_tokens = max_tokens
        self.overlap_tokens = overlap_tokens

        # ``None`` means "no tokenizer": count_tokens then falls back to a
        # character-based approximation.
        encoding = None
        if TIKTOKEN_AVAILABLE:
            try:
                encoding = tiktoken.get_encoding(encoding_name)
                logger.debug(f"Using tiktoken encoding: {encoding_name}")
            except Exception as e:
                # Best effort: any failure to load the encoding downgrades
                # to the approximation instead of aborting construction.
                logger.warning(f"Failed to load tiktoken: {e}, using approximation")
                encoding = None
        self.encoding = encoding

    def count_tokens(self, text: str) -> int:
        """
        Count tokens in text.

        Args:
            text: Text to count tokens for

        Returns:
            Token count
        """
        if self.encoding:
            return len(self.encoding.encode(text))
        else:
            # Approximate: ~4 chars per token for English
            # Adjust for Japanese/mixed content (~2 chars per token)
            # Use a conservative estimate
            return len(text) // 3

    def segment_posts(self, posts: list) -> list[TextSegment]:
        """
        Segment every post in ``posts`` and return all segments in order.

        Each post is read through its ``content``, ``index`` and ``title``
        attributes and passed to :meth:`segment_text`.

        Args:
            posts: List of BlogPost objects

        Returns:
            List of TextSegment objects
        """
        all_segments = [
            segment
            for post in posts
            for segment in self.segment_text(
                text=post.content,
                source_post_index=post.index,
                source_post_title=post.title,
            )
        ]

        logger.info(f"Created {len(all_segments)} segments from {len(posts)} posts")
        return all_segments

    def segment_text(
        self,
        text: str,
        source_post_index: int = 0,
        source_post_title: str = "Unknown",
    ) -> list[TextSegment]:
        """
        Segment a single text into chunks.

        Args:
            text: Text content to segment
            source_post_index: Index of source post
            source_post_title: Title of source post

        Returns:
            List of TextSegment objects
        """
        if not text.strip():
            return []

        # First, split into paragraphs
        paragraphs = self._split_into_paragraphs(text)

        # Then, group paragraphs into segments
        segments = self._group_paragraphs(
            paragraphs, source_post_index, source_post_title
        )

        return segments

    def _split_into_paragraphs(self, text: str) -> list[dict]:
        """
        Split text into paragraphs while preserving structure.

        Args:
            text: Text to split

        Returns:
            List of paragraph dicts with content and metadata
        """
        # Preserve code blocks as single units
        code_blocks = self.CODE_BLOCK_PATTERN.findall(text)
        for i, block in enumerate(code_blocks):
            text = text.replace(block, f"__CODE_BLOCK_{i}__")

        # Split on double newlines
        raw_paragraphs = re.split(r"\n{2,}", text)

        paragraphs = []
        for para in raw_paragraphs:
            para = para.strip()
            if not para:
                continue

            # Restore code blocks
            for i, block in enumerate(code_blocks):
                para = para.replace(f"__CODE_BLOCK_{i}__", block)

            # Determine paragraph type
            para_type = self._detect_paragraph_type(para)

            paragraphs.append({
                "content": para,
                "type": para_type,
                "tokens": self.count_tokens(para),
            })

        return paragraphs

    def _detect_paragraph_type(self, text: str) -> str:
        """Detect the type of paragraph for better segmentation."""
        if self.CODE_BLOCK_PATTERN.search(text):
            return "code"
        if self.LIST_ITEM_PATTERN.match(text):
            return "list"
        if self.NUMBERED_LIST_PATTERN.match(text):
            return "numbered_list"
        if self.BLOCKQUOTE_PATTERN.match(text):
            return "quote"
        if text.startswith("#"):
            return "header"
        return "text"

    def _group_paragraphs(
        self,
        paragraphs: list[dict],
        source_post_index: int,
        source_post_title: str,
    ) -> list[TextSegment]:
        """
        Group paragraphs into segments of appropriate size.

        Args:
            paragraphs: List of paragraph dicts
            source_post_index: Index of source post
            source_post_title: Title of source post

        Returns:
            List of TextSegment objects
        """
        segments = []
        current_content = []
        current_tokens = 0
        segment_index = 0

        for i, para in enumerate(paragraphs):
            para_tokens = para["tokens"]

            # If single paragraph exceeds max, split it
            if para_tokens > self.max_tokens:
                # First, save current segment if not empty
                if current_content:
                    segments.append(self._create_segment(
                        content="\n\n".join(current_content),
                        tokens=current_tokens,
                        source_post_index=source_post_index,
                        source_post_title=source_post_title,
                        segment_index=segment_index,
                        is_complete=False,
                    ))
                    segment_index += 1
                    current_content = []
                    current_tokens = 0

                # Split large paragraph
                sub_segments = self._split_large_paragraph(
                    para["content"],
                    source_post_index,
                    source_post_title,
                    segment_index,
                )
                segments.extend(sub_segments)
                segment_index += len(sub_segments)
                continue

            # Check if adding this paragraph would exceed target
            if current_tokens + para_tokens > self.target_tokens and current_content:
                # Save current segment
                segments.append(self._create_segment(
                    content="\n\n".join(current_content),
                    tokens=current_tokens,
                    source_post_index=source_post_index,
                    source_post_title=source_post_title,
                    segment_index=segment_index,
                    is_complete=True,
                ))
                segment_index += 1

                # Start new segment with overlap if configured
                if self.overlap_tokens > 0 and current_content:
                    overlap_content = self._get_overlap_content(
                        current_content, self.overlap_tokens
                    )
                    current_content = [overlap_content] if overlap_content else []
                    current_tokens = self.count_tokens(overlap_content) if overlap_content else 0
                else:
                    current_content = []
                    current_tokens = 0

            current_content.append(para["content"])
            current_tokens += para_tokens

        # Don't forget the last segment
        if current_content:
            # Only add if meets minimum token requirement
            if current_tokens >= self.min_tokens:
                segments.append(self._create_segment(
                    content="\n\n".join(current_content),
                    tokens=current_tokens,
                    source_post_index=source_post_index,
                    source_post_title=source_post_title,
                    segment_index=segment_index,
                    is_complete=True,
                ))
            elif segments:
                # Merge with previous segment if too short
                last_segment = segments[-1]
                merged_content = last_segment.content + "\n\n" + "\n\n".join(current_content)
                segments[-1] = self._create_segment(
                    content=merged_content,
                    tokens=self.count_tokens(merged_content),
                    source_post_index=source_post_index,
                    source_post_title=source_post_title,
                    segment_index=last_segment.segment_index,
                    is_complete=True,
                )

        return segments

    def _split_large_paragraph(
        self,
        text: str,
        source_post_index: int,
        source_post_title: str,
        start_index: int,
    ) -> list[TextSegment]:
        """
        Split a large paragraph into smaller segments at sentence boundaries.

        Args:
            text: Text to split
            source_post_index: Source post index
            source_post_title: Source post title
            start_index: Starting segment index

        Returns:
            List of TextSegment objects
        """
        # Split into sentences
        sentences = self.SENTENCE_END_PATTERN.split(text)
        sentences = [s.strip() for s in sentences if s.strip()]

        segments = []
        current_sentences = []
        current_tokens = 0
        segment_index = start_index

        for sentence in sentences:
            sent_tokens = self.count_tokens(sentence)

            if current_tokens + sent_tokens > self.target_tokens and current_sentences:
                # Save current segment
                content = " ".join(current_sentences)
                segments.append(self._create_segment(
                    content=content,
                    tokens=current_tokens,
                    source_post_index=source_post_index,
                    source_post_title=source_post_title,
                    segment_index=segment_index,
                    is_complete=False,
                ))
                segment_index += 1
                current_sentences = []
                current_tokens = 0

            current_sentences.append(sentence)
            current_tokens += sent_tokens

        # Last segment
        if current_sentences:
            content = " ".join(current_sentences)
            segments.append(self._create_segment(
                content=content,
                tokens=self.count_tokens(content),
                source_post_index=source_post_index,
                source_post_title=source_post_title,
                segment_index=segment_index,
                is_complete=True,
            ))

        return segments

    def _get_overlap_content(self, paragraphs: list[str], target_tokens: int) -> str:
        """
        Get content from the end of paragraphs for overlap.

        Args:
            paragraphs: List of paragraph strings
            target_tokens: Target tokens for overlap

        Returns:
            Overlap content string
        """
        # Start from the last paragraph and work backwards
        overlap_parts = []
        current_tokens = 0

        for para in reversed(paragraphs):
            para_tokens = self.count_tokens(para)

            if current_tokens + para_tokens <= target_tokens:
                overlap_parts.insert(0, para)
                current_tokens += para_tokens
            else:
                # Take partial from this paragraph (last sentences)
                sentences = self.SENTENCE_END_PATTERN.split(para)
                for sent in reversed(sentences):
                    sent = sent.strip()
                    if not sent:
                        continue
                    sent_tokens = self.count_tokens(sent)
                    if current_tokens + sent_tokens <= target_tokens:
                        overlap_parts.insert(0, sent)
                        current_tokens += sent_tokens
                    else:
                        break
                break

        return " ".join(overlap_parts) if overlap_parts else ""

    def _create_segment(
        self,
        content: str,
        tokens: int,
        source_post_index: int,
        source_post_title: str,
        segment_index: int,
        is_complete: bool,
    ) -> TextSegment:
        """Build a TextSegment from the given values.

        ``tokens`` is stored as the segment's ``token_count``; ``metadata``
        is left at its default.
        """
        segment_fields = {
            "content": content,
            "token_count": tokens,
            "source_post_index": source_post_index,
            "source_post_title": source_post_title,
            "segment_index": segment_index,
            "is_complete": is_complete,
        }
        return TextSegment(**segment_fields)


def main():
    """Command-line entry point: segment a text file and report the result.

    Reads the input file as UTF-8, segments it with the limits given on the
    command line, prints a one-line summary plus an 80-character preview per
    segment, and writes the segments as JSON when ``--output`` is given.
    """
    import argparse
    import json

    cli = argparse.ArgumentParser(
        description="Segment text into chunks for LLM training",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python text_segmenter.py input.txt --output segments.json
    python text_segmenter.py input.txt --target-tokens 256
    python text_segmenter.py input.txt --overlap 30
        """,
    )
    cli.add_argument("input", help="Input text file")
    cli.add_argument("--output", "-o", help="Output JSON file")

    # All size options are integers and differ only in flag, default and help.
    int_options = (
        ("--target-tokens", 384, "Target tokens per segment (default: 384)"),
        ("--min-tokens", 100, "Minimum tokens per segment (default: 100)"),
        ("--max-tokens", 512, "Maximum tokens per segment (default: 512)"),
        ("--overlap", 50, "Overlap tokens between segments (default: 50)"),
    )
    for flag, default, help_text in int_options:
        cli.add_argument(flag, type=int, default=default, help=help_text)

    opts = cli.parse_args()

    segmenter = TextSegmenter(
        target_tokens=opts.target_tokens,
        min_tokens=opts.min_tokens,
        max_tokens=opts.max_tokens,
        overlap_tokens=opts.overlap,
    )

    with open(opts.input, "r", encoding="utf-8") as source:
        segments = segmenter.segment_text(source.read())

    print(f"\nCreated {len(segments)} segments:")
    print("-" * 50)

    for seg in segments:
        print(f"\n[{seg.segment_index}] {seg.token_count} tokens (complete: {seg.is_complete})")
        print(f"    {seg.content[:80]}...")

    if not opts.output:
        return

    with open(opts.output, "w", encoding="utf-8") as sink:
        json.dump([seg.to_dict() for seg in segments], sink, indent=2, ensure_ascii=False)
    print(f"\nSaved to: {opts.output}")


if __name__ == "__main__":
    main()