Spaces:
Paused
Paused
| """ | |
| Dataset Builder Module | |
| Build final training dataset in ChatML format for Qwen3 fine-tuning. | |
| Creates train/validation splits with proper formatting. | |
| Example usage: | |
| builder = DatasetBuilder(system_prompt="You are Ryouken Okuni...") | |
| builder.build_from_qa_pairs(qa_pairs, output_dir="data/training/") | |
| """ | |
| import json | |
| import random | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Optional | |
| from loguru import logger | |
# Optional dependency: tiktoken provides exact token counts for the
# configured encoding. When it is not installed, DatasetBuilder.count_tokens
# falls back to a rough character-based approximation instead of failing.
try:
    import tiktoken
    TIKTOKEN_AVAILABLE = True
except ImportError:
    TIKTOKEN_AVAILABLE = False
@dataclass
class DatasetStatistics:
    """Statistics about the built dataset.

    Fix: the class declares annotated fields and is instantiated with
    keyword arguments (see DatasetBuilder._calculate_statistics), but the
    @dataclass decorator was missing, so construction raised TypeError.
    The decorator (already imported at module top) restores the intended
    generated __init__.
    """

    total_examples: int        # all examples after dedup / token filtering
    train_examples: int        # examples written to train.jsonl
    validation_examples: int   # examples written to validation.jsonl
    avg_tokens_per_example: float
    max_tokens: int
    min_tokens: int
    total_tokens: int
    question_type_distribution: dict  # question_type -> count

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dictionary (average rounded to 2 dp)."""
        return {
            "total_examples": self.total_examples,
            "train_examples": self.train_examples,
            "validation_examples": self.validation_examples,
            "avg_tokens_per_example": round(self.avg_tokens_per_example, 2),
            "max_tokens": self.max_tokens,
            "min_tokens": self.min_tokens,
            "total_tokens": self.total_tokens,
            "question_type_distribution": self.question_type_distribution,
        }
class DatasetBuilder:
    """
    Assemble ChatML-format training datasets for Qwen3 fine-tuning.

    Capabilities:
      * ChatML message formatting with a configurable system prompt
      * train/validation splitting
      * question-level deduplication
      * token-budget filtering (exact via tiktoken when available)
      * dataset statistics reporting

    Example:
        >>> builder = DatasetBuilder()
        >>> stats = builder.build_from_qa_pairs(qa_pairs, "data/training/")
        >>> print(f"Built {stats.total_examples} examples")
    """

    # Default system prompt template; {ceo_name}/{company_name} are filled
    # in by __init__ when no custom prompt is supplied.
    DEFAULT_SYSTEM_PROMPT = """You are {ceo_name}, CEO of {company_name}.
