AniFileBERT / tools /llm_relabel_rows.py
ModerRAS's picture
Update relabel artifacts and pathfix export
5af16a8
raw
history blame
30.3 kB
#!/usr/bin/env python3
"""
Relabel selected rows in a JSONL dataset via an OpenAI-compatible Responses API.
Designed for high-throughput cleanup with a stable prompt prefix and
`prompt_cache_key` to improve cache hit rates across calls.
"""
from __future__ import annotations
import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
import os
import re
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Sequence
import requests
from anifilebert.label_repairs import repair_jsonl_item
# Every tag the model may emit: "O" plus a B-/I- pair for each entity type.
# validate_labels() and response_schema() both read this set.
ALLOWED_LABELS = {"O"} | {
    f"{prefix}-{entity}"
    for entity in ("TITLE", "SEASON", "EPISODE", "SPECIAL", "GROUP", "RESOLUTION", "SOURCE")
    for prefix in ("B", "I")
}
# Filename substrings that mark a language-specific release. The default
# "language" selector (select_row) picks any row whose filename contains one of
# these, and the prompt requires such tokens to be labelled SOURCE.
LANG_MARKERS = (
    "中文版",  # Chinese version
    "日语版",  # Japanese version
    "国语版",  # Mandarin version
    "粤语版",  # Cantonese version
    "英语版",  # English version
    "英配版",  # English dub
    "中配版",  # Chinese dub
    "日配版",  # Japanese dub
)
# Bracket characters that must be labelled "O" when they appear as standalone
# tokens (enforced by force_bracket_delimiters_o; rule 6 of the prompt).
# The full-width parentheses are written as escapes on purpose: NFKC-style
# normalisation maps U+FF08/U+FF09 to ASCII "(" / ")". When that happened, the
# set ended up with duplicate ASCII entries and silently lost the full-width pair.
BRACKET_DELIMITER_TOKENS = {
    "[",
    "]",
    "(",
    ")",
    "【",
    "】",
    "《",
    "》",
    "\uff08",  # FULLWIDTH LEFT PARENTHESIS
    "\uff09",  # FULLWIDTH RIGHT PARENTHESIS
}
# Fixed system prompt sent as "instructions" on every request. Keep this text
# byte-stable: the module relies on a stable prompt prefix plus
# prompt_cache_key for cache hits, so any edit effectively resets the cache.
# The JSON shape described at the end mirrors response_schema().
# NOTE(review): rule 6 lists "( )" twice; the second pair is presumably meant to
# be the full-width parentheses U+FF08/U+FF09. Confirm the literal was not
# normalised to ASCII.
SYSTEM_INSTRUCTIONS = """You relabel anime filename tokens with BIO tags.
Allowed labels only:
O, B/I-TITLE, B/I-SEASON, B/I-EPISODE, B/I-SPECIAL, B/I-GROUP, B/I-RESOLUTION, B/I-SOURCE.
Hard rules:
1) Output exactly one label per token.
2) Language markers like 中文版/日语版/国语版/粤语版/英语版/英配版/中配版/日配版 must be SOURCE.
3) Episode identifiers (e.g. 01, 13, EP13, 第13集/話/话) must be EPISODE.
4) If title already appears before episode number, episode-name text after the episode number should be O (not TITLE).
5) Preserve obvious GROUP/RESOLUTION/SOURCE tags when present.
6) If bracket delimiters are split into standalone tokens (`[ ] ( ) 【 】 《 》 ( )`), they must be O.
Return strict JSON only:
{"results":[{"row_id":int,"labels":[str,...]}]}
No markdown. No explanation.
"""
@dataclass
class Row:
    """A record selected for relabeling, plus where it came from in the input file."""

    # 1-based physical line number in the input JSONL; only used in failure-log messages.
    line_no: int
    # The parsed JSON object. This is the same dict object stored in load_rows'
    # all_records list, so assigning record["labels"] updates the data written out.
    record: Dict[str, Any]
class ConcurrentMeter:
    """Lock-protected gauge of how many tasks are currently running.

    Besides the live and peak task counts, it integrates the live count over
    wall-clock time (``active_time_accum``). Dividing that integral by elapsed
    time gives the average number of concurrently active tasks.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.current_active = 0
        self.max_active = 0
        self.active_time_accum = 0.0
        self.last_ts = time.time()

    def _accumulate(self, now: float) -> None:
        """Advance the integral to ``now``. Callers must hold ``self._lock``."""
        elapsed = now - self.last_ts
        if elapsed > 0:
            self.active_time_accum += self.current_active * elapsed
        self.last_ts = now

    def task_start(self) -> None:
        """Register one more running task and update the peak."""
        stamp = time.time()
        with self._lock:
            self._accumulate(stamp)
            self.current_active += 1
            self.max_active = max(self.max_active, self.current_active)

    def task_end(self) -> None:
        """Register that a task finished; the live count never drops below zero."""
        stamp = time.time()
        with self._lock:
            self._accumulate(stamp)
            self.current_active = max(0, self.current_active - 1)

    def snapshot(self) -> Dict[str, float]:
        """Return a consistent copy of the counters, stamped with the read time."""
        stamp = time.time()
        with self._lock:
            self._accumulate(stamp)
            active = self.current_active
            peak = self.max_active
            integral = self.active_time_accum
        return {
            "current_active": float(active),
            "max_active": float(peak),
            "active_time_accum": float(integral),
            "timestamp": stamp,
        }
@dataclass
class UsageStats:
    """Token counters taken from a Responses API ``usage`` object (see parse_usage)."""

    input_tokens: int = 0
    output_tokens: int = 0
    cached_tokens: int = 0
    reasoning_tokens: int = 0

    def add(self, other: "UsageStats") -> None:
        """Add ``other``'s counters into this instance, in place."""
        for counter in ("input_tokens", "output_tokens", "cached_tokens", "reasoning_tokens"):
            setattr(self, counter, getattr(self, counter) + int(getattr(other, counter)))
