Token Classification
Transformers
ONNX
Safetensors
English
Japanese
Chinese
bert
anime
filename-parsing
Eval Results (legacy)
Instructions to use ModerRAS/AniFileBERT with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ModerRAS/AniFileBERT with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("token-classification", model="ModerRAS/AniFileBERT")

# Load model directly
from transformers import AutoTokenizer, AutoModelForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("ModerRAS/AniFileBERT")
model = AutoModelForTokenClassification.from_pretrained("ModerRAS/AniFileBERT")
- Notebooks
- Google Colab
- Kaggle
Add robust LLM relabel pipeline and enforce contiguous title
Browse files
- datasets/AnimeName +1 -1
- tools/enforce_contiguous_title.py +176 -0
- tools/llm_relabel_rows.py +159 -25
datasets/AnimeName
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
Subproject commit
|
|
|
|
| 1 |
+
Subproject commit ad48d8da74cf8e611a14f22ffc2a9734872e1f03
|
tools/enforce_contiguous_title.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Enforce a single contiguous TITLE span for every JSONL row.
|
| 4 |
+
|
| 5 |
+
This script is deterministic and streaming-friendly for very large datasets.
|
| 6 |
+
It is intended as a hard safety pass before/alongside LLM relabeling.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import json
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Dict, List, Sequence, Tuple
|
| 15 |
+
|
| 16 |
+
from anifilebert.label_repairs import repair_jsonl_item
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the contiguous-TITLE enforcement tool.

    Returns:
        Namespace exposing ``input``, ``output``, ``manifest_output`` and
        ``progress``.
    """
    cli = argparse.ArgumentParser(description="Force contiguous TITLE spans in JSONL labels")
    # Both file paths are mandatory and are declared the same way.
    for flag, help_text in (("--input", "Input JSONL"), ("--output", "Output JSONL")):
        cli.add_argument(flag, required=True, help=help_text)
    # Empty string means "no explicit manifest path was given".
    cli.add_argument("--manifest-output", default="", help="Optional manifest JSON")
    cli.add_argument("--progress", type=int, default=50000, help="Progress print interval")
    return cli.parse_args()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def normalize_iob2(labels: Sequence[str]) -> List[str]:
    """Rewrite a label sequence into strict IOB2 form.

    Every run of adjacent labels that share the same entity becomes one span:
    the first label gets the ``B-`` prefix and the rest get ``I-``. Note that
    this also merges two directly adjacent spans of the same entity.

    Anything that is not a string starting with ``B-`` or ``I-`` becomes
    ``"O"``. A prefix with no entity name (a bare ``"B-"`` or ``"I-"``) is
    also mapped to ``"O"``; previously it collided with the empty-string
    "no open span" sentinel and was emitted as the invalid label ``"I-"``.

    Args:
        labels: Raw labels; non-string items are tolerated.

    Returns:
        A new list of the same length containing normalized labels.
    """
    out: List[str] = []
    # None (not "") marks "no span is open", so it can never equal an entity.
    prev = None
    for lb in labels:
        entity = lb[2:] if isinstance(lb, str) and lb.startswith(("B-", "I-")) else ""
        if not entity:
            out.append("O")
            prev = None
            continue
        prefix = "I" if prev == entity else "B"
        out.append(f"{prefix}-{entity}")
        prev = entity
    return out
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def is_discontinuous_title(labels: Sequence[str]) -> bool:
    """Tell whether TITLE labels form more than one separate run.

    A label counts as a title label when it is a string ending in ``"TITLE"``.

    Args:
        labels: Label sequence; non-string items count as non-title.

    Returns:
        ``True`` as soon as a second run of title labels begins, else ``False``.
    """
    run_count = 0
    inside_run = False
    for label in labels:
        title_here = isinstance(label, str) and label.endswith("TITLE")
        if title_here and not inside_run:
            # A new run starts here; a second one means the title is split.
            run_count += 1
            if run_count > 1:
                return True
        inside_run = title_here
    return False
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def title_segments(labels: Sequence[str]) -> List[Tuple[int, int]]:
    """Locate every maximal run of title labels.

    A label belongs to a run when its string form ends in ``"TITLE"``.

    Args:
        labels: Label sequence to scan.

    Returns:
        Half-open ``(start, end)`` index pairs, in left-to-right order.
    """
    spans: List[Tuple[int, int]] = []
    run_start = None
    for pos, label in enumerate(labels):
        if str(label).endswith("TITLE"):
            if run_start is None:
                run_start = pos
        elif run_start is not None:
            # First non-title label after a run closes it.
            spans.append((run_start, pos))
            run_start = None
    if run_start is not None:
        # The sequence ended while a run was still open.
        spans.append((run_start, len(labels)))
    return spans
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def first_episode_or_special_index(labels: Sequence[str]) -> int:
    """Return the position of the first EPISODE or SPECIAL label.

    A label matches when its string form ends in ``"EPISODE"`` or
    ``"SPECIAL"``.

    Args:
        labels: Label sequence to scan.

    Returns:
        Index of the first match, or ``len(labels)`` when there is none.
    """
    matches = (
        pos
        for pos, label in enumerate(labels)
        if str(label).endswith(("EPISODE", "SPECIAL"))
    )
    return next(matches, len(labels))
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def pick_primary_title_segment(labels: Sequence[str], segs: Sequence[Tuple[int, int]]) -> Tuple[int, int]:
    """Choose the title segment that should be kept.

    The earliest-starting segment wins. The previous implementation first
    filtered segments to those starting before the first EPISODE/SPECIAL
    label, but that filter could never change the outcome: if any segment
    starts before the boundary, the earliest segment overall is one of them,
    and otherwise the code already fell back to the earliest segment. The
    no-op branch is removed; results are identical.

    Args:
        labels: Label sequence the segments refer to. Unused by the selection;
            kept so existing callers keep working.
        segs: Half-open ``(start, end)`` title segments.

    Returns:
        The segment with the smallest start index, or ``(-1, -1)`` when
        ``segs`` is empty.
    """
    if not segs:
        return (-1, -1)
    return min(segs, key=lambda seg: seg[0])
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def enforce_contiguous_title(labels: Sequence[str]) -> List[str]:
    """Reduce the labels to at most one contiguous TITLE span.

    The labels are normalized to IOB2 first. If more than one title segment
    remains, the segment chosen by ``pick_primary_title_segment`` is kept and
    every other title label is replaced with ``"O"``; the result is
    normalized once more.

    Args:
        labels: Raw label sequence.

    Returns:
        A new, normalized label list.
    """
    normalized = normalize_iob2(labels)
    spans = title_segments(normalized)
    if len(spans) < 2:
        # Zero or one title span: nothing to demote.
        return normalized

    first, last = pick_primary_title_segment(normalized, spans)
    if first < 0:
        return normalized

    kept = range(first, last)
    cleaned = [
        "O" if str(label).endswith("TITLE") and pos not in kept else label
        for pos, label in enumerate(normalized)
    ]
    return normalize_iob2(cleaned)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def main() -> None:
    """Stream a JSONL file, enforce one contiguous TITLE span per row, and write a manifest.

    Each non-blank input line is parsed as JSON. Rows that are not JSON
    objects, or whose ``tokens``/``labels`` are not equal-length lists, are
    counted as invalid and written through unchanged. Valid rows get their
    labels fixed by ``enforce_contiguous_title`` and are then passed to
    ``repair_jsonl_item``, whose first return value is used as the output
    record.

    Output goes to a temporary sibling file that replaces the real output
    only after the whole input was processed; on failure the temporary file
    is removed so no partial output is left behind. A JSON manifest with the
    counters is written and also printed.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the input cannot be read or the outputs cannot be written.
    """
    args = parse_args()
    input_path = Path(args.input)
    output_path = Path(args.output)
    manifest_path = (
        Path(args.manifest_output)
        if args.manifest_output
        else output_path.with_suffix(".contiguous_title.manifest.json")
    )
    output_path.parent.mkdir(parents=True, exist_ok=True)
    manifest_path.parent.mkdir(parents=True, exist_ok=True)

    rows = 0
    changed_rows = 0
    bad_before = 0
    bad_after = 0
    invalid_rows = 0

    tmp_path = output_path.with_suffix(output_path.suffix + ".tmp")
    try:
        with input_path.open("r", encoding="utf-8") as src, tmp_path.open(
            "w", encoding="utf-8", newline="\n"
        ) as dst:
            for line in src:
                # Whitespace-only lines carry no record; skip them instead of
                # handing them to the JSON parser.
                line = line.strip()
                if not line:
                    continue
                rows += 1
                rec = json.loads(line)
                # A JSON array/scalar has no .get(); treat it as an invalid row.
                is_object = isinstance(rec, dict)
                tokens = rec.get("tokens", []) if is_object else None
                labels = rec.get("labels", []) if is_object else None
                if not isinstance(tokens, list) or not isinstance(labels, list) or len(tokens) != len(labels):
                    invalid_rows += 1
                    dst.write(json.dumps(rec, ensure_ascii=False, separators=(",", ":")) + "\n")
                    continue

                if is_discontinuous_title(labels):
                    bad_before += 1

                new_labels = enforce_contiguous_title(labels)
                out_rec: Dict = dict(rec)
                out_rec["labels"] = new_labels
                repaired, _ = repair_jsonl_item(out_rec)
                out_labels = repaired.get("labels", new_labels)
                # Re-check after the repair step, which may touch labels again.
                if is_discontinuous_title(out_labels):
                    bad_after += 1

                if out_labels != labels:
                    changed_rows += 1
                repaired["labels"] = out_labels
                dst.write(json.dumps(repaired, ensure_ascii=False, separators=(",", ":")) + "\n")

                if args.progress > 0 and rows % args.progress == 0:
                    print(
                        f"rows={rows} changed={changed_rows} "
                        f"bad_before={bad_before} bad_after={bad_after} invalid={invalid_rows}"
                    )

        tmp_path.replace(output_path)
    finally:
        # No-op after a successful replace; removes the partial file on error.
        tmp_path.unlink(missing_ok=True)

    manifest = {
        "input": str(input_path),
        "output": str(output_path),
        "rows": rows,
        "changed_rows": changed_rows,
        "discontinuous_before": bad_before,
        "discontinuous_after": bad_after,
        "invalid_rows": invalid_rows,
    }
    manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps(manifest, ensure_ascii=False, indent=2))
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
if __name__ == "__main__":
|
| 175 |
+
main()
|
| 176 |
+
|
tools/llm_relabel_rows.py
CHANGED
|
@@ -20,6 +20,7 @@ from pathlib import Path
|
|
| 20 |
from typing import Any, Dict, List, Sequence
|
| 21 |
|
| 22 |
import requests
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
ALLOWED_LABELS = {
|
|
@@ -151,6 +152,7 @@ def parse_args() -> argparse.Namespace:
|
|
| 151 |
p.add_argument("--checkpoint-rows", type=int, default=100, help="Write checkpoint every N processed rows")
|
| 152 |
p.add_argument("--failure-log", default="reports/llm_relabel_failures.log", help="Failure log path")
|
| 153 |
p.add_argument("--perf-log", default="", help="Optional JSON perf summary path")
|
|
|
|
| 154 |
p.add_argument("--usd-per-1m-input", type=float, default=0.75, help="Input token price (USD per 1M tokens)")
|
| 155 |
p.add_argument("--usd-per-1m-output", type=float, default=4.5, help="Output token price (USD per 1M tokens)")
|
| 156 |
p.add_argument(
|
|
@@ -244,6 +246,64 @@ def validate_labels(tokens: Sequence[str], labels: Sequence[str]) -> bool:
|
|
| 244 |
return True
|
| 245 |
|
| 246 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
def response_schema() -> Dict[str, Any]:
|
| 248 |
return {
|
| 249 |
"type": "object",
|
|
@@ -276,6 +336,44 @@ def append_failure_log(path: str, message: str) -> None:
|
|
| 276 |
f.write(message.rstrip() + "\n")
|
| 277 |
|
| 278 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
def parse_usage(response_obj: Dict[str, Any]) -> UsageStats:
|
| 280 |
usage = response_obj.get("usage", {}) or {}
|
| 281 |
in_details = usage.get("input_tokens_details", {}) or {}
|
|
@@ -299,6 +397,7 @@ def relabel_batch(
|
|
| 299 |
user_agent: str,
|
| 300 |
retries: int,
|
| 301 |
failure_log: str,
|
|
|
|
| 302 |
) -> tuple[Dict[int, List[str]], UsageStats]:
|
| 303 |
url = f"{api_base.rstrip('/')}/responses"
|
| 304 |
headers = {
|
|
@@ -308,29 +407,30 @@ def relabel_batch(
|
|
| 308 |
}
|
| 309 |
user_payload = build_user_payload(batch_rows)
|
| 310 |
|
| 311 |
-
|
| 312 |
-
"
|
| 313 |
-
"
|
| 314 |
-
"
|
| 315 |
-
"
|
| 316 |
-
"
|
| 317 |
-
"reasoning": {"effort": reasoning_effort},
|
| 318 |
-
"tools": [
|
| 319 |
-
{
|
| 320 |
-
"type": "function",
|
| 321 |
-
"name": "submit_labels",
|
| 322 |
-
"description": "Submit relabeled BIO labels.",
|
| 323 |
-
"parameters": response_schema(),
|
| 324 |
-
"strict": True,
|
| 325 |
-
}
|
| 326 |
-
],
|
| 327 |
-
"tool_choice": {"type": "function", "name": "submit_labels"},
|
| 328 |
}
|
| 329 |
|
| 330 |
last_error: Exception | None = None
|
| 331 |
for attempt in range(1, retries + 1):
|
| 332 |
try:
|
| 333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
resp.raise_for_status()
|
| 335 |
obj = resp.json()
|
| 336 |
usage_stats = parse_usage(obj)
|
|
@@ -378,12 +478,25 @@ def relabel_batch(
|
|
| 378 |
return mapping, usage_stats
|
| 379 |
except Exception as exc: # noqa: BLE001
|
| 380 |
last_error = exc
|
| 381 |
-
# Some compatible gateways may not support
|
|
|
|
| 382 |
if isinstance(exc, requests.HTTPError) and exc.response is not None and exc.response.status_code == 400:
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
if attempt == retries:
|
| 388 |
break
|
| 389 |
time.sleep(0.8 * attempt)
|
|
@@ -410,6 +523,7 @@ def process_batch_with_fallback(
|
|
| 410 |
user_agent: str,
|
| 411 |
retries: int,
|
| 412 |
failure_log: str,
|
|
|
|
| 413 |
) -> List[tuple[Row, List[str]]]:
|
| 414 |
usage_total = UsageStats()
|
| 415 |
try:
|
|
@@ -424,6 +538,7 @@ def process_batch_with_fallback(
|
|
| 424 |
user_agent=user_agent,
|
| 425 |
retries=retries,
|
| 426 |
failure_log=failure_log,
|
|
|
|
| 427 |
)
|
| 428 |
usage_total.add(usage)
|
| 429 |
except RuntimeError:
|
|
@@ -441,6 +556,7 @@ def process_batch_with_fallback(
|
|
| 441 |
user_agent=user_agent,
|
| 442 |
retries=max(retries, 4),
|
| 443 |
failure_log=failure_log,
|
|
|
|
| 444 |
)
|
| 445 |
usage_total.add(usage)
|
| 446 |
mapping[idx] = single[0]
|
|
@@ -449,8 +565,23 @@ def process_batch_with_fallback(
|
|
| 449 |
failure_log,
|
| 450 |
f"[row-skip] file_id={row.record.get('file_id')} line={row.line_no} reason={exc}",
|
| 451 |
)
|
| 452 |
-
|
| 453 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
|
| 455 |
|
| 456 |
def process_batch_timed(
|
|
@@ -465,6 +596,7 @@ def process_batch_timed(
|
|
| 465 |
user_agent: str,
|
| 466 |
retries: int,
|
| 467 |
failure_log: str,
|
|
|
|
| 468 |
) -> Dict[str, Any]:
|
| 469 |
meter.task_start()
|
| 470 |
t0 = time.time()
|
|
@@ -480,6 +612,7 @@ def process_batch_timed(
|
|
| 480 |
user_agent=user_agent,
|
| 481 |
retries=retries,
|
| 482 |
failure_log=failure_log,
|
|
|
|
| 483 |
)
|
| 484 |
return {
|
| 485 |
"updates": updates,
|
|
@@ -552,6 +685,7 @@ def main() -> None:
|
|
| 552 |
user_agent=args.user_agent,
|
| 553 |
retries=args.retries,
|
| 554 |
failure_log=args.failure_log,
|
|
|
|
| 555 |
)
|
| 556 |
for batch in batches
|
| 557 |
]
|
|
|
|
| 20 |
from typing import Any, Dict, List, Sequence
|
| 21 |
|
| 22 |
import requests
|
| 23 |
+
from anifilebert.label_repairs import repair_jsonl_item
|
| 24 |
|
| 25 |
|
| 26 |
ALLOWED_LABELS = {
|
|
|
|
| 152 |
p.add_argument("--checkpoint-rows", type=int, default=100, help="Write checkpoint every N processed rows")
|
| 153 |
p.add_argument("--failure-log", default="reports/llm_relabel_failures.log", help="Failure log path")
|
| 154 |
p.add_argument("--perf-log", default="", help="Optional JSON perf summary path")
|
| 155 |
+
p.add_argument("--http-timeout", type=int, default=240, help="HTTP timeout in seconds per request")
|
| 156 |
p.add_argument("--usd-per-1m-input", type=float, default=0.75, help="Input token price (USD per 1M tokens)")
|
| 157 |
p.add_argument("--usd-per-1m-output", type=float, default=4.5, help="Output token price (USD per 1M tokens)")
|
| 158 |
p.add_argument(
|
|
|
|
| 246 |
return True
|
| 247 |
|
| 248 |
|
| 249 |
+
def normalize_iob2_labels(labels: Sequence[str]) -> List[str]:
    """Rewrite a label sequence into strict IOB2 form.

    Adjacent labels with the same entity are turned into one span (``B-`` on
    the first label, ``I-`` on the rest), which also merges two directly
    adjacent spans of the same entity.

    Items that are not strings starting with ``B-`` or ``I-`` become ``"O"``.
    A prefix without an entity name (bare ``"B-"`` or ``"I-"``) also becomes
    ``"O"``; it used to match the empty-string "no open span" sentinel and was
    emitted as the invalid label ``"I-"``.

    Args:
        labels: Raw labels; non-string items are tolerated.

    Returns:
        A new list of the same length containing normalized labels.
    """
    normalized: List[str] = []
    # None (not "") marks "no span is open", so it can never equal an entity.
    prev_entity = None
    for lb in labels:
        entity = lb[2:] if isinstance(lb, str) and lb.startswith(("B-", "I-")) else ""
        if not entity:
            normalized.append("O")
            prev_entity = None
            continue
        prefix = "I" if prev_entity == entity else "B"
        normalized.append(f"{prefix}-{entity}")
        prev_entity = entity
    return normalized
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def title_segments(labels: Sequence[str]) -> List[tuple[int, int]]:
    """Collect the maximal runs of title labels.

    A label is part of a run when its string form ends in ``"TITLE"``.

    Args:
        labels: Label sequence to scan.

    Returns:
        Half-open ``(start, end)`` index pairs, ordered left to right.
    """
    found: List[tuple[int, int]] = []
    opened_at = None
    for pos, label in enumerate(labels):
        in_title = str(label).endswith("TITLE")
        if in_title and opened_at is None:
            opened_at = pos
        elif not in_title and opened_at is not None:
            # The run ended on the previous label.
            found.append((opened_at, pos))
            opened_at = None
    if opened_at is not None:
        # A run reaching the end of the sequence is closed at len(labels).
        found.append((opened_at, len(labels)))
    return found
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def force_single_title_segment(tokens: Sequence[str], labels: Sequence[str]) -> List[str]:
    """Guarantee TITLE is a single contiguous segment.

    When ``tokens`` and ``labels`` differ in length the labels are returned
    as a plain list copy, untouched. Otherwise the labels are normalized to
    IOB2 and, if several title segments exist, one is kept and the others are
    replaced with ``"O"``.

    The kept segment is the best by, in order: starting before the first
    EPISODE label, greater length, earlier start.

    Args:
        tokens: Tokens of the row; only their count is used.
        labels: Labels aligned with ``tokens``.

    Returns:
        A new label list.
    """
    if len(labels) != len(tokens):
        return list(labels)

    normalized = normalize_iob2_labels(labels)
    spans = title_segments(normalized)
    if len(spans) < 2:
        return normalized

    # Position of the first EPISODE label; defaults to "past the end".
    episode_at = len(normalized)
    for pos, label in enumerate(normalized):
        if str(label).endswith("EPISODE"):
            episode_at = pos
            break

    def rank(span: tuple[int, int]) -> tuple[int, int, int]:
        begin, stop = span
        # Negated start so that, on ties, the earlier segment ranks higher.
        return (1 if begin < episode_at else 0, stop - begin, -begin)

    keep_from, keep_to = max(spans, key=rank)
    cleaned = [
        "O" if str(label).endswith("TITLE") and not (keep_from <= pos < keep_to) else label
        for pos, label in enumerate(normalized)
    ]
    return normalize_iob2_labels(cleaned)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
def response_schema() -> Dict[str, Any]:
|
| 308 |
return {
|
| 309 |
"type": "object",
|
|
|
|
| 336 |
f.write(message.rstrip() + "\n")
|
| 337 |
|
| 338 |
|
| 339 |
+
def build_request_body(
    model: str,
    user_payload: str,
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    include_tools: bool = True,
    include_tool_choice: bool = True,
    include_reasoning: bool = True,
    include_cache_key: bool = True,
    include_cache_retention: bool = True,
) -> Dict[str, Any]:
    """Assemble the JSON request body, leaving out optional fields on demand.

    ``model``, ``instructions`` and ``input`` are always present. Each
    ``include_*`` flag controls one optional field, so a caller can drop
    fields one at a time.

    Args:
        model: Value of the ``model`` field.
        user_payload: Value of the ``input`` field.
        prompt_cache_key: Value of ``prompt_cache_key`` when included.
        prompt_cache_retention: Value of ``prompt_cache_retention`` when included.
        reasoning_effort: Stored as ``reasoning.effort`` when included.
        include_tools: Add the ``submit_labels`` function tool.
        include_tool_choice: Force that tool via ``tool_choice``; has no
            effect unless ``include_tools`` is also set.
        include_reasoning: Add the ``reasoning`` field.
        include_cache_key: Add the ``prompt_cache_key`` field.
        include_cache_retention: Add the ``prompt_cache_retention`` field.

    Returns:
        The request body as a dictionary.
    """
    payload: Dict[str, Any] = {
        "model": model,
        "instructions": SYSTEM_INSTRUCTIONS,
        "input": user_payload,
    }

    # (enabled, field name, value) for the simple optional fields.
    optional_fields = (
        (include_cache_key, "prompt_cache_key", prompt_cache_key),
        (include_cache_retention, "prompt_cache_retention", prompt_cache_retention),
        (include_reasoning, "reasoning", {"effort": reasoning_effort}),
    )
    payload.update({name: value for enabled, name, value in optional_fields if enabled})

    if not include_tools:
        # tool_choice without a tool list is never sent.
        return payload

    submit_tool = {
        "type": "function",
        "name": "submit_labels",
        "description": "Submit relabeled BIO labels.",
        "parameters": response_schema(),
        "strict": True,
    }
    payload["tools"] = [submit_tool]
    if include_tool_choice:
        payload["tool_choice"] = {"type": "function", "name": "submit_labels"}
    return payload
|
| 375 |
+
|
| 376 |
+
|
| 377 |
def parse_usage(response_obj: Dict[str, Any]) -> UsageStats:
|
| 378 |
usage = response_obj.get("usage", {}) or {}
|
| 379 |
in_details = usage.get("input_tokens_details", {}) or {}
|
|
|
|
| 397 |
user_agent: str,
|
| 398 |
retries: int,
|
| 399 |
failure_log: str,
|
| 400 |
+
http_timeout: int,
|
| 401 |
) -> tuple[Dict[int, List[str]], UsageStats]:
|
| 402 |
url = f"{api_base.rstrip('/')}/responses"
|
| 403 |
headers = {
|
|
|
|
| 407 |
}
|
| 408 |
user_payload = build_user_payload(batch_rows)
|
| 409 |
|
| 410 |
+
cfg = {
|
| 411 |
+
"include_tools": True,
|
| 412 |
+
"include_tool_choice": True,
|
| 413 |
+
"include_reasoning": True,
|
| 414 |
+
"include_cache_key": True,
|
| 415 |
+
"include_cache_retention": True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
}
|
| 417 |
|
| 418 |
last_error: Exception | None = None
|
| 419 |
for attempt in range(1, retries + 1):
|
| 420 |
try:
|
| 421 |
+
body = build_request_body(
|
| 422 |
+
model=model,
|
| 423 |
+
user_payload=user_payload,
|
| 424 |
+
prompt_cache_key=prompt_cache_key,
|
| 425 |
+
prompt_cache_retention=prompt_cache_retention,
|
| 426 |
+
reasoning_effort=reasoning_effort,
|
| 427 |
+
include_tools=cfg["include_tools"],
|
| 428 |
+
include_tool_choice=cfg["include_tool_choice"],
|
| 429 |
+
include_reasoning=cfg["include_reasoning"],
|
| 430 |
+
include_cache_key=cfg["include_cache_key"],
|
| 431 |
+
include_cache_retention=cfg["include_cache_retention"],
|
| 432 |
+
)
|
| 433 |
+
resp = requests.post(url, headers=headers, json=body, timeout=http_timeout)
|
| 434 |
resp.raise_for_status()
|
| 435 |
obj = resp.json()
|
| 436 |
usage_stats = parse_usage(obj)
|
|
|
|
| 478 |
return mapping, usage_stats
|
| 479 |
except Exception as exc: # noqa: BLE001
|
| 480 |
last_error = exc
|
| 481 |
+
# Some compatible gateways may not support all optional fields.
|
| 482 |
+
# Downgrade progressively and keep structured tool output whenever possible.
|
| 483 |
if isinstance(exc, requests.HTTPError) and exc.response is not None and exc.response.status_code == 400:
|
| 484 |
+
response_text = (exc.response.text or "")[:1200]
|
| 485 |
+
lowered = response_text.lower()
|
| 486 |
+
append_failure_log(
|
| 487 |
+
failure_log,
|
| 488 |
+
f"[http400] attempt={attempt} model={model} body_cfg={cfg} response={response_text!r}",
|
| 489 |
+
)
|
| 490 |
+
if "prompt_cache_retention" in lowered and cfg["include_cache_retention"]:
|
| 491 |
+
cfg["include_cache_retention"] = False
|
| 492 |
+
elif "prompt_cache_key" in lowered and cfg["include_cache_key"]:
|
| 493 |
+
cfg["include_cache_key"] = False
|
| 494 |
+
elif "reasoning" in lowered and cfg["include_reasoning"]:
|
| 495 |
+
cfg["include_reasoning"] = False
|
| 496 |
+
elif "tool_choice" in lowered and cfg["include_tool_choice"]:
|
| 497 |
+
cfg["include_tool_choice"] = False
|
| 498 |
+
elif "tools" in lowered and cfg["include_tools"]:
|
| 499 |
+
cfg["include_tools"] = False
|
| 500 |
if attempt == retries:
|
| 501 |
break
|
| 502 |
time.sleep(0.8 * attempt)
|
|
|
|
| 523 |
user_agent: str,
|
| 524 |
retries: int,
|
| 525 |
failure_log: str,
|
| 526 |
+
http_timeout: int,
|
| 527 |
) -> List[tuple[Row, List[str]]]:
|
| 528 |
usage_total = UsageStats()
|
| 529 |
try:
|
|
|
|
| 538 |
user_agent=user_agent,
|
| 539 |
retries=retries,
|
| 540 |
failure_log=failure_log,
|
| 541 |
+
http_timeout=http_timeout,
|
| 542 |
)
|
| 543 |
usage_total.add(usage)
|
| 544 |
except RuntimeError:
|
|
|
|
| 556 |
user_agent=user_agent,
|
| 557 |
retries=max(retries, 4),
|
| 558 |
failure_log=failure_log,
|
| 559 |
+
http_timeout=http_timeout,
|
| 560 |
)
|
| 561 |
usage_total.add(usage)
|
| 562 |
mapping[idx] = single[0]
|
|
|
|
| 565 |
failure_log,
|
| 566 |
f"[row-skip] file_id={row.record.get('file_id')} line={row.line_no} reason={exc}",
|
| 567 |
)
|
| 568 |
+
# Hard fallback: enforce contiguous TITLE rather than keeping polluted labels.
|
| 569 |
+
toks = row.record.get("tokens", [])
|
| 570 |
+
lbs = row.record.get("labels", [])
|
| 571 |
+
if isinstance(toks, list) and isinstance(lbs, list) and len(toks) == len(lbs):
|
| 572 |
+
mapping[idx] = force_single_title_segment(toks, lbs)
|
| 573 |
+
else:
|
| 574 |
+
mapping[idx] = lbs
|
| 575 |
+
|
| 576 |
+
updates: List[tuple[Row, List[str]]] = []
|
| 577 |
+
for row_id, labels in mapping.items():
|
| 578 |
+
row = batch[row_id]
|
| 579 |
+
rec = dict(row.record)
|
| 580 |
+
rec["labels"] = force_single_title_segment(rec.get("tokens", []), labels)
|
| 581 |
+
repaired, _repairs = repair_jsonl_item(rec)
|
| 582 |
+
new_labels = repaired.get("labels", rec.get("labels", []))
|
| 583 |
+
updates.append((row, new_labels))
|
| 584 |
+
return updates, usage_total
|
| 585 |
|
| 586 |
|
| 587 |
def process_batch_timed(
|
|
|
|
| 596 |
user_agent: str,
|
| 597 |
retries: int,
|
| 598 |
failure_log: str,
|
| 599 |
+
http_timeout: int,
|
| 600 |
) -> Dict[str, Any]:
|
| 601 |
meter.task_start()
|
| 602 |
t0 = time.time()
|
|
|
|
| 612 |
user_agent=user_agent,
|
| 613 |
retries=retries,
|
| 614 |
failure_log=failure_log,
|
| 615 |
+
http_timeout=http_timeout,
|
| 616 |
)
|
| 617 |
return {
|
| 618 |
"updates": updates,
|
|
|
|
| 685 |
user_agent=args.user_agent,
|
| 686 |
retries=args.retries,
|
| 687 |
failure_log=args.failure_log,
|
| 688 |
+
http_timeout=args.http_timeout,
|
| 689 |
)
|
| 690 |
for batch in batches
|
| 691 |
]
|