""" Dataset Builder Module Build final training dataset in ChatML format for Qwen3 fine-tuning. Creates train/validation splits with proper formatting. Example usage: builder = DatasetBuilder(system_prompt="You are Ryouken Okuni...") builder.build_from_qa_pairs(qa_pairs, output_dir="data/training/") """ import json import random from dataclasses import dataclass from pathlib import Path from typing import Optional from loguru import logger try: import tiktoken TIKTOKEN_AVAILABLE = True except ImportError: TIKTOKEN_AVAILABLE = False @dataclass class DatasetStatistics: """Statistics about the built dataset.""" total_examples: int train_examples: int validation_examples: int avg_tokens_per_example: float max_tokens: int min_tokens: int total_tokens: int question_type_distribution: dict def to_dict(self) -> dict: """Convert to dictionary for serialization.""" return { "total_examples": self.total_examples, "train_examples": self.train_examples, "validation_examples": self.validation_examples, "avg_tokens_per_example": round(self.avg_tokens_per_example, 2), "max_tokens": self.max_tokens, "min_tokens": self.min_tokens, "total_tokens": self.total_tokens, "question_type_distribution": self.question_type_distribution, } class DatasetBuilder: """ Build training datasets in ChatML format for Qwen3. Features: - ChatML message format - Train/validation split - Deduplication - Token count validation - Statistics generation Example: >>> builder = DatasetBuilder() >>> stats = builder.build_from_qa_pairs(qa_pairs, "data/training/") >>> print(f"Built {stats.total_examples} examples") """ # Default system prompt template DEFAULT_SYSTEM_PROMPT = """You are {ceo_name}, CEO of {company_name}. You are a visionary technology leader with deep expertise in AI, business strategy, and innovation. Your communication style is thoughtful, confident, and grounded in real-world experience. Key traits: - You explain complex concepts clearly using analogies and examples - You balance strategic thinking with practical insights - You are passionate about technology's potential to transform business - You value authenticity and speak from genuine experience - You are direct but respectful in your communication When responding: - Draw from your extensive experience in technology and business - Share insights that reflect your unique perspective as a CEO - Be helpful and substantive in your answers - Maintain a professional yet personable tone appropriate for Japanese business culture""" def __init__( self, system_prompt: Optional[str] = None, ceo_name: str = "Ryouken Okuni", company_name: str = "Akatsuki AI Technologies", max_tokens_per_example: int = 2048, encoding_name: str = "cl100k_base", ): """ Initialize the dataset builder. Args: system_prompt: Custom system prompt (uses default if None) ceo_name: CEO name to insert into prompt company_name: Company name to insert into prompt max_tokens_per_example: Maximum tokens per training example encoding_name: Tiktoken encoding name """ self.ceo_name = ceo_name self.company_name = company_name self.max_tokens_per_example = max_tokens_per_example # Set system prompt if system_prompt: self.system_prompt = system_prompt else: self.system_prompt = self.DEFAULT_SYSTEM_PROMPT.format( ceo_name=ceo_name, company_name=company_name, ) # Initialize tokenizer if TIKTOKEN_AVAILABLE: try: self.encoding = tiktoken.get_encoding(encoding_name) except Exception: self.encoding = None else: self.encoding = None def count_tokens(self, text: str) -> int: """Count tokens in text.""" if self.encoding: return len(self.encoding.encode(text)) return len(text) // 3 # Rough approximation def build_from_qa_pairs( self, qa_pairs: list, output_dir: str | Path, train_ratio: float = 0.9, shuffle: bool = True, deduplicate: bool = True, ) -> DatasetStatistics: """ Build training dataset from Q&A pairs. Args: qa_pairs: List of QAPair objects or dicts output_dir: Directory to save train.jsonl and validation.jsonl train_ratio: Ratio for train/validation split (default 0.9) shuffle: Whether to shuffle data deduplicate: Whether to remove duplicate questions Returns: DatasetStatistics object """ output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Building dataset from {len(qa_pairs)} Q&A pairs") # Convert to standard format examples = self._convert_qa_pairs(qa_pairs) # Deduplicate if deduplicate: original_count = len(examples) examples = self._deduplicate(examples) logger.info(f"Deduplication: {original_count} -> {len(examples)} examples") # Validate token counts examples = self._validate_token_counts(examples) logger.info(f"After token validation: {len(examples)} examples") # Shuffle if shuffle: random.shuffle(examples) # Split into train/validation split_idx = int(len(examples) * train_ratio) train_examples = examples[:split_idx] val_examples = examples[split_idx:] # Save datasets train_path = output_dir / "train.jsonl" val_path = output_dir / "validation.jsonl" self._save_jsonl(train_examples, train_path) self._save_jsonl(val_examples, val_path) # Calculate statistics stats = self._calculate_statistics(examples, train_examples, val_examples) # Save statistics stats_path = output_dir / "dataset_stats.json" with open(stats_path, "w", encoding="utf-8") as f: json.dump(stats.to_dict(), f, indent=2) logger.info(f"Saved train set: {train_path} ({len(train_examples)} examples)") logger.info(f"Saved validation set: {val_path} ({len(val_examples)} examples)") logger.info(f"Saved statistics: {stats_path}") return stats def _convert_qa_pairs(self, qa_pairs: list) -> list[dict]: """Convert Q&A pairs to ChatML format.""" examples = [] for pair in qa_pairs: # Handle both QAPair objects and dicts if hasattr(pair, "question"): question = pair.question answer = pair.answer q_type = pair.question_type else: question = pair["question"] answer = pair["answer"] q_type = pair.get("question_type", "unknown") example = { "messages": [ {"role": "system", "content": self.system_prompt}, {"role": "user", "content": question}, {"role": "assistant", "content": answer}, ], "metadata": { "question_type": q_type, }, } examples.append(example) return examples def _deduplicate(self, examples: list[dict]) -> list[dict]: """Remove examples with duplicate questions.""" seen_questions = set() unique_examples = [] for example in examples: # Get user message (the question) question = None for msg in example["messages"]: if msg["role"] == "user": question = msg["content"].strip().lower() break if question and question not in seen_questions: seen_questions.add(question) unique_examples.append(example) return unique_examples def _validate_token_counts(self, examples: list[dict]) -> list[dict]: """Filter out examples that exceed token limit.""" valid_examples = [] for example in examples: # Calculate total tokens total_tokens = 0 for msg in example["messages"]: total_tokens += self.count_tokens(msg["content"]) total_tokens += 4 # Approximate overhead per message if total_tokens <= self.max_tokens_per_example: example["token_count"] = total_tokens valid_examples.append(example) else: logger.debug(f"Skipping example with {total_tokens} tokens (max: {self.max_tokens_per_example})") return valid_examples def _save_jsonl(self, examples: list[dict], path: Path) -> None: """Save examples to JSONL format.""" with open(path, "w", encoding="utf-8") as f: for example in examples: # Remove metadata before saving (keep only messages) output = {"messages": example["messages"]} f.write(json.dumps(output, ensure_ascii=False) + "\n") def _calculate_statistics( self, all_examples: list[dict], train_examples: list[dict], val_examples: list[dict], ) -> DatasetStatistics: """Calculate dataset statistics.""" token_counts = [ex.get("token_count", 0) for ex in all_examples] # Question type distribution type_counts = {} for ex in all_examples: q_type = ex.get("metadata", {}).get("question_type", "unknown") type_counts[q_type] = type_counts.get(q_type, 0) + 1 return DatasetStatistics( total_examples=len(all_examples), train_examples=len(train_examples), validation_examples=len(val_examples), avg_tokens_per_example=sum(token_counts) / len(token_counts) if token_counts else 0, max_tokens=max(token_counts) if token_counts else 0, min_tokens=min(token_counts) if token_counts else 0, total_tokens=sum(token_counts), question_type_distribution=type_counts, ) def build_from_segments( self, segments: list, output_dir: str | Path, train_ratio: float = 0.9, ) -> DatasetStatistics: """ Build training dataset directly from text segments (for continuation training). This creates examples where the model learns to continue CEO-style text. Args: segments: List of TextSegment objects or dicts output_dir: Directory to save datasets train_ratio: Train/validation split ratio Returns: DatasetStatistics object """ output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Building continuation dataset from {len(segments)} segments") examples = [] for segment in segments: content = segment.content if hasattr(segment, "content") else segment["content"] # Create a simple prompt asking to continue the thought example = { "messages": [ {"role": "system", "content": self.system_prompt}, {"role": "user", "content": "Please share your thoughts on this topic."}, {"role": "assistant", "content": content}, ], "metadata": {"type": "continuation"}, } examples.append(example) # Validate and save examples = self._validate_token_counts(examples) random.shuffle(examples) split_idx = int(len(examples) * train_ratio) train_examples = examples[:split_idx] val_examples = examples[split_idx:] self._save_jsonl(train_examples, output_dir / "train.jsonl") self._save_jsonl(val_examples, output_dir / "validation.jsonl") stats = self._calculate_statistics(examples, train_examples, val_examples) with open(output_dir / "dataset_stats.json", "w", encoding="utf-8") as f: json.dump(stats.to_dict(), f, indent=2) return stats @staticmethod def load_dataset(path: str | Path) -> list[dict]: """Load a JSONL dataset file.""" examples = [] with open(path, "r", encoding="utf-8") as f: for line in f: if line.strip(): examples.append(json.loads(line)) return examples def update_system_prompt(self, new_prompt: str) -> None: """Update the system prompt for future builds.""" self.system_prompt = new_prompt logger.info("System prompt updated") def get_system_prompt(self) -> str: """Get the current system prompt.""" return self.system_prompt def main(): """CLI entry point for testing the builder.""" import argparse parser = argparse.ArgumentParser( description="Build training datasets in ChatML format", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: python dataset_builder.py qa_pairs.json --output data/training/ python dataset_builder.py qa_pairs.json --train-ratio 0.85 python dataset_builder.py qa_pairs.json --system-prompt "Custom prompt..." Input format (qa_pairs.json): [ {"question": "...", "answer": "...", "question_type": "..."}, ... ] Output format (train.jsonl): {"messages": [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]} """, ) parser.add_argument("input", help="Input Q&A pairs JSON file") parser.add_argument( "--output", "-o", default="data/training/", help="Output directory (default: data/training/)", ) parser.add_argument( "--train-ratio", type=float, default=0.9, help="Train/validation split ratio (default: 0.9)", ) parser.add_argument( "--system-prompt", help="Custom system prompt (uses default if not provided)", ) parser.add_argument( "--ceo-name", default="Ryouken Okuni", help="CEO name for default prompt", ) parser.add_argument( "--company-name", default="Akatsuki AI Technologies", help="Company name for default prompt", ) parser.add_argument( "--max-tokens", type=int, default=2048, help="Maximum tokens per example (default: 2048)", ) parser.add_argument( "--no-shuffle", action="store_true", help="Don't shuffle the data", ) parser.add_argument( "--no-dedup", action="store_true", help="Don't deduplicate questions", ) args = parser.parse_args() # Load Q&A pairs with open(args.input, "r", encoding="utf-8") as f: qa_pairs = json.load(f) print(f"Loaded {len(qa_pairs)} Q&A pairs") # Build dataset builder = DatasetBuilder( system_prompt=args.system_prompt, ceo_name=args.ceo_name, company_name=args.company_name, max_tokens_per_example=args.max_tokens, ) stats = builder.build_from_qa_pairs( qa_pairs=qa_pairs, output_dir=args.output, train_ratio=args.train_ratio, shuffle=not args.no_shuffle, deduplicate=not args.no_dedup, ) # Print statistics print("\n=== Dataset Statistics ===") print(f"Total examples: {stats.total_examples}") print(f"Train examples: {stats.train_examples}") print(f"Validation examples: {stats.validation_examples}") print(f"Avg tokens/example: {stats.avg_tokens_per_example:.1f}") print(f"Token range: {stats.min_tokens} - {stats.max_tokens}") print(f"Total tokens: {stats.total_tokens:,}") print("\nQuestion type distribution:") for q_type, count in stats.question_type_distribution.items(): print(f" {q_type}: {count}") if __name__ == "__main__": main()