Token Classification
Transformers
ONNX
Safetensors
English
Japanese
Chinese
bert
anime
filename-parsing
Eval Results (legacy)
Instructions to use ModerRAS/AniFileBERT with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ModerRAS/AniFileBERT with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("token-classification", model="ModerRAS/AniFileBERT")

# Load model directly
from transformers import AutoTokenizer, AutoModelForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("ModerRAS/AniFileBERT")
model = AutoModelForTokenClassification.from_pretrained("ModerRAS/AniFileBERT")
```
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| """ | |
| Relabel selected rows in a JSONL dataset via an OpenAI-compatible Responses API. | |
| Designed for high-throughput cleanup with a stable prompt prefix and | |
| `prompt_cache_key` to improve cache hit rates across calls. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| import json | |
| import os | |
| import re | |
| import threading | |
| import time | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Sequence | |
| import requests | |
| from anifilebert.label_repairs import repair_jsonl_item | |
# Closed BIO tag vocabulary: the outside tag "O" plus a B-/I- pair for each
# entity type. Used to validate model output and as the schema enum.
ALLOWED_LABELS = {"O"} | {
    f"{prefix}-{entity}"
    for entity in (
        "TITLE",
        "SEASON",
        "EPISODE",
        "SPECIAL",
        "GROUP",
        "RESOLUTION",
        "SOURCE",
    )
    for prefix in ("B", "I")
}
# Language-version markers looked for as substrings of a record's filename.
# select_row() uses them for the default "language" selector, and rule 2 of
# SYSTEM_INSTRUCTIONS lists the same markers as tokens that must be SOURCE.
LANG_MARKERS = (
    "中文版",
    "日语版",
    "国语版",
    "粤语版",
    "英语版",
    "英配版",
    "中配版",
    "日配版",
)
# Tokens that consist solely of one bracket character. When such a token
# stands alone it is forced to "O" so entity spans never include a delimiter.
# Covers ASCII brackets, CJK brackets and full-width parentheses; the last
# pair previously duplicated the ASCII "(" / ")" entries, which left the
# full-width forms uncovered.
BRACKET_DELIMITER_TOKENS = {
    "[",
    "]",
    "(",
    ")",
    "【",
    "】",
    "《",
    "》",
    "（",
    "）",
}
# System prompt sent as the Responses API "instructions" field on every call.
# It is the stable prompt prefix the module docstring refers to, so any edit
# here invalidates previously cached prompts.
# Rule 6 lists the standalone bracket tokens; the final pair is the full-width
# parentheses (it previously repeated the ASCII pair).
SYSTEM_INSTRUCTIONS = """You relabel anime filename tokens with BIO tags.
Allowed labels only:
O, B/I-TITLE, B/I-SEASON, B/I-EPISODE, B/I-SPECIAL, B/I-GROUP, B/I-RESOLUTION, B/I-SOURCE.
Hard rules:
1) Output exactly one label per token.
2) Language markers like 中文版/日语版/国语版/粤语版/英语版/英配版/中配版/日配版 must be SOURCE.
3) Episode identifiers (e.g. 01, 13, EP13, 第13集/話/话) must be EPISODE.
4) If title already appears before episode number, episode-name text after the episode number should be O (not TITLE).
5) Preserve obvious GROUP/RESOLUTION/SOURCE tags when present.
6) If bracket delimiters are split into standalone tokens (`[ ] ( ) 【 】 《 》 （ ）`), they must be O.
Return strict JSON only:
{"results":[{"row_id":int,"labels":[str,...]}]}
No markdown. No explanation.
"""
@dataclass
class Row:
    """A selected JSONL record paired with its position in the input file.

    Declared as a dataclass so it can be built with keyword arguments, as
    ``load_rows`` does (``Row(line_no=..., record=...)``).
    """

    # 1-based line number of the record in the input file.
    line_no: int
    # Parsed JSON object for that line; load_rows stores the same dict object
    # in its full record list, so edits made through a Row are visible there.
    record: Dict[str, Any]
class ConcurrentMeter:
    """Thread-safe gauge of in-flight tasks.

    Tracks the current and peak number of active tasks and integrates
    ``active tasks x elapsed wall time`` so callers can derive an average
    concurrency level.

    The clock is always read while holding the lock. Sampling it before
    acquiring the lock let two threads apply their timestamps out of order,
    which moved ``last_ts`` backwards and counted an interval twice.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.current_active = 0
        self.max_active = 0
        # Integral of current_active over time, in task-seconds.
        self.active_time_accum = 0.0
        # time.time() of the most recent accumulation.
        self.last_ts = time.time()

    def _accumulate(self, now: float) -> None:
        """Fold the interval since ``last_ts`` into the integral.

        Must be called with ``self._lock`` held.
        """
        dt = now - self.last_ts
        if dt > 0:
            self.active_time_accum += self.current_active * dt
        # Always advance, so a backwards wall-clock step does not stall
        # accumulation until the clock catches up.
        self.last_ts = now

    def task_start(self) -> None:
        """Record that one more task is running and update the peak."""
        with self._lock:
            self._accumulate(time.time())
            self.current_active += 1
            if self.current_active > self.max_active:
                self.max_active = self.current_active

    def task_end(self) -> None:
        """Record that a task finished; the counter never goes below zero."""
        with self._lock:
            self._accumulate(time.time())
            if self.current_active > 0:
                self.current_active -= 1

    def snapshot(self) -> Dict[str, float]:
        """Return the current counters plus the ``time.time()`` they refer to."""
        with self._lock:
            now = time.time()
            self._accumulate(now)
            return {
                "current_active": float(self.current_active),
                "max_active": float(self.max_active),
                "active_time_accum": float(self.active_time_accum),
                "timestamp": now,
            }
@dataclass
class UsageStats:
    """Token counters taken from a Responses API ``usage`` object.

    Declared as a dataclass so it can be built with keyword arguments, as
    ``parse_usage`` does, while ``UsageStats()`` still yields all zeros.
    """

    input_tokens: int = 0
    output_tokens: int = 0
    cached_tokens: int = 0
    reasoning_tokens: int = 0

    def add(self, other: "UsageStats") -> None:
        """Add ``other``'s counters to this instance in place."""
        self.input_tokens += int(other.input_tokens)
        self.output_tokens += int(other.output_tokens)
        self.cached_tokens += int(other.cached_tokens)
        self.reasoning_tokens += int(other.reasoning_tokens)
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Relabel selected JSONL rows via Responses API")
    add = parser.add_argument

    # Files and endpoint.
    add("--input", required=True, help="Input JSONL")
    add("--output", required=True, help="Output JSONL (can equal input)")
    add("--api-base", required=True, help="API base URL, e.g. http://host:port/v1")
    add("--api-key", default=None, help="API key; falls back to env ANIFILEBERT_RELABEL_API_KEY")
    add("--model", default="gpt-5.4-mini", help="Model name")

    # Row selection and ordering.
    add(
        "--selector",
        choices=("language", "discontinuous_title", "all"),
        default="language",
        help="Row selector",
    )
    add("--batch-size", type=int, default=12, help="Rows per request")
    add("--concurrency", type=int, default=4, help="Parallel request workers")
    add("--max-rows", type=int, default=0, help="Optional cap; 0 means no cap")
    add("--skip-selected", type=int, default=0, help="Skip this many selected rows before processing")
    add("--min-token-len", type=int, default=0, help="Only process rows with token length >= this value")
    add(
        "--max-token-len",
        type=int,
        default=0,
        help="Only process rows with token length <= this value (0 = no limit)",
    )
    add(
        "--sort-by",
        choices=("none", "token_len_asc"),
        default="none",
        help="Optional ordering of selected rows",
    )

    # Request behaviour.
    add("--retries", type=int, default=3, help="Retries per batch")
    add("--sleep-ms", type=int, default=150, help="Delay between successful calls")
    add("--prompt-cache-key", default="anifilebert-relabel-v1", help="Stable prompt cache key")
    add("--prompt-cache-retention", default="24h", help="Prompt cache retention hint")
    add("--reasoning-effort", default="medium", help="Reasoning effort (e.g. low/medium/high)")

    # Checkpointing, logging and cost reporting.
    add("--checkpoint-rows", type=int, default=100, help="Write checkpoint every N processed rows")
    add("--failure-log", default="reports/llm_relabel_failures.log", help="Failure log path")
    add("--perf-log", default="", help="Optional JSON perf summary path")
    add("--http-timeout", type=int, default=240, help="HTTP timeout in seconds per request")
    add("--usd-per-1m-input", type=float, default=0.75, help="Input token price (USD per 1M tokens)")
    add("--usd-per-1m-output", type=float, default=4.5, help="Output token price (USD per 1M tokens)")
    add(
        "--user-agent",
        default="Codex Desktop/0.133.0-alpha.1 (Windows 10.0.22631; x86_64) unknown (Codex Desktop; 26.519.41501)",
        help="User-Agent header",
    )
    return parser.parse_args()
def select_row(record: Dict[str, Any], selector: str) -> bool:
    """Decide whether ``record`` should be sent for relabeling.

    Selectors:
      * ``"all"`` - every record.
      * ``"discontinuous_title"`` - records whose ``labels`` contain a TITLE
        tag, then at least one non-TITLE tag, then another TITLE tag.
      * anything else (the CLI default is ``"language"``) - records whose
        ``filename`` contains one of ``LANG_MARKERS``.

    Non-string labels (e.g. JSON ``null``) count as non-TITLE instead of
    raising, matching how ``title_segments`` treats them.
    """
    if selector == "all":
        return True
    if selector == "discontinuous_title":
        labels = record.get("labels", [])
        if not isinstance(labels, list):
            return False
        seen_title = False
        seen_gap = False
        for lb in labels:
            if isinstance(lb, str) and lb.endswith("TITLE"):
                # A TITLE tag after a gap means the title is split.
                if seen_title and seen_gap:
                    return True
                seen_title = True
            elif seen_title:
                seen_gap = True
        return False
    filename = str(record.get("filename", ""))
    return any(marker in filename for marker in LANG_MARKERS)
def load_rows(path: Path, selector: str) -> tuple[List[Dict[str, Any]], List[Row]]:
    """Read a JSONL file and pick the rows that match ``selector``.

    Returns ``(all_records, selected)``. ``all_records`` holds every parsed
    record in file order. ``selected`` holds a ``Row`` for each record that
    ``select_row`` accepts; each ``Row.record`` is the same dict object that
    sits in ``all_records``.

    Blank or whitespace-only lines (such as a trailing empty line) are
    skipped instead of raising ``JSONDecodeError``. ``line_no`` still counts
    physical lines, so it matches the position in the file.
    """
    all_records: List[Dict[str, Any]] = []
    selected: List[Row] = []
    with path.open("r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            if not line.strip():
                continue
            rec = json.loads(line)
            all_records.append(rec)
            if select_row(rec, selector):
                selected.append(Row(line_no=line_no, record=rec))
    return all_records, selected
def parse_model_json(text: str) -> Dict[str, Any]:
    """Decode model text as JSON, tolerating a surrounding Markdown code fence.

    A leading fence (optionally tagged ``json``) and a trailing fence are
    removed before decoding. Raises whatever ``json.loads`` raises when the
    remainder is not valid JSON.
    """
    without_opening = re.sub(r"^```(?:json)?\s*", "", text.strip())
    unfenced = re.sub(r"\s*```$", "", without_opening)
    return json.loads(unfenced)
