codellama-fine-tuning / scripts /dataset_split.py
Prithvik-1's picture
Upload scripts/dataset_split.py with huggingface_hub
bb9fa45 verified
#!/usr/bin/env python3
"""
Dataset splitting script for CodeLlama fine-tuning
Creates train/val/test splits with validation
"""
import json
import random
from pathlib import Path
from typing import List, Dict
def validate_sample(sample: Dict, min_length: int = 3) -> bool:
    """Return True when *sample* is a usable instruction/response record.

    A sample is accepted only when it is a dict whose ``"instruction"`` and
    ``"response"`` values are both strings that, once surrounding whitespace
    is stripped, are non-empty and at least ``min_length`` characters long.

    Args:
        sample: One decoded dataset record. Anything that is not a dict
            (for example a bare number, string, list or ``None``) is
            rejected instead of raising.
        min_length: Minimum number of characters required in each stripped
            field.

    Returns:
        True if the sample passes every check, False otherwise.
    """
    # A decoded JSON line may be any JSON type, not just an object; treat
    # everything that is not a dict as invalid rather than letting the
    # membership test / subscript below raise TypeError.
    if not isinstance(sample, dict):
        return False

    # .get() folds the "required field is missing" check into the type check:
    # a missing key yields None, which is not a str.
    instruction = sample.get("instruction")
    response = sample.get("response")
    if not isinstance(instruction, str) or not isinstance(response, str):
        return False

    instruction = instruction.strip()
    response = response.strip()

    # Whitespace-only fields are always invalid, even when min_length <= 0.
    if not instruction or not response:
        return False

    return len(instruction) >= min_length and len(response) >= min_length
