ModerRAS commited on
Commit
7c5b41e
·
1 Parent(s): e964ae5

feat: add concurrency and token/cost telemetry for llm relabel runs

Browse files
datasets/AnimeName CHANGED
@@ -1 +1 @@
1
- Subproject commit 9987cc8d7b7bf829d0022ee6e6a0b08de5327975
 
1
+ Subproject commit 5de6ddeed7dafd43207953072a9e197f13b32077
reports/llm_relabel_perf_char_chunk2.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "wall_seconds": 586.9843375682831,
3
+ "rows_processed": 1000,
4
+ "rows_per_second": 1.703622969128493,
5
+ "batches_completed": 250,
6
+ "avg_batch_seconds": 15.12994603919983,
7
+ "avg_active_workers": 6.443932279510124,
8
+ "max_active_workers": 8,
9
+ "configured_workers": 8
10
+ }
reports/llm_relabel_perf_char_chunk3_tokens.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "wall_seconds": 739.8576474189758,
3
+ "rows_processed": 1000,
4
+ "rows_per_second": 1.351611358601944,
5
+ "batches_completed": 250,
6
+ "avg_batch_seconds": 20.91811747932434,
7
+ "avg_active_workers": 7.068292213849803,
8
+ "max_active_workers": 8,
9
+ "configured_workers": 8,
10
+ "input_tokens": 406372,
11
+ "output_tokens": 351059,
12
+ "cached_tokens": 52096,
13
+ "reasoning_tokens": 220306,
14
+ "input_tokens_per_sec": 549.2570110177892,
15
+ "output_tokens_per_sec": 474.49533193943984,
16
+ "input_tokens_per_hour": 1977325.239664041,
17
+ "output_tokens_per_hour": 1708183.1949819834,
18
+ "usd_per_1m_input": 0.0,
19
+ "usd_per_1m_output": 0.0
20
+ }
reports/llm_relabel_perf_weak_chunk2.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "wall_seconds": 436.6158034801483,
3
+ "rows_processed": 1000,
4
+ "rows_per_second": 2.2903431163720285,
5
+ "batches_completed": 250,
6
+ "avg_batch_seconds": 13.27953750705719,
7
+ "avg_active_workers": 7.603676611459359,
8
+ "max_active_workers": 8,
9
+ "configured_workers": 8
10
+ }
reports/llm_relabel_perf_weak_chunk3_tokens.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "wall_seconds": 315.9569420814514,
3
+ "rows_processed": 1000,
4
+ "rows_per_second": 3.164988220901971,
5
+ "batches_completed": 250,
6
+ "avg_batch_seconds": 9.494983646392821,
7
+ "avg_active_workers": 7.51287982740815,
8
+ "max_active_workers": 8,
9
+ "configured_workers": 8,
10
+ "input_tokens": 271438,
11
+ "output_tokens": 173591,
12
+ "cached_tokens": 0,
13
+ "reasoning_tokens": 109518,
14
+ "input_tokens_per_sec": 859.0980727051892,
15
+ "output_tokens_per_sec": 549.413470254594,
16
+ "input_tokens_per_hour": 3092753.061738681,
17
+ "output_tokens_per_hour": 1977888.4929165384,
18
+ "usd_per_1m_input": 0.0,
19
+ "usd_per_1m_output": 0.0
20
+ }
reports/llm_relabel_perf_weak_smoke.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "wall_seconds": 134.8101508617401,
3
+ "rows_processed": 40,
4
+ "rows_per_second": 0.29671356158501433,
5
+ "batches_completed": 10,
6
+ "avg_batch_seconds": 20.668276357650758,
7
+ "avg_active_workers": 1.5331398703993708,
8
+ "max_active_workers": 8,
9
+ "configured_workers": 8
10
+ }
reports/llm_relabel_perf_weak_smoke_with_tokens.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "wall_seconds": 36.052828311920166,
3
+ "rows_processed": 40,
4
+ "rows_per_second": 1.1094829968381366,
5
+ "batches_completed": 10,
6
+ "avg_batch_seconds": 11.140262508392334,
7
+ "avg_active_workers": 3.0899838841684244,
8
+ "max_active_workers": 8,
9
+ "configured_workers": 8,
10
+ "input_tokens": 9621,
11
+ "output_tokens": 6818,
12
+ "cached_tokens": 0,
13
+ "reasoning_tokens": 4760,
14
+ "input_tokens_per_sec": 266.8583978144928,
15
+ "output_tokens_per_sec": 189.11137681106038,
16
+ "input_tokens_per_hour": 960690.232132174,
17
+ "output_tokens_per_hour": 680800.9565198174,
18
+ "usd_per_1m_input": 0.0,
19
+ "usd_per_1m_output": 0.0
20
+ }
tools/llm_relabel_rows.py CHANGED
@@ -13,6 +13,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
13
  import json