def build_user_payload(batch_rows: Sequence[Row]) -> str:
    """Serialize a batch into the JSON string sent as the request input.

    Each row is identified by ``row_id``, its index within ``batch_rows``, so
    the model's answer can be matched back by position. Missing record fields
    are emitted as ``null``; non-ASCII text is written unescaped.
    """
    rows: List[Dict[str, Any]] = [
        {
            "row_id": position,
            "file_id": row.record.get("file_id"),
            "filename": row.record.get("filename"),
            "tokens": row.record.get("tokens"),
            "current_labels": row.record.get("labels"),
        }
        for position, row in enumerate(batch_rows)
    ]
    return json.dumps({"rows": rows}, ensure_ascii=False)
def extract_output_text(response_obj: Dict[str, Any]) -> str:
    """Return the text of the first ``output_text`` content part in a response.

    Output items whose ``content`` is missing or ``null`` are skipped, and a
    missing or ``null`` ``output`` is treated as empty, so such responses
    reach the ``ValueError`` below instead of failing with ``TypeError``.

    Raises:
        ValueError: if no ``output_text`` part is present.
    """
    for item in response_obj.get("output") or []:
        for content in item.get("content") or []:
            if content.get("type") == "output_text":
                return content.get("text", "")
    raise ValueError("No output_text found in response")
