File size: 16,436 Bytes
45ee481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
"""
Dataset Builder Module

Build final training dataset in ChatML format for Qwen3 fine-tuning.
Creates train/validation splits with proper formatting.

Example usage:
    builder = DatasetBuilder(system_prompt="You are Ryouken Okuni...")
    builder.build_from_qa_pairs(qa_pairs, output_dir="data/training/")
"""

import json
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from loguru import logger

# tiktoken is an optional dependency: when it is missing, DatasetBuilder
# falls back to an approximate character-based token count (len(text) // 3).
try:
    import tiktoken

    TIKTOKEN_AVAILABLE = True  # real tokenizer available for count_tokens()
except ImportError:
    TIKTOKEN_AVAILABLE = False  # approximate counting will be used instead


@dataclass
class DatasetStatistics:
    """Summary metrics describing a built training dataset."""

    total_examples: int  # all examples kept after dedup + token validation
    train_examples: int  # size of the train split
    validation_examples: int  # size of the validation split
    avg_tokens_per_example: float  # mean token count across all examples
    max_tokens: int  # largest single-example token count
    min_tokens: int  # smallest single-example token count
    total_tokens: int  # sum of token counts over all examples
    question_type_distribution: dict  # question_type -> example count

    def to_dict(self) -> dict:
        """Convert to dictionary for serialization.

        The average token count is rounded to two decimal places; every
        other field is passed through unchanged.
        """
        field_order = (
            "total_examples",
            "train_examples",
            "validation_examples",
            "avg_tokens_per_example",
            "max_tokens",
            "min_tokens",
            "total_tokens",
            "question_type_distribution",
        )
        payload = {name: getattr(self, name) for name in field_order}
        payload["avg_tokens_per_example"] = round(self.avg_tokens_per_example, 2)
        return payload


