#!/usr/bin/env python3
"""
Dataset splitting script for CodeLlama fine-tuning.

Creates train/val/test splits with validation.
"""

import argparse
import json
import random
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple


def validate_sample(sample: Any, min_length: int = 3) -> bool:
    """Return True if *sample* is a usable instruction/response record.

    A sample is valid when it is a JSON object (dict) whose ``"instruction"``
    and ``"response"`` values are both strings that, after stripping
    surrounding whitespace, are non-empty and at least *min_length*
    characters long.

    Args:
        sample: A decoded JSON value. Any type is accepted; non-dicts
            (lists, strings, numbers, ``None``) are reported as invalid.
        min_length: Minimum length of each stripped field.

    Returns:
        True if every check passes, False otherwise.
    """
    # A JSONL line may decode to a list/str/number. Without this guard,
    # `"instruction" in sample` can be True for a list or a string, and the
    # following `sample["instruction"]` raises TypeError.
    if not isinstance(sample, dict):
        return False

    # .get() returns None for a missing key, which the isinstance check rejects.
    instruction = sample.get("instruction")
    response = sample.get("response")
    if not isinstance(instruction, str) or not isinstance(response, str):
        return False

    instruction = instruction.strip()
    response = response.strip()
    # Checked explicitly so that min_length <= 0 still rejects blank fields.
    if not instruction or not response:
        return False

    return len(instruction) >= min_length and len(response) >= min_length


def _load_samples(input_file: str, min_length: int) -> Tuple[List[Dict], int]:
    """Read a JSONL file; return (valid samples in file order, invalid count)."""
    samples: List[Dict] = []
    invalid_count = 0
    # "utf-8-sig" decodes plain UTF-8 unchanged but also strips a leading BOM.
    # json.loads rejects a BOM, which would otherwise silently drop line 1.
    with open(input_file, "r", encoding="utf-8-sig") as f:
        for line_num, raw_line in enumerate(f, 1):
            line = raw_line.strip()
            if not line:
                continue  # blank lines are neither valid nor invalid
            try:
                sample = json.loads(line)
            except json.JSONDecodeError as e:
                invalid_count += 1
                print(f"❌ Invalid JSON at line {line_num}: {e}")
                continue
            if validate_sample(sample, min_length):
                samples.append(sample)
            else:
                invalid_count += 1
                print(
                    f"⚠️ Invalid sample at line {line_num}: "
                    "missing/invalid fields or too short"
                )
    return samples, invalid_count