def extract_function_args(response_obj: Dict[str, Any], func_name: str) -> Dict[str, Any]:
    """Return the decoded arguments of the first call to ``func_name``.

    Scans the response ``output`` items for one with ``type ==
    "function_call"`` and a matching ``name``, then JSON-decodes its
    ``arguments`` string. A missing or ``null`` ``output`` is treated as
    empty, and missing, ``null`` or empty ``arguments`` decode to ``{}``.

    Raises:
        ValueError: if no matching function call exists, or (as
            ``json.JSONDecodeError``) if the arguments are not valid JSON.
    """
    for item in response_obj.get("output") or []:
        if item.get("type") == "function_call" and item.get("name") == func_name:
            return json.loads(item.get("arguments") or "{}")
    raise ValueError(f"No function_call '{func_name}' found in response")
def validate_labels(tokens: Sequence[str], labels: Sequence[str]) -> bool:
    """Check that ``labels`` is a usable tagging for ``tokens``.

    Valid means one label per token and every label a string from
    ``ALLOWED_LABELS``. Non-string entries, including unhashable ones such as
    lists or dicts from malformed model output, make the result ``False``
    instead of raising ``TypeError`` on the set lookup.
    """
    if len(tokens) != len(labels):
        return False
    return all(isinstance(lb, str) and lb in ALLOWED_LABELS for lb in labels)
def normalize_iob2_labels(labels: Sequence[str]) -> List[str]:
    """Rewrite a tag sequence so the B-/I- prefixes follow from adjacency alone.

    Anything that is not a string starting with ``B-`` or ``I-`` becomes
    ``"O"``. A tagged label becomes ``I-<entity>`` when the preceding label
    carried the same entity and ``B-<entity>`` otherwise, whatever prefix it
    had. Adjacent spans of the same entity are therefore merged.
    """
    result: List[str] = []
    open_entity = ""
    for raw in labels:
        if isinstance(raw, str) and raw.startswith(("B-", "I-")):
            _, _, entity = raw.partition("-")
            if entity == open_entity:
                result.append(f"I-{entity}")
            else:
                result.append(f"B-{entity}")
            open_entity = entity
        else:
            result.append("O")
            open_entity = ""
    return result
