#!/usr/bin/env python3
"""
Relabel selected rows in a JSONL dataset via an OpenAI-compatible Responses API.

Designed for high-throughput cleanup with a stable prompt prefix and
`prompt_cache_key` to improve cache hit rates across calls.
"""

from __future__ import annotations

import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
import os
import re
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Sequence

import requests

from anifilebert.label_repairs import repair_jsonl_item

ALLOWED_LABELS = {
    "O",
    "B-TITLE", "I-TITLE",
    "B-SEASON", "I-SEASON",
    "B-EPISODE", "I-EPISODE",
    "B-SPECIAL", "I-SPECIAL",
    "B-GROUP", "I-GROUP",
    "B-RESOLUTION", "I-RESOLUTION",
    "B-SOURCE", "I-SOURCE",
}

# Substrings of a filename that mark a language/dub variant; used by the
# default "language" selector.
LANG_MARKERS = (
    "中文版",
    "日语版",
    "国语版",
    "粤语版",
    "英语版",
    "英配版",
    "中配版",
    "日配版",
)

# Kept byte-stable across calls so the prompt prefix can be served from cache.
SYSTEM_INSTRUCTIONS = """You relabel anime filename tokens with BIO tags.
Allowed labels only: O, B/I-TITLE, B/I-SEASON, B/I-EPISODE, B/I-SPECIAL, B/I-GROUP, B/I-RESOLUTION, B/I-SOURCE.
Hard rules:
1) Output exactly one label per token.
2) Language markers like 中文版/日语版/国语版/粤语版/英语版/英配版/中配版/日配版 must be SOURCE.
3) Episode identifiers (e.g. 01, 13, EP13, 第13集/話/话) must be EPISODE.
4) If title already appears before episode number, episode-name text after the episode number should be O (not TITLE).
5) Preserve obvious GROUP/RESOLUTION/SOURCE tags when present.
Return strict JSON only: {"results":[{"row_id":int,"labels":[str,...]}]}
No markdown. No explanation.
"""


@dataclass
class Row:
    """A selected input row: its 1-based line number and the parsed record.

    ``record`` is the very dict object also stored in the full record list built by
    ``load_rows``, so assigning ``record["labels"]`` updates what ``write_jsonl`` emits.
    """

    line_no: int  # 1-based line number in the input JSONL
    record: Dict[str, Any]