def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for a relabeling run.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which matches the original zero-argument call.

    Returns:
        The parsed namespace.

    Raises:
        SystemExit: On missing or invalid options (argparse convention). This
            includes ``--batch-size`` and ``--http-timeout`` values below 1.
            Batch size 0 would make ``range()`` raise, a negative batch size
            would silently process nothing, and a zero timeout is rejected by
            the HTTP client.
    """

    def positive_int(text: str) -> int:
        # argparse converts ValueError / ArgumentTypeError into a usage error.
        value = int(text)
        if value < 1:
            raise argparse.ArgumentTypeError(f"must be an integer >= 1, got {value}")
        return value

    p = argparse.ArgumentParser(description="Relabel selected JSONL rows via Responses API")
    p.add_argument("--input", required=True, help="Input JSONL")
    p.add_argument("--output", required=True, help="Output JSONL (can equal input)")
    p.add_argument("--api-base", required=True, help="API base URL, e.g. http://host:port/v1")
    p.add_argument("--api-key", default=None, help="API key; falls back to env ANIFILEBERT_RELABEL_API_KEY")
    p.add_argument("--model", default="gpt-5.4-mini", help="Model name")
    p.add_argument(
        "--selector",
        choices=("language", "discontinuous_title", "all"),
        default="language",
        help="Row selector",
    )
    p.add_argument("--batch-size", type=positive_int, default=12, help="Rows per request")
    p.add_argument("--concurrency", type=int, default=4, help="Parallel request workers")
    p.add_argument("--max-rows", type=int, default=0, help="Optional cap; 0 means no cap")
    p.add_argument("--skip-selected", type=int, default=0, help="Skip this many selected rows before processing")
    p.add_argument("--min-token-len", type=int, default=0, help="Only process rows with token length >= this value")
    p.add_argument("--max-token-len", type=int, default=0, help="Only process rows with token length <= this value (0 = no limit)")
    p.add_argument("--sort-by", choices=("none", "token_len_asc"), default="none", help="Optional ordering of selected rows")
    p.add_argument("--retries", type=int, default=3, help="Retries per batch")
    p.add_argument("--sleep-ms", type=int, default=150, help="Delay between successful calls")
    p.add_argument("--prompt-cache-key", default="anifilebert-relabel-v1", help="Stable prompt cache key")
    p.add_argument("--prompt-cache-retention", default="24h", help="Prompt cache retention hint")
    p.add_argument("--reasoning-effort", default="medium", help="Reasoning effort (e.g. low/medium/high)")
    p.add_argument("--checkpoint-rows", type=int, default=100, help="Write checkpoint every N processed rows")
    p.add_argument("--failure-log", default="reports/llm_relabel_failures.log", help="Failure log path")
    p.add_argument("--perf-log", default="", help="Optional JSON perf summary path")
    p.add_argument("--http-timeout", type=positive_int, default=240, help="HTTP timeout in seconds per request")
    p.add_argument("--usd-per-1m-input", type=float, default=0.75, help="Input token price (USD per 1M tokens)")
    p.add_argument("--usd-per-1m-output", type=float, default=4.5, help="Output token price (USD per 1M tokens)")
    p.add_argument(
        "--user-agent",
        default="Codex Desktop/0.133.0-alpha.1 (Windows 10.0.22631; x86_64) unknown (Codex Desktop; 26.519.41501)",
        help="User-Agent header",
    )
    return p.parse_args(argv)
def select_row(record: Dict[str, Any], selector: str) -> bool:
    """Decide whether ``record`` should be sent for relabeling.

    Selectors:
      * ``"all"``: every record.
      * ``"discontinuous_title"``: records whose ``labels`` list has two or more
        separate TITLE runs (TITLE, then something else, then TITLE again).
        Records whose ``labels`` is not a list are never selected.
      * anything else (``"language"``): records whose filename contains one of
        ``LANG_MARKERS``.
    """
    if selector == "all":
        return True
    if selector != "discontinuous_title":
        filename = str(record.get("filename", ""))
        return any(marker in filename for marker in LANG_MARKERS)

    labels = record.get("labels", [])
    if not isinstance(labels, list):
        return False
    title_flags = [lb.endswith("TITLE") for lb in labels]
    runs_started = 0
    # Pair each flag with its predecessor; a False->True edge opens a new TITLE run.
    for previous, current in zip([False] + title_flags, title_flags):
        if current and not previous:
            runs_started += 1
            if runs_started == 2:
                return True
    return False
def load_rows(path: Path, selector: str) -> tuple[List[Dict[str, Any]], List[Row]]:
    """Read a JSONL file and pick the rows to relabel.

    Whitespace-only lines are skipped instead of being passed to ``json.loads``,
    where they would raise. A leading UTF-8 byte-order mark is tolerated
    (``utf-8-sig``), which files saved by some Windows editors carry.

    Args:
        path: Input JSONL file, one JSON object per non-blank line.
        selector: Selector name passed to ``select_row``.

    Returns:
        ``(all_records, selected)``. ``all_records`` holds every parsed record in
        file order. ``selected`` wraps the chosen ones as ``Row`` objects whose
        ``record`` is the same dict object stored in ``all_records``, so
        in-place edits show up when ``all_records`` is written back.
        ``Row.line_no`` is the 1-based physical line number (blank lines still
        count), for log messages.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    all_records: List[Dict[str, Any]] = []
    selected: List[Row] = []
    with path.open("r", encoding="utf-8-sig") as f:
        for line_no, line in enumerate(f, 1):
            if not line.strip():
                continue
            rec = json.loads(line)
            all_records.append(rec)
            if select_row(rec, selector):
                selected.append(Row(line_no=line_no, record=rec))
    return all_records, selected