def title_segments(labels: Sequence[str]) -> List[tuple[int, int]]:
    """Locate every maximal run of TITLE-tagged positions.

    Returns half-open ``(start, end)`` index pairs in left-to-right order. A
    position belongs to a run when ``str(label)`` ends with ``"TITLE"``.
    """
    spans: List[tuple[int, int]] = []
    run_start = None
    for pos, label in enumerate(labels):
        if str(label).endswith("TITLE"):
            if run_start is None:
                run_start = pos
        elif run_start is not None:
            spans.append((run_start, pos))
            run_start = None
    if run_start is not None:
        # The sequence ended while still inside a run.
        spans.append((run_start, len(labels)))
    return spans
def force_single_title_segment(tokens: Sequence[str], labels: Sequence[str]) -> List[str]:
    """Return labels in which TITLE occupies at most one contiguous run.

    If ``tokens`` and ``labels`` differ in length the labels are returned as
    a plain list copy, untouched. Otherwise the labels are normalized, and if
    more than one TITLE run remains, exactly one is kept. Runs that start
    before the first EPISODE tag are preferred, then longer runs, then earlier
    ones. Every other TITLE position becomes ``"O"`` and the result is
    normalized again.
    """
    if len(tokens) != len(labels):
        return list(labels)

    normalized = normalize_iob2_labels(labels)
    runs = title_segments(normalized)
    if len(runs) <= 1:
        return normalized

    # Index of the first EPISODE tag, or one past the end when there is none.
    episode_at = len(normalized)
    for pos, label in enumerate(normalized):
        if str(label).endswith("EPISODE"):
            episode_at = pos
            break

    def rank(run: tuple[int, int]) -> tuple[int, int, int]:
        begin, finish = run
        return (int(begin < episode_at), finish - begin, -begin)

    keep_start, keep_end = max(runs, key=rank)
    pruned = [
        "O" if str(label).endswith("TITLE") and not (keep_start <= pos < keep_end) else label
        for pos, label in enumerate(normalized)
    ]
    return normalize_iob2_labels(pruned)
def force_bracket_delimiters_o(tokens: Sequence[str], labels: Sequence[str]) -> List[str]:
    """Label every standalone bracket token ``"O"``.

    If ``tokens`` and ``labels`` differ in length, or no bracket token
    carries a non-``"O"`` label, a plain list copy of ``labels`` is returned.
    Otherwise the offending positions are set to ``"O"`` and the sequence is
    re-normalized, because removing a tag can split a span.
    """
    if len(tokens) != len(labels):
        return list(labels)
    patched = [
        "O" if token in BRACKET_DELIMITER_TOKENS else label
        for token, label in zip(tokens, labels)
    ]
    if patched == list(labels):
        return list(labels)
    return normalize_iob2_labels(patched)
def response_schema() -> Dict[str, Any]:
    """Build the JSON Schema used as the parameters of the label-submission tool.

    The schema describes ``{"results": [{"row_id": int, "labels": [...]}]}``.
    Each label is restricted to the sorted ``ALLOWED_LABELS`` and no extra
    properties are allowed at either object level.
    """
    label_array = {
        "type": "array",
        "items": {"type": "string", "enum": sorted(ALLOWED_LABELS)},
    }
    row_result = {
        "type": "object",
        "additionalProperties": False,
        "properties": {
            "row_id": {"type": "integer"},
            "labels": label_array,
        },
        "required": ["row_id", "labels"],
    }
    return {
        "type": "object",
        "additionalProperties": False,
        "properties": {
            "results": {"type": "array", "items": row_result},
        },
        "required": ["results"],
    }
def append_failure_log(path: str, message: str) -> None:
    """Append one line to the failure log, creating parent directories first.

    Trailing whitespace is stripped from ``message`` and a single newline is
    added, so each call writes exactly one line ending.
    """
    log_path = Path(path)
    log_path.parent.mkdir(parents=True, exist_ok=True)
    entry = message.rstrip() + "\n"
    with log_path.open("a", encoding="utf-8") as handle:
        handle.write(entry)
