# ai_exec/src/data_processing/dataset_builder.py
# (Uploaded by Chaitanya-aitf — "Upload 38 files", commit 45ee481, verified)
"""
Dataset Builder Module
Build final training dataset in ChatML format for Qwen3 fine-tuning.
Creates train/validation splits with proper formatting.
Example usage:
builder = DatasetBuilder(system_prompt="You are Ryouken Okuni...")
builder.build_from_qa_pairs(qa_pairs, output_dir="data/training/")
"""
import json
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from loguru import logger
try:
import tiktoken
TIKTOKEN_AVAILABLE = True
except ImportError:
TIKTOKEN_AVAILABLE = False
@dataclass
class DatasetStatistics:
    """Summary statistics describing a built training dataset."""

    total_examples: int  # count across train + validation
    train_examples: int
    validation_examples: int
    avg_tokens_per_example: float
    max_tokens: int
    min_tokens: int
    total_tokens: int
    question_type_distribution: dict  # question_type -> example count

    def to_dict(self) -> dict:
        """Convert to dictionary for serialization."""
        # avg_tokens_per_example is rounded to 2 decimals for readable JSON.
        return dict(
            total_examples=self.total_examples,
            train_examples=self.train_examples,
            validation_examples=self.validation_examples,
            avg_tokens_per_example=round(self.avg_tokens_per_example, 2),
            max_tokens=self.max_tokens,
            min_tokens=self.min_tokens,
            total_tokens=self.total_tokens,
            question_type_distribution=self.question_type_distribution,
        )