def parse_model_json(text: str) -> Dict[str, Any]:
    """Decode a JSON reply that may be wrapped in a Markdown code fence.

    Surrounding whitespace, a leading ```` ``` ```` / ```` ```json ```` fence and a
    trailing ```` ``` ```` fence are removed before ``json.loads``.
    """
    opening_fence = r"^```(?:json)?\s*"
    closing_fence = r"\s*```$"
    unfenced = re.sub(closing_fence, "", re.sub(opening_fence, "", text.strip()))
    return json.loads(unfenced)
def build_user_payload(batch_rows: Sequence[Row]) -> str:
    """Serialise a batch into the JSON user message sent to the model.

    Each entry's ``row_id`` is its index in ``batch_rows``; results are mapped
    back by that index. Non-ASCII text (CJK filenames) is emitted verbatim
    because ``ensure_ascii=False``.
    """
    payload_rows = [
        {
            "row_id": position,
            "file_id": row.record.get("file_id"),
            "filename": row.record.get("filename"),
            "tokens": row.record.get("tokens"),
            "current_labels": row.record.get("labels"),
        }
        for position, row in enumerate(batch_rows)
    ]
    return json.dumps({"rows": payload_rows}, ensure_ascii=False)
def extract_output_text(response_obj: Dict[str, Any]) -> str:
    """Return the text of the first ``output_text`` content part of a response.

    Output items are scanned in order. A matching part without ``text`` yields
    ``""``.

    Raises:
        ValueError: If no ``output_text`` part exists.
    """
    not_found = object()
    text = next(
        (
            part.get("text", "")
            for item in response_obj.get("output", [])
            for part in item.get("content", [])
            if part.get("type") == "output_text"
        ),
        not_found,
    )
    if text is not_found:
        raise ValueError("No output_text found in response")
    return text
def extract_function_args(response_obj: Dict[str, Any], func_name: str) -> Dict[str, Any]:
    """Decode the JSON arguments of the first ``function_call`` named ``func_name``.

    A call without an ``arguments`` field decodes as ``{}``.

    Raises:
        ValueError: If no such call is present (``json.JSONDecodeError``, also a
            ValueError, if its arguments are not valid JSON).
    """
    calls = (
        item
        for item in response_obj.get("output", [])
        if item.get("type") == "function_call" and item.get("name") == func_name
    )
    call = next(calls, None)
    if call is None:
        raise ValueError(f"No function_call '{func_name}' found in response")
    return json.loads(call.get("arguments", "{}"))
def validate_labels(tokens: Sequence[str], labels: Sequence[str]) -> bool:
    """Return True when ``labels`` has one allowed tag per token.

    ``labels`` comes from untrusted model JSON. Non-string entries (e.g. a
    nested list) are rejected explicitly. Otherwise an unhashable value would
    raise TypeError from the set-membership test instead of returning False.
    """
    if len(tokens) != len(labels):
        return False
    return all(isinstance(lb, str) and lb in ALLOWED_LABELS for lb in labels)
def normalize_iob2_labels(labels: Sequence[str]) -> List[str]:
    """Rewrite ``labels`` into strict IOB2 form.

    Anything that is not a string starting with ``B-`` or ``I-`` becomes ``O``.
    For tagged entries the prefix is recomputed from context: ``I-`` when the
    previous entry carried the same entity type, otherwise ``B-``. Adjacent
    spans of the same type are therefore merged into one span.
    """
    result: List[str] = []
    previous = ""
    for label in labels:
        if isinstance(label, str) and label.startswith(("B-", "I-")):
            entity = label[2:]
            result.append(("I-" if entity == previous else "B-") + entity)
        else:
            entity = ""
            result.append("O")
        previous = entity
    return result