14
  import os
15
  import re
 
16
  import time
17
  from dataclasses import dataclass
18
  from pathlib import Path
@@ -67,6 +68,61 @@ class Row:
67
  record: Dict[str, Any]
68
 
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def parse_args() -> argparse.Namespace:
71
  p = argparse.ArgumentParser(description="Relabel selected JSONL rows via Responses API")
72
  p.add_argument("--input", required=True, help="Input JSONL")
@@ -84,6 +140,9 @@ def parse_args() -> argparse.Namespace:
84
  p.add_argument("--concurrency", type=int, default=4, help="Parallel request workers")
85
  p.add_argument("--max-rows", type=int, default=0, help="Optional cap; 0 means no cap")
86
  p.add_argument("--skip-selected", type=int, default=0, help="Skip this many selected rows before processing")
 
 
 
87
  p.add_argument("--retries", type=int, default=3, help="Retries per batch")
88
  p.add_argument("--sleep-ms", type=int, default=150, help="Delay between successful calls")
89
  p.add_argument("--prompt-cache-key", default="anifilebert-relabel-v1", help="Stable prompt cache key")
@@ -91,6 +150,9 @@ def parse_args() -> argparse.Namespace:
91
  p.add_argument("--reasoning-effort", default="medium", help="Reasoning effort (e.g. low/medium/high)")
92
  p.add_argument("--checkpoint-rows", type=int, default=100, help="Write checkpoint every N processed rows")
93
  p.add_argument("--failure-log", default="reports/llm_relabel_failures.log", help="Failure log path")
 
 
 
