| """ | |
| Prepare Dataset Module | |
| Load and preprocess training data for fine-tuning. | |
| Converts JSONL files to Hugging Face Dataset format. | |
| Example usage: | |
| from src.training.prepare_dataset import prepare_dataset | |
| train_dataset, val_dataset = prepare_dataset( | |
| train_path="data/training/train.jsonl", | |
| val_path="data/training/validation.jsonl", | |
| ) | |
| """ | |
from __future__ import annotations

import json
from pathlib import Path
from typing import Optional, Tuple

from loguru import logger

# Optional dependencies: the module stays importable without them. The
# `from __future__ import annotations` above keeps type hints such as
# `-> DatasetDict` from raising NameError at import time when these fail.
try:
    from datasets import Dataset, DatasetDict
    HF_DATASETS_AVAILABLE = True
except ImportError:
    HF_DATASETS_AVAILABLE = False
    logger.warning("datasets library not available")

try:
    from transformers import AutoTokenizer
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    logger.warning("transformers library not available")


def load_jsonl(path: str | Path) -> list[dict]:
    """Load data from a JSONL file."""
    data = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line))
    return data
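

# Each non-empty JSONL line is expected to hold one chat example shaped like
# the following (illustrative record, not taken from the actual data files):
#   {"messages": [{"role": "system", "content": "..."},
#                 {"role": "user", "content": "..."},
#                 {"role": "assistant", "content": "..."}]}
# `prepare_dataset` below reads the "messages" key from every record.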


def format_chat_template(
    messages: list[dict],
    tokenizer,
    add_generation_prompt: bool = False,
) -> str:
    """
    Format messages using the tokenizer's chat template.

    Args:
        messages: List of message dicts with 'role' and 'content' keys
        tokenizer: Hugging Face tokenizer
        add_generation_prompt: Whether to add a generation prompt at the end

    Returns:
        Formatted string
    """
    if hasattr(tokenizer, "apply_chat_template"):
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
        )
    # Fallback to ChatML format for tokenizers without a chat template;
    # messages with unrecognized roles are skipped.
    formatted = ""
    for msg in messages:
        role = msg["role"]
        content = msg["content"]
        if role in ("system", "user", "assistant"):
            formatted += f"<|im_start|>{role}\n{content}<|im_end|>\n"
    return formatted
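

# Illustrative only: for [{"role": "user", "content": "Hi"}] the ChatML
# fallback above yields
#   "<|im_start|>user\nHi<|im_end|>\n"
# whereas the apply_chat_template branch depends on the model's own template
# and may produce different special tokens.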


def prepare_dataset(
    train_path: str | Path,
    val_path: Optional[str | Path] = None,
    tokenizer_name: str = "Qwen/Qwen3-4B-Instruct",
    max_length: int = 2048,
    add_eos_token: bool = True,
) -> Tuple[Dataset, Optional[Dataset]]:
    """
    Prepare training and validation datasets.

    Args:
        train_path: Path to training JSONL file
        val_path: Path to validation JSONL file (optional)
        tokenizer_name: Name of tokenizer to use for formatting
        max_length: Maximum sequence length (not enforced here; applied
            later by ``tokenize_dataset``)
        add_eos_token: Whether to append the EOS token to each example

    Returns:
        Tuple of (train_dataset, val_dataset) or (train_dataset, None)
    """
    if not HF_DATASETS_AVAILABLE:
        raise ImportError("datasets library required. Run: pip install datasets")
    if not TRANSFORMERS_AVAILABLE:
        raise ImportError("transformers library required. Run: pip install transformers")

    logger.info(f"Loading tokenizer: {tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)

    # Ensure a padding token is set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def _format_examples(data: list[dict]) -> list[dict]:
        """Render each chat example down to a single 'text' field."""
        formatted = []
        for example in data:
            text = format_chat_template(example["messages"], tokenizer)
            # Guard against tokenizers whose eos_token is None
            if add_eos_token and tokenizer.eos_token and not text.endswith(tokenizer.eos_token):
                text += tokenizer.eos_token
            formatted.append({"text": text})
        return formatted

    # Load and format training data
    logger.info(f"Loading training data from: {train_path}")
    train_data = load_jsonl(train_path)
    logger.info(f"Loaded {len(train_data)} training examples")
    train_dataset = Dataset.from_list(_format_examples(train_data))

    # Load and format validation data if provided
    val_dataset = None
    if val_path:
        logger.info(f"Loading validation data from: {val_path}")
        val_data = load_jsonl(val_path)
        logger.info(f"Loaded {len(val_data)} validation examples")
        val_dataset = Dataset.from_list(_format_examples(val_data))

    logger.info("Dataset preparation complete")
    return train_dataset, val_dataset


def prepare_dataset_dict(
    train_path: str | Path,
    val_path: str | Path,
    tokenizer_name: str = "Qwen/Qwen3-4B-Instruct",
    max_length: int = 2048,
) -> DatasetDict:
    """
    Prepare a DatasetDict with train and validation splits.

    Args:
        train_path: Path to training JSONL
        val_path: Path to validation JSONL
        tokenizer_name: Tokenizer name
        max_length: Maximum sequence length

    Returns:
        DatasetDict with 'train' and 'validation' keys
    """
    train_dataset, val_dataset = prepare_dataset(
        train_path=train_path,
        val_path=val_path,
        tokenizer_name=tokenizer_name,
        max_length=max_length,
    )
    return DatasetDict({
        "train": train_dataset,
        "validation": val_dataset,
    })


def tokenize_dataset(
    dataset: Dataset,
    tokenizer,
    max_length: int = 2048,
    num_proc: int = 4,
) -> Dataset:
    """
    Tokenize a dataset for training.

    Args:
        dataset: Dataset with a 'text' column
        tokenizer: Hugging Face tokenizer
        max_length: Maximum sequence length
        num_proc: Number of processes for parallel tokenization

    Returns:
        Tokenized dataset
    """
    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_length,
            padding=False,  # leave padding to the data collator
            return_tensors=None,
        )

    return dataset.map(
        tokenize_function,
        batched=True,
        num_proc=num_proc,
        remove_columns=dataset.column_names,
        desc="Tokenizing",
    )
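

# Typical end-to-end flow (sketch; the paths and tokenizer name are just the
# module defaults, not requirements):
#
#   train_ds, val_ds = prepare_dataset(
#       train_path="data/training/train.jsonl",
#       val_path="data/training/validation.jsonl",
#   )
#   tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct", trust_remote_code=True)
#   train_tok = tokenize_dataset(train_ds, tok, max_length=2048)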


def push_dataset_to_hub(
    dataset_dict: DatasetDict,
    repo_id: str,
    private: bool = True,
    token: Optional[str] = None,
) -> None:
    """
    Push a dataset to the Hugging Face Hub.

    Args:
        dataset_dict: DatasetDict to push
        repo_id: Repository ID on the HF Hub
        private: Whether the repo should be private
        token: HF token (falls back to the HF_TOKEN env var)
    """
    import os

    token = token or os.environ.get("HF_TOKEN")
    logger.info(f"Pushing dataset to: {repo_id}")
    dataset_dict.push_to_hub(
        repo_id,
        private=private,
        token=token,
    )
    logger.info("Dataset pushed successfully")
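

# Example (sketch; the repo ID is illustrative):
#
#   export HF_TOKEN=hf_...        # in the shell, or pass token= explicitly
#   dataset_dict = prepare_dataset_dict("train.jsonl", "validation.jsonl")
#   push_dataset_to_hub(dataset_dict, "username/dataset-name", private=True)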


def main():
    """CLI entry point for testing dataset preparation."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Prepare training datasets for fine-tuning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python prepare_dataset.py data/training/train.jsonl --val data/training/validation.jsonl
  python prepare_dataset.py data/training/train.jsonl --val data/training/validation.jsonl --push-to-hub username/dataset-name
""",
    )
    parser.add_argument("train", help="Path to training JSONL file")
    parser.add_argument("--val", help="Path to validation JSONL file")
    parser.add_argument(
        "--tokenizer",
        default="Qwen/Qwen3-4B-Instruct",
        help="Tokenizer name (default: Qwen/Qwen3-4B-Instruct)",
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=2048,
        help="Maximum sequence length (default: 2048)",
    )
    parser.add_argument(
        "--push-to-hub",
        help="Push dataset to the HF Hub with this repo ID (requires --val)",
    )
    parser.add_argument(
        "--private",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Make the HF repo private (default: True; pass --no-private for public)",
    )
    args = parser.parse_args()

    # Prepare dataset
    if args.val:
        dataset_dict = prepare_dataset_dict(
            train_path=args.train,
            val_path=args.val,
            tokenizer_name=args.tokenizer,
            max_length=args.max_length,
        )
        print("\nDataset prepared:")
        print(f"  Train: {len(dataset_dict['train'])} examples")
        print(f"  Validation: {len(dataset_dict['validation'])} examples")

        # Show a sample
        print("\nSample training example:")
        print(dataset_dict["train"][0]["text"][:500] + "...")

        # Push to hub if requested
        if args.push_to_hub:
            push_dataset_to_hub(
                dataset_dict,
                args.push_to_hub,
                private=args.private,
            )
    else:
        train_dataset, _ = prepare_dataset(
            train_path=args.train,
            tokenizer_name=args.tokenizer,
            max_length=args.max_length,
        )
        print(f"\nDataset prepared: {len(train_dataset)} examples")
        print("\nSample:")
        print(train_dataset[0]["text"][:500] + "...")
        if args.push_to_hub:
            logger.warning("--push-to-hub requires --val; nothing was pushed")


if __name__ == "__main__":
    main()