def title_segments(labels: Sequence[str]) -> List[tuple[int, int]]:
    """Return the half-open ``(start, end)`` ranges of maximal runs of ``*TITLE`` labels."""
    segments: List[tuple[int, int]] = []
    run_start = None
    for idx, label in enumerate(labels):
        is_title = str(label).endswith("TITLE")
        if is_title and run_start is None:
            run_start = idx
        elif not is_title and run_start is not None:
            segments.append((run_start, idx))
            run_start = None
    if run_start is not None:
        segments.append((run_start, len(labels)))
    return segments
def force_single_title_segment(tokens: Sequence[str], labels: Sequence[str]) -> List[str]:
    """Guarantee TITLE is a single contiguous segment.

    Labels are first normalised to IOB2. If more than one TITLE run remains,
    one run is kept, chosen in priority order by: it starts before the first
    EPISODE label, then it is the longest, then it starts earliest. All other
    TITLE labels become ``O`` and the result is normalised again. If
    ``tokens`` and ``labels`` differ in length, a plain copy of ``labels`` is
    returned untouched.
    """
    if len(tokens) != len(labels):
        return list(labels)
    normalized = normalize_iob2_labels(labels)
    runs = title_segments(normalized)
    if len(runs) < 2:
        return normalized

    episode_at = len(normalized)
    for pos, label in enumerate(normalized):
        if str(label).endswith("EPISODE"):
            episode_at = pos
            break

    best_start, best_end = max(
        runs,
        key=lambda run: (int(run[0] < episode_at), run[1] - run[0], -run[0]),
    )
    trimmed = [
        "O" if str(label).endswith("TITLE") and not best_start <= pos < best_end else label
        for pos, label in enumerate(normalized)
    ]
    return normalize_iob2_labels(trimmed)
def force_bracket_delimiters_o(tokens: Sequence[str], labels: Sequence[str]) -> List[str]:
    """Keep standalone bracket delimiters outside entities for clean boundaries.

    Any token in ``BRACKET_DELIMITER_TOKENS`` whose label is not ``O`` is reset
    to ``O`` and the sequence is re-normalised to IOB2, so an entity split by a
    bracket restarts with ``B-``. If nothing needs changing, or the lengths
    differ, a plain copy of ``labels`` is returned.
    """
    if len(tokens) != len(labels):
        return list(labels)
    offending = [
        pos
        for pos, token in enumerate(tokens)
        if token in BRACKET_DELIMITER_TOKENS and labels[pos] != "O"
    ]
    if not offending:
        return list(labels)
    patched = list(labels)
    for pos in offending:
        patched[pos] = "O"
    return normalize_iob2_labels(patched)
def response_schema() -> Dict[str, Any]:
    """JSON Schema for the ``submit_labels`` tool arguments.

    Shape: ``{"results": [{"row_id": int, "labels": [label, ...]}]}``. Every
    label is restricted to ``ALLOWED_LABELS``, sorted so the schema is
    identical on every call. Objects forbid extra keys and require all of
    their properties; the tool is declared with ``strict: True`` in
    build_request_body.
    """
    label_schema = {"type": "string", "enum": sorted(ALLOWED_LABELS)}
    result_schema = {
        "type": "object",
        "additionalProperties": False,
        "properties": {
            "row_id": {"type": "integer"},
            "labels": {"type": "array", "items": label_schema},
        },
        "required": ["row_id", "labels"],
    }
    return {
        "type": "object",
        "additionalProperties": False,
        "properties": {"results": {"type": "array", "items": result_schema}},
        "required": ["results"],
    }
def append_failure_log(path: str, message: str) -> None:
    """Append ``message`` to the log file at ``path`` as exactly one line.

    Missing parent directories are created. Trailing whitespace in ``message``
    is replaced by a single newline. The file is reopened in append mode on
    every call.
    """
    log_path = Path(path)
    log_path.parent.mkdir(exist_ok=True, parents=True)
    with log_path.open("a", encoding="utf-8") as log_file:
        log_file.write(f"{message.rstrip()}\n")
def build_request_body(
    model: str,
    user_payload: str,
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    include_tools: bool = True,
    include_tool_choice: bool = True,
    include_reasoning: bool = True,
    include_cache_key: bool = True,
    include_cache_retention: bool = True,
) -> Dict[str, Any]:
    """Assemble the JSON body for ``POST {api_base}/responses``.

    ``model``, the fixed ``SYSTEM_INSTRUCTIONS`` and ``user_payload`` are always
    sent. Each optional field is controlled by its ``include_*`` flag so callers
    can drop fields a gateway rejects. ``tool_choice`` is only sent together
    with ``tools``, and it forces the ``submit_labels`` function.
    """
    request: Dict[str, Any] = {
        "model": model,
        "instructions": SYSTEM_INSTRUCTIONS,
        "input": user_payload,
    }
    optional_fields = (
        ("prompt_cache_key", include_cache_key, prompt_cache_key),
        ("prompt_cache_retention", include_cache_retention, prompt_cache_retention),
    )
    for field_name, wanted, value in optional_fields:
        if wanted:
            request[field_name] = value
    if include_reasoning:
        request["reasoning"] = {"effort": reasoning_effort}
    if not include_tools:
        return request

    submit_tool = {
        "type": "function",
        "name": "submit_labels",
        "description": "Submit relabeled BIO labels.",
        "parameters": response_schema(),
        "strict": True,
    }
    request["tools"] = [submit_tool]
    if include_tool_choice:
        request["tool_choice"] = {"type": "function", "name": "submit_labels"}
    return request
