DouDou commited on
Commit
0003466
·
verified ·
1 Parent(s): 3562304

Upload data3/generate_problems_batch.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. data3/generate_problems_batch.py +655 -0
data3/generate_problems_batch.py ADDED
@@ -0,0 +1,655 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generate programming problems from function_dataset_v2.csv using OpenAI Batch API.
4
+ Batch API offers 50% cost savings compared to standard API.
5
+ """
6
+
7
+ import csv
8
+ import json
9
+ import os
10
+ import sys
11
+ from openai import OpenAI
12
+ from datetime import datetime
13
+ from typing import Dict, Optional, List
14
+ import time
15
+
16
+ # Configuration
17
+ MODEL_NAME = "gpt-4o-mini"
18
+ MIN_RELEVANCE_SCORE = 60
19
+ MAX_BUDGET_USD = 10.0
20
+
21
+ # OpenAI Batch API pricing (50% off standard pricing)
22
+ # Official pricing: https://openai.com/api/pricing/
23
+ BATCH_PRICING = {
24
+ # GPT-5 series with Batch API discount
25
+ "gpt-5.2": {
26
+ "input": 0.875 / 1_000_000, # $0.875 per 1M (50% off $1.75)
27
+ "output": 7.00 / 1_000_000, # $7.00 per 1M (50% off $14.00)
28
+ },
29
+ "gpt-5.1": {
30
+ "input": 0.625 / 1_000_000, # $0.625 per 1M (50% off $1.25)
31
+ "output": 5.00 / 1_000_000, # $5.00 per 1M (50% off $10.00)
32
+ },
33
+ "gpt-5": {
34
+ "input": 0.625 / 1_000_000, # $0.625 per 1M (50% off $1.25)
35
+ "output": 5.00 / 1_000_000, # $5.00 per 1M (50% off $10.00)
36
+ },
37
+ "gpt-5-mini": {
38
+ "input": 0.125 / 1_000_000, # $0.125 per 1M (50% off $0.25)
39
+ "output": 1.00 / 1_000_000, # $1.00 per 1M (50% off $2.00)
40
+ },
41
+ "gpt-5-nano": {
42
+ "input": 0.025 / 1_000_000, # $0.025 per 1M (50% off $0.05)
43
+ "output": 0.20 / 1_000_000, # $0.20 per 1M (50% off $0.40)
44
+ },
45
+ # GPT-4o series with Batch API discount
46
+ "gpt-4o": {
47
+ "input": 1.25 / 1_000_000, # $1.25 per 1M (50% off $2.50)
48
+ "output": 5.00 / 1_000_000, # $5.00 per 1M (50% off $10.00)
49
+ },
50
+ "gpt-4o-2024-05-13": {
51
+ "input": 2.50 / 1_000_000, # $2.50 per 1M (50% off $5.00)
52
+ "output": 7.50 / 1_000_000, # $7.50 per 1M (50% off $15.00)
53
+ },
54
+ "gpt-4o-mini": {
55
+ "input": 0.075 / 1_000_000, # $0.075 per 1M (50% off $0.15)
56
+ "output": 0.30 / 1_000_000, # $0.30 per 1M (50% off $0.60)
57
+ },
58
+ # GPT-4 Turbo
59
+ "gpt-4-turbo": {
60
+ "input": 5.00 / 1_000_000, # $5.00 per 1M (50% off $10.00)
61
+ "output": 15.00 / 1_000_000, # $15.00 per 1M (50% off $30.00)
62
+ },
63
+ # GPT-3.5 Turbo
64
+ "gpt-3.5-turbo": {
65
+ "input": 0.25 / 1_000_000, # $0.25 per 1M (50% off $0.50)
66
+ "output": 0.75 / 1_000_000, # $0.75 per 1M (50% off $1.50)
67
+ },
68
+ }
69
+
70
+ PROMPT_TEMPLATE = """You are an expert in scientific computing and computational chemistry/biology/physics. Please create a high-quality programming problem inspired by the following code snippet from a real scientific computing project.
71
+
72
+ The problem should focus on scientific computing concepts such as:
73
+ - Numerical algorithms and simulations
74
+ - Data analysis and visualization
75
+ - Mathematical modeling
76
+ - Scientific data processing
77
+ - Computational methods in chemistry, biology, or physics
78
+
79
+ Code snippet for inspiration:
80
+ ```python
81
+ {code}
82
+ ```
83
+
84
+ Present your output in two distinct sections:
85
+
86
+ [Problem Description]
87
+ Create a **completely self-contained** problem description that:
88
+ - Does NOT directly reference the code snippet above
89
+ - Provides all necessary context and background
90
+ - Clearly states what needs to be implemented
91
+ - Specifies input/output format and constraints
92
+ - Is inspired by the scientific computing concepts in the code but creates a NEW, interesting problem
93
+ - Assumes common programming knowledge but explains any domain-specific concepts
94
+
95
+ [Solution]
96
+ Provide a comprehensive, **correct** Python solution that:
97
+ - Accurately solves the problem described
98
+ - Includes clear comments explaining the approach
99
+ - Uses appropriate scientific computing libraries (numpy, scipy, etc.) when relevant
100
+ - Is complete and runnable
101
+ - Follows best practices for scientific computing
102
+
103
+ Remember: The problem should be INSPIRED by the code, not a direct copy. Create something educational and interesting for scientific computing practitioners."""
104
+
105
+
106
+ class BatchAPIClient:
107
+ """Client for OpenAI Batch API with cost tracking."""
108
+
109
+ def __init__(self, model_name: str = MODEL_NAME, api_key: Optional[str] = None):
110
+ """Initialize OpenAI Batch API client.
111
+
112
+ Args:
113
+ model_name: Name of the OpenAI model to use
114
+ api_key: OpenAI API key (if None, will use OPENAI_API_KEY env variable)
115
+ """
116
+ self.model_name = model_name
117
+ self.client = OpenAI(api_key=api_key)
118
+
119
+ # Get pricing for the model (Batch API is 50% off)
120
+ if model_name in BATCH_PRICING:
121
+ self.input_price = BATCH_PRICING[model_name]["input"]
122
+ self.output_price = BATCH_PRICING[model_name]["output"]
123
+ else:
124
+ print(f"Warning: No Batch pricing info for {model_name}, using gpt-4o-mini prices")
125
+ self.input_price = BATCH_PRICING["gpt-4o-mini"]["input"]
126
+ self.output_price = BATCH_PRICING["gpt-4o-mini"]["output"]
127
+
128
+ print(f"📊 Batch API Pricing (50% off standard rates):")
129
+ print(f" Input: ${self.input_price * 1_000_000:.4f} per 1M tokens")
130
+ print(f" Output: ${self.output_price * 1_000_000:.4f} per 1M tokens")
131
+ print()
132
+
133
+ def create_batch_file(self, requests: List[Dict], output_path: str) -> str:
134
+ """Create a JSONL file for batch processing.
135
+
136
+ Args:
137
+ requests: List of request dictionaries
138
+ output_path: Path to save the JSONL file
139
+
140
+ Returns:
141
+ Path to the created file
142
+ """
143
+ with open(output_path, 'w', encoding='utf-8') as f:
144
+ for req in requests:
145
+ f.write(json.dumps(req, ensure_ascii=False) + '\n')
146
+
147
+ print(f"✅ Created batch file: {output_path}")
148
+ print(f" Total requests: {len(requests)}")
149
+ return output_path
150
+
151
+ def upload_batch_file(self, file_path: str) -> str:
152
+ """Upload batch file to OpenAI.
153
+
154
+ Args:
155
+ file_path: Path to the JSONL file
156
+
157
+ Returns:
158
+ File ID
159
+ """
160
+ print(f"⬆️ Uploading batch file to OpenAI...")
161
+ with open(file_path, 'rb') as f:
162
+ batch_file = self.client.files.create(
163
+ file=f,
164
+ purpose='batch'
165
+ )
166
+
167
+ print(f"✅ File uploaded: {batch_file.id}")
168
+ return batch_file.id
169
+
170
+ def create_batch(self, file_id: str, description: Optional[str] = None) -> str:
171
+ """Create a batch job.
172
+
173
+ Args:
174
+ file_id: ID of the uploaded file
175
+ description: Optional description for the batch
176
+
177
+ Returns:
178
+ Batch ID
179
+ """
180
+ print(f"🚀 Creating batch job...")
181
+ batch = self.client.batches.create(
182
+ input_file_id=file_id,
183
+ endpoint="/v1/chat/completions",
184
+ completion_window="24h",
185
+ metadata={
186
+ "description": description or "Programming problems generation",
187
+ "created_at": datetime.now().isoformat()
188
+ }
189
+ )
190
+
191
+ print(f"✅ Batch created: {batch.id}")
192
+ print(f" Status: {batch.status}")
193
+ print(f" Total requests: {batch.request_counts.total}")
194
+ return batch.id
195
+
196
+ def check_batch_status(self, batch_id: str) -> Dict:
197
+ """Check the status of a batch job.
198
+
199
+ Args:
200
+ batch_id: ID of the batch
201
+
202
+ Returns:
203
+ Batch status information
204
+ """
205
+ batch = self.client.batches.retrieve(batch_id)
206
+
207
+ status_info = {
208
+ 'id': batch.id,
209
+ 'status': batch.status,
210
+ 'created_at': batch.created_at,
211
+ 'completed_at': batch.completed_at,
212
+ 'failed_at': batch.failed_at,
213
+ 'expired_at': batch.expired_at,
214
+ 'request_counts': {
215
+ 'total': batch.request_counts.total,
216
+ 'completed': batch.request_counts.completed,
217
+ 'failed': batch.request_counts.failed,
218
+ },
219
+ 'output_file_id': batch.output_file_id,
220
+ 'error_file_id': batch.error_file_id,
221
+ }
222
+
223
+ return status_info
224
+
225
+ def download_results(self, file_id: str, output_path: str):
226
+ """Download batch results.
227
+
228
+ Args:
229
+ file_id: ID of the output file
230
+ output_path: Path to save the results
231
+ """
232
+ print(f"⬇️ Downloading results...")
233
+ content = self.client.files.content(file_id)
234
+
235
+ with open(output_path, 'wb') as f:
236
+ f.write(content.content)
237
+
238
+ print(f"✅ Results saved to: {output_path}")
239
+
240
+ def estimate_cost(self, num_requests: int, avg_input_tokens: int, avg_output_tokens: int) -> Dict:
241
+ """Estimate the cost of a batch job.
242
+
243
+ Args:
244
+ num_requests: Number of requests
245
+ avg_input_tokens: Average input tokens per request
246
+ avg_output_tokens: Average output tokens per request
247
+
248
+ Returns:
249
+ Cost estimation dictionary
250
+ """
251
+ total_input_tokens = num_requests * avg_input_tokens
252
+ total_output_tokens = num_requests * avg_output_tokens
253
+
254
+ input_cost = total_input_tokens * self.input_price
255
+ output_cost = total_output_tokens * self.output_price
256
+ total_cost = input_cost + output_cost
257
+
258
+ # Compare with standard API (2x the batch price)
259
+ standard_cost = total_cost * 2
260
+ savings = standard_cost - total_cost
261
+
262
+ return {
263
+ 'num_requests': num_requests,
264
+ 'total_input_tokens': total_input_tokens,
265
+ 'total_output_tokens': total_output_tokens,
266
+ 'total_tokens': total_input_tokens + total_output_tokens,
267
+ 'input_cost': input_cost,
268
+ 'output_cost': output_cost,
269
+ 'total_cost': total_cost,
270
+ 'standard_api_cost': standard_cost,
271
+ 'savings': savings,
272
+ 'savings_percentage': 50.0
273
+ }
274
+
275
+
276
+ def prepare_batch_requests(
277
+ input_file: str,
278
+ min_score: int = MIN_RELEVANCE_SCORE,
279
+ max_samples: Optional[int] = None,
280
+ start_from: int = 0,
281
+ ) -> List[Dict]:
282
+ """Prepare batch requests from function dataset.
283
+
284
+ Args:
285
+ input_file: Path to function_dataset_v2.csv
286
+ min_score: Minimum relevance score to process
287
+ max_samples: Maximum number of samples to process
288
+ start_from: Skip first N rows
289
+
290
+ Returns:
291
+ List of batch request dictionaries
292
+ """
293
+ print(f"📋 Preparing batch requests...")
294
+ print(f" Input: {input_file}")
295
+ print(f" Min Score: {min_score}")
296
+ if max_samples:
297
+ print(f" Max Samples: {max_samples}")
298
+ print()
299
+
300
+ requests = []
301
+ total_rows = 0
302
+ skipped_low_score = 0
303
+ skipped_no_code = 0
304
+
305
+ with open(input_file, 'r', encoding='utf-8') as infile:
306
+ reader = csv.DictReader(infile)
307
+
308
+ for row in reader:
309
+ total_rows += 1
310
+
311
+ # Skip if resuming
312
+ if total_rows <= start_from:
313
+ continue
314
+
315
+ # Check if we've reached max samples
316
+ if max_samples and len(requests) >= max_samples:
317
+ break
318
+
319
+ # Filter by relevance score
320
+ try:
321
+ relevance_score = int(row.get('relevance_score', 0))
322
+ except (ValueError, TypeError):
323
+ relevance_score = 0
324
+
325
+ if relevance_score < min_score:
326
+ skipped_low_score += 1
327
+ continue
328
+
329
+ # Get function content
330
+ function_content = row.get('function_content', '').strip()
331
+ if not function_content or len(function_content) < 50:
332
+ skipped_no_code += 1
333
+ continue
334
+
335
+ # Prepare metadata (OpenAI Batch API requires all metadata values to be strings)
336
+ metadata = {
337
+ 'original_index': str(row.get('original_index', '')),
338
+ 'function_name': str(row.get('function_name', '')),
339
+ 'repo_name': str(row.get('repo_name', '')),
340
+ 'path': str(row.get('path', '')),
341
+ 'language': str(row.get('language', '')),
342
+ 'relevance_score': str(relevance_score), # Convert to string!
343
+ 'function_start_line': str(row.get('function_start_line', '')),
344
+ 'function_end_line': str(row.get('function_end_line', '')),
345
+ }
346
+
347
+ # Generate prompt
348
+ prompt = PROMPT_TEMPLATE.format(code=function_content)
349
+
350
+ # Create batch request in OpenAI Batch API format
351
+ request = {
352
+ "custom_id": f"request-{len(requests)}",
353
+ "method": "POST",
354
+ "url": "/v1/chat/completions",
355
+ "body": {
356
+ "model": MODEL_NAME,
357
+ "messages": [
358
+ {
359
+ "role": "system",
360
+ "content": "You are an expert in scientific computing and programming education."
361
+ },
362
+ {
363
+ "role": "user",
364
+ "content": prompt
365
+ }
366
+ ],
367
+ "temperature": 0.7,
368
+ "metadata": metadata # All values are now strings
369
+ }
370
+ }
371
+
372
+ requests.append(request)
373
+
374
+ print(f"✅ Prepared {len(requests)} requests")
375
+ print(f" Total rows: {total_rows}")
376
+ print(f" Skipped (low score): {skipped_low_score}")
377
+ print(f" Skipped (no/short code): {skipped_no_code}")
378
+ print()
379
+
380
+ return requests
381
+
382
+
383
+ def process_batch_results(
384
+ results_file: str,
385
+ output_file: str,
386
+ model_name: str,
387
+ input_price: float,
388
+ output_price: float,
389
+ requests_file: Optional[str] = None
390
+ ):
391
+ """Process batch results and save to JSONL format.
392
+
393
+ Args:
394
+ results_file: Path to batch results file
395
+ output_file: Path to output JSONL file
396
+ model_name: Model name used
397
+ input_price: Input token price
398
+ output_price: Output token price
399
+ requests_file: Optional path to original batch requests file (to restore prompts)
400
+ """
401
+ print(f"📊 Processing batch results...")
402
+
403
+ # Load prompts from requests file if provided
404
+ prompts_map = {}
405
+ if requests_file and os.path.exists(requests_file):
406
+ print(f" Loading prompts from: {requests_file}")
407
+ with open(requests_file, 'r', encoding='utf-8') as f:
408
+ for line in f:
409
+ req = json.loads(line)
410
+ custom_id = req['custom_id']
411
+ # Extract prompt from messages
412
+ for msg in req['body']['messages']:
413
+ if msg['role'] == 'user':
414
+ prompts_map[custom_id] = msg['content']
415
+ break
416
+ print(f" Loaded {len(prompts_map)} prompts")
417
+
418
+ processed = 0
419
+ errors = 0
420
+ total_input_tokens = 0
421
+ total_output_tokens = 0
422
+ total_cost = 0.0
423
+
424
+ with open(results_file, 'r', encoding='utf-8') as infile, \
425
+ open(output_file, 'w', encoding='utf-8') as outfile:
426
+
427
+ for line in infile:
428
+ batch_result = json.loads(line)
429
+
430
+ # Check if request was successful
431
+ if batch_result.get('error'):
432
+ errors += 1
433
+ print(f"❌ Error in {batch_result['custom_id']}: {batch_result['error']}")
434
+ continue
435
+
436
+ response = batch_result['response']
437
+ custom_id = batch_result['custom_id']
438
+
439
+ # Extract usage information
440
+ usage = response['body']['usage']
441
+ input_tokens = usage['prompt_tokens']
442
+ output_tokens = usage['completion_tokens']
443
+
444
+ # Calculate cost
445
+ input_cost = input_tokens * input_price
446
+ output_cost = output_tokens * output_price
447
+ request_cost = input_cost + output_cost
448
+
449
+ # Update totals
450
+ total_input_tokens += input_tokens
451
+ total_output_tokens += output_tokens
452
+ total_cost += request_cost
453
+
454
+ # Get metadata from the original request
455
+ metadata = response['body'].get('metadata', {})
456
+
457
+ # Extract the response text
458
+ response_text = response['body']['choices'][0]['message']['content']
459
+
460
+ # Build result - include prompt if available
461
+ result = {
462
+ 'metadata': metadata,
463
+ 'response': response_text,
464
+ 'usage': {
465
+ 'input_tokens': input_tokens,
466
+ 'output_tokens': output_tokens,
467
+ 'total_tokens': input_tokens + output_tokens,
468
+ 'input_cost': input_cost,
469
+ 'output_cost': output_cost,
470
+ 'request_cost': request_cost
471
+ },
472
+ 'model': model_name,
473
+ 'timestamp': datetime.now().isoformat(),
474
+ 'custom_id': custom_id
475
+ }
476
+
477
+ # Add prompt if we have it
478
+ if custom_id in prompts_map:
479
+ result['prompt'] = prompts_map[custom_id]
480
+
481
+ outfile.write(json.dumps(result, ensure_ascii=False) + '\n')
482
+ processed += 1
483
+
484
+ print(f"\n✅ Processed {processed} results")
485
+ print(f" Errors: {errors}")
486
+ print()
487
+
488
+ # Print usage summary
489
+ print("=" * 70)
490
+ print("BATCH API USAGE SUMMARY")
491
+ print("=" * 70)
492
+ print(f"Model: {model_name}")
493
+ print(f"Total Requests: {processed}")
494
+ print(f"Total Input Tokens: {total_input_tokens:,}")
495
+ print(f"Total Output Tokens: {total_output_tokens:,}")
496
+ print(f"Total Tokens: {total_input_tokens + total_output_tokens:,}")
497
+ print(f"\nBatch API Cost: ${total_cost:.6f}")
498
+ print(f"Standard API Cost: ${total_cost * 2:.6f}")
499
+ print(f"Savings (50%): ${total_cost:.6f}")
500
+ print("=" * 70)
501
+
502
+
503
+ def main():
504
+ import argparse
505
+
506
+ parser = argparse.ArgumentParser(
507
+ description='Generate programming problems using OpenAI Batch API (50% cost savings)'
508
+ )
509
+
510
+ subparsers = parser.add_subparsers(dest='command', help='Command to run')
511
+
512
+ # Prepare command
513
+ prepare_parser = subparsers.add_parser('prepare', help='Prepare batch requests')
514
+ prepare_parser.add_argument('--input', default='function_dataset_v2.csv')
515
+ prepare_parser.add_argument('--output', default='batch_requests.jsonl')
516
+ prepare_parser.add_argument('--min-score', type=int, default=MIN_RELEVANCE_SCORE)
517
+ prepare_parser.add_argument('--max-samples', type=int, default=None)
518
+ prepare_parser.add_argument('--start-from', type=int, default=0)
519
+ prepare_parser.add_argument('--model', default=MODEL_NAME)
520
+
521
+ # Submit command
522
+ submit_parser = subparsers.add_parser('submit', help='Submit batch job to OpenAI')
523
+ submit_parser.add_argument('--input', default='batch_requests.jsonl')
524
+ submit_parser.add_argument('--model', default=MODEL_NAME)
525
+ submit_parser.add_argument('--description', default='Programming problems generation')
526
+
527
+ # Status command
528
+ status_parser = subparsers.add_parser('status', help='Check batch job status')
529
+ status_parser.add_argument('batch_id', help='Batch ID to check')
530
+
531
+ # Download command
532
+ download_parser = subparsers.add_parser('download', help='Download batch results')
533
+ download_parser.add_argument('batch_id', help='Batch ID to download')
534
+ download_parser.add_argument('--output', default='batch_results.jsonl')
535
+
536
+ # Process command
537
+ process_parser = subparsers.add_parser('process', help='Process downloaded results')
538
+ process_parser.add_argument('--input', default='batch_results.jsonl')
539
+ process_parser.add_argument('--output', default='programming_problems_batch.jsonl')
540
+ process_parser.add_argument('--model', default=MODEL_NAME)
541
+ process_parser.add_argument('--requests', default='batch_requests_full.jsonl',
542
+ help='Original batch requests file (to restore prompts)')
543
+
544
+ # Estimate command
545
+ estimate_parser = subparsers.add_parser('estimate', help='Estimate batch cost')
546
+ estimate_parser.add_argument('--num-requests', type=int, required=True)
547
+ estimate_parser.add_argument('--avg-input-tokens', type=int, default=1917)
548
+ estimate_parser.add_argument('--avg-output-tokens', type=int, default=2552)
549
+ estimate_parser.add_argument('--model', default=MODEL_NAME)
550
+
551
+ args = parser.parse_args()
552
+
553
+ if not args.command:
554
+ parser.print_help()
555
+ sys.exit(1)
556
+
557
+ # Check API key
558
+ if not os.getenv('OPENAI_API_KEY'):
559
+ print("❌ Error: OPENAI_API_KEY environment variable not set.")
560
+ print(" Please set it with: export OPENAI_API_KEY='your-api-key'")
561
+ sys.exit(1)
562
+
563
+ client = BatchAPIClient(model_name=args.model if hasattr(args, 'model') else MODEL_NAME)
564
+
565
+ if args.command == 'prepare':
566
+ requests = prepare_batch_requests(
567
+ input_file=args.input,
568
+ min_score=args.min_score,
569
+ max_samples=args.max_samples,
570
+ start_from=args.start_from
571
+ )
572
+
573
+ client.create_batch_file(requests, args.output)
574
+
575
+ # Estimate cost
576
+ print("\n💰 Cost Estimation:")
577
+ estimate = client.estimate_cost(
578
+ num_requests=len(requests),
579
+ avg_input_tokens=1917, # From your test
580
+ avg_output_tokens=2552 # From your test
581
+ )
582
+ print(f" Estimated Batch API Cost: ${estimate['total_cost']:.2f}")
583
+ print(f" Standard API Cost: ${estimate['standard_api_cost']:.2f}")
584
+ print(f" Savings (50%): ${estimate['savings']:.2f}")
585
+ print()
586
+
587
+ elif args.command == 'submit':
588
+ file_id = client.upload_batch_file(args.input)
589
+ batch_id = client.create_batch(file_id, args.description)
590
+
591
+ print(f"\n📝 Save this Batch ID: {batch_id}")
592
+ print(f" Check status with: python3 {sys.argv[0]} status {batch_id}")
593
+
594
+ elif args.command == 'status':
595
+ status = client.check_batch_status(args.batch_id)
596
+
597
+ print("\n📊 Batch Status:")
598
+ print(f" ID: {status['id']}")
599
+ print(f" Status: {status['status']}")
600
+ print(f" Total: {status['request_counts']['total']}")
601
+ print(f" Completed: {status['request_counts']['completed']}")
602
+ print(f" Failed: {status['request_counts']['failed']}")
603
+
604
+ if status['status'] == 'completed':
605
+ print(f"\n✅ Batch completed!")
606
+ print(f" Download with: python3 {sys.argv[0]} download {args.batch_id}")
607
+ elif status['status'] == 'failed':
608
+ print(f"\n❌ Batch failed!")
609
+ else:
610
+ print(f"\n⏳ Batch is still processing...")
611
+
612
+ elif args.command == 'download':
613
+ status = client.check_batch_status(args.batch_id)
614
+
615
+ if status['status'] != 'completed':
616
+ print(f"❌ Batch is not completed yet (status: {status['status']})")
617
+ sys.exit(1)
618
+
619
+ client.download_results(status['output_file_id'], args.output)
620
+ print(f"\n✅ Downloaded to: {args.output}")
621
+ print(f" Process with: python3 {sys.argv[0]} process --input {args.output}")
622
+
623
+ elif args.command == 'process':
624
+ process_batch_results(
625
+ results_file=args.input,
626
+ output_file=args.output,
627
+ model_name=args.model,
628
+ input_price=client.input_price,
629
+ output_price=client.output_price,
630
+ requests_file=args.requests
631
+ )
632
+ print(f"\n✅ Final results saved to: {args.output}")
633
+
634
+ elif args.command == 'estimate':
635
+ estimate = client.estimate_cost(
636
+ num_requests=args.num_requests,
637
+ avg_input_tokens=args.avg_input_tokens,
638
+ avg_output_tokens=args.avg_output_tokens
639
+ )
640
+
641
+ print("\n💰 COST ESTIMATION")
642
+ print("=" * 70)
643
+ print(f"Number of Requests: {estimate['num_requests']:,}")
644
+ print(f"Total Input Tokens: {estimate['total_input_tokens']:,}")
645
+ print(f"Total Output Tokens: {estimate['total_output_tokens']:,}")
646
+ print(f"Total Tokens: {estimate['total_tokens']:,}")
647
+ print()
648
+ print(f"Batch API Cost: ${estimate['total_cost']:.2f}")
649
+ print(f"Standard API Cost: ${estimate['standard_api_cost']:.2f}")
650
+ print(f"💰 Savings (50%): ${estimate['savings']:.2f}")
651
+ print("=" * 70)
652
+
653
+
654
+ if __name__ == "__main__":
655
+ main()