class DatasetBuilder:
    """
    Build training datasets in ChatML format for Qwen3.

    Features:
    - ChatML message format
    - Train/validation split
    - Deduplication
    - Token count validation
    - Statistics generation

    Example:
        >>> builder = DatasetBuilder()
        >>> stats = builder.build_from_qa_pairs(qa_pairs, "data/training/")
        >>> print(f"Built {stats.total_examples} examples")
    """

    # Default system prompt template
    DEFAULT_SYSTEM_PROMPT = """You are {ceo_name}, CEO of {company_name}.

You are a visionary technology leader with deep expertise in AI, business strategy, and innovation. Your communication style is thoughtful, confident, and grounded in real-world experience.

Key traits:
- You explain complex concepts clearly using analogies and examples
- You balance strategic thinking with practical insights
- You are passionate about technology's potential to transform business
- You value authenticity and speak from genuine experience
- You are direct but respectful in your communication

When responding:
- Draw from your extensive experience in technology and business
- Share insights that reflect your unique perspective as a CEO
- Be helpful and substantive in your answers
- Maintain a professional yet personable tone appropriate for Japanese business culture"""

    def __init__(
        self,
        system_prompt: Optional[str] = None,
        ceo_name: str = "Ryouken Okuni",
        company_name: str = "Akatsuki AI Technologies",
        max_tokens_per_example: int = 2048,
        encoding_name: str = "cl100k_base",
    ):
        """
        Initialize the dataset builder.

        Args:
            system_prompt: Custom system prompt (uses default if None)
            ceo_name: CEO name to insert into prompt
            company_name: Company name to insert into prompt
            max_tokens_per_example: Maximum tokens per training example
            encoding_name: Tiktoken encoding name
        """
        self.ceo_name = ceo_name
        self.company_name = company_name
        self.max_tokens_per_example = max_tokens_per_example

        # Set system prompt
        if system_prompt:
            self.system_prompt = system_prompt
        else:
            self.system_prompt = self.DEFAULT_SYSTEM_PROMPT.format(
                ceo_name=ceo_name,
                company_name=company_name,
            )

        # Initialize tokenizer
        if TIKTOKEN_AVAILABLE:
            try:
                self.encoding = tiktoken.get_encoding(encoding_name)
            except Exception:
                self.encoding = None
        else:
            self.encoding = None

    def count_tokens(self, text: str) -> int:
        """Count tokens in text."""
        if self.encoding:
            return len(self.encoding.encode(text))
        return len(text) // 3  # Rough approximation

    def build_from_qa_pairs(
        self,
        qa_pairs: list,
        output_dir: str | Path,
        train_ratio: float = 0.9,
        shuffle: bool = True,
        deduplicate: bool = True,
    ) -> DatasetStatistics:
        """
        Build training dataset from Q&A pairs.

        Args:
            qa_pairs: List of QAPair objects or dicts
            output_dir: Directory to save train.jsonl and validation.jsonl
            train_ratio: Ratio for train/validation split (default 0.9)
            shuffle: Whether to shuffle data
            deduplicate: Whether to remove duplicate questions

        Returns:
            DatasetStatistics object
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Building dataset from {len(qa_pairs)} Q&A pairs")

        # Convert to standard format
        examples = self._convert_qa_pairs(qa_pairs)

        # Deduplicate
        if deduplicate:
            original_count = len(examples)
            examples = self._deduplicate(examples)
            logger.info(f"Deduplication: {original_count} -> {len(examples)} examples")

        # Validate token counts
        examples = self._validate_token_counts(examples)
        logger.info(f"After token validation: {len(examples)} examples")

        # Shuffle
        if shuffle:
            random.shuffle(examples)

        # Split into train/validation
        split_idx = int(len(examples) * train_ratio)
        train_examples = examples[:split_idx]
        val_examples = examples[split_idx:]

        # Save datasets
        train_path = output_dir / "train.jsonl"
        val_path = output_dir / "validation.jsonl"

        self._save_jsonl(train_examples, train_path)
        self._save_jsonl(val_examples, val_path)

        # Calculate statistics
        stats = self._calculate_statistics(examples, train_examples, val_examples)

        # Save statistics
        stats_path = output_dir / "dataset_stats.json"
        with open(stats_path, "w", encoding="utf-8") as f:
            json.dump(stats.to_dict(), f, indent=2)

        logger.info(f"Saved train set: {train_path} ({len(train_examples)} examples)")
        logger.info(f"Saved validation set: {val_path} ({len(val_examples)} examples)")
        logger.info(f"Saved statistics: {stats_path}")

        return stats

    def _convert_qa_pairs(self, qa_pairs: list) -> list[dict]:
        """Convert Q&A pairs to ChatML format."""
        examples = []

        for pair in qa_pairs:
            # Handle both QAPair objects and dicts
            if hasattr(pair, "question"):
                question = pair.question
                answer = pair.answer
                q_type = pair.question_type
            else:
                question = pair["question"]
                answer = pair["answer"]
                q_type = pair.get("question_type", "unknown")

            example = {
                "messages": [
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": question},
                    {"role": "assistant", "content": answer},
                ],
                "metadata": {
                    "question_type": q_type,
                },
            }
            examples.append(example)

        return examples

    def _deduplicate(self, examples: list[dict]) -> list[dict]:
        """Remove examples with duplicate questions."""
        seen_questions = set()
        unique_examples = []

        for example in examples:
            # Get user message (the question)
            question = None
            for msg in example["messages"]:
                if msg["role"] == "user":
                    question = msg["content"].strip().lower()
                    break

            if question and question not in seen_questions:
                seen_questions.add(question)
                unique_examples.append(example)

        return unique_examples

    def _validate_token_counts(self, examples: list[dict]) -> list[dict]:
        """Filter out examples that exceed token limit."""
        valid_examples = []

        for example in examples:
            # Calculate total tokens
            total_tokens = 0
            for msg in example["messages"]:
                total_tokens += self.count_tokens(msg["content"])
                total_tokens += 4  # Approximate overhead per message

            if total_tokens <= self.max_tokens_per_example:
                example["token_count"] = total_tokens
                valid_examples.append(example)
            else:
                logger.debug(f"Skipping example with {total_tokens} tokens (max: {self.max_tokens_per_example})")

        return valid_examples

    def _save_jsonl(self, examples: list[dict], path: Path) -> None:
        """Save examples to JSONL format."""
        with open(path, "w", encoding="utf-8") as f:
            for example in examples:
                # Remove metadata before saving (keep only messages)
                output = {"messages": example["messages"]}
                f.write(json.dumps(output, ensure_ascii=False) + "\n")

    def _calculate_statistics(
        self,
        all_examples: list[dict],
        train_examples: list[dict],
        val_examples: list[dict],
    ) -> DatasetStatistics:
        """Calculate dataset statistics."""
        token_counts = [ex.get("token_count", 0) for ex in all_examples]

        # Question type distribution
        type_counts = {}
        for ex in all_examples:
            q_type = ex.get("metadata", {}).get("question_type", "unknown")
            type_counts[q_type] = type_counts.get(q_type, 0) + 1

        return DatasetStatistics(
            total_examples=len(all_examples),
            train_examples=len(train_examples),
            validation_examples=len(val_examples),
            avg_tokens_per_example=sum(token_counts) / len(token_counts) if token_counts else 0,
            max_tokens=max(token_counts) if token_counts else 0,
            min_tokens=min(token_counts) if token_counts else 0,
            total_tokens=sum(token_counts),
            question_type_distribution=type_counts,
        )

    def build_from_segments(
        self,
        segments: list,
        output_dir: str | Path,
        train_ratio: float = 0.9,
    ) -> DatasetStatistics:
        """
        Build training dataset directly from text segments (for continuation training).

        This creates examples where the model learns to continue CEO-style text.

        Args:
            segments: List of TextSegment objects or dicts
            output_dir: Directory to save datasets
            train_ratio: Train/validation split ratio

        Returns:
            DatasetStatistics object
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Building continuation dataset from {len(segments)} segments")

        examples = []
        for segment in segments:
            content = segment.content if hasattr(segment, "content") else segment["content"]

            # Create a simple prompt asking to continue the thought
            example = {
                "messages": [
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": "Please share your thoughts on this topic."},
                    {"role": "assistant", "content": content},
                ],
                "metadata": {"type": "continuation"},
            }
            examples.append(example)

        # Validate and save
        examples = self._validate_token_counts(examples)
        random.shuffle(examples)

        split_idx = int(len(examples) * train_ratio)
        train_examples = examples[:split_idx]
        val_examples = examples[split_idx:]

        self._save_jsonl(train_examples, output_dir / "train.jsonl")
        self._save_jsonl(val_examples, output_dir / "validation.jsonl")

        stats = self._calculate_statistics(examples, train_examples, val_examples)

        with open(output_dir / "dataset_stats.json", "w", encoding="utf-8") as f:
            json.dump(stats.to_dict(), f, indent=2)

        return stats

    @staticmethod
    def load_dataset(path: str | Path) -> list[dict]:
        """Load a JSONL dataset file."""
        examples = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    examples.append(json.loads(line))
        return examples

    def update_system_prompt(self, new_prompt: str) -> None:
        """Update the system prompt for future builds."""
        self.system_prompt = new_prompt
        logger.info("System prompt updated")

    def get_system_prompt(self) -> str:
        """Get the current system prompt."""
        return self.system_prompt