class ConcurrentMeter:
    """Lock-protected gauge of in-flight tasks for performance reporting.

    ``task_start``/``task_end`` bracket each task. Between events the meter
    integrates ``current_active`` over wall-clock time into ``active_time_accum``
    (sum of active_count * seconds), so ``active_time_accum / wall_seconds`` gives
    the average number of concurrently active tasks.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.current_active = 0
        self.max_active = 0
        self.active_time_accum = 0.0
        self.last_ts = time.time()

    def _accumulate(self, now: float) -> None:
        """Add ``current_active * (now - last_ts)`` to the integral; caller holds the lock."""
        dt = now - self.last_ts
        if dt > 0:
            self.active_time_accum += self.current_active * dt
        self.last_ts = now

    def task_start(self) -> None:
        """Record that one task became active and update the high-water mark."""
        now = time.time()
        with self._lock:
            self._accumulate(now)
            self.current_active += 1
            if self.current_active > self.max_active:
                self.max_active = self.current_active

    def task_end(self) -> None:
        """Record that one task finished (never drops below zero)."""
        now = time.time()
        with self._lock:
            self._accumulate(now)
            if self.current_active > 0:
                self.current_active -= 1

    def snapshot(self) -> Dict[str, float]:
        """Return current/max active counts, the time integral, and the sample timestamp."""
        now = time.time()
        with self._lock:
            self._accumulate(now)
            return {
                "current_active": float(self.current_active),
                "max_active": float(self.max_active),
                "active_time_accum": float(self.active_time_accum),
                "timestamp": now,
            }


@dataclass
class UsageStats:
    """Token counters accumulated from Responses API ``usage`` objects (see ``parse_usage``)."""

    input_tokens: int = 0
    output_tokens: int = 0
    cached_tokens: int = 0
    reasoning_tokens: int = 0

    def add(self, other: "UsageStats") -> None:
        """Add ``other``'s counters into this instance in place."""
        self.input_tokens += int(other.input_tokens)
        self.output_tokens += int(other.output_tokens)
        self.cached_tokens += int(other.cached_tokens)
        self.reasoning_tokens += int(other.reasoning_tokens)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for a relabel run."""
    p = argparse.ArgumentParser(description="Relabel selected JSONL rows via Responses API")
    p.add_argument("--input", required=True, help="Input JSONL")
    p.add_argument("--output", required=True, help="Output JSONL (can equal input)")
    p.add_argument("--api-base", required=True, help="API base URL, e.g. http://host:port/v1")
    p.add_argument("--api-key", default=None, help="API key; falls back to env ANIFILEBERT_RELABEL_API_KEY")
    p.add_argument("--model", default="gpt-5.4-mini", help="Model name")
    p.add_argument(
        "--selector",
        choices=("language", "discontinuous_title", "all"),
        default="language",
        help="Row selector",
    )
    p.add_argument("--batch-size", type=int, default=12, help="Rows per request")
    p.add_argument("--concurrency", type=int, default=4, help="Parallel request workers")
    p.add_argument("--max-rows", type=int, default=0, help="Optional cap; 0 means no cap")
    p.add_argument("--skip-selected", type=int, default=0, help="Skip this many selected rows before processing")
    p.add_argument("--min-token-len", type=int, default=0, help="Only process rows with token length >= this value")
    p.add_argument("--max-token-len", type=int, default=0, help="Only process rows with token length <= this value (0 = no limit)")
    p.add_argument("--sort-by", choices=("none", "token_len_asc"), default="none", help="Optional ordering of selected rows")
    p.add_argument("--retries", type=int, default=3, help="Retries per batch")
    p.add_argument("--sleep-ms", type=int, default=150, help="Delay between successful calls")
    p.add_argument("--prompt-cache-key", default="anifilebert-relabel-v1", help="Stable prompt cache key")
    p.add_argument("--prompt-cache-retention", default="24h", help="Prompt cache retention hint")
    p.add_argument("--reasoning-effort", default="medium", help="Reasoning effort (e.g. low/medium/high)")
    p.add_argument("--checkpoint-rows", type=int, default=100, help="Write checkpoint every N processed rows")
    p.add_argument("--failure-log", default="reports/llm_relabel_failures.log", help="Failure log path")
    p.add_argument("--perf-log", default="", help="Optional JSON perf summary path")
    p.add_argument("--http-timeout", type=int, default=240, help="HTTP timeout in seconds per request")
    p.add_argument("--usd-per-1m-input", type=float, default=0.75, help="Input token price (USD per 1M tokens)")
    p.add_argument("--usd-per-1m-output", type=float, default=4.5, help="Output token price (USD per 1M tokens)")
    # NOTE(review): the default impersonates a third-party desktop client; confirm the
    # target gateway actually requires this rather than an honest identifier.
    p.add_argument(
        "--user-agent",
        default="Codex Desktop/0.133.0-alpha.1 (Windows 10.0.22631; x86_64) unknown (Codex Desktop; 26.519.41501)",
        help="User-Agent header",
    )
    return p.parse_args()


def select_row(record: Dict[str, Any], selector: str) -> bool:
    """Return True if ``record`` should be sent for relabeling under ``selector``.

    - ``"all"``: every row.
    - ``"discontinuous_title"``: rows whose TITLE labels form two or more separate runs.
    - anything else (the ``"language"`` default): filenames containing a LANG_MARKERS entry.
    """
    if selector == "all":
        return True
    if selector == "discontinuous_title":
        labels = record.get("labels", [])
        if not isinstance(labels, list):
            return False
        seen_title = False
        seen_gap = False
        for lb in labels:
            # str() so a stray non-string label is treated as "not TITLE" instead of crashing.
            if str(lb).endswith("TITLE"):
                if seen_title and seen_gap:
                    return True
                seen_title = True
            elif seen_title:
                seen_gap = True
        return False
    filename = str(record.get("filename", ""))
    return any(marker in filename for marker in LANG_MARKERS)


def load_rows(path: Path, selector: str) -> tuple[List[Dict[str, Any]], List[Row]]:
    """Read a JSONL file and return ``(all_records, selected_rows)``.

    Blank lines are skipped. Each selected ``Row`` shares its ``record`` dict with
    ``all_records`` and keeps its 1-based physical line number.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    all_records: List[Dict[str, Any]] = []
    selected: List[Row] = []
    with path.open("r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            if not line.strip():
                continue
            rec = json.loads(line)
            all_records.append(rec)
            if select_row(rec, selector):
                selected.append(Row(line_no=line_no, record=rec))
    return all_records, selected


def parse_model_json(text: str) -> Dict[str, Any]:
    """Decode model text output as JSON, tolerating a surrounding ```/```json fence."""
    raw = text.strip()
    raw = re.sub(r"^```(?:json)?\s*", "", raw)
    raw = re.sub(r"\s*```$", "", raw)
    return json.loads(raw)


def build_user_payload(batch_rows: Sequence[Row]) -> str:
    """Serialize a batch as the user input; ``row_id`` is the row's index within the batch."""
    rows: List[Dict[str, Any]] = []
    for i, row in enumerate(batch_rows):
        rec = row.record
        rows.append(
            {
                "row_id": i,
                "file_id": rec.get("file_id"),
                "filename": rec.get("filename"),
                "tokens": rec.get("tokens"),
                "current_labels": rec.get("labels"),
            }
        )
    return json.dumps({"rows": rows}, ensure_ascii=False)


def extract_output_text(response_obj: Dict[str, Any]) -> str:
    """Return the text of the first ``output_text`` content part in a Responses API object.

    Raises:
        ValueError: if no ``output_text`` part exists.
    """
    for item in response_obj.get("output") or []:
        # ``or []`` guards items that carry ``"content": null`` (e.g. non-message items).
        for content in item.get("content") or []:
            if content.get("type") == "output_text":
                return content.get("text", "")
    raise ValueError("No output_text found in response")


def extract_function_args(response_obj: Dict[str, Any], func_name: str) -> Dict[str, Any]:
    """Return the JSON-decoded arguments of the first ``function_call`` named ``func_name``.

    Raises:
        ValueError: if no such call exists (or its arguments are not valid JSON).
    """
    output = response_obj.get("output", [])
    for item in output:
        if item.get("type") == "function_call" and item.get("name") == func_name:
            return json.loads(item.get("arguments", "{}"))
    raise ValueError(f"No function_call '{func_name}' found in response")


def validate_labels(tokens: Sequence[str], labels: Sequence[str]) -> bool:
    """True iff there is exactly one label per token and every label is in ALLOWED_LABELS."""
    if len(tokens) != len(labels):
        return False
    return all(lb in ALLOWED_LABELS for lb in labels)


def normalize_iob2_labels(labels: Sequence[str]) -> List[str]:
    """Rewrite labels into strict IOB2 form.

    Anything that is not a ``B-``/``I-`` string becomes ``O``. A tagged token becomes
    ``I-X`` when the previous token had the same entity type ``X``, otherwise ``B-X``.
    As a consequence, adjacent same-type entities are merged into one span.
    """
    normalized: List[str] = []
    prev_entity = ""
    for lb in labels:
        if not isinstance(lb, str) or not lb.startswith(("B-", "I-")):
            normalized.append("O")
            prev_entity = ""
            continue
        entity = lb.split("-", 1)[1]
        prefix = "I" if prev_entity == entity else "B"
        normalized.append(f"{prefix}-{entity}")
        prev_entity = entity
    return normalized


def title_segments(labels: Sequence[str]) -> List[tuple[int, int]]:
    """Return half-open ``(start, end)`` ranges of consecutive labels ending in ``TITLE``."""
    segments: List[tuple[int, int]] = []
    i = 0
    n = len(labels)
    while i < n:
        if str(labels[i]).endswith("TITLE"):
            j = i + 1
            while j < n and str(labels[j]).endswith("TITLE"):
                j += 1
            segments.append((i, j))
            i = j
        else:
            i += 1
    return segments


def force_single_title_segment(tokens: Sequence[str], labels: Sequence[str]) -> List[str]:
    """Guarantee TITLE is a single contiguous segment.

    Labels are first IOB2-normalized. If several TITLE runs remain, one is kept,
    preferring (in order) a run starting before the first EPISODE label, the longest
    run, then the earliest run. All other TITLE tokens become ``O``. If the token
    and label counts differ, the labels are returned unchanged (as a list).
    """
    if len(tokens) != len(labels):
        return list(labels)
    fixed = normalize_iob2_labels(labels)
    segs = title_segments(fixed)
    if len(segs) <= 1:
        return fixed

    first_episode = next((idx for idx, lb in enumerate(fixed) if str(lb).endswith("EPISODE")), len(fixed))

    def score(seg: tuple[int, int]) -> tuple[int, int, int]:
        start, end = seg
        length = end - start
        before_episode = 1 if start < first_episode else 0
        return (before_episode, length, -start)

    ks, ke = max(segs, key=score)
    out = list(fixed)
    for i in range(len(out)):
        if str(out[i]).endswith("TITLE") and not (ks <= i < ke):
            out[i] = "O"
    return normalize_iob2_labels(out)


def response_schema() -> Dict[str, Any]:
    """JSON Schema for the ``submit_labels`` tool arguments: ``{"results": [{row_id, labels}]}``."""
    return {
        "type": "object",
        "additionalProperties": False,
        "properties": {
            "results": {
                "type": "array",
                "items": {
                    "type": "object",
                    "additionalProperties": False,
                    "properties": {
                        "row_id": {"type": "integer"},
                        "labels": {
                            "type": "array",
                            "items": {"type": "string", "enum": sorted(ALLOWED_LABELS)},
                        },
                    },
                    "required": ["row_id", "labels"],
                },
            }
        },
        "required": ["results"],
    }


def append_failure_log(path: str, message: str) -> None:
    """Append one line to the failure log, creating parent directories as needed."""
    p = Path(path)
    p.parent.mkdir(parents=True, exist_ok=True)
    with p.open("a", encoding="utf-8") as f:
        f.write(message.rstrip() + "\n")


def build_request_body(
    model: str,
    user_payload: str,
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    include_tools: bool = True,
    include_tool_choice: bool = True,
    include_reasoning: bool = True,
    include_cache_key: bool = True,
    include_cache_retention: bool = True,
) -> Dict[str, Any]:
    """Build a ``/responses`` request body.

    The ``include_*`` flags let callers drop optional fields that a gateway rejects.
    ``tool_choice`` is only sent when ``tools`` is also sent.
    """
    body: Dict[str, Any] = {
        "model": model,
        "instructions": SYSTEM_INSTRUCTIONS,
        "input": user_payload,
    }
    if include_cache_key:
        body["prompt_cache_key"] = prompt_cache_key
    if include_cache_retention:
        body["prompt_cache_retention"] = prompt_cache_retention
    if include_reasoning:
        body["reasoning"] = {"effort": reasoning_effort}
    if include_tools:
        body["tools"] = [
            {
                "type": "function",
                "name": "submit_labels",
                "description": "Submit relabeled BIO labels.",
                "parameters": response_schema(),
                "strict": True,
            }
        ]
    if include_tool_choice and include_tools:
        body["tool_choice"] = {"type": "function", "name": "submit_labels"}
    return body


def parse_usage(response_obj: Dict[str, Any]) -> UsageStats:
    """Read token counts from a response's ``usage`` block; missing/null fields count as 0."""
    usage = response_obj.get("usage", {}) or {}
    in_details = usage.get("input_tokens_details", {}) or {}
    out_details = usage.get("output_tokens_details", {}) or {}
    return UsageStats(
        input_tokens=int(usage.get("input_tokens", 0) or 0),
        output_tokens=int(usage.get("output_tokens", 0) or 0),
        cached_tokens=int(in_details.get("cached_tokens", 0) or 0),
        reasoning_tokens=int(out_details.get("reasoning_tokens", 0) or 0),
    )


def _parse_batch_mapping(
    obj: Dict[str, Any],
    batch_rows: Sequence[Row],
    model: str,
    failure_log: str,
) -> Dict[int, List[str]]:
    """Convert one decoded response into a complete ``row_id -> labels`` mapping.

    Prefers the ``submit_labels`` function-call arguments and falls back to parsing
    the text output as JSON. Entries with a bad ``row_id`` are dropped silently;
    entries whose labels fail ``validate_labels`` are logged and dropped.

    Raises:
        ValueError: if there is no ``results`` list, or some batch row did not get
            valid labels. Both cases are also written to ``failure_log``.
    """
    try:
        parsed = extract_function_args(obj, "submit_labels")
    except Exception:  # noqa: BLE001 - any failure here just means "try the text output"
        text = extract_output_text(obj)
        parsed = parse_model_json(text)

    results = parsed.get("results")
    if not isinstance(results, list):
        append_failure_log(
            failure_log,
            f"[invalid-results] model={model} batch={len(batch_rows)} parsed_keys={list(parsed.keys())}",
        )
        raise ValueError("response JSON missing 'results' list")

    mapping: Dict[int, List[str]] = {}
    for item in results:
        if not isinstance(item, dict):
            continue
        row_id = item.get("row_id")
        labels = item.get("labels")
        if not isinstance(row_id, int) or not isinstance(labels, list):
            continue
        if row_id < 0 or row_id >= len(batch_rows):
            continue
        tokens = batch_rows[row_id].record.get("tokens", [])
        if not validate_labels(tokens, labels):
            append_failure_log(
                failure_log,
                f"[invalid-labels] file_id={batch_rows[row_id].record.get('file_id')} "
                f"tokens_len={len(tokens)} labels_len={len(labels)}",
            )
            continue
        mapping[row_id] = labels

    if len(mapping) != len(batch_rows):
        missing = sorted(set(range(len(batch_rows))) - set(mapping))
        append_failure_log(
            failure_log,
            f"[missing] model={model} batch={len(batch_rows)} missing={missing}",
        )
        raise ValueError(f"incomplete/invalid rows from model: missing={missing}")
    return mapping


# (error-text substring, cfg flag) pairs, in the order optional fields are dropped.
_DOWNGRADE_ORDER = (
    ("prompt_cache_retention", "include_cache_retention"),
    ("prompt_cache_key", "include_cache_key"),
    ("reasoning", "include_reasoning"),
    ("tool_choice", "include_tool_choice"),
    ("tools", "include_tools"),
)


def _downgrade_request_cfg(
    exc: Exception,
    cfg: Dict[str, bool],
    attempt: int,
    model: str,
    failure_log: str,
) -> None:
    """On an HTTP 400, log it and disable the first optional field named in the error text.

    Some compatible gateways do not support every optional field. At most one flag
    is turned off per failure, and structured tool output is given up last.
    Exceptions other than an HTTP 400 are ignored.
    """
    if not (isinstance(exc, requests.HTTPError) and exc.response is not None and exc.response.status_code == 400):
        return
    response_text = (exc.response.text or "")[:1200]
    lowered = response_text.lower()
    append_failure_log(
        failure_log,
        f"[http400] attempt={attempt} model={model} body_cfg={cfg} response={response_text!r}",
    )
    for needle, flag in _DOWNGRADE_ORDER:
        if needle in lowered and cfg[flag]:
            cfg[flag] = False
            break


def relabel_batch(
    api_base: str,
    api_key: str,
    model: str,
    batch_rows: Sequence[Row],
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    user_agent: str,
    retries: int,
    failure_log: str,
    http_timeout: int,
) -> tuple[Dict[int, List[str]], UsageStats]:
    """Request new labels for ``batch_rows`` and return ``(row_id -> labels, usage)``.

    ``retries`` is the total number of attempts, clamped to at least 1. Attempts
    back off linearly (0.8s * attempt). After an HTTP 400, optional body fields are
    progressively dropped (see ``_downgrade_request_cfg``). Token usage of failed
    attempts is not returned.

    Raises:
        RuntimeError: if every attempt fails; the last error is chained as the cause.
    """
    url = f"{api_base.rstrip('/')}/responses"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "User-Agent": user_agent,
    }
    user_payload = build_user_payload(batch_rows)
    # Keys match build_request_body's include_* keyword parameters.
    cfg = {
        "include_tools": True,
        "include_tool_choice": True,
        "include_reasoning": True,
        "include_cache_key": True,
        "include_cache_retention": True,
    }
    # retries <= 0 previously meant "never call the API"; always make at least one attempt.
    attempts = max(1, retries)
    last_error: Exception | None = None
    for attempt in range(1, attempts + 1):
        try:
            body = build_request_body(
                model=model,
                user_payload=user_payload,
                prompt_cache_key=prompt_cache_key,
                prompt_cache_retention=prompt_cache_retention,
                reasoning_effort=reasoning_effort,
                **cfg,
            )
            resp = requests.post(url, headers=headers, json=body, timeout=http_timeout)
            resp.raise_for_status()
            obj = resp.json()
            usage_stats = parse_usage(obj)
            mapping = _parse_batch_mapping(obj, batch_rows, model, failure_log)
            return mapping, usage_stats
        except Exception as exc:  # noqa: BLE001 - every failure is retried, then surfaced below
            last_error = exc
            _downgrade_request_cfg(exc, cfg, attempt, model, failure_log)
            if attempt == attempts:
                break
            time.sleep(0.8 * attempt)
    raise RuntimeError(f"failed relabel batch after {attempts} attempts: {last_error}") from last_error


