"""
Prepare Dataset Module
Load and preprocess training data for fine-tuning.
Converts JSONL files to Hugging Face Dataset format.
Example usage:
from src.training.prepare_dataset import prepare_dataset
train_dataset, val_dataset = prepare_dataset(
train_path="data/training/train.jsonl",
val_path="data/training/validation.jsonl",
)
"""
import json
from pathlib import Path
from typing import Optional, Tuple

from loguru import logger

try:
    from datasets import Dataset, DatasetDict

    HF_DATASETS_AVAILABLE = True
except ImportError:
    HF_DATASETS_AVAILABLE = False
    logger.warning("datasets library not available")

try:
    from transformers import AutoTokenizer

    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    logger.warning("transformers library not available")

def load_jsonl(path: str | Path) -> list[dict]:
    """Load data from a JSONL file, skipping blank lines."""
    data = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line))
    return data
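
# Minimal usage sketch for load_jsonl (the path is hypothetical):
#   records = load_jsonl("data/training/train.jsonl")
#   records[0]["messages"][0]["role"]  # e.g. "system"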

def format_chat_template(
    messages: list[dict],
    tokenizer,
    add_generation_prompt: bool = False,
) -> str:
    """
    Format messages using the tokenizer's chat template.

    Args:
        messages: List of message dicts with 'role' and 'content'
        tokenizer: HuggingFace tokenizer
        add_generation_prompt: Whether to add a generation prompt at the end

    Returns:
        Formatted string
    """
    if hasattr(tokenizer, "apply_chat_template"):
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
        )
    else:
        # Fallback to ChatML format; messages with unrecognized roles are skipped
        formatted = ""
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "system":
                formatted += f"<|im_start|>system\n{content}<|im_end|>\n"
            elif role == "user":
                formatted += f"<|im_start|>user\n{content}<|im_end|>\n"
            elif role == "assistant":
                formatted += f"<|im_start|>assistant\n{content}<|im_end|>\n"
        return formatted
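
# What the ChatML fallback above produces for a two-turn exchange (illustrative):
#   <|im_start|>user
#   Hello<|im_end|>
#   <|im_start|>assistant
#   Hi!<|im_end|>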

def prepare_dataset(
    train_path: str | Path,
    val_path: Optional[str | Path] = None,
    tokenizer_name: str = "Qwen/Qwen3-4B-Instruct",
    max_length: int = 2048,
    add_eos_token: bool = True,
) -> Tuple["Dataset", Optional["Dataset"]]:
    """
    Prepare training and validation datasets.

    Args:
        train_path: Path to training JSONL file
        val_path: Path to validation JSONL file (optional)
        tokenizer_name: Name of tokenizer to use for formatting
        max_length: Maximum sequence length (enforced later, in tokenize_dataset)
        add_eos_token: Whether to add EOS token

    Returns:
        Tuple of (train_dataset, val_dataset) or (train_dataset, None)
    """
    if not HF_DATASETS_AVAILABLE:
        raise ImportError("datasets library required. Run: pip install datasets")
    if not TRANSFORMERS_AVAILABLE:
        raise ImportError("transformers library required. Run: pip install transformers")

    logger.info(f"Loading tokenizer: {tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)

    # Ensure padding token is set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def format_examples(data: list[dict]) -> list[dict]:
        """Apply the chat template to each example, appending EOS if needed."""
        formatted = []
        for example in data:
            text = format_chat_template(example["messages"], tokenizer)
            if add_eos_token and not text.endswith(tokenizer.eos_token):
                text += tokenizer.eos_token
            formatted.append({"text": text})
        return formatted

    # Load and format training data
    logger.info(f"Loading training data from: {train_path}")
    train_data = load_jsonl(train_path)
    logger.info(f"Loaded {len(train_data)} training examples")
    train_dataset = Dataset.from_list(format_examples(train_data))

    # Load and format validation data if provided
    val_dataset = None
    if val_path:
        logger.info(f"Loading validation data from: {val_path}")
        val_data = load_jsonl(val_path)
        logger.info(f"Loaded {len(val_data)} validation examples")
        val_dataset = Dataset.from_list(format_examples(val_data))

    logger.info("Dataset preparation complete")
    return train_dataset, val_dataset
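
# Sketch of a call without a validation split; the second element of the
# returned tuple is None in that case (path hypothetical):
#   train_ds, val_ds = prepare_dataset(train_path="data/training/train.jsonl")
#   assert val_ds is None
#   train_ds[0]["text"]  # chat-templated string, ending in the EOS token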

def prepare_dataset_dict(
    train_path: str | Path,
    val_path: str | Path,
    tokenizer_name: str = "Qwen/Qwen3-4B-Instruct",
    max_length: int = 2048,
) -> DatasetDict:
    """
    Prepare a DatasetDict with train and validation splits.

    Args:
        train_path: Path to training JSONL
        val_path: Path to validation JSONL
        tokenizer_name: Tokenizer name
        max_length: Maximum sequence length

    Returns:
        DatasetDict with 'train' and 'validation' keys
    """
    train_dataset, val_dataset = prepare_dataset(
        train_path=train_path,
        val_path=val_path,
        tokenizer_name=tokenizer_name,
        max_length=max_length,
    )
    return DatasetDict({
        "train": train_dataset,
        "validation": val_dataset,
    })
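
# Sketch: building the DatasetDict and checking split sizes (paths hypothetical):
#   ds = prepare_dataset_dict("data/training/train.jsonl", "data/training/validation.jsonl")
#   {split: len(ds[split]) for split in ds}  # {"train": ..., "validation": ...}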

def tokenize_dataset(
    dataset: Dataset,
    tokenizer,
    max_length: int = 2048,
    num_proc: int = 4,
) -> Dataset:
    """
    Tokenize a dataset for training.

    Args:
        dataset: Dataset with 'text' column
        tokenizer: HuggingFace tokenizer
        max_length: Maximum sequence length
        num_proc: Number of processes for parallel tokenization

    Returns:
        Tokenized dataset
    """

    def tokenize_function(examples):
        # Truncate to max_length; leave padding to the data collator at training time
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_length,
            padding=False,
            return_tensors=None,
        )

    tokenized = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=num_proc,
        remove_columns=dataset.column_names,
        desc="Tokenizing",
    )
    return tokenized
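
# Sketch: chaining prepare_dataset into tokenize_dataset. Because padding is
# disabled here, the result is usually paired with a padding collator such as
# transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False); that
# pairing is an assumption about downstream training, not something this module mandates.
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct", trust_remote_code=True)
#   tokenized = tokenize_dataset(train_ds, tokenizer, max_length=2048)
#   tokenized.column_names  # typically ["input_ids", "attention_mask"]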

def push_dataset_to_hub(
    dataset_dict: DatasetDict,
    repo_id: str,
    private: bool = True,
    token: Optional[str] = None,
) -> None:
    """
    Push dataset to Hugging Face Hub.

    Args:
        dataset_dict: DatasetDict to push
        repo_id: Repository ID on HF Hub
        private: Whether repo should be private
        token: HF token (uses HF_TOKEN env var if not provided)
    """
    import os

    token = token or os.environ.get("HF_TOKEN")
    logger.info(f"Pushing dataset to: {repo_id}")
    dataset_dict.push_to_hub(
        repo_id,
        private=private,
        token=token,
    )
    logger.info("Dataset pushed successfully")

def main():
    """CLI entry point for testing dataset preparation."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Prepare training datasets for fine-tuning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python prepare_dataset.py data/training/train.jsonl --val data/training/validation.jsonl
    python prepare_dataset.py data/training/train.jsonl --push-to-hub username/dataset-name
""",
    )
    parser.add_argument("train", help="Path to training JSONL file")
    parser.add_argument("--val", help="Path to validation JSONL file")
    parser.add_argument(
        "--tokenizer",
        default="Qwen/Qwen3-4B-Instruct",
        help="Tokenizer name (default: Qwen/Qwen3-4B-Instruct)",
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=2048,
        help="Maximum sequence length (default: 2048)",
    )
    parser.add_argument(
        "--push-to-hub",
        help="Push dataset to HF Hub with this repo ID",
    )
    parser.add_argument(
        "--private",
        # store_true with default=True could never be turned off;
        # BooleanOptionalAction adds a matching --no-private flag
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Make HF repo private (default: True)",
    )
    args = parser.parse_args()

    # Prepare dataset
    if args.val:
        dataset_dict = prepare_dataset_dict(
            train_path=args.train,
            val_path=args.val,
            tokenizer_name=args.tokenizer,
            max_length=args.max_length,
        )
        print("\nDataset prepared:")
        print(f"  Train: {len(dataset_dict['train'])} examples")
        print(f"  Validation: {len(dataset_dict['validation'])} examples")

        # Show sample
        print("\nSample training example:")
        print(dataset_dict["train"][0]["text"][:500] + "...")

        # Push to hub if requested
        if args.push_to_hub:
            push_dataset_to_hub(
                dataset_dict,
                args.push_to_hub,
                private=args.private,
            )
    else:
        train_dataset, _ = prepare_dataset(
            train_path=args.train,
            tokenizer_name=args.tokenizer,
            max_length=args.max_length,
        )
        print(f"\nDataset prepared: {len(train_dataset)} examples")
        print("\nSample:")
        print(train_dataset[0]["text"][:500] + "...")


if __name__ == "__main__":
    main()