Prithvik-1 committed on
Commit
bb9fa45
·
verified ·
1 Parent(s): e99a50b

Upload scripts/dataset_split.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/dataset_split.py +180 -0
scripts/dataset_split.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Dataset splitting script for CodeLlama fine-tuning
4
+ Creates train/val/test splits with validation
5
+ """
6
+
7
+ import json
8
+ import random
9
+ from pathlib import Path
10
+ from typing import List, Dict
11
+
12
+ def validate_sample(sample: Dict, min_length: int = 3) -> bool:
13
+ """Validate a single sample"""
14
+ # Check required fields
15
+ if "instruction" not in sample or "response" not in sample:
16
+ return False
17
+
18
+ # Check data types
19
+ if not isinstance(sample["instruction"], str) or not isinstance(sample["response"], str):
20
+ return False
21
+
22
+ # Check empty content
23
+ instruction = sample["instruction"].strip()
24
+ response = sample["response"].strip()
25
+
26
+ if not instruction or not response:
27
+ return False
28
+
29
+ # Check minimum length
30
+ if len(instruction) < min_length or len(response) < min_length:
31
+ return False
32
+
33
+ return True
34
+
35
+ def split_dataset(
36
+ input_file: str,
37
+ output_dir: str,
38
+ train_ratio: float = 0.75,
39
+ val_ratio: float = 0.10,
40
+ test_ratio: float = 0.15,
41
+ seed: int = 42,
42
+ min_length: int = 3
43
+ ) -> Dict:
44
+ """Split dataset into train/val/test with validation"""
45
+
46
+ # Validate ratios
47
+ ratio_sum = train_ratio + val_ratio + test_ratio
48
+ if abs(ratio_sum - 1.0) > 0.01:
49
+ raise ValueError(f"Ratios must sum to 1.0, got {ratio_sum}")
50
+
51
+ print(f"๐Ÿ“Š Loading dataset from: {input_file}")
52
+
53
+ # Load data
54
+ samples = []
55
+ invalid_count = 0
56
+
57
+ with open(input_file, 'r', encoding='utf-8') as f:
58
+ for line_num, line in enumerate(f, 1):
59
+ line = line.strip()
60
+ if not line:
61
+ continue
62
+
63
+ try:
64
+ sample = json.loads(line)
65
+ if validate_sample(sample, min_length):
66
+ samples.append(sample)
67
+ else:
68
+ invalid_count += 1
69
+ print(f"โš ๏ธ Invalid sample at line {line_num}: missing fields or too short")
70
+ except json.JSONDecodeError as e:
71
+ invalid_count += 1
72
+ print(f"โŒ Invalid JSON at line {line_num}: {e}")
73
+
74
+ print(f"\n๐Ÿ“Š Dataset Statistics:")
75
+ print(f" โœ… Valid samples: {len(samples)}")
76
+ print(f" โŒ Invalid samples: {invalid_count}")
77
+
78
+ if len(samples) < 10:
79
+ raise ValueError(f"Insufficient samples: {len(samples)} (minimum 10 required)")
80
+
81
+ # Shuffle with fixed seed
82
+ print(f"\n๐Ÿ”€ Shuffling with seed={seed}...")
83
+ random.seed(seed)
84
+ random.shuffle(samples)
85
+
86
+ # Calculate split indices
87
+ total = len(samples)
88
+ train_end = int(total * train_ratio)
89
+ val_end = train_end + int(total * val_ratio)
90
+
91
+ train_data = samples[:train_end]
92
+ val_data = samples[train_end:val_end]
93
+ test_data = samples[val_end:]
94
+
95
+ # Create output directory
96
+ output_path = Path(output_dir)
97
+ output_path.mkdir(parents=True, exist_ok=True)
98
+
99
+ # Save splits
100
+ splits = {
101
+ "train": train_data,
102
+ "val": val_data,
103
+ "test": test_data
104
+ }
105
+
106
+ print(f"\n๐Ÿ’พ Saving splits to: {output_path}")
107
+ for split_name, data in splits.items():
108
+ output_file = output_path / f"{split_name}.jsonl"
109
+ with open(output_file, 'w', encoding='utf-8') as f:
110
+ for item in data:
111
+ f.write(json.dumps(item, ensure_ascii=False) + '\n')
112
+
113
+ print(f" โœ… {split_name}.jsonl: {len(data)} samples")
114
+
115
+ # Return statistics
116
+ stats = {
117
+ "total": total,
118
+ "train": len(train_data),
119
+ "val": len(val_data),
120
+ "test": len(test_data),
121
+ "invalid": invalid_count,
122
+ "train_ratio": len(train_data) / total,
123
+ "val_ratio": len(val_data) / total,
124
+ "test_ratio": len(test_data) / total
125
+ }
126
+
127
+ return stats
128
+
129
if __name__ == "__main__":
    import argparse
    import sys

    # Command-line interface mirroring split_dataset's parameters.
    parser = argparse.ArgumentParser(description="Split dataset for training")
    parser.add_argument("--input", required=True, help="Input JSONL file")
    parser.add_argument("--output-dir", required=True, help="Output directory")
    parser.add_argument("--train-ratio", type=float, default=0.75, help="Training ratio (default: 0.75)")
    parser.add_argument("--val-ratio", type=float, default=0.10, help="Validation ratio (default: 0.10)")
    parser.add_argument("--test-ratio", type=float, default=0.15, help="Test ratio (default: 0.15)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    parser.add_argument("--min-length", type=int, default=3, help="Minimum field length (default: 3)")

    args = parser.parse_args()

    # Echo the configuration before doing any work.
    print("=" * 70)
    print("๐Ÿ“Š DATASET SPLITTING FOR CODELLAMA FINE-TUNING")
    print("=" * 70)
    print(f"\nConfiguration:")
    print(f" Input: {args.input}")
    print(f" Output: {args.output_dir}")
    print(f" Ratios: Train={args.train_ratio:.0%}, Val={args.val_ratio:.0%}, Test={args.test_ratio:.0%}")
    print(f" Seed: {args.seed}")
    print()

    try:
        # Arguments are passed positionally in the same order as the
        # split_dataset signature.
        stats = split_dataset(
            args.input,
            args.output_dir,
            args.train_ratio,
            args.val_ratio,
            args.test_ratio,
            args.seed,
            args.min_length
        )

        print(f"\n" + "=" * 70)
        print(f"โœ… SPLIT COMPLETE!")
        print("=" * 70)
        print(f"\nFinal Statistics:")
        print(f" Total samples: {stats['total']}")
        print(f" Training: {stats['train']} samples ({stats['train_ratio']*100:.1f}%)")
        print(f" Validation: {stats['val']} samples ({stats['val_ratio']*100:.1f}%)")
        print(f" Test: {stats['test']} samples ({stats['test_ratio']*100:.1f}%)")
        if stats['invalid'] > 0:
            print(f" โš ๏ธ Invalid samples skipped: {stats['invalid']}")
        print("=" * 70)

    except Exception as e:
        # Top-level boundary: report and set a non-zero exit status.
        print(f"\nโŒ Error: {e}")
        # sys.exit is the canonical way to terminate a script with a
        # status code; the builtin exit() is injected by the site module
        # and is not guaranteed to exist (e.g. under `python -S`).
        sys.exit(1)
179
+
180
+