You are a visionary technology leader with deep expertise in AI, business strategy, and innovation. Your communication style is thoughtful, confident, and grounded in real-world experience.
Key traits:
- You explain complex concepts clearly using analogies and examples
- You balance strategic thinking with practical insights
- You are passionate about technology's potential to transform business
- You value authenticity and speak from genuine experience
- You are direct but respectful in your communication
When responding:
- Draw from your extensive experience in technology and business
- Share insights that reflect your unique perspective as a CEO
- Be helpful and substantive in your answers
- Maintain a professional yet personable tone appropriate for Japanese business culture"""

    def __init__(
        self,
        system_prompt: Optional[str] = None,
        ceo_name: str = "Ryouken Okuni",
        company_name: str = "Akatsuki AI Technologies",
        max_tokens_per_example: int = 2048,
        encoding_name: str = "cl100k_base",
    ):
        """
        Initialize the dataset builder.

        Args:
            system_prompt: Custom system prompt; when falsy, the default
                template is formatted with ceo_name/company_name.
            ceo_name: CEO name inserted into the default prompt.
            company_name: Company name inserted into the default prompt.
            max_tokens_per_example: Upper bound on tokens per training example.
            encoding_name: Tiktoken encoding name used for exact counting.
        """
        self.ceo_name = ceo_name
        self.company_name = company_name
        self.max_tokens_per_example = max_tokens_per_example

        # A falsy custom prompt (None or "") falls back to the template,
        # matching the original truthiness-based behavior.
        self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT.format(
            ceo_name=ceo_name,
            company_name=company_name,
        )

        # Tokenizer is best-effort: None signals the approximate fallback
        # in count_tokens (tiktoken missing or encoding lookup failed).
        self.encoding = None
        if TIKTOKEN_AVAILABLE:
            try:
                self.encoding = tiktoken.get_encoding(encoding_name)
            except Exception:
                self.encoding = None
| def count_tokens(self, text: str) -> int: | |
| """Count tokens in text.""" | |
| if self.encoding: | |
| return len(self.encoding.encode(text)) | |
| return len(text) // 3 # Rough approximation | |
| def build_from_qa_pairs( | |
| self, | |
| qa_pairs: list, | |
| output_dir: str | Path, | |
| train_ratio: float = 0.9, | |
| shuffle: bool = True, | |
| deduplicate: bool = True, | |
| ) -> DatasetStatistics: | |
| """ | |
| Build training dataset from Q&A pairs. | |
| Args: | |
| qa_pairs: List of QAPair objects or dicts | |
| output_dir: Directory to save train.jsonl and validation.jsonl | |
| train_ratio: Ratio for train/validation split (default 0.9) | |
| shuffle: Whether to shuffle data | |
| deduplicate: Whether to remove duplicate questions | |
| Returns: | |
| DatasetStatistics object | |
| """ | |
| output_dir = Path(output_dir) | |
| output_dir.mkdir(parents=True, exist_ok=True) | |
| logger.info(f"Building dataset from {len(qa_pairs)} Q&A pairs") | |
| # Convert to standard format | |
| examples = self._convert_qa_pairs(qa_pairs) | |
| # Deduplicate | |
| if deduplicate: | |
| original_count = len(examples) | |
| examples = self._deduplicate(examples) | |
| logger.info(f"Deduplication: {original_count} -> {len(examples)} examples") | |
| # Validate token counts | |
| examples = self._validate_token_counts(examples) | |
| logger.info(f"After token validation: {len(examples)} examples") | |
| # Shuffle | |
| if shuffle: | |
| random.shuffle(examples) | |
| # Split into train/validation | |
| split_idx = int(len(examples) * train_ratio) | |
| train_examples = examples[:split_idx] | |
| val_examples = examples[split_idx:] | |
| # Save datasets | |
| train_path = output_dir / "train.jsonl" | |
| val_path = output_dir / "validation.jsonl" | |
| self._save_jsonl(train_examples, train_path) | |
| self._save_jsonl(val_examples, val_path) | |
| # Calculate statistics | |
| stats = self._calculate_statistics(examples, train_examples, val_examples) | |
| # Save statistics | |
| stats_path = output_dir / "dataset_stats.json" | |
| with open(stats_path, "w", encoding="utf-8") as f: | |
| json.dump(stats.to_dict(), f, indent=2) | |
| logger.info(f"Saved train set: {train_path} ({len(train_examples)} examples)") | |
| logger.info(f"Saved validation set: {val_path} ({len(val_examples)} examples)") | |
| logger.info(f"Saved statistics: {stats_path}") | |
| return stats | |
| def _convert_qa_pairs(self, qa_pairs: list) -> list[dict]: | |
| """Convert Q&A pairs to ChatML format.""" | |
| examples = [] | |
| for pair in qa_pairs: | |
| # Handle both QAPair objects and dicts | |
| if hasattr(pair, "question"): | |
| question = pair.question | |
| answer = pair.answer | |
| q_type = pair.question_type | |
| else: | |
| question = pair["question"] | |
| answer = pair["answer"] | |
| q_type = pair.get("question_type", "unknown") | |
| example = { | |
| "messages": [ | |
| {"role": "system", "content": self.system_prompt}, | |
| {"role": "user", "content": question}, | |
| {"role": "assistant", "content": answer}, | |
| ], | |
| "metadata": { | |
| "question_type": q_type, | |
| }, | |
| } | |
| examples.append(example) | |
| return examples | |
| def _deduplicate(self, examples: list[dict]) -> list[dict]: | |
| """Remove examples with duplicate questions.""" | |
| seen_questions = set() | |
| unique_examples = [] | |
| for example in examples: | |
| # Get user message (the question) | |
| question = None | |
| for msg in example["messages"]: | |
| if msg["role"] == "user": | |
| question = msg["content"].strip().lower() | |
| break | |
| if question and question not in seen_questions: | |
| seen_questions.add(question) | |
| unique_examples.append(example) | |
| return unique_examples | |
| def _validate_token_counts(self, examples: list[dict]) -> list[dict]: | |
| """Filter out examples that exceed token limit.""" | |
| valid_examples = [] | |
| for example in examples: | |
| # Calculate total tokens | |
| total_tokens = 0 | |
| for msg in example["messages"]: | |
| total_tokens += self.count_tokens(msg["content"]) | |
| total_tokens += 4 # Approximate overhead per message | |
| if total_tokens <= self.max_tokens_per_example: | |
| example["token_count"] = total_tokens | |
| valid_examples.append(example) | |
| else: | |
| logger.debug(f"Skipping example with {total_tokens} tokens (max: {self.max_tokens_per_example})") | |
| return valid_examples | |
| def _save_jsonl(self, examples: list[dict], path: Path) -> None: | |
| """Save examples to JSONL format.""" | |
| with open(path, "w", encoding="utf-8") as f: | |
| for example in examples: | |
| # Remove metadata before saving (keep only messages) | |
| output = {"messages": example["messages"]} | |
| f.write(json.dumps(output, ensure_ascii=False) + "\n") | |
| def _calculate_statistics( | |
| self, | |
| all_examples: list[dict], | |
| train_examples: list[dict], | |
| val_examples: list[dict], | |
| ) -> DatasetStatistics: | |
| """Calculate dataset statistics.""" | |
| token_counts = [ex.get("token_count", 0) for ex in all_examples] | |
| # Question type distribution | |
| type_counts = {} | |
| for ex in all_examples: | |
| q_type = ex.get("metadata", {}).get("question_type", "unknown") | |
| type_counts[q_type] = type_counts.get(q_type, 0) + 1 | |
| return DatasetStatistics( | |
| total_examples=len(all_examples), | |
| train_examples=len(train_examples), | |
| validation_examples=len(val_examples), | |
| avg_tokens_per_example=sum(token_counts) / len(token_counts) if token_counts else 0, | |
| max_tokens=max(token_counts) if token_counts else 0, | |
| min_tokens=min(token_counts) if token_counts else 0, | |
| total_tokens=sum(token_counts), | |
| question_type_distribution=type_counts, | |
| ) | |
| def build_from_segments( | |
| self, | |
| segments: list, | |
| output_dir: str | Path, | |
| train_ratio: float = 0.9, | |
| ) -> DatasetStatistics: | |
| """ | |
| Build training dataset directly from text segments (for continuation training). | |
| This creates examples where the model learns to continue CEO-style text. | |
| Args: | |
| segments: List of TextSegment objects or dicts | |
| output_dir: Directory to save datasets | |
| train_ratio: Train/validation split ratio | |
| Returns: | |
| DatasetStatistics object | |
| """ | |
| output_dir = Path(output_dir) | |
| output_dir.mkdir(parents=True, exist_ok=True) | |
| logger.info(f"Building continuation dataset from {len(segments)} segments") | |
| examples = [] | |
| for segment in segments: | |
| content = segment.content if hasattr(segment, "content") else segment["content"] | |
| # Create a simple prompt asking to continue the thought | |
| example = { | |
| "messages": [ | |
| {"role": "system", "content": self.system_prompt}, | |
| {"role": "user", "content": "Please share your thoughts on this topic."}, | |
| {"role": "assistant", "content": content}, | |
| ], | |
| "metadata": {"type": "continuation"}, | |
| } | |
| examples.append(example) | |
| # Validate and save | |
| examples = self._validate_token_counts(examples) | |
| random.shuffle(examples) | |
| split_idx = int(len(examples) * train_ratio) | |
| train_examples = examples[:split_idx] | |
| val_examples = examples[split_idx:] | |
| self._save_jsonl(train_examples, output_dir / "train.jsonl") | |
| self._save_jsonl(val_examples, output_dir / "validation.jsonl") | |
| stats = self._calculate_statistics(examples, train_examples, val_examples) | |
| with open(output_dir / "dataset_stats.json", "w", encoding="utf-8") as f: | |
| json.dump(stats.to_dict(), f, indent=2) | |
| return stats | |
| def load_dataset(path: str | Path) -> list[dict]: | |
| """Load a JSONL dataset file.""" | |
| examples = [] | |
| with open(path, "r", encoding="utf-8") as f: | |
| for line in f: | |
| if line.strip(): | |
| examples.append(json.loads(line)) | |
| return examples | |
| def update_system_prompt(self, new_prompt: str) -> None: | |
| """Update the system prompt for future builds.""" | |
| self.system_prompt = new_prompt | |
| logger.info("System prompt updated") | |
| def get_system_prompt(self) -> str: | |
| """Get the current system prompt.""" | |
| return self.system_prompt | |
def main():
    """CLI entry point for testing the builder.

    Loads Q&A pairs from a JSON file, builds the ChatML dataset via
    DatasetBuilder.build_from_qa_pairs, and prints the resulting
    statistics to stdout.
    """
    # Local import keeps argparse out of library users' import cost.
    import argparse
    parser = argparse.ArgumentParser(
        description="Build training datasets in ChatML format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python dataset_builder.py qa_pairs.json --output data/training/
  python dataset_builder.py qa_pairs.json --train-ratio 0.85
  python dataset_builder.py qa_pairs.json --system-prompt "Custom prompt..."
Input format (qa_pairs.json):
  [
    {"question": "...", "answer": "...", "question_type": "..."},
    ...
  ]
Output format (train.jsonl):
  {"messages": [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
""",
    )
    parser.add_argument("input", help="Input Q&A pairs JSON file")
    parser.add_argument(
        "--output",
        "-o",
        default="data/training/",
        help="Output directory (default: data/training/)",
    )
    parser.add_argument(
        "--train-ratio",
        type=float,
        default=0.9,
        help="Train/validation split ratio (default: 0.9)",
    )
    parser.add_argument(
        "--system-prompt",
        help="Custom system prompt (uses default if not provided)",
    )
    parser.add_argument(
        "--ceo-name",
        default="Ryouken Okuni",
        help="CEO name for default prompt",
    )
    parser.add_argument(
        "--company-name",
        default="Akatsuki AI Technologies",
        help="Company name for default prompt",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=2048,
        help="Maximum tokens per example (default: 2048)",
    )
    parser.add_argument(
        "--no-shuffle",
        action="store_true",
        help="Don't shuffle the data",
    )
    parser.add_argument(
        "--no-dedup",
        action="store_true",
        help="Don't deduplicate questions",
    )
    args = parser.parse_args()
    # Load Q&A pairs (expects a JSON array of {question, answer, ...} dicts).
    with open(args.input, "r", encoding="utf-8") as f:
        qa_pairs = json.load(f)
    print(f"Loaded {len(qa_pairs)} Q&A pairs")
    # Build dataset; flags are inverted because argparse exposes the
    # negative options (--no-shuffle/--no-dedup).
    builder = DatasetBuilder(
        system_prompt=args.system_prompt,
        ceo_name=args.ceo_name,
        company_name=args.company_name,
        max_tokens_per_example=args.max_tokens,
    )
    stats = builder.build_from_qa_pairs(
        qa_pairs=qa_pairs,
        output_dir=args.output,
        train_ratio=args.train_ratio,
        shuffle=not args.no_shuffle,
        deduplicate=not args.no_dedup,
    )
    # Print statistics summary for the operator.
    print("\n=== Dataset Statistics ===")
    print(f"Total examples: {stats.total_examples}")
    print(f"Train examples: {stats.train_examples}")
    print(f"Validation examples: {stats.validation_examples}")
    print(f"Avg tokens/example: {stats.avg_tokens_per_example:.1f}")
    print(f"Token range: {stats.min_tokens} - {stats.max_tokens}")
    print(f"Total tokens: {stats.total_tokens:,}")
    print("\nQuestion type distribution:")
    for q_type, count in stats.question_type_distribution.items():
        print(f"  {q_type}: {count}")
| if __name__ == "__main__": | |
| main() | |