def build_request_body(
    model: str,
    user_payload: str,
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    include_tools: bool = True,
    include_tool_choice: bool = True,
    include_reasoning: bool = True,
    include_cache_key: bool = True,
    include_cache_retention: bool = True,
) -> Dict[str, Any]:
    """Assemble the JSON body for one Responses API request.

    ``model``, ``instructions`` and ``input`` are always present. Each
    ``include_*`` flag switches one optional field on or off so callers can
    drop fields a gateway rejects. ``tool_choice`` is only sent together with
    ``tools``.
    """
    request: Dict[str, Any] = {
        "model": model,
        "instructions": SYSTEM_INSTRUCTIONS,
        "input": user_payload,
    }
    if include_cache_key:
        request["prompt_cache_key"] = prompt_cache_key
    if include_cache_retention:
        request["prompt_cache_retention"] = prompt_cache_retention
    if include_reasoning:
        request["reasoning"] = {"effort": reasoning_effort}

    if not include_tools:
        # Without the tool definition a forced tool choice is meaningless.
        return request

    submit_tool = {
        "type": "function",
        "name": "submit_labels",
        "description": "Submit relabeled BIO labels.",
        "parameters": response_schema(),
        "strict": True,
    }
    request["tools"] = [submit_tool]
    if include_tool_choice:
        request["tool_choice"] = {"type": "function", "name": "submit_labels"}
    return request
def parse_usage(response_obj: Dict[str, Any]) -> UsageStats:
    """Extract token counters from a response's ``usage`` object.

    Missing or ``null`` sections and counters are read as zero.
    """
    usage = response_obj.get("usage") or {}
    input_details = usage.get("input_tokens_details") or {}
    output_details = usage.get("output_tokens_details") or {}

    def count(section: Dict[str, Any], key: str) -> int:
        return int(section.get(key) or 0)

    return UsageStats(
        input_tokens=count(usage, "input_tokens"),
        output_tokens=count(usage, "output_tokens"),
        cached_tokens=count(input_details, "cached_tokens"),
        reasoning_tokens=count(output_details, "reasoning_tokens"),
    )
def _collect_valid_labels(
    results: Sequence[Any],
    batch_rows: Sequence[Row],
    failure_log: str,
) -> Dict[int, List[str]]:
    """Keep the result items that are well-formed for their row.

    Items that are not dicts, lack an int ``row_id`` or list ``labels``, or
    name a ``row_id`` outside the batch are dropped silently. Items whose
    labels fail ``validate_labels`` are dropped and logged.
    """
    mapping: Dict[int, List[str]] = {}
    for item in results:
        if not isinstance(item, dict):
            continue
        row_id = item.get("row_id")
        labels = item.get("labels")
        if not isinstance(row_id, int) or not isinstance(labels, list):
            continue
        if row_id < 0 or row_id >= len(batch_rows):
            continue
        tokens = batch_rows[row_id].record.get("tokens", [])
        if not validate_labels(tokens, labels):
            append_failure_log(
                failure_log,
                f"[invalid-labels] file_id={batch_rows[row_id].record.get('file_id')} "
                f"tokens_len={len(tokens)} labels_len={len(labels)}",
            )
            continue
        mapping[row_id] = labels
    return mapping


def _downgrade_cfg_for_400(cfg: Dict[str, bool], lowered_response: str) -> None:
    """Switch off the first still-enabled optional field named in a 400 body.

    Fields are checked in a fixed order and at most one is disabled per call,
    so the tool definition is given up last.
    """
    for needle, flag in (
        ("prompt_cache_retention", "include_cache_retention"),
        ("prompt_cache_key", "include_cache_key"),
        ("reasoning", "include_reasoning"),
        ("tool_choice", "include_tool_choice"),
        ("tools", "include_tools"),
    ):
        if needle in lowered_response and cfg[flag]:
            cfg[flag] = False
            return