def split_dataset(
    input_file: str,
    output_dir: str,
    train_ratio: float = 0.75,
    val_ratio: float = 0.10,
    test_ratio: float = 0.15,
    seed: int = 42,
    min_length: int = 3
) -> Dict:
    """Split a JSONL dataset into train/val/test files.

    Reads ``input_file`` line by line, keeps the records accepted by
    ``validate_sample``, shuffles them deterministically and writes
    ``train.jsonl``, ``val.jsonl`` and ``test.jsonl`` into ``output_dir``
    (created if it does not exist). Progress is reported with ``print``.

    Args:
        input_file: Path to a UTF-8 JSONL file, one JSON value per line.
            Blank lines are ignored.
        output_dir: Directory that receives the three split files.
        train_ratio: Fraction of valid samples placed in the training split.
        val_ratio: Fraction of valid samples placed in the validation split.
        test_ratio: Nominal fraction for the test split; the test split
            actually receives every sample left after train and val.
        seed: Seed for the shuffle, so repeated runs give the same splits.
        min_length: Minimum stripped field length, forwarded to
            ``validate_sample``.

    Returns:
        A dict with the keys ``total``, ``train``, ``val``, ``test``,
        ``invalid`` (counts) and ``train_ratio``, ``val_ratio``,
        ``test_ratio`` (the fractions actually achieved).

    Raises:
        ValueError: If the ratios do not sum to 1.0 (within 0.01), if any
            ratio is negative, or if fewer than 10 valid samples are found.
    """
    # Validate ratios
    ratio_sum = train_ratio + val_ratio + test_ratio
    if abs(ratio_sum - 1.0) > 0.01:
        raise ValueError(f"Ratios must sum to 1.0, got {ratio_sum}")
    # Negative ratios can still sum to 1.0 but would make the slice
    # boundaries below overlap, leaking training samples into other splits.
    if train_ratio < 0 or val_ratio < 0 or test_ratio < 0:
        raise ValueError(
            "Ratios must be non-negative, got "
            f"train={train_ratio}, val={val_ratio}, test={test_ratio}"
        )

    print(f"πŸ“Š Loading dataset from: {input_file}")

    # Load data
    samples = []
    invalid_count = 0
    with open(input_file, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                sample = json.loads(line)
            except json.JSONDecodeError as e:
                invalid_count += 1
                print(f"❌ Invalid JSON at line {line_num}: {e}")
                continue
            # A line can be valid JSON without being an object (e.g. `42` or
            # `null`); count it as invalid instead of handing a non-dict to
            # the field validator.
            if isinstance(sample, dict) and validate_sample(sample, min_length):
                samples.append(sample)
            else:
                invalid_count += 1
                print(f"⚠️ Invalid sample at line {line_num}: missing fields or too short")

    print("\nπŸ“Š Dataset Statistics:")
    print(f" βœ… Valid samples: {len(samples)}")
    print(f" ❌ Invalid samples: {invalid_count}")

    if len(samples) < 10:
        raise ValueError(f"Insufficient samples: {len(samples)} (minimum 10 required)")

    # Shuffle with fixed seed. A private Random instance yields the same
    # order as seeding the module-level generator, without resetting the
    # global random state for the rest of the process.
    print(f"\nπŸ”€ Shuffling with seed={seed}...")
    random.Random(seed).shuffle(samples)

    # Calculate split indices; the test split takes whatever remains, so
    # rounding never drops a sample.
    total = len(samples)
    train_end = int(total * train_ratio)
    val_end = train_end + int(total * val_ratio)

    train_data = samples[:train_end]
    val_data = samples[train_end:val_end]
    test_data = samples[val_end:]

    # Create output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Save splits
    splits = {
        "train": train_data,
        "val": val_data,
        "test": test_data
    }

    print(f"\nπŸ’Ύ Saving splits to: {output_path}")
    for split_name, data in splits.items():
        output_file = output_path / f"{split_name}.jsonl"
        with open(output_file, 'w', encoding='utf-8') as f:
            for item in data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
        print(f" βœ… {split_name}.jsonl: {len(data)} samples")

    # Return statistics
    stats = {
        "total": total,
        "train": len(train_data),
        "val": len(val_data),
        "test": len(test_data),
        "invalid": invalid_count,
        "train_ratio": len(train_data) / total,
        "val_ratio": len(val_data) / total,
        "test_ratio": len(test_data) / total
    }
    return stats
# Command-line entry point: parse the options, run the split and print a
# summary. Exits with status 1 if the split raises.
if __name__ == "__main__":
    import argparse
    import sys

    parser = argparse.ArgumentParser(description="Split dataset for training")
    parser.add_argument("--input", required=True, help="Input JSONL file")
    parser.add_argument("--output-dir", required=True, help="Output directory")
    parser.add_argument("--train-ratio", type=float, default=0.75, help="Training ratio (default: 0.75)")
    parser.add_argument("--val-ratio", type=float, default=0.10, help="Validation ratio (default: 0.10)")
    parser.add_argument("--test-ratio", type=float, default=0.15, help="Test ratio (default: 0.15)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    parser.add_argument("--min-length", type=int, default=3, help="Minimum field length (default: 3)")
    args = parser.parse_args()

    print("=" * 70)
    print("πŸ“Š DATASET SPLITTING FOR CODELLAMA FINE-TUNING")
    print("=" * 70)
    print("\nConfiguration:")
    print(f" Input: {args.input}")
    print(f" Output: {args.output_dir}")
    print(f" Ratios: Train={args.train_ratio:.0%}, Val={args.val_ratio:.0%}, Test={args.test_ratio:.0%}")
    print(f" Seed: {args.seed}")
    print()

    try:
        stats = split_dataset(
            args.input,
            args.output_dir,
            args.train_ratio,
            args.val_ratio,
            args.test_ratio,
            args.seed,
            args.min_length
        )
        print("\n" + "=" * 70)
        print("βœ… SPLIT COMPLETE!")
        print("=" * 70)
        print("\nFinal Statistics:")
        print(f" Total samples: {stats['total']}")
        print(f" Training: {stats['train']} samples ({stats['train_ratio']*100:.1f}%)")
        print(f" Validation: {stats['val']} samples ({stats['val_ratio']*100:.1f}%)")
        print(f" Test: {stats['test']} samples ({stats['test_ratio']*100:.1f}%)")
        if stats['invalid'] > 0:
            print(f" ⚠️ Invalid samples skipped: {stats['invalid']}")
        print("=" * 70)
    except Exception as e:
        # Top-level boundary: report any failure and signal it through the
        # exit status. sys.exit is used because the bare exit() helper is
        # injected by the site module and is missing under `python -S` or in
        # frozen/embedded interpreters.
        print(f"\n❌ Error: {e}")
        sys.exit(1)