class DatasetBuilder:
    """
    Build training datasets in ChatML format for Qwen3.

    Each training example is a ``{"messages": [...], "metadata": {...}}`` dict;
    metadata is kept internally (for statistics) and stripped before the JSONL
    files are written.

    Features:
    - ChatML message format
    - Train/validation split
    - Deduplication
    - Token count validation
    - Statistics generation

    Example:
        >>> builder = DatasetBuilder()
        >>> stats = builder.build_from_qa_pairs(qa_pairs, "data/training/")
        >>> print(f"Built {stats.total_examples} examples")
    """

    # Default system prompt template; {ceo_name}/{company_name} are filled in __init__.
    DEFAULT_SYSTEM_PROMPT = """You are {ceo_name}, CEO of {company_name}.
You are a visionary technology leader with deep expertise in AI, business strategy, and innovation. Your communication style is thoughtful, confident, and grounded in real-world experience.
Key traits:
- You explain complex concepts clearly using analogies and examples
- You balance strategic thinking with practical insights
- You are passionate about technology's potential to transform business
- You value authenticity and speak from genuine experience
- You are direct but respectful in your communication
When responding:
- Draw from your extensive experience in technology and business
- Share insights that reflect your unique perspective as a CEO
- Be helpful and substantive in your answers
- Maintain a professional yet personable tone appropriate for Japanese business culture"""

    def __init__(
        self,
        system_prompt: Optional[str] = None,
        ceo_name: str = "Ryouken Okuni",
        company_name: str = "Akatsuki AI Technologies",
        max_tokens_per_example: int = 2048,
        encoding_name: str = "cl100k_base",
        seed: Optional[int] = None,
    ):
        """
        Initialize the dataset builder.

        Args:
            system_prompt: Custom system prompt (uses default if None)
            ceo_name: CEO name to insert into prompt
            company_name: Company name to insert into prompt
            max_tokens_per_example: Maximum tokens per training example
            encoding_name: Tiktoken encoding name
            seed: Optional seed for the shuffle RNG, enabling reproducible
                train/validation splits (default None: nondeterministic,
                matching previous behavior)
        """
        self.ceo_name = ceo_name
        self.company_name = company_name
        self.max_tokens_per_example = max_tokens_per_example
        # Dedicated RNG instance: seeding it does not disturb the global
        # `random` module state, and `seed=None` keeps the old nondeterminism.
        self._rng = random.Random(seed)
        # Set system prompt (custom wins; otherwise fill the default template)
        if system_prompt:
            self.system_prompt = system_prompt
        else:
            self.system_prompt = self.DEFAULT_SYSTEM_PROMPT.format(
                ceo_name=ceo_name,
                company_name=company_name,
            )
        # Initialize tokenizer. tiktoken is optional; on import failure or an
        # unknown encoding name we fall back to a character-count heuristic.
        self.encoding = None
        if TIKTOKEN_AVAILABLE:
            try:
                self.encoding = tiktoken.get_encoding(encoding_name)
            except Exception:
                self.encoding = None

    def count_tokens(self, text: str) -> int:
        """Count tokens in text (tiktoken if available, else ~chars/3)."""
        if self.encoding:
            return len(self.encoding.encode(text))
        return len(text) // 3  # Rough approximation

    def build_from_qa_pairs(
        self,
        qa_pairs: list,
        output_dir: str | Path,
        train_ratio: float = 0.9,
        shuffle: bool = True,
        deduplicate: bool = True,
    ) -> DatasetStatistics:
        """
        Build training dataset from Q&A pairs.

        Writes ``train.jsonl``, ``validation.jsonl`` and ``dataset_stats.json``
        into *output_dir* (created if missing).

        Args:
            qa_pairs: List of QAPair objects or dicts
            output_dir: Directory to save train.jsonl and validation.jsonl
            train_ratio: Ratio for train/validation split (default 0.9)
            shuffle: Whether to shuffle data
            deduplicate: Whether to remove duplicate questions

        Returns:
            DatasetStatistics object
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Building dataset from {len(qa_pairs)} Q&A pairs")
        # Convert to standard ChatML format
        examples = self._convert_qa_pairs(qa_pairs)
        # Deduplicate (case-insensitive on the question text)
        if deduplicate:
            original_count = len(examples)
            examples = self._deduplicate(examples)
            logger.info(f"Deduplication: {original_count} -> {len(examples)} examples")
        # Drop examples exceeding the token budget
        examples = self._validate_token_counts(examples)
        logger.info(f"After token validation: {len(examples)} examples")
        # Shuffle with the instance RNG so a seeded builder is reproducible
        if shuffle:
            self._rng.shuffle(examples)
        # Split into train/validation
        split_idx = int(len(examples) * train_ratio)
        train_examples = examples[:split_idx]
        val_examples = examples[split_idx:]
        # Save datasets
        train_path = output_dir / "train.jsonl"
        val_path = output_dir / "validation.jsonl"
        self._save_jsonl(train_examples, train_path)
        self._save_jsonl(val_examples, val_path)
        # Calculate statistics
        stats = self._calculate_statistics(examples, train_examples, val_examples)
        # Save statistics. ensure_ascii=False keeps non-ASCII question-type
        # keys readable, consistent with _save_jsonl.
        stats_path = output_dir / "dataset_stats.json"
        with open(stats_path, "w", encoding="utf-8") as f:
            json.dump(stats.to_dict(), f, indent=2, ensure_ascii=False)
        logger.info(f"Saved train set: {train_path} ({len(train_examples)} examples)")
        logger.info(f"Saved validation set: {val_path} ({len(val_examples)} examples)")
        logger.info(f"Saved statistics: {stats_path}")
        return stats

    def _convert_qa_pairs(self, qa_pairs: list) -> list[dict]:
        """Convert Q&A pairs (QAPair objects or dicts) to ChatML format."""
        examples = []
        for pair in qa_pairs:
            # Handle both QAPair objects and dicts
            if hasattr(pair, "question"):
                question = pair.question
                answer = pair.answer
                q_type = pair.question_type
            else:
                question = pair["question"]
                answer = pair["answer"]
                q_type = pair.get("question_type", "unknown")
            example = {
                "messages": [
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": question},
                    {"role": "assistant", "content": answer},
                ],
                "metadata": {
                    "question_type": q_type,
                },
            }
            examples.append(example)
        return examples

    def _deduplicate(self, examples: list[dict]) -> list[dict]:
        """Remove examples with duplicate questions (first occurrence wins).

        Questions are compared case-insensitively after stripping whitespace.
        NOTE(review): an example without a user message is dropped here as
        well; _convert_qa_pairs always emits one, so this is unreachable for
        examples built by this class.
        """
        seen_questions = set()
        unique_examples = []
        for example in examples:
            # Get user message (the question)
            question = None
            for msg in example["messages"]:
                if msg["role"] == "user":
                    question = msg["content"].strip().lower()
                    break
            if question and question not in seen_questions:
                seen_questions.add(question)
                unique_examples.append(example)
        return unique_examples

    def _validate_token_counts(self, examples: list[dict]) -> list[dict]:
        """Filter out examples that exceed the token limit.

        Annotates each kept example with a ``token_count`` key used later by
        _calculate_statistics.
        """
        valid_examples = []
        for example in examples:
            # Sum tokens across all messages, plus a fixed per-message
            # overhead approximating ChatML role/delimiter tokens.
            total_tokens = 0
            for msg in example["messages"]:
                total_tokens += self.count_tokens(msg["content"])
                total_tokens += 4  # Approximate overhead per message
            if total_tokens <= self.max_tokens_per_example:
                example["token_count"] = total_tokens
                valid_examples.append(example)
            else:
                logger.debug(f"Skipping example with {total_tokens} tokens (max: {self.max_tokens_per_example})")
        return valid_examples

    def _save_jsonl(self, examples: list[dict], path: Path) -> None:
        """Save examples to JSONL format (one JSON object per line)."""
        with open(path, "w", encoding="utf-8") as f:
            for example in examples:
                # Remove metadata before saving (keep only messages)
                output = {"messages": example["messages"]}
                f.write(json.dumps(output, ensure_ascii=False) + "\n")

    def _calculate_statistics(
        self,
        all_examples: list[dict],
        train_examples: list[dict],
        val_examples: list[dict],
    ) -> DatasetStatistics:
        """Calculate dataset statistics from annotated examples."""
        # token_count is set by _validate_token_counts; default 0 is defensive.
        token_counts = [ex.get("token_count", 0) for ex in all_examples]
        # Question type distribution
        type_counts = {}
        for ex in all_examples:
            q_type = ex.get("metadata", {}).get("question_type", "unknown")
            type_counts[q_type] = type_counts.get(q_type, 0) + 1
        return DatasetStatistics(
            total_examples=len(all_examples),
            train_examples=len(train_examples),
            validation_examples=len(val_examples),
            avg_tokens_per_example=sum(token_counts) / len(token_counts) if token_counts else 0,
            max_tokens=max(token_counts) if token_counts else 0,
            min_tokens=min(token_counts) if token_counts else 0,
            total_tokens=sum(token_counts),
            question_type_distribution=type_counts,
        )

    def build_from_segments(
        self,
        segments: list,
        output_dir: str | Path,
        train_ratio: float = 0.9,
    ) -> DatasetStatistics:
        """
        Build training dataset directly from text segments (for continuation training).

        This creates examples where the model learns to continue CEO-style text:
        every example uses the same generic user prompt, with the segment
        content as the assistant reply.

        Args:
            segments: List of TextSegment objects or dicts
            output_dir: Directory to save datasets
            train_ratio: Train/validation split ratio

        Returns:
            DatasetStatistics object
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Building continuation dataset from {len(segments)} segments")
        examples = []
        for segment in segments:
            content = segment.content if hasattr(segment, "content") else segment["content"]
            # Create a simple prompt asking to continue the thought
            example = {
                "messages": [
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": "Please share your thoughts on this topic."},
                    {"role": "assistant", "content": content},
                ],
                "metadata": {"type": "continuation"},
            }
            examples.append(example)
        # Validate, shuffle (seedable via the instance RNG), split and save
        examples = self._validate_token_counts(examples)
        self._rng.shuffle(examples)
        split_idx = int(len(examples) * train_ratio)
        train_examples = examples[:split_idx]
        val_examples = examples[split_idx:]
        self._save_jsonl(train_examples, output_dir / "train.jsonl")
        self._save_jsonl(val_examples, output_dir / "validation.jsonl")
        stats = self._calculate_statistics(examples, train_examples, val_examples)
        # ensure_ascii=False for consistency with _save_jsonl output.
        with open(output_dir / "dataset_stats.json", "w", encoding="utf-8") as f:
            json.dump(stats.to_dict(), f, indent=2, ensure_ascii=False)
        return stats

    @staticmethod
    def load_dataset(path: str | Path) -> list[dict]:
        """Load a JSONL dataset file, skipping blank lines."""
        examples = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    examples.append(json.loads(line))
        return examples

    def update_system_prompt(self, new_prompt: str) -> None:
        """Update the system prompt for future builds (already-built files are untouched)."""
        self.system_prompt = new_prompt
        logger.info("System prompt updated")

    def get_system_prompt(self) -> str:
        """Get the current system prompt."""
        return self.system_prompt