def write_jsonl(path: Path, records: Sequence[Dict[str, Any]]) -> None:
    """Write ``records`` as compact UTF-8 JSONL.

    The data goes to a sibling ``.tmp`` file that is then renamed over ``path``, so
    readers never see a half-written file. Parent directories are created if missing.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    with tmp.open("w", encoding="utf-8", newline="") as f:
        f.writelines(
            json.dumps(rec, ensure_ascii=False, separators=(",", ":")) + "\n" for rec in records
        )
    tmp.replace(path)


def process_batch_with_fallback(
    api_base: str,
    api_key: str,
    model: str,
    batch: Sequence[Row],
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    user_agent: str,
    retries: int,
    failure_log: str,
    http_timeout: int,
) -> tuple[List[tuple[Row, List[str]]], UsageStats]:
    """Relabel a batch, degrading to per-row calls and then to a local repair.

    The whole batch is tried first. If that fails, each row is retried alone with at
    least 4 attempts. A row that still fails is logged, and its existing labels are
    kept but forced to a single TITLE segment. Every resulting label list goes
    through ``force_single_title_segment`` and ``repair_jsonl_item``.

    Returns:
        ``(updates, usage)``: one ``(row, new_labels)`` pair per row in ``batch``, and
        the token usage of the successful calls. Input records are not mutated.
    """
    usage_total = UsageStats()

    def call(rows: Sequence[Row], n_retries: int) -> tuple[Dict[int, List[str]], UsageStats]:
        return relabel_batch(
            api_base=api_base,
            api_key=api_key,
            model=model,
            batch_rows=rows,
            prompt_cache_key=prompt_cache_key,
            prompt_cache_retention=prompt_cache_retention,
            reasoning_effort=reasoning_effort,
            user_agent=user_agent,
            retries=n_retries,
            failure_log=failure_log,
            http_timeout=http_timeout,
        )

    try:
        mapping, usage = call(batch, retries)
        usage_total.add(usage)
    except RuntimeError:
        mapping = {}
        for idx, row in enumerate(batch):
            try:
                single, usage = call([row], max(retries, 4))
                usage_total.add(usage)
                mapping[idx] = single[0]
            except RuntimeError as exc:
                append_failure_log(
                    failure_log,
                    f"[row-skip] file_id={row.record.get('file_id')} line={row.line_no} reason={exc}",
                )
                # Hard fallback: enforce contiguous TITLE rather than keeping polluted labels.
                toks = row.record.get("tokens", [])
                lbs = row.record.get("labels", [])
                if isinstance(toks, list) and isinstance(lbs, list) and len(toks) == len(lbs):
                    mapping[idx] = force_single_title_segment(toks, lbs)
                else:
                    mapping[idx] = lbs

    updates: List[tuple[Row, List[str]]] = []
    for row_id, labels in mapping.items():
        row = batch[row_id]
        # Shallow copy so the shared record is only updated by the main thread.
        rec = dict(row.record)
        rec["labels"] = force_single_title_segment(rec.get("tokens", []), labels)
        repaired, _repairs = repair_jsonl_item(rec)
        new_labels = repaired.get("labels", rec.get("labels", []))
        updates.append((row, new_labels))
    return updates, usage_total


def process_batch_timed(
    meter: ConcurrentMeter,
    api_base: str,
    api_key: str,
    model: str,
    batch: Sequence[Row],
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    user_agent: str,
    retries: int,
    failure_log: str,
    http_timeout: int,
    sleep_ms: int = 0,
) -> Dict[str, Any]:
    """Worker entry point: run ``process_batch_with_fallback`` under ``meter`` accounting.

    After a batch completes, the worker sleeps ``sleep_ms`` before returning. This
    spaces out that worker's successive calls. The pause is excluded from
    ``elapsed`` and from the meter's active time.

    Returns:
        A dict with ``updates``, ``elapsed`` (seconds), ``batch_size`` and ``usage``.
    """
    meter.task_start()
    t0 = time.time()
    try:
        updates, usage = process_batch_with_fallback(
            api_base=api_base,
            api_key=api_key,
            model=model,
            batch=batch,
            prompt_cache_key=prompt_cache_key,
            prompt_cache_retention=prompt_cache_retention,
            reasoning_effort=reasoning_effort,
            user_agent=user_agent,
            retries=retries,
            failure_log=failure_log,
            http_timeout=http_timeout,
        )
        elapsed = time.time() - t0
    finally:
        meter.task_end()
    if sleep_ms > 0:
        time.sleep(sleep_ms / 1000.0)
    return {
        "updates": updates,
        "elapsed": elapsed,
        "batch_size": len(batch),
        "usage": usage,
    }


def _filter_and_order_rows(rows: List[Row], args: argparse.Namespace) -> List[Row]:
    """Apply the token-length filters, optional sort, ``--skip-selected`` and ``--max-rows``."""
    if args.min_token_len > 0 or args.max_token_len > 0:
        kept: List[Row] = []
        for row in rows:
            tok_len = len(row.record.get("tokens", []))
            if tok_len < args.min_token_len:
                continue
            if args.max_token_len > 0 and tok_len > args.max_token_len:
                continue
            kept.append(row)
        rows = kept
    if args.sort_by == "token_len_asc":
        rows = sorted(rows, key=lambda r: len(r.record.get("tokens", [])))
    if args.skip_selected > 0:
        rows = rows[args.skip_selected:]
    if args.max_rows > 0:
        rows = rows[: args.max_rows]
    return rows


def _estimate_cost_usd(usage: UsageStats, usd_per_1m_input: float, usd_per_1m_output: float) -> float | None:
    """Return the USD cost of ``usage``, or None when neither price is positive (pricing disabled)."""
    if usd_per_1m_input <= 0 and usd_per_1m_output <= 0:
        return None
    return (usage.input_tokens / 1_000_000.0) * usd_per_1m_input + (
        usage.output_tokens / 1_000_000.0
    ) * usd_per_1m_output


def main() -> None:
    """Relabel the selected rows concurrently, checkpoint progress, and report throughput/cost."""
    args = parse_args()
    api_key = args.api_key or os.environ.get("ANIFILEBERT_RELABEL_API_KEY")
    if not api_key:
        raise SystemExit("Missing API key. Use --api-key or env ANIFILEBERT_RELABEL_API_KEY")

    input_path = Path(args.input)
    output_path = Path(args.output)
    all_records, selected_rows = load_rows(input_path, args.selector)
    selected_rows = _filter_and_order_rows(selected_rows, args)

    if not selected_rows:
        print("selected_rows=0; nothing to do")
        if output_path != input_path:
            write_jsonl(output_path, all_records)
        return

    total = len(selected_rows)
    changed = 0
    concurrency = max(1, min(args.concurrency, 8))
    # A non-positive batch size would make range() raise; treat it as 1.
    batch_size = max(1, args.batch_size)
    batches: List[List[Row]] = [
        selected_rows[i:i + batch_size] for i in range(0, total, batch_size)
    ]

    done_rows = 0
    last_checkpoint_rows = 0
    wall_start = time.time()
    meter = ConcurrentMeter()
    total_batch_elapsed = 0.0
    completed_batches = 0
    usage_total = UsageStats()
    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        futures = [
            executor.submit(
                process_batch_timed,
                meter,
                api_base=args.api_base,
                api_key=api_key,
                model=args.model,
                batch=batch,
                prompt_cache_key=args.prompt_cache_key,
                prompt_cache_retention=args.prompt_cache_retention,
                reasoning_effort=args.reasoning_effort,
                user_agent=args.user_agent,
                retries=args.retries,
                failure_log=args.failure_log,
                http_timeout=args.http_timeout,
                sleep_ms=args.sleep_ms,
            )
            for batch in batches
        ]
        for fut in as_completed(futures):
            result = fut.result()
            updates = result["updates"]
            total_batch_elapsed += float(result["elapsed"])
            completed_batches += 1
            usage_total.add(result["usage"])
            # Only this (main) thread mutates the shared records.
            for row, new_labels in updates:
                rec = row.record
                if rec.get("labels") != new_labels:
                    rec["labels"] = new_labels
                    changed += 1
            done_rows += len(updates)

            snap = meter.snapshot()
            wall_elapsed = max(1e-9, snap["timestamp"] - wall_start)
            rows_per_sec = done_rows / wall_elapsed
            avg_active = snap["active_time_accum"] / wall_elapsed
            in_tok_per_sec = usage_total.input_tokens / wall_elapsed
            out_tok_per_sec = usage_total.output_tokens / wall_elapsed
            cost = _estimate_cost_usd(usage_total, args.usd_per_1m_input, args.usd_per_1m_output)
            hourly_usd = 0.0 if cost is None else cost / wall_elapsed * 3600.0
            print(
                f"processed={done_rows}/{total} changed={changed} "
                f"rows_per_sec={rows_per_sec:.2f} active_now={int(snap['current_active'])} "
                f"avg_active={avg_active:.2f} max_active={int(snap['max_active'])}/{concurrency} "
                f"in_tok_s={in_tok_per_sec:.1f} out_tok_s={out_tok_per_sec:.1f} usd_h={hourly_usd:.3f}"
            )
            # done_rows grows in batch-sized steps, so compare against the rows written at the
            # last checkpoint instead of testing divisibility (which would rarely trigger).
            if args.checkpoint_rows > 0 and (
                done_rows - last_checkpoint_rows >= args.checkpoint_rows or done_rows == total
            ):
                write_jsonl(output_path, all_records)
                last_checkpoint_rows = done_rows

    # rows in selected_rows reference dicts in all_records by identity, so changes are already reflected.
    write_jsonl(output_path, all_records)

    wall_total = time.time() - wall_start
    safe_wall = max(1e-9, wall_total)
    final_snap = meter.snapshot()
    avg_active = final_snap["active_time_accum"] / safe_wall
    perf_summary = {
        "wall_seconds": wall_total,
        "rows_processed": done_rows,
        "rows_per_second": done_rows / safe_wall,
        "batches_completed": completed_batches,
        "avg_batch_seconds": total_batch_elapsed / max(1, completed_batches),
        "avg_active_workers": avg_active,
        "max_active_workers": int(final_snap["max_active"]),
        "configured_workers": concurrency,
        "input_tokens": usage_total.input_tokens,
        "output_tokens": usage_total.output_tokens,
        "cached_tokens": usage_total.cached_tokens,
        "reasoning_tokens": usage_total.reasoning_tokens,
        "input_tokens_per_sec": usage_total.input_tokens / safe_wall,
        "output_tokens_per_sec": usage_total.output_tokens / safe_wall,
        "input_tokens_per_hour": usage_total.input_tokens / safe_wall * 3600.0,
        "output_tokens_per_hour": usage_total.output_tokens / safe_wall * 3600.0,
        "usd_per_1m_input": args.usd_per_1m_input,
        "usd_per_1m_output": args.usd_per_1m_output,
    }
    total_cost = _estimate_cost_usd(usage_total, args.usd_per_1m_input, args.usd_per_1m_output)
    if total_cost is not None:
        perf_summary["usd_total"] = total_cost
        perf_summary["usd_per_hour"] = total_cost / safe_wall * 3600.0
    if args.perf_log:
        p = Path(args.perf_log)
        p.parent.mkdir(parents=True, exist_ok=True)
        p.write_text(json.dumps(perf_summary, ensure_ascii=False, indent=2), encoding="utf-8")

    print(
        f"perf wall={wall_total:.1f}s rows_per_sec={perf_summary['rows_per_second']:.2f} "
        f"avg_active={avg_active:.2f} max_active={int(final_snap['max_active'])}/{concurrency} "
        f"in_tok_s={perf_summary['input_tokens_per_sec']:.1f} out_tok_s={perf_summary['output_tokens_per_sec']:.1f}"
    )
    print(f"done selected_rows={total} changed_rows={changed} output={output_path}")


if __name__ == "__main__":
    main()