def _write_jsonl(path: Path, rows: List[Dict]) -> None:
    """Write *rows* to *path* as UTF-8 JSON Lines, one object per line."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)


def split_dataset(
    input_file: str,
    output_dir: str,
    train_ratio: float = 0.75,
    val_ratio: float = 0.10,
    test_ratio: float = 0.15,
    seed: int = 42,
    min_length: int = 3,
) -> Dict:
    """Split a JSONL dataset into train/val/test files after validating it.

    Valid samples (see ``validate_sample``) are shuffled deterministically
    with *seed*. They are then written to ``train.jsonl``, ``val.jsonl`` and
    ``test.jsonl`` inside *output_dir*, which is created if needed.

    Args:
        input_file: Path to the input JSONL file.
        output_dir: Directory that receives the three split files.
        train_ratio: Fraction of valid samples for the train split.
        val_ratio: Fraction of valid samples for the validation split.
        test_ratio: Fraction for the test split. The test split actually
            receives every sample left after train and val, including
            rounding leftovers.
        seed: Seed for the shuffle. It produces the same order as
            ``random.seed(seed); random.shuffle(...)`` but does not touch the
            global ``random`` state.
        min_length: Minimum stripped length of each text field.

    Returns:
        Dict with keys ``total``, ``train``, ``val``, ``test``, ``invalid``
        (counts) and ``train_ratio``, ``val_ratio``, ``test_ratio`` (achieved
        fractions of ``total``).

    Raises:
        ValueError: If any ratio is outside [0, 1], if the ratios do not sum
            to 1.0 (within 0.01), or if fewer than 10 valid samples are found.
        OSError: If the input cannot be read or the outputs cannot be written.
    """
    for ratio_name, ratio in (
        ("train_ratio", train_ratio),
        ("val_ratio", val_ratio),
        ("test_ratio", test_ratio),
    ):
        # Also rejects NaN, because every comparison with NaN is False.
        if not 0.0 <= ratio <= 1.0:
            raise ValueError(f"{ratio_name} must be between 0 and 1, got {ratio}")

    ratio_sum = train_ratio + val_ratio + test_ratio
    if abs(ratio_sum - 1.0) > 0.01:
        raise ValueError(f"Ratios must sum to 1.0, got {ratio_sum}")

    print(f"📊 Loading dataset from: {input_file}")
    samples, invalid_count = _load_samples(input_file, min_length)

    print("\n📊 Dataset Statistics:")
    print(f" ✅ Valid samples: {len(samples)}")
    print(f" ❌ Invalid samples: {invalid_count}")

    if len(samples) < 10:
        raise ValueError(f"Insufficient samples: {len(samples)} (minimum 10 required)")

    print(f"\n🔀 Shuffling with seed={seed}...")
    # A private generator gives the same permutation as seeding the module
    # RNG, without changing global state for code that imports this module.
    random.Random(seed).shuffle(samples)

    total = len(samples)
    train_end = int(total * train_ratio)
    val_end = train_end + int(total * val_ratio)
    splits = {
        "train": samples[:train_end],
        "val": samples[train_end:val_end],
        "test": samples[val_end:],  # remainder, including rounding leftovers
    }

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"\n💾 Saving splits to: {output_path}")
    for split_name, data in splits.items():
        _write_jsonl(output_path / f"{split_name}.jsonl", data)
        print(f" ✅ {split_name}.jsonl: {len(data)} samples")

    n_train, n_val, n_test = (len(splits[k]) for k in ("train", "val", "test"))
    return {
        "total": total,
        "train": n_train,
        "val": n_val,
        "test": n_test,
        "invalid": invalid_count,
        "train_ratio": n_train / total,
        "val_ratio": n_val / total,
        "test_ratio": n_test / total,
    }


def main(argv: Optional[List[str]] = None) -> int:
    """Run the command-line interface.

    Args:
        argv: Argument list to parse. Defaults to ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 if splitting failed.
    """
    parser = argparse.ArgumentParser(description="Split dataset for training")
    parser.add_argument("--input", required=True, help="Input JSONL file")
    parser.add_argument("--output-dir", required=True, help="Output directory")
    parser.add_argument("--train-ratio", type=float, default=0.75,
                        help="Training ratio (default: 0.75)")
    parser.add_argument("--val-ratio", type=float, default=0.10,
                        help="Validation ratio (default: 0.10)")
    parser.add_argument("--test-ratio", type=float, default=0.15,
                        help="Test ratio (default: 0.15)")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed (default: 42)")
    parser.add_argument("--min-length", type=int, default=3,
                        help="Minimum field length (default: 3)")
    args = parser.parse_args(argv)

    print("=" * 70)
    print("📊 DATASET SPLITTING FOR CODELLAMA FINE-TUNING")
    print("=" * 70)
    print("\nConfiguration:")
    print(f" Input: {args.input}")
    print(f" Output: {args.output_dir}")
    print(f" Ratios: Train={args.train_ratio:.0%}, Val={args.val_ratio:.0%}, "
          f"Test={args.test_ratio:.0%}")
    print(f" Seed: {args.seed}")
    print()

    try:
        stats = split_dataset(
            args.input,
            args.output_dir,
            args.train_ratio,
            args.val_ratio,
            args.test_ratio,
            args.seed,
            args.min_length,
        )
    except Exception as e:  # top-level CLI boundary: report and fail
        print(f"\n❌ Error: {e}", file=sys.stderr)
        return 1

    print("\n" + "=" * 70)
    print("✅ SPLIT COMPLETE!")
    print("=" * 70)
    print("\nFinal Statistics:")
    print(f" Total samples: {stats['total']}")
    print(f" Training: {stats['train']} samples ({stats['train_ratio'] * 100:.1f}%)")
    print(f" Validation: {stats['val']} samples ({stats['val_ratio'] * 100:.1f}%)")
    print(f" Test: {stats['test']} samples ({stats['test_ratio'] * 100:.1f}%)")
    if stats["invalid"]:
        print(f" ⚠️ Invalid samples skipped: {stats['invalid']}")
    print("=" * 70)
    return 0


if __name__ == "__main__":
    sys.exit(main())