def relabel_batch(
    api_base: str,
    api_key: str,
    model: str,
    batch_rows: Sequence[Row],
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    user_agent: str,
    retries: int,
    failure_log: str,
    http_timeout: int,
) -> tuple[Dict[int, List[str]], UsageStats]:
    """Ask the model to relabel one batch, retrying on any failure.

    Posts to ``<api_base>/responses``. The answer is read from the
    ``submit_labels`` function call when present, otherwise from the plain
    output text. An attempt succeeds only if every row in the batch gets a
    valid label list.

    On an HTTP 400 whose body names an optional request field, that field is
    dropped for later attempts. Each failed attempt except the last is
    followed by a sleep of ``0.8 * attempt`` seconds.

    Returns:
        ``(mapping, usage)``: ``mapping`` maps each row's index within
        ``batch_rows`` to its new labels, and ``usage`` holds the token
        counters of the successful response.

    Raises:
        RuntimeError: when all ``retries`` attempts fail. The last underlying
            exception is chained as ``__cause__``.
    """
    url = f"{api_base.rstrip('/')}/responses"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "User-Agent": user_agent,
    }
    user_payload = build_user_payload(batch_rows)
    # Optional request fields; entries are switched off as gateways reject them.
    cfg = {
        "include_tools": True,
        "include_tool_choice": True,
        "include_reasoning": True,
        "include_cache_key": True,
        "include_cache_retention": True,
    }
    last_error: Exception | None = None
    for attempt in range(1, retries + 1):
        try:
            body = build_request_body(
                model=model,
                user_payload=user_payload,
                prompt_cache_key=prompt_cache_key,
                prompt_cache_retention=prompt_cache_retention,
                reasoning_effort=reasoning_effort,
                **cfg,
            )
            resp = requests.post(url, headers=headers, json=body, timeout=http_timeout)
            resp.raise_for_status()
            obj = resp.json()
            usage_stats = parse_usage(obj)
            try:
                parsed = extract_function_args(obj, "submit_labels")
            except Exception:
                # No usable tool call (e.g. tools were dropped): use plain text.
                text = extract_output_text(obj)
                parsed = parse_model_json(text)
            results = parsed.get("results")
            if not isinstance(results, list):
                append_failure_log(
                    failure_log,
                    f"[invalid-results] model={model} batch={len(batch_rows)} parsed_keys={list(parsed.keys())}",
                )
                raise ValueError("response JSON missing 'results' list")
            mapping = _collect_valid_labels(results, batch_rows, failure_log)
            if len(mapping) != len(batch_rows):
                missing = sorted(set(range(len(batch_rows))) - set(mapping))
                append_failure_log(
                    failure_log,
                    f"[missing] model={model} batch={len(batch_rows)} missing={missing}",
                )
                raise ValueError(f"incomplete/invalid rows from model: missing={missing}")
            return mapping, usage_stats
        except Exception as exc:  # noqa: BLE001
            last_error = exc
            # Some compatible gateways may not support all optional fields.
            # Downgrade progressively and keep structured tool output whenever possible.
            if isinstance(exc, requests.HTTPError) and exc.response is not None and exc.response.status_code == 400:
                response_text = (exc.response.text or "")[:1200]
                append_failure_log(
                    failure_log,
                    f"[http400] attempt={attempt} model={model} body_cfg={cfg} response={response_text!r}",
                )
                _downgrade_cfg_for_400(cfg, response_text.lower())
            if attempt == retries:
                break
            time.sleep(0.8 * attempt)
    raise RuntimeError(f"failed relabel batch after {retries} attempts: {last_error}") from last_error
def write_jsonl(path: Path, records: Sequence[Dict[str, Any]]) -> None:
    """Write ``records`` to ``path`` as compact JSON lines.

    The data is written to a sibling ``<name>.tmp`` file first and then moved
    over ``path``, so ``path`` is not truncated while writing is in progress.
    """
    staging = path.with_suffix(path.suffix + ".tmp")
    with staging.open("w", encoding="utf-8", newline="") as out:
        out.writelines(
            json.dumps(rec, ensure_ascii=False, separators=(",", ":")) + "\n"
            for rec in records
        )
    staging.replace(path)
def process_batch_with_fallback(
    api_base: str,
    api_key: str,
    model: str,
    batch: Sequence[Row],
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    user_agent: str,
    retries: int,
    failure_log: str,
    http_timeout: int,
) -> tuple[List[tuple[Row, List[str]]], UsageStats]:
    """Relabel a batch, degrading to per-row requests and then to local repair.

    Steps:
      1. Try the whole batch with ``relabel_batch``.
      2. If that raises ``RuntimeError``, retry each row on its own with at
         least 4 attempts.
      3. If a single row still fails, log it and fall back to the row's
         existing labels, passed through ``force_single_title_segment`` when
         tokens and labels are lists of equal length.

    Every resulting label list is then post-processed with
    ``force_single_title_segment``, ``repair_jsonl_item`` and
    ``force_bracket_delimiters_o``. Records are not modified here; the caller
    applies the returned labels.

    Returns:
        ``(updates, usage_total)``: ``updates`` pairs each row with its final
        labels, and ``usage_total`` sums the token usage of the successful
        requests.
    """
    usage_total = UsageStats()
    try:
        mapping, usage = relabel_batch(
            api_base=api_base,
            api_key=api_key,
            model=model,
            batch_rows=batch,
            prompt_cache_key=prompt_cache_key,
            prompt_cache_retention=prompt_cache_retention,
            reasoning_effort=reasoning_effort,
            user_agent=user_agent,
            retries=retries,
            failure_log=failure_log,
            http_timeout=http_timeout,
        )
        usage_total.add(usage)
    except RuntimeError:
        mapping = {}
        for idx, row in enumerate(batch):
            try:
                single, usage = relabel_batch(
                    api_base=api_base,
                    api_key=api_key,
                    model=model,
                    batch_rows=[row],
                    prompt_cache_key=prompt_cache_key,
                    prompt_cache_retention=prompt_cache_retention,
                    reasoning_effort=reasoning_effort,
                    user_agent=user_agent,
                    retries=max(retries, 4),
                    failure_log=failure_log,
                    http_timeout=http_timeout,
                )
                usage_total.add(usage)
                # A one-row batch always answers under row_id 0.
                mapping[idx] = single[0]
            except RuntimeError as exc:
                append_failure_log(
                    failure_log,
                    f"[row-skip] file_id={row.record.get('file_id')} line={row.line_no} reason={exc}",
                )
                # Hard fallback: enforce contiguous TITLE rather than keeping polluted labels.
                toks = row.record.get("tokens", [])
                lbs = row.record.get("labels", [])
                if isinstance(toks, list) and isinstance(lbs, list) and len(toks) == len(lbs):
                    mapping[idx] = force_single_title_segment(toks, lbs)
                else:
                    mapping[idx] = lbs

    updates: List[tuple[Row, List[str]]] = []
    for row_id, labels in mapping.items():
        row = batch[row_id]
        # Work on a shallow copy so the shared record is left untouched here.
        rec = dict(row.record)
        tokens = rec.get("tokens", [])
        rec["labels"] = force_single_title_segment(tokens, labels)
        repaired, _repairs = repair_jsonl_item(rec)
        new_labels = repaired.get("labels", rec.get("labels", []))
        new_labels = force_bracket_delimiters_o(tokens, new_labels)
        updates.append((row, new_labels))
    return updates, usage_total
