hassanshka committed on
Commit
eae0f44
·
verified ·
1 Parent(s): cb9eb68

Add calibration data: prepare_calibration.py

Browse files
calibration_data/prepare_calibration.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Prepare calibration dataset from baseline R0 evaluation results.

This script extracts successful completions (prompt + full_response) from the
baseline model evaluation to use as calibration data. This captures the full
"trajectory" of the model's behavior, which is better for quantization calibration
than using prompts alone.

Key features:
- Only uses successful completions (success=True)
- Balances across all tasks for fair representation
- Uses full prompt + full_response as calibration text
- Random stratified sampling for diversity
"""

import sys
import json
from pathlib import Path
from collections import defaultdict
import random

print("=" * 80)
print("CALIBRATION DATASET PREPARATION (Baseline Trajectories)")
print("=" * 80)

# Configuration
BASELINE_RESULTS_PATH = "Data_r0_annotated_cleaned.jsonl"
OUTPUT_DIR = Path(__file__).parent
NUM_CALIBRATION_SAMPLES = 128
RANDOM_SEED = 42

print(f"\nBaseline results: {BASELINE_RESULTS_PATH}")
print(f"Number of calibration samples: {NUM_CALIBRATION_SAMPLES}")
print(f"Sampling strategy: Stratified random across tasks (successful completions only)")
print(f"Calibration format: prompt + full_response (complete trajectories)")
print(f"Random seed: {RANDOM_SEED}")
print(f"Output directory: {OUTPUT_DIR}")

# Set random seed so the stratified sampling below is reproducible
random.seed(RANDOM_SEED)

# --- TOKEN COUNTING FUNCTION ---
try:
    import tiktoken

    # Cache encoders by name: tiktoken.get_encoding() loads the full BPE
    # tables, so constructing it per call (as the original did) is very slow
    # when count_tokens runs over hundreds of long texts.
    _ENCODER_CACHE = {}

    def count_tokens(text, enc_name="gpt2"):
        """Return the exact token count of *text* under the named tiktoken encoding."""
        enc = _ENCODER_CACHE.get(enc_name)
        if enc is None:
            enc = _ENCODER_CACHE[enc_name] = tiktoken.get_encoding(enc_name)
        return len(enc.encode(text))
except ImportError:
    # fallback: crude whitespace split as estimation
    def count_tokens(text, enc_name=None):
        """Approximate token count via whitespace split (order of magnitude only)."""
        # Not accurate, but gives an order of magnitude
        return len(text.split())

    print("[!] tiktoken not found. Falling back to whitespace token count (less accurate).")
else:
    print("[i] Using tiktoken for accurate token counting.")

# Load baseline results (JSONL: one evaluation instance per line)
print("\n[1/5] Loading baseline evaluation results...")
try:
    results = []
    with open(BASELINE_RESULTS_PATH, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if line:
                try:
                    result = json.loads(line)
                    results.append(result)
                except json.JSONDecodeError as e:
                    print(f" ⚠️ Warning: Skipping line {line_num} (invalid JSON): {e}")

    print(f"✓ Loaded {len(results)} evaluation instances")
except FileNotFoundError:
    print(f"✗ ERROR: Baseline results file not found at:")
    print(f" {BASELINE_RESULTS_PATH}")
    print(f"\nPlease ensure you have run the baseline evaluation and the results file exists.")
    sys.exit(1)
except Exception as e:
    print(f"✗ ERROR: Failed to load baseline results: {e}")
    sys.exit(1)

# Guard: an empty (or all-invalid) file would otherwise crash with
# ZeroDivisionError on the success-rate computation below.
if not results:
    print("✗ ERROR: Baseline results file contains no valid records.")
    sys.exit(1)

# Filter for successful completions
print("\n[2/5] Filtering for successful completions...")
successful_results = [r for r in results if r.get('success', False) is True]
print(f"✓ Found {len(successful_results)} successful completions out of {len(results)} total")
print(f" Success rate: {len(successful_results)/len(results)*100:.1f}%")

# Guard: with zero successful completions, task grouping would be empty and
# the per-task division below would raise ZeroDivisionError.
if not successful_results:
    print("✗ ERROR: No successful completions available for calibration.")
    sys.exit(1)

if len(successful_results) < NUM_CALIBRATION_SAMPLES:
    print(f"\n⚠️ WARNING: Only {len(successful_results)} successful completions available")
    print(f" Requested {NUM_CALIBRATION_SAMPLES} samples")
    print(f" Will use all {len(successful_results)} available samples")
    NUM_CALIBRATION_SAMPLES = len(successful_results)

# Group by task so sampling can be stratified
print("\n[3/5] Grouping by task...")
task_groups = defaultdict(list)
for result in successful_results:
    task_name = result.get('task_name', 'unknown')
    task_groups[task_name].append(result)

print(f"✓ Found {len(task_groups)} unique tasks:")
for task, instances in sorted(task_groups.items()):
    print(f" • {task}: {len(instances)} successful completions")

# Stratified sampling: equal share per task, remainder spread one-by-one
print(f"\n[4/5] Performing stratified random sampling...")
print(f" Target: {NUM_CALIBRATION_SAMPLES} samples balanced across {len(task_groups)} tasks")

# Calculate samples per task
samples_per_task = NUM_CALIBRATION_SAMPLES // len(task_groups)
remainder = NUM_CALIBRATION_SAMPLES % len(task_groups)

print(f" Base samples per task: {samples_per_task}")
print(f" Remainder to distribute: {remainder}")

calibration_samples = []
task_sample_counts = {}

# Sample from each task (sorted order + fixed seed => deterministic output)
for task, instances in sorted(task_groups.items()):
    # Calculate how many samples for this task
    n_samples = samples_per_task
    if remainder > 0:
        n_samples += 1
        remainder -= 1

    # Don't sample more than available
    n_samples = min(n_samples, len(instances))

    # Random sample without replacement
    sampled = random.sample(instances, n_samples)
    calibration_samples.extend(sampled)
    task_sample_counts[task] = n_samples

    print(f" • {task}: sampled {n_samples}/{len(instances)}")

print(f"\n✓ Sampled {len(calibration_samples)} total instances")

# Create calibration data
print("\n[5/5] Creating calibration dataset...")

calibration_texts = []
for sample in calibration_samples:
    # Combine prompt + full_response as the calibration text
    prompt = str(sample.get('prompt', '')).strip()
    full_response = str(sample.get('full_response', '')).strip()

    if not prompt:
        print(f" ⚠️ Warning: Empty prompt for instance {sample.get('instance_id', '?')}")
        continue

    if not full_response:
        print(f" ⚠️ Warning: Empty response for instance {sample.get('instance_id', '?')}")
        continue

    # Combine as conversation trajectory
    calibration_text = f"{prompt}\n\n{full_response}"
    calibration_texts.append(calibration_text)

print(f"✓ Created {len(calibration_texts)} calibration texts")

# Validation check — compute each per-text measurement once, then derive
# all four statistics from it (the original re-scanned the corpus for each).
print("\n Validation:")
char_lengths = [len(text) for text in calibration_texts]
total_length = sum(char_lengths)
avg_length = total_length / len(char_lengths) if char_lengths else 0
min_length = min(char_lengths) if char_lengths else 0
max_length = max(char_lengths) if char_lengths else 0

print(f" • Total characters: {total_length:,}")
print(f" • Average length: {avg_length:,.0f} chars")
print(f" • Min length: {min_length:,} chars")
print(f" • Max length: {max_length:,} chars")

# --- TOKEN COUNT CALCULATION & PRINT ---
# Tokenize each text exactly once; the original called count_tokens three
# times per text (sum / min / max), tripling the most expensive step.
token_counts = [count_tokens(text) for text in calibration_texts]
total_tokens = sum(token_counts)
avg_tokens = total_tokens / len(token_counts) if token_counts else 0
min_tokens = min(token_counts) if token_counts else 0
max_tokens = max(token_counts) if token_counts else 0

print(f" • Total tokens: {total_tokens:,}")
print(f" • Average tokens: {avg_tokens:,.0f}")
print(f" • Min tokens: {min_tokens:,}")
print(f" • Max tokens: {max_tokens:,}")

# Sanity thresholds: very short texts suggest extraction problems, very long
# ones can exceed the context budget of some quantization calibrators.
if avg_length < 100:
    print(f"\n ⚠️ WARNING: Average text length is very short ({avg_length:.0f} chars)")
    print(f" This might indicate a problem with data extraction")
elif avg_length > 10000:
    print(f"\n ⚠️ WARNING: Average text length is very long ({avg_length:.0f} chars)")
    print(f" Some quantization methods may have issues with very long texts")
else:
    print(f" ✓ Text lengths look reasonable")

# Save calibration data
output_json = OUTPUT_DIR / "calibration_data.json"
output_preview = OUTPUT_DIR / "calibration_preview.txt"

print(f"\nSaving calibration data...")
print(f" JSON: {output_json}")
print(f" Preview: {output_preview}")

# Save as JSON (the list of calibration texts is the quantizers' input)
with open(output_json, 'w', encoding='utf-8') as f:
    json.dump(calibration_texts, f, indent=2)
print(f"✓ Saved {len(calibration_texts)} calibration texts to JSON")

# Save human-readable preview for manual review
with open(output_preview, 'w', encoding='utf-8') as f:
    f.write("=" * 80 + "\n")
    f.write("CALIBRATION DATASET PREVIEW\n")
    f.write("=" * 80 + "\n\n")

    f.write(f"Total samples: {len(calibration_texts)}\n")
    f.write(f"Data source: Baseline R0 evaluation (successful completions only)\n")
    f.write(f"Format: prompt + full_response (complete trajectories)\n\n")

    f.write("Task distribution:\n")
    for task, count in sorted(task_sample_counts.items()):
        f.write(f" • {task}: {count} samples\n")

    f.write(f"\nText statistics:\n")
    f.write(f" • Average length: {avg_length:,.0f} characters\n")
    f.write(f" • Min length: {min_length:,} characters\n")
    f.write(f" • Max length: {max_length:,} characters\n")
    f.write(f" • Total characters: {total_length:,}\n")
    f.write(f" • Average tokens: {avg_tokens:,.0f}\n")
    f.write(f" • Total tokens: {total_tokens:,}\n")
    f.write(f" • Min tokens: {min_tokens:,}\n")
    f.write(f" • Max tokens: {max_tokens:,}\n")

    f.write("\n" + "=" * 80 + "\n")
    f.write("SAMPLE PREVIEW (First 3 samples, truncated)\n")
    f.write("=" * 80 + "\n\n")

    for i, text in enumerate(calibration_texts[:3], 1):
        f.write(f"Sample {i}:\n")
        f.write("-" * 80 + "\n")
        # Show first 500 chars and last 200 chars if text is long
        if len(text) > 1000:
            f.write(text[:500])
            f.write(f"\n\n... [{len(text)-700:,} characters omitted] ...\n\n")
            f.write(text[-200:])
        else:
            f.write(text)
        f.write("\n" + "=" * 80 + "\n\n")

print(f"✓ Saved preview to {output_preview}")

# Final summary
print("\n" + "=" * 80)
print("CALIBRATION DATASET PREPARATION COMPLETE")
print("=" * 80)
print(f"\n✓ Successfully created calibration dataset with {len(calibration_texts)} samples")
print(f"✓ Balanced across {len(task_groups)} tasks")
print(f"✓ Using full trajectories (prompt + response)")
print(f"✓ Average calibration text length: {avg_length:,.0f} characters")
print(f"✓ Total tokens in calibration data: {total_tokens:,}")
print(f"✓ Average tokens per calibration text: {avg_tokens:,.0f}")
print(f"\nOutput files:")
print(f" • {output_json}")
print(f" • {output_preview}")
print(f"\nNext steps:")
print(f" 1. Review calibration_preview.txt to verify the data looks correct")
print(f" 2. Run quantization scripts (AWQ, PTQ) - they will use calibration_data.json")
print(f" 3. AWQ: python awq/quantize_awq.py")
print(f" 4. PTQ: python ptq/quantize_ptq.py")
print()