def main():
    """CLI entry point for testing the builder."""
    import argparse

    epilog_text = """
Examples:
  python dataset_builder.py qa_pairs.json --output data/training/
  python dataset_builder.py qa_pairs.json --train-ratio 0.85
  python dataset_builder.py qa_pairs.json --system-prompt "Custom prompt..."

Input format (qa_pairs.json):
  [
    {"question": "...", "answer": "...", "question_type": "..."},
    ...
  ]

Output format (train.jsonl):
  {"messages": [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
"""
    cli = argparse.ArgumentParser(
        description="Build training datasets in ChatML format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=epilog_text,
    )
    cli.add_argument("input", help="Input Q&A pairs JSON file")
    cli.add_argument(
        "--output",
        "-o",
        default="data/training/",
        help="Output directory (default: data/training/)",
    )
    cli.add_argument(
        "--train-ratio",
        type=float,
        default=0.9,
        help="Train/validation split ratio (default: 0.9)",
    )
    cli.add_argument(
        "--system-prompt",
        help="Custom system prompt (uses default if not provided)",
    )
    cli.add_argument(
        "--ceo-name",
        default="Ryouken Okuni",
        help="CEO name for default prompt",
    )
    cli.add_argument(
        "--company-name",
        default="Akatsuki AI Technologies",
        help="Company name for default prompt",
    )
    cli.add_argument(
        "--max-tokens",
        type=int,
        default=2048,
        help="Maximum tokens per example (default: 2048)",
    )
    cli.add_argument(
        "--no-shuffle",
        action="store_true",
        help="Don't shuffle the data",
    )
    cli.add_argument(
        "--no-dedup",
        action="store_true",
        help="Don't deduplicate questions",
    )
    opts = cli.parse_args()

    # Load Q&A pairs from the input JSON file
    pairs = json.loads(Path(opts.input).read_text(encoding="utf-8"))
    print(f"Loaded {len(pairs)} Q&A pairs")

    # Build dataset with the configured builder
    ds_builder = DatasetBuilder(
        system_prompt=opts.system_prompt,
        ceo_name=opts.ceo_name,
        company_name=opts.company_name,
        max_tokens_per_example=opts.max_tokens,
    )
    stats = ds_builder.build_from_qa_pairs(
        qa_pairs=pairs,
        output_dir=opts.output,
        train_ratio=opts.train_ratio,
        shuffle=not opts.no_shuffle,
        deduplicate=not opts.no_dedup,
    )

    # Report statistics to stdout
    print("\n=== Dataset Statistics ===")
    print(f"Total examples: {stats.total_examples}")
    print(f"Train examples: {stats.train_examples}")
    print(f"Validation examples: {stats.validation_examples}")
    print(f"Avg tokens/example: {stats.avg_tokens_per_example:.1f}")
    print(f"Token range: {stats.min_tokens} - {stats.max_tokens}")
    print(f"Total tokens: {stats.total_tokens:,}")
    print("\nQuestion type distribution:")
    for q_type, count in stats.question_type_distribution.items():
        print(f"  {q_type}: {count}")


if __name__ == "__main__":
    main()