94
  p.add_argument(
95
  "--user-agent",
96
  default="Codex Desktop/0.133.0-alpha.1 (Windows 10.0.22631; x86_64) unknown (Codex Desktop; 26.519.41501)",
@@ -214,6 +276,18 @@ def append_failure_log(path: str, message: str) -> None:
214
  f.write(message.rstrip() + "\n")
215
 
216
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  def relabel_batch(
218
  api_base: str,
219
  api_key: str,
@@ -225,7 +299,7 @@ def relabel_batch(
225
  user_agent: str,
226
  retries: int,
227
  failure_log: str,
228
- ) -> Dict[int, List[str]]:
229
  url = f"{api_base.rstrip('/')}/responses"
230
  headers = {
231
  "Authorization": f"Bearer {api_key}",
@@ -259,6 +333,7 @@ def relabel_batch(
259
  resp = requests.post(url, headers=headers, json=body, timeout=120)
260
  resp.raise_for_status()
261
  obj = resp.json()
 
262
  try:
263
  parsed = extract_function_args(obj, "submit_labels")
264
  except Exception:
@@ -300,7 +375,7 @@ def relabel_batch(
300
  )
301
  raise ValueError(f"incomplete/invalid rows from model: missing={missing}")
302
 
303
- return mapping
304
  except Exception as exc: # noqa: BLE001
305
  last_error = exc
306
  # Some compatible gateways may not support prompt caching or reasoning fields.
@@ -336,8 +411,9 @@ def process_batch_with_fallback(
336
  retries: int,
337
  failure_log: str,
338
  ) -> List[tuple[Row, List[str]]]:
 
339
  try:
340
- mapping = relabel_batch(
341
  api_base=api_base,
342
  api_key=api_key,
343
  model=model,
@@ -349,11 +425,12 @@ def process_batch_with_fallback(
349
  retries=retries,
350
  failure_log=failure_log,
351
  )
 
352
  except RuntimeError:
353
  mapping = {}
354
  for idx, row in enumerate(batch):
355
  try:
356
- single = relabel_batch(
357
  api_base=api_base,
358
  api_key=api_key,
359
  model=model,
@@ -365,6 +442,7 @@ def process_batch_with_fallback(
365
  retries=max(retries, 4),
366
  failure_log=failure_log,
367
  )
 
368
  mapping[idx] = single[0]
369
  except RuntimeError as exc:
370
  append_failure_log(
@@ -372,7 +450,45 @@ def process_batch_with_fallback(
372
  f"[row-skip] file_id={row.record.get('file_id')} line={row.line_no} reason={exc}",
373
  )
374
  mapping[idx] = row.record.get("labels", [])
375
- return [(batch[row_id], labels) for row_id, labels in mapping.items()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
 
377
 
378
  def main() -> None:
@@ -385,6 +501,18 @@ def main() -> None:
385
  output_path = Path(args.output)
386
 
387
  all_records, selected_rows = load_rows(input_path, args.selector)
 
 
 
 
 
 
 
 
 
 
 
 
388
  if args.skip_selected > 0:
389
  selected_rows = selected_rows[args.skip_selected:]
390
  if args.max_rows > 0:
@@ -404,10 +532,16 @@ def main() -> None:
404
  ]
405
 
406
  done_rows = 0
 
 
 
 
 
407
  with ThreadPoolExecutor(max_workers=concurrency) as executor:
408
  futures = [
409
  executor.submit(
410
- process_batch_with_fallback,
 
411
  api_base=args.api_base,
412
  api_key=api_key,
413
  model=args.model,
@@ -422,14 +556,35 @@ def main() -> None:
422
  for batch in batches
423
  ]
424
  for fut in as_completed(futures):
425
- updates = fut.result()
 
 
 
 
426
  for row, new_labels in updates:
427
  rec = row.record
428
  if rec.get("labels") != new_labels:
429
  rec["labels"] = new_labels
430
  changed += 1
431
  done_rows += len(updates)
432
- print(f"processed={done_rows}/{total} changed={changed}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433
  if args.checkpoint_rows > 0 and (done_rows % args.checkpoint_rows == 0 or done_rows == total):
434
  write_jsonl(output_path, all_records)
435
  if args.sleep_ms > 0:
@@ -437,6 +592,44 @@ def main() -> None:
437
 
438
  # rows in selected_rows reference dicts in all_records by identity, so changes are already reflected.
439
  write_jsonl(output_path, all_records)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
  print(f"done selected_rows={total} changed_rows={changed} output={output_path}")
441
 
442
 
 
13
  import json
14
  import os
15
  import re
16
+ import threading
17
  import time
18
  from dataclasses import dataclass
19
  from pathlib import Path
 
68
  record: Dict[str, Any]
69
 
70
 
71
+ class ConcurrentMeter:
72
+ def __init__(self) -> None:
73
+ self._lock = threading.Lock()
74
+ self.current_active = 0
75
+ self.max_active = 0
76
+ self.active_time_accum = 0.0
77
+ self.last_ts = time.time()
78
+
79
+ def _accumulate(self, now: float) -> None:
80
+ dt = now - self.last_ts
81
+ if dt > 0:
82
+ self.active_time_accum += self.current_active * dt
83
+ self.last_ts = now
84
+
85
+ def task_start(self) -> None:
86
+ now = time.time()
87
+ with self._lock:
88
+ self._accumulate(now)
89
+ self.current_active += 1
90
+ if self.current_active > self.max_active:
91
+ self.max_active = self.current_active
92
+
93
+ def task_end(self) -> None:
94
+ now = time.time()
95
+ with self._lock:
96
+ self._accumulate(now)
97
+ if self.current_active > 0:
98
+ self.current_active -= 1
99
+
100
+ def snapshot(self) -> Dict[str, float]:
101
+ now = time.time()
102
+ with self._lock:
103
+ self._accumulate(now)
104
+ return {
105
+ "current_active": float(self.current_active),
106
+ "max_active": float(self.max_active),
107
+ "active_time_accum": float(self.active_time_accum),
108
+ "timestamp": now,
109
+ }
110
+
111
+
112
+ @dataclass
113
+ class UsageStats:
114
+ input_tokens: int = 0
115
+ output_tokens: int = 0
116
+ cached_tokens: int = 0
117
+ reasoning_tokens: int = 0
118
+
119
+ def add(self, other: "UsageStats") -> None:
120
+ self.input_tokens += int(other.input_tokens)
121
+ self.output_tokens += int(other.output_tokens)
122
+ self.cached_tokens += int(other.cached_tokens)
123
+ self.reasoning_tokens += int(other.reasoning_tokens)
124
+
125
+
126
  def parse_args() -> argparse.Namespace:
127
  p = argparse.ArgumentParser(description="Relabel selected JSONL rows via Responses API")
128
  p.add_argument("--input", required=True, help="Input JSONL")
 
140
  p.add_argument("--concurrency", type=int, default=4, help="Parallel request workers")
141
  p.add_argument("--max-rows", type=int, default=0, help="Optional cap; 0 means no cap")
142
  p.add_argument("--skip-selected", type=int, default=0, help="Skip this many selected rows before processing")
143
+ p.add_argument("--min-token-len", type=int, default=0, help="Only process rows with token length >= this value")
144
+ p.add_argument("--max-token-len", type=int, default=0, help="Only process rows with token length <= this value (0 = no limit)")
145
+ p.add_argument("--sort-by", choices=("none", "token_len_asc"), default="none", help="Optional ordering of selected rows")
146
  p.add_argument("--retries", type=int, default=3, help="Retries per batch")
147
  p.add_argument("--sleep-ms", type=int, default=150, help="Delay between successful calls")
148
  p.add_argument("--prompt-cache-key", default="anifilebert-relabel-v1", help="Stable prompt cache key")
 
150
  p.add_argument("--reasoning-effort", default="medium", help="Reasoning effort (e.g. low/medium/high)")
151
  p.add_argument("--checkpoint-rows", type=int, default=100, help="Write checkpoint every N processed rows")
152
  p.add_argument("--failure-log", default="reports/llm_relabel_failures.log", help="Failure log path")
153
+ p.add_argument("--perf-log", default="", help="Optional JSON perf summary path")
154
+ p.add_argument("--usd-per-1m-input", type=float, default=0.0, help="Input token price (USD per 1M tokens)")
155
+ p.add_argument("--usd-per-1m-output", type=float, default=0.0, help="Output token price (USD per 1M tokens)")
156
  p.add_argument(
157
  "--user-agent",
158
  default="Codex Desktop/0.133.0-alpha.1 (Windows 10.0.22631; x86_64) unknown (Codex Desktop; 26.519.41501)",
 
276
  f.write(message.rstrip() + "\n")
277
 
278
 
279
+ def parse_usage(response_obj: Dict[str, Any]) -> UsageStats:
280
+ usage = response_obj.get("usage", {}) or {}
281
+ in_details = usage.get("input_tokens_details", {}) or {}
282
+ out_details = usage.get("output_tokens_details", {}) or {}
283
+ return UsageStats(
284
+ input_tokens=int(usage.get("input_tokens", 0) or 0),
285
+ output_tokens=int(usage.get("output_tokens", 0) or 0),
286
+ cached_tokens=int(in_details.get("cached_tokens", 0) or 0),
287
+ reasoning_tokens=int(out_details.get("reasoning_tokens", 0) or 0),
288
+ )
289
+
290
+
291
  def relabel_batch(
292
  api_base: str,
293
  api_key: str,
 
299
  user_agent: str,
300
  retries: int,
301
  failure_log: str,
302
+ ) -> tuple[Dict[int, List[str]], UsageStats]:
303
  url = f"{api_base.rstrip('/')}/responses"
304
  headers = {
305
  "Authorization": f"Bearer {api_key}",
 
333
  resp = requests.post(url, headers=headers, json=body, timeout=120)
334
  resp.raise_for_status()
335
  obj = resp.json()
336
+ usage_stats = parse_usage(obj)
337
  try:
338
  parsed = extract_function_args(obj, "submit_labels")
339
  except Exception:
 
375
  )
376
  raise ValueError(f"incomplete/invalid rows from model: missing={missing}")
377
 
378
+ return mapping, usage_stats
379
  except Exception as exc: # noqa: BLE001
380
  last_error = exc
381
  # Some compatible gateways may not support prompt caching or reasoning fields.
 
411
  retries: int,
412
  failure_log: str,
413
  ) -> List[tuple[Row, List[str]]]:
414
+ usage_total = UsageStats()
415
  try:
416
+ mapping, usage = relabel_batch(
417
  api_base=api_base,
418
  api_key=api_key,
419
  model=model,
 
425
  retries=retries,
426
  failure_log=failure_log,
427
  )
428
+ usage_total.add(usage)
429
  except RuntimeError:
430
  mapping = {}
431
  for idx, row in enumerate(batch):
432
  try:
433
+ single, usage = relabel_batch(
434
  api_base=api_base,
435
  api_key=api_key,
436
  model=model,
 
442
  retries=max(retries, 4),
443
  failure_log=failure_log,
444
  )
445
+ usage_total.add(usage)
446
  mapping[idx] = single[0]
447
  except RuntimeError as exc:
448
  append_failure_log(
 
450
  f"[row-skip] file_id={row.record.get('file_id')} line={row.line_no} reason={exc}",
451
  )
452
  mapping[idx] = row.record.get("labels", [])
453
+ return [(batch[row_id], labels) for row_id, labels in mapping.items()], usage_total
454
+
455
+
456
+ def process_batch_timed(
457
+ meter: ConcurrentMeter,
458
+ api_base: str,
459
+ api_key: str,
460
+ model: str,
461
+ batch: Sequence[Row],
462
+ prompt_cache_key: str,
463
+ prompt_cache_retention: str,
464
+ reasoning_effort: str,
465
+ user_agent: str,
466
+ retries: int,
467
+ failure_log: str,
468
+ ) -> Dict[str, Any]:
469
+ meter.task_start()
470
+ t0 = time.time()
471
+ try:
472
+ updates, usage = process_batch_with_fallback(
473
+ api_base=api_base,
474
+ api_key=api_key,
475
+ model=model,
476
+ batch=batch,
477
+ prompt_cache_key=prompt_cache_key,
478
+ prompt_cache_retention=prompt_cache_retention,
479
+ reasoning_effort=reasoning_effort,
480
+ user_agent=user_agent,
481
+ retries=retries,
482
+ failure_log=failure_log,
483
+ )
484
+ return {
485
+ "updates": updates,
486
+ "elapsed": time.time() - t0,
487
+ "batch_size": len(batch),
488
+ "usage": usage,
489
+ }
490
+ finally:
491
+ meter.task_end()
492
 
493
 
494
  def main() -> None:
 
501
  output_path = Path(args.output)
502
 
503
  all_records, selected_rows = load_rows(input_path, args.selector)
504
+ if args.min_token_len > 0 or args.max_token_len > 0:
505
+ filtered: List[Row] = []
506
+ for row in selected_rows:
507
+ tok_len = len(row.record.get("tokens", []))
508
+ if tok_len < args.min_token_len:
509
+ continue
510
+ if args.max_token_len > 0 and tok_len > args.max_token_len:
511
+ continue
512
+ filtered.append(row)
513
+ selected_rows = filtered
514
+ if args.sort_by == "token_len_asc":
515
+ selected_rows.sort(key=lambda r: len(r.record.get("tokens", [])))
516
  if args.skip_selected > 0:
517
  selected_rows = selected_rows[args.skip_selected:]
518
  if args.max_rows > 0:
 
532
  ]
533
 
534
  done_rows = 0
535
+ wall_start = time.time()
536
+ meter = ConcurrentMeter()
537
+ total_batch_elapsed = 0.0
538
+ completed_batches = 0
539
+ usage_total = UsageStats()
540
  with ThreadPoolExecutor(max_workers=concurrency) as executor:
541
  futures = [
542
  executor.submit(
543
+ process_batch_timed,
544
+ meter,
545
  api_base=args.api_base,
546
  api_key=api_key,
547
  model=args.model,
 
556
  for batch in batches
557
  ]
558
  for fut in as_completed(futures):
559
+ result = fut.result()
560
+ updates = result["updates"]
561
+ total_batch_elapsed += float(result["elapsed"])
562
+ completed_batches += 1
563
+ usage_total.add(result["usage"])
564
  for row, new_labels in updates:
565
  rec = row.record
566
  if rec.get("labels") != new_labels:
567
  rec["labels"] = new_labels
568
  changed += 1
569
  done_rows += len(updates)
570
+ snap = meter.snapshot()
571
+ wall_elapsed = max(1e-9, snap["timestamp"] - wall_start)
572
+ rows_per_sec = done_rows / wall_elapsed
573
+ avg_active = snap["active_time_accum"] / wall_elapsed
574
+ in_tok_per_sec = usage_total.input_tokens / wall_elapsed
575
+ out_tok_per_sec = usage_total.output_tokens / wall_elapsed
576
+ hourly_usd = 0.0
577
+ if args.usd_per_1m_input > 0 or args.usd_per_1m_output > 0:
578
+ cost = (usage_total.input_tokens / 1_000_000.0) * args.usd_per_1m_input + (
579
+ usage_total.output_tokens / 1_000_000.0
580
+ ) * args.usd_per_1m_output
581
+ hourly_usd = cost / wall_elapsed * 3600.0
582
+ print(
583
+ f"processed={done_rows}/{total} changed={changed} "
584
+ f"rows_per_sec={rows_per_sec:.2f} active_now={int(snap['current_active'])} "
585
+ f"avg_active={avg_active:.2f} max_active={int(snap['max_active'])}/{concurrency} "
586
+ f"in_tok_s={in_tok_per_sec:.1f} out_tok_s={out_tok_per_sec:.1f} usd_h={hourly_usd:.3f}"
587
+ )
588
  if args.checkpoint_rows > 0 and (done_rows % args.checkpoint_rows == 0 or done_rows == total):
589
  write_jsonl(output_path, all_records)
590
  if args.sleep_ms > 0:
 
592
 
593
  # rows in selected_rows reference dicts in all_records by identity, so changes are already reflected.
594
  write_jsonl(output_path, all_records)
595
+ wall_total = time.time() - wall_start
596
+ final_snap = meter.snapshot()
597
+ avg_active = final_snap["active_time_accum"] / max(1e-9, wall_total)
598
+ perf_summary = {
599
+ "wall_seconds": wall_total,
600
+ "rows_processed": done_rows,
601
+ "rows_per_second": done_rows / max(1e-9, wall_total),
602
+ "batches_completed": completed_batches,
603
+ "avg_batch_seconds": total_batch_elapsed / max(1, completed_batches),
604
+ "avg_active_workers": avg_active,
605
+ "max_active_workers": int(final_snap["max_active"]),
606
+ "configured_workers": concurrency,
607
+ "input_tokens": usage_total.input_tokens,
608
+ "output_tokens": usage_total.output_tokens,
609
+ "cached_tokens": usage_total.cached_tokens,
610
+ "reasoning_tokens": usage_total.reasoning_tokens,
611
+ "input_tokens_per_sec": usage_total.input_tokens / max(1e-9, wall_total),
612
+ "output_tokens_per_sec": usage_total.output_tokens / max(1e-9, wall_total),
613
+ "input_tokens_per_hour": usage_total.input_tokens / max(1e-9, wall_total) * 3600.0,
614
+ "output_tokens_per_hour": usage_total.output_tokens / max(1e-9, wall_total) * 3600.0,
615
+ "usd_per_1m_input": args.usd_per_1m_input,
616
+ "usd_per_1m_output": args.usd_per_1m_output,
617
+ }
618
+ if args.usd_per_1m_input > 0 or args.usd_per_1m_output > 0:
619
+ total_cost = (usage_total.input_tokens / 1_000_000.0) * args.usd_per_1m_input + (
620
+ usage_total.output_tokens / 1_000_000.0
621
+ ) * args.usd_per_1m_output
622
+ perf_summary["usd_total"] = total_cost
623
+ perf_summary["usd_per_hour"] = total_cost / max(1e-9, wall_total) * 3600.0
624
+ if args.perf_log:
625
+ p = Path(args.perf_log)
626
+ p.parent.mkdir(parents=True, exist_ok=True)
627
+ p.write_text(json.dumps(perf_summary, ensure_ascii=False, indent=2), encoding="utf-8")
628
+ print(
629
+ f"perf wall={wall_total:.1f}s rows_per_sec={perf_summary['rows_per_second']:.2f} "
630
+ f"avg_active={avg_active:.2f} max_active={int(final_snap['max_active'])}/{concurrency} "
631
+ f"in_tok_s={perf_summary['input_tokens_per_sec']:.1f} out_tok_s={perf_summary['output_tokens_per_sec']:.1f}"
632
+ )
633
  print(f"done selected_rows={total} changed_rows={changed} output={output_path}")
634
 
635