def parse_usage(response_obj: Dict[str, Any]) -> UsageStats:
    """Read token counters from a response's ``usage`` object.

    Missing or null values at any level count as 0. Cached tokens are read from
    ``usage.input_tokens_details`` and reasoning tokens from
    ``usage.output_tokens_details``.
    """
    usage = response_obj.get("usage") or {}
    input_details = usage.get("input_tokens_details") or {}
    output_details = usage.get("output_tokens_details") or {}

    def count(section: Dict[str, Any], key: str) -> int:
        return int(section.get(key) or 0)

    return UsageStats(
        input_tokens=count(usage, "input_tokens"),
        output_tokens=count(usage, "output_tokens"),
        cached_tokens=count(input_details, "cached_tokens"),
        reasoning_tokens=count(output_details, "reasoning_tokens"),
    )
def _parse_batch_results(
    response_obj: Dict[str, Any],
    batch_rows: Sequence[Row],
    model: str,
    failure_log: str,
) -> Dict[int, List[str]]:
    """Turn one response into ``{row_id: labels}``; raise ValueError unless every row is present and valid."""
    try:
        parsed = extract_function_args(response_obj, "submit_labels")
    except Exception:  # noqa: BLE001 - no usable tool call; fall back to a JSON text answer
        parsed = parse_model_json(extract_output_text(response_obj))
    results = parsed.get("results")
    if not isinstance(results, list):
        append_failure_log(
            failure_log,
            f"[invalid-results] model={model} batch={len(batch_rows)} parsed_keys={list(parsed.keys())}",
        )
        raise ValueError("response JSON missing 'results' list")

    mapping: Dict[int, List[str]] = {}
    for item in results:
        if not isinstance(item, dict):
            continue
        row_id = item.get("row_id")
        labels = item.get("labels")
        if not isinstance(row_id, int) or not isinstance(labels, list):
            continue
        # row_id is the row's index in batch_rows (see build_user_payload).
        if row_id < 0 or row_id >= len(batch_rows):
            continue
        tokens = batch_rows[row_id].record.get("tokens", [])
        if not validate_labels(tokens, labels):
            append_failure_log(
                failure_log,
                f"[invalid-labels] file_id={batch_rows[row_id].record.get('file_id')} "
                f"tokens_len={len(tokens)} labels_len={len(labels)}",
            )
            continue
        mapping[row_id] = labels

    if len(mapping) != len(batch_rows):
        missing = sorted(set(range(len(batch_rows))) - set(mapping))
        append_failure_log(
            failure_log,
            f"[missing] model={model} batch={len(batch_rows)} missing={missing}",
        )
        raise ValueError(f"incomplete/invalid rows from model: missing={missing}")
    return mapping


def relabel_batch(
    api_base: str,
    api_key: str,
    model: str,
    batch_rows: Sequence[Row],
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    user_agent: str,
    retries: int,
    failure_log: str,
    http_timeout: int,
) -> tuple[Dict[int, List[str]], UsageStats]:
    """Send one batch to the Responses API and return validated labels per row.

    Each attempt POSTs to ``{api_base}/responses``. The answer is read from the
    ``submit_labels`` function call, falling back to a JSON ``output_text``
    message. An attempt succeeds only when every row in ``batch_rows`` comes
    back with a label list that passes ``validate_labels``.

    On HTTP 400 the response text is inspected, and the first optional request
    field it mentions is dropped for later attempts, checked in this order:
    cache retention, cache key, reasoning, tool_choice, tools. Some compatible
    gateways reject these fields. Every failure is retried after a linear
    backoff of 0.8 s × attempt.

    Returns:
        ``(mapping, usage)``. ``mapping[row_id]`` is the model's label list for
        ``batch_rows[row_id]``. ``usage`` sums the token usage of every attempt
        that received a response, including attempts whose content was
        rejected and retried, so cost reporting covers all spent tokens.

    Raises:
        RuntimeError: After ``retries`` failed attempts. The error has a
            ``usage`` attribute with the usage accumulated so far, and its
            ``__cause__`` is the last underlying exception.
    """
    url = f"{api_base.rstrip('/')}/responses"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "User-Agent": user_agent,
    }
    user_payload = build_user_payload(batch_rows)
    # Keys match build_request_body's include_* parameters; flipped off on HTTP 400.
    cfg = {
        "include_tools": True,
        "include_tool_choice": True,
        "include_reasoning": True,
        "include_cache_key": True,
        "include_cache_retention": True,
    }
    usage_total = UsageStats()
    last_error: Exception | None = None
    for attempt in range(1, retries + 1):
        try:
            body = build_request_body(
                model=model,
                user_payload=user_payload,
                prompt_cache_key=prompt_cache_key,
                prompt_cache_retention=prompt_cache_retention,
                reasoning_effort=reasoning_effort,
                **cfg,
            )
            resp = requests.post(url, headers=headers, json=body, timeout=http_timeout)
            resp.raise_for_status()
            obj = resp.json()
            # Count usage before validating content: a rejected answer still consumed tokens.
            usage_total.add(parse_usage(obj))
            mapping = _parse_batch_results(obj, batch_rows, model, failure_log)
            return mapping, usage_total
        except Exception as exc:  # noqa: BLE001
            last_error = exc
            # Some compatible gateways may not support all optional fields.
            # Downgrade progressively and keep structured tool output whenever possible.
            if isinstance(exc, requests.HTTPError) and exc.response is not None and exc.response.status_code == 400:
                response_text = (exc.response.text or "")[:1200]
                lowered = response_text.lower()
                append_failure_log(
                    failure_log,
                    f"[http400] attempt={attempt} model={model} body_cfg={cfg} response={response_text!r}",
                )
                if "prompt_cache_retention" in lowered and cfg["include_cache_retention"]:
                    cfg["include_cache_retention"] = False
                elif "prompt_cache_key" in lowered and cfg["include_cache_key"]:
                    cfg["include_cache_key"] = False
                elif "reasoning" in lowered and cfg["include_reasoning"]:
                    cfg["include_reasoning"] = False
                elif "tool_choice" in lowered and cfg["include_tool_choice"]:
                    cfg["include_tool_choice"] = False
                elif "tools" in lowered and cfg["include_tools"]:
                    cfg["include_tools"] = False
            if attempt == retries:
                break
            time.sleep(0.8 * attempt)
    error = RuntimeError(f"failed relabel batch after {retries} attempts: {last_error}")
    # Expose the tokens already spent so callers can still account for them.
    error.usage = usage_total  # type: ignore[attr-defined]
    raise error from last_error