def process_batch_timed(
    meter: ConcurrentMeter,
    api_base: str,
    api_key: str,
    model: str,
    batch: Sequence[Row],
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    user_agent: str,
    retries: int,
    failure_log: str,
    http_timeout: int,
) -> Dict[str, Any]:
    """Run one batch through ``process_batch_with_fallback`` under the meter.

    The meter is told when the task starts and, whether or not it succeeds,
    when it ends. Returns a dict with ``updates``, ``elapsed`` (seconds),
    ``batch_size`` and ``usage``.
    """
    meter.task_start()
    started_at = time.time()
    try:
        updates, usage = process_batch_with_fallback(
            api_base=api_base,
            api_key=api_key,
            model=model,
            batch=batch,
            prompt_cache_key=prompt_cache_key,
            prompt_cache_retention=prompt_cache_retention,
            reasoning_effort=reasoning_effort,
            user_agent=user_agent,
            retries=retries,
            failure_log=failure_log,
            http_timeout=http_timeout,
        )
        elapsed = time.time() - started_at
        outcome: Dict[str, Any] = {
            "updates": updates,
            "elapsed": elapsed,
            "batch_size": len(batch),
            "usage": usage,
        }
        return outcome
    finally:
        meter.task_end()
def main() -> None:
    """Run the relabel CLI end to end.

    Loads the input JSONL, narrows the selected rows by token length, order,
    skip and cap, relabels them in batches on a thread pool, applies the new
    labels to the loaded records, and writes the output file, with periodic
    checkpoints and an optional JSON perf summary.

    Notes:
      * Concurrency is clamped to 1..8 and batch size to at least 1. A batch
        size of 0 used to crash ``range()``.
      * A checkpoint is written whenever the processed-row count reaches or
        passes the next multiple of ``--checkpoint-rows``. The previous exact
        ``done_rows % N == 0`` test almost never fired when the batch size
        did not divide N.
      * ``--sleep-ms`` pauses this result-handling loop after each completed
        batch.
    """
    args = parse_args()
    api_key = args.api_key or os.environ.get("ANIFILEBERT_RELABEL_API_KEY")
    if not api_key:
        raise SystemExit("Missing API key. Use --api-key or env ANIFILEBERT_RELABEL_API_KEY")

    input_path = Path(args.input)
    output_path = Path(args.output)
    all_records, selected_rows = load_rows(input_path, args.selector)

    if args.min_token_len > 0 or args.max_token_len > 0:
        filtered: List[Row] = []
        for row in selected_rows:
            tok_len = len(row.record.get("tokens", []))
            if tok_len < args.min_token_len:
                continue
            if args.max_token_len > 0 and tok_len > args.max_token_len:
                continue
            filtered.append(row)
        selected_rows = filtered
    if args.sort_by == "token_len_asc":
        selected_rows.sort(key=lambda r: len(r.record.get("tokens", [])))
    if args.skip_selected > 0:
        selected_rows = selected_rows[args.skip_selected:]
    if args.max_rows > 0:
        selected_rows = selected_rows[: args.max_rows]

    if not selected_rows:
        print("selected_rows=0; nothing to do")
        if output_path != input_path:
            write_jsonl(output_path, all_records)
        return

    total = len(selected_rows)
    changed = 0
    concurrency = max(1, min(args.concurrency, 8))
    batch_size = max(1, args.batch_size)
    batches: List[List[Row]] = [
        selected_rows[i:i + batch_size]
        for i in range(0, total, batch_size)
    ]
    done_rows = 0
    # Processed-row count at which the next checkpoint is due.
    next_checkpoint = args.checkpoint_rows
    wall_start = time.time()
    meter = ConcurrentMeter()
    total_batch_elapsed = 0.0
    completed_batches = 0
    usage_total = UsageStats()
    pricing_enabled = args.usd_per_1m_input > 0 or args.usd_per_1m_output > 0

    def cost_usd() -> float:
        """Cost of the tokens consumed so far at the configured prices."""
        return (usage_total.input_tokens / 1_000_000.0) * args.usd_per_1m_input + (
            usage_total.output_tokens / 1_000_000.0
        ) * args.usd_per_1m_output

    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        futures = [
            executor.submit(
                process_batch_timed,
                meter,
                api_base=args.api_base,
                api_key=api_key,
                model=args.model,
                batch=batch,
                prompt_cache_key=args.prompt_cache_key,
                prompt_cache_retention=args.prompt_cache_retention,
                reasoning_effort=args.reasoning_effort,
                user_agent=args.user_agent,
                retries=args.retries,
                failure_log=args.failure_log,
                http_timeout=args.http_timeout,
            )
            for batch in batches
        ]
        for fut in as_completed(futures):
            result = fut.result()
            updates = result["updates"]
            total_batch_elapsed += float(result["elapsed"])
            completed_batches += 1
            usage_total.add(result["usage"])
            for row, new_labels in updates:
                rec = row.record
                if rec.get("labels") != new_labels:
                    rec["labels"] = new_labels
                    changed += 1
            done_rows += len(updates)

            snap = meter.snapshot()
            wall_elapsed = max(1e-9, snap["timestamp"] - wall_start)
            rows_per_sec = done_rows / wall_elapsed
            avg_active = snap["active_time_accum"] / wall_elapsed
            in_tok_per_sec = usage_total.input_tokens / wall_elapsed
            out_tok_per_sec = usage_total.output_tokens / wall_elapsed
            hourly_usd = cost_usd() / wall_elapsed * 3600.0 if pricing_enabled else 0.0
            print(
                f"processed={done_rows}/{total} changed={changed} "
                f"rows_per_sec={rows_per_sec:.2f} active_now={int(snap['current_active'])} "
                f"avg_active={avg_active:.2f} max_active={int(snap['max_active'])}/{concurrency} "
                f"in_tok_s={in_tok_per_sec:.1f} out_tok_s={out_tok_per_sec:.1f} usd_h={hourly_usd:.3f}"
            )

            if args.checkpoint_rows > 0 and (done_rows >= next_checkpoint or done_rows == total):
                write_jsonl(output_path, all_records)
                # Schedule the next checkpoint at the following multiple.
                next_checkpoint = (done_rows // args.checkpoint_rows + 1) * args.checkpoint_rows
            if args.sleep_ms > 0:
                time.sleep(args.sleep_ms / 1000.0)

    # rows in selected_rows reference dicts in all_records by identity, so changes are already reflected.
    write_jsonl(output_path, all_records)

    wall_total = time.time() - wall_start
    final_snap = meter.snapshot()
    safe_wall = max(1e-9, wall_total)
    avg_active = final_snap["active_time_accum"] / safe_wall
    perf_summary = {
        "wall_seconds": wall_total,
        "rows_processed": done_rows,
        "rows_per_second": done_rows / safe_wall,
        "batches_completed": completed_batches,
        "avg_batch_seconds": total_batch_elapsed / max(1, completed_batches),
        "avg_active_workers": avg_active,
        "max_active_workers": int(final_snap["max_active"]),
        "configured_workers": concurrency,
        "input_tokens": usage_total.input_tokens,
        "output_tokens": usage_total.output_tokens,
        "cached_tokens": usage_total.cached_tokens,
        "reasoning_tokens": usage_total.reasoning_tokens,
        "input_tokens_per_sec": usage_total.input_tokens / safe_wall,
        "output_tokens_per_sec": usage_total.output_tokens / safe_wall,
        "input_tokens_per_hour": usage_total.input_tokens / safe_wall * 3600.0,
        "output_tokens_per_hour": usage_total.output_tokens / safe_wall * 3600.0,
        "usd_per_1m_input": args.usd_per_1m_input,
        "usd_per_1m_output": args.usd_per_1m_output,
    }
    if pricing_enabled:
        total_cost = cost_usd()
        perf_summary["usd_total"] = total_cost
        perf_summary["usd_per_hour"] = total_cost / safe_wall * 3600.0
    if args.perf_log:
        p = Path(args.perf_log)
        p.parent.mkdir(parents=True, exist_ok=True)
        p.write_text(json.dumps(perf_summary, ensure_ascii=False, indent=2), encoding="utf-8")
    print(
        f"perf wall={wall_total:.1f}s rows_per_sec={perf_summary['rows_per_second']:.2f} "
        f"avg_active={avg_active:.2f} max_active={int(final_snap['max_active'])}/{concurrency} "
        f"in_tok_s={perf_summary['input_tokens_per_sec']:.1f} out_tok_s={perf_summary['output_tokens_per_sec']:.1f}"
    )
    print(f"done selected_rows={total} changed_rows={changed} output={output_path}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()