def main():
    """CLI entry point for testing the builder."""
    import argparse

    cli = argparse.ArgumentParser(
        description="Build training datasets in ChatML format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python dataset_builder.py qa_pairs.json --output data/training/
    python dataset_builder.py qa_pairs.json --train-ratio 0.85
    python dataset_builder.py qa_pairs.json --system-prompt "Custom prompt..."

Input format (qa_pairs.json):
    [
        {"question": "...", "answer": "...", "question_type": "..."},
        ...
    ]

Output format (train.jsonl):
    {"messages": [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
        """,
    )
    cli.add_argument("input", help="Input Q&A pairs JSON file")
    cli.add_argument("--output", "-o", default="data/training/", help="Output directory (default: data/training/)")
    cli.add_argument("--train-ratio", type=float, default=0.9, help="Train/validation split ratio (default: 0.9)")
    cli.add_argument("--system-prompt", help="Custom system prompt (uses default if not provided)")
    cli.add_argument("--ceo-name", default="Ryouken Okuni", help="CEO name for default prompt")
    cli.add_argument("--company-name", default="Akatsuki AI Technologies", help="Company name for default prompt")
    cli.add_argument("--max-tokens", type=int, default=2048, help="Maximum tokens per example (default: 2048)")
    cli.add_argument("--no-shuffle", action="store_true", help="Don't shuffle the data")
    cli.add_argument("--no-dedup", action="store_true", help="Don't deduplicate questions")

    opts = cli.parse_args()

    # Read the raw Q&A pairs from disk.
    with open(opts.input, "r", encoding="utf-8") as handle:
        pairs = json.load(handle)

    print(f"Loaded {len(pairs)} Q&A pairs")

    # Configure the builder from the CLI options and run the build.
    maker = DatasetBuilder(
        system_prompt=opts.system_prompt,
        ceo_name=opts.ceo_name,
        company_name=opts.company_name,
        max_tokens_per_example=opts.max_tokens,
    )

    report = maker.build_from_qa_pairs(
        qa_pairs=pairs,
        output_dir=opts.output,
        train_ratio=opts.train_ratio,
        shuffle=not opts.no_shuffle,
        deduplicate=not opts.no_dedup,
    )

    # Summarize the build on stdout.
    print("\n=== Dataset Statistics ===")
    print(f"Total examples: {report.total_examples}")
    print(f"Train examples: {report.train_examples}")
    print(f"Validation examples: {report.validation_examples}")
    print(f"Avg tokens/example: {report.avg_tokens_per_example:.1f}")
    print(f"Token range: {report.min_tokens} - {report.max_tokens}")
    print(f"Total tokens: {report.total_tokens:,}")
    print("\nQuestion type distribution:")
    for q_type, count in report.question_type_distribution.items():
        print(f"  {q_type}: {count}")

# Allow running this module directly as a command-line script.
if __name__ == "__main__":
    main()