def write_jsonl(path: Path, records: Sequence[Dict[str, Any]]) -> None:
    """Rewrite ``path`` with one compact JSON object per line, replacing it atomically.

    Data is written to ``<path>.tmp``, flushed and fsync'ed, then moved over
    ``path`` with ``Path.replace``. Readers therefore see either the old or the
    new file, never a truncated one. This matters because the output may be
    the input file itself. Missing parent directories are created. If writing
    fails, the temporary file is removed and the error re-raised, leaving any
    existing ``path`` untouched.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        with tmp.open("w", encoding="utf-8", newline="") as f:
            for rec in records:
                f.write(json.dumps(rec, ensure_ascii=False, separators=(",", ":")) + "\n")
            f.flush()
            # Make sure the bytes are on disk before the rename publishes them.
            os.fsync(f.fileno())
        tmp.replace(path)
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise
def process_batch_with_fallback(
    api_base: str,
    api_key: str,
    model: str,
    batch: Sequence[Row],
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    user_agent: str,
    retries: int,
    failure_log: str,
    http_timeout: int,
) -> tuple[List[tuple[Row, List[str]]], UsageStats]:
    """Relabel one batch, degrading to per-row requests when the batch fails.

    1. The whole batch is sent through ``relabel_batch``.
    2. If that raises RuntimeError, each row is retried alone with at least
       4 attempts.
    3. A row that still fails is logged as ``[row-skip]`` and keeps its current
       labels, reduced to one contiguous TITLE segment when tokens and labels
       have equal length.

    Every resulting label list is then post-processed: one contiguous TITLE
    segment, ``repair_jsonl_item``, and standalone bracket tokens forced to ``O``.

    Returns:
        ``(updates, usage)``. ``updates`` pairs each Row with its new labels;
        the Row's record itself is not modified here. ``usage`` sums the token
        usage of successful calls, plus any ``usage`` attribute a raised error
        carries (read via getattr, so errors without it are fine).
    """
    usage_total = UsageStats()

    def add_error_usage(exc: BaseException) -> None:
        # Failed calls may still have consumed tokens; count them if reported.
        partial = getattr(exc, "usage", None)
        if isinstance(partial, UsageStats):
            usage_total.add(partial)

    request_kwargs: Dict[str, Any] = {
        "api_base": api_base,
        "api_key": api_key,
        "model": model,
        "prompt_cache_key": prompt_cache_key,
        "prompt_cache_retention": prompt_cache_retention,
        "reasoning_effort": reasoning_effort,
        "user_agent": user_agent,
        "failure_log": failure_log,
        "http_timeout": http_timeout,
    }
    try:
        mapping, usage = relabel_batch(batch_rows=batch, retries=retries, **request_kwargs)
        usage_total.add(usage)
    except RuntimeError as batch_exc:
        add_error_usage(batch_exc)
        mapping = {}
        for idx, row in enumerate(batch):
            try:
                single, usage = relabel_batch(
                    batch_rows=[row],
                    retries=max(retries, 4),
                    **request_kwargs,
                )
                usage_total.add(usage)
                mapping[idx] = single[0]
            except RuntimeError as exc:
                add_error_usage(exc)
                append_failure_log(
                    failure_log,
                    f"[row-skip] file_id={row.record.get('file_id')} line={row.line_no} reason={exc}",
                )
                # Hard fallback: enforce contiguous TITLE rather than keeping polluted labels.
                toks = row.record.get("tokens", [])
                lbs = row.record.get("labels", [])
                if isinstance(toks, list) and isinstance(lbs, list) and len(toks) == len(lbs):
                    mapping[idx] = force_single_title_segment(toks, lbs)
                else:
                    mapping[idx] = lbs

    updates: List[tuple[Row, List[str]]] = []
    for row_id, labels in mapping.items():
        row = batch[row_id]
        rec = dict(row.record)  # shallow copy: post-processing must not touch the shared record
        tokens = rec.get("tokens", [])
        rec["labels"] = force_single_title_segment(tokens, labels)
        repaired, _repairs = repair_jsonl_item(rec)
        new_labels = repaired.get("labels", rec.get("labels", []))
        new_labels = force_bracket_delimiters_o(tokens, new_labels)
        updates.append((row, new_labels))
    return updates, usage_total
def process_batch_timed(
    meter: ConcurrentMeter,
    api_base: str,
    api_key: str,
    model: str,
    batch: Sequence[Row],
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    user_agent: str,
    retries: int,
    failure_log: str,
    http_timeout: int,
) -> Dict[str, Any]:
    """Run ``process_batch_with_fallback`` while reporting activity to ``meter``.

    ``meter.task_start()`` is called before the work and ``meter.task_end()``
    after it, even if the work raises.

    Returns:
        A dict with ``"updates"``, ``"elapsed"`` (seconds spent on this batch),
        ``"batch_size"`` and ``"usage"``.
    """
    meter.task_start()
    began = time.time()
    try:
        batch_updates, batch_usage = process_batch_with_fallback(
            api_base=api_base,
            api_key=api_key,
            model=model,
            batch=batch,
            prompt_cache_key=prompt_cache_key,
            prompt_cache_retention=prompt_cache_retention,
            reasoning_effort=reasoning_effort,
            user_agent=user_agent,
            retries=retries,
            failure_log=failure_log,
            http_timeout=http_timeout,
        )
        outcome = {
            "updates": batch_updates,
            "elapsed": time.time() - began,
            "batch_size": len(batch),
            "usage": batch_usage,
        }
    finally:
        meter.task_end()
    return outcome
def _estimate_cost_usd(usage: UsageStats, usd_per_1m_input: float, usd_per_1m_output: float) -> float:
    """Return the USD cost of ``usage`` at the given per-million-token prices."""
    return (usage.input_tokens / 1_000_000.0) * usd_per_1m_input + (
        usage.output_tokens / 1_000_000.0
    ) * usd_per_1m_output


def _filter_selected_rows(rows: List[Row], args: argparse.Namespace) -> List[Row]:
    """Apply --min/--max-token-len, --sort-by, --skip-selected and --max-rows, in that order."""
    if args.min_token_len > 0 or args.max_token_len > 0:
        kept: List[Row] = []
        for row in rows:
            tok_len = len(row.record.get("tokens", []))
            if tok_len < args.min_token_len:
                continue
            if args.max_token_len > 0 and tok_len > args.max_token_len:
                continue
            kept.append(row)
        rows = kept
    if args.sort_by == "token_len_asc":
        rows = sorted(rows, key=lambda r: len(r.record.get("tokens", [])))
    if args.skip_selected > 0:
        rows = rows[args.skip_selected:]
    if args.max_rows > 0:
        rows = rows[: args.max_rows]
    return rows


def main() -> None:
    """CLI entry point: relabel the selected rows and write the JSONL output.

    Batches run on a thread pool (at most 8 workers). Label changes are applied
    to the shared record dicts on the main thread. The whole file is rewritten
    each time at least ``--checkpoint-rows`` new rows have been processed, and
    once more at the end. A batch that fails unexpectedly is logged and
    skipped, so completed work is still written out. Progress, throughput and
    optional cost estimates are printed after every batch; a JSON performance
    summary is written when ``--perf-log`` is given.
    """
    args = parse_args()
    api_key = args.api_key or os.environ.get("ANIFILEBERT_RELABEL_API_KEY")
    if not api_key:
        raise SystemExit("Missing API key. Use --api-key or env ANIFILEBERT_RELABEL_API_KEY")
    input_path = Path(args.input)
    output_path = Path(args.output)
    all_records, selected_rows = load_rows(input_path, args.selector)
    selected_rows = _filter_selected_rows(selected_rows, args)
    if not selected_rows:
        print("selected_rows=0; nothing to do")
        if output_path != input_path:
            write_jsonl(output_path, all_records)
        return

    total = len(selected_rows)
    changed = 0
    concurrency = max(1, min(args.concurrency, 8))
    # A batch size of 0 makes range() raise and a negative one yields no batches at all.
    batch_size = max(1, args.batch_size)
    batches: List[List[Row]] = [
        selected_rows[i:i + batch_size]
        for i in range(0, total, batch_size)
    ]
    has_prices = args.usd_per_1m_input > 0 or args.usd_per_1m_output > 0
    done_rows = 0
    rows_since_checkpoint = 0
    failed_batches = 0
    wall_start = time.time()
    meter = ConcurrentMeter()
    total_batch_elapsed = 0.0
    completed_batches = 0
    usage_total = UsageStats()
    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        future_to_batch = {
            executor.submit(
                process_batch_timed,
                meter,
                api_base=args.api_base,
                api_key=api_key,
                model=args.model,
                batch=batch,
                prompt_cache_key=args.prompt_cache_key,
                prompt_cache_retention=args.prompt_cache_retention,
                reasoning_effort=args.reasoning_effort,
                user_agent=args.user_agent,
                retries=args.retries,
                failure_log=args.failure_log,
                http_timeout=args.http_timeout,
            ): batch
            for batch in batches
        }
        for fut in as_completed(future_to_batch):
            try:
                result = fut.result()
            except Exception as exc:  # noqa: BLE001
                # Keep going: one broken batch must not discard every other batch's results.
                failed_batches += 1
                failed = future_to_batch[fut]
                append_failure_log(
                    args.failure_log,
                    f"[batch-error] rows={len(failed)} first_line={failed[0].line_no} error={exc!r}",
                )
                print(f"batch failed rows={len(failed)} error={exc!r}")
                continue
            updates = result["updates"]
            total_batch_elapsed += float(result["elapsed"])
            completed_batches += 1
            usage_total.add(result["usage"])
            for row, new_labels in updates:
                rec = row.record
                if rec.get("labels") != new_labels:
                    rec["labels"] = new_labels
                    changed += 1
            done_rows += len(updates)
            rows_since_checkpoint += len(updates)
            snap = meter.snapshot()
            wall_elapsed = max(1e-9, snap["timestamp"] - wall_start)
            rows_per_sec = done_rows / wall_elapsed
            avg_active = snap["active_time_accum"] / wall_elapsed
            in_tok_per_sec = usage_total.input_tokens / wall_elapsed
            out_tok_per_sec = usage_total.output_tokens / wall_elapsed
            hourly_usd = 0.0
            if has_prices:
                cost = _estimate_cost_usd(usage_total, args.usd_per_1m_input, args.usd_per_1m_output)
                hourly_usd = cost / wall_elapsed * 3600.0
            print(
                f"processed={done_rows}/{total} changed={changed} "
                f"rows_per_sec={rows_per_sec:.2f} active_now={int(snap['current_active'])} "
                f"avg_active={avg_active:.2f} max_active={int(snap['max_active'])}/{concurrency} "
                f"in_tok_s={in_tok_per_sec:.1f} out_tok_s={out_tok_per_sec:.1f} usd_h={hourly_usd:.3f}"
            )
            # done_rows grows in batch-sized steps, so an exact `% checkpoint_rows == 0`
            # test would almost never fire; checkpoint once enough new rows have landed.
            if args.checkpoint_rows > 0 and rows_since_checkpoint >= args.checkpoint_rows:
                write_jsonl(output_path, all_records)
                rows_since_checkpoint = 0
            if args.sleep_ms > 0:
                time.sleep(args.sleep_ms / 1000.0)

    # rows in selected_rows reference dicts in all_records by identity, so changes are already reflected.
    write_jsonl(output_path, all_records)
    wall_total = time.time() - wall_start
    wall_safe = max(1e-9, wall_total)
    final_snap = meter.snapshot()
    avg_active = final_snap["active_time_accum"] / wall_safe
    perf_summary = {
        "wall_seconds": wall_total,
        "rows_processed": done_rows,
        "rows_per_second": done_rows / wall_safe,
        "batches_completed": completed_batches,
        "batches_failed": failed_batches,
        "avg_batch_seconds": total_batch_elapsed / max(1, completed_batches),
        "avg_active_workers": avg_active,
        "max_active_workers": int(final_snap["max_active"]),
        "configured_workers": concurrency,
        "input_tokens": usage_total.input_tokens,
        "output_tokens": usage_total.output_tokens,
        "cached_tokens": usage_total.cached_tokens,
        "reasoning_tokens": usage_total.reasoning_tokens,
        "input_tokens_per_sec": usage_total.input_tokens / wall_safe,
        "output_tokens_per_sec": usage_total.output_tokens / wall_safe,
        "input_tokens_per_hour": usage_total.input_tokens / wall_safe * 3600.0,
        "output_tokens_per_hour": usage_total.output_tokens / wall_safe * 3600.0,
        "usd_per_1m_input": args.usd_per_1m_input,
        "usd_per_1m_output": args.usd_per_1m_output,
    }
    if has_prices:
        total_cost = _estimate_cost_usd(usage_total, args.usd_per_1m_input, args.usd_per_1m_output)
        perf_summary["usd_total"] = total_cost
        perf_summary["usd_per_hour"] = total_cost / wall_safe * 3600.0
    if args.perf_log:
        p = Path(args.perf_log)
        p.parent.mkdir(parents=True, exist_ok=True)
        p.write_text(json.dumps(perf_summary, ensure_ascii=False, indent=2), encoding="utf-8")
    print(
        f"perf wall={wall_total:.1f}s rows_per_sec={perf_summary['rows_per_second']:.2f} "
        f"avg_active={avg_active:.2f} max_active={int(final_snap['max_active'])}/{concurrency} "
        f"in_tok_s={perf_summary['input_tokens_per_sec']:.1f} out_tok_s={perf_summary['output_tokens_per_sec']:.1f}"
    )
    if failed_batches:
        print(f"warning: {failed_batches} batch(es) failed; see {args.failure_log}")
    print(f"done selected_rows={total} changed_rows={changed} output={output_path}")
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()