File size: 30,298 Bytes
e964ae5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c5b41e
e964ae5
 
 
 
 
 
fed9d99
e964ae5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5af16a8
 
 
 
 
 
 
 
 
 
 
 
 
e964ae5
 
 
 
 
 
 
 
 
 
 
5af16a8
e964ae5
 
 
 
 
 
 
 
 
 
 
 
 
7c5b41e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e964ae5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c5b41e
 
 
e964ae5
 
 
 
 
 
 
7c5b41e
fed9d99
8165cc3
 
e964ae5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fed9d99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5af16a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e964ae5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fed9d99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c5b41e
 
 
 
 
 
 
 
 
 
 
 
e964ae5
 
 
 
 
 
 
 
 
 
 
fed9d99
7c5b41e
e964ae5
 
 
 
 
 
 
 
fed9d99
 
 
 
 
 
e964ae5
 
 
 
 
fed9d99
 
 
 
 
 
 
 
 
 
 
 
 
e964ae5
 
7c5b41e
e964ae5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c5b41e
e964ae5
 
fed9d99
 
e964ae5
fed9d99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e964ae5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fed9d99
e964ae5
7c5b41e
e964ae5
7c5b41e
e964ae5
 
 
 
 
 
 
 
 
 
fed9d99
e964ae5
7c5b41e
e964ae5
 
 
 
7c5b41e
e964ae5
 
 
 
 
 
 
 
 
 
fed9d99
e964ae5
7c5b41e
e964ae5
 
 
 
 
 
fed9d99
 
 
 
 
 
 
 
 
 
 
 
5af16a8
 
fed9d99
 
5af16a8
fed9d99
 
7c5b41e
 
 
 
 
 
 
 
 
 
 
 
 
 
fed9d99
7c5b41e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fed9d99
7c5b41e
 
 
 
 
 
 
 
 
e964ae5
 
 
 
 
 
 
 
 
 
 
 
7c5b41e
 
 
 
 
 
 
 
 
 
 
 
e964ae5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c5b41e
 
 
 
 
e964ae5
 
 
7c5b41e
 
e964ae5
 
 
 
 
 
 
 
 
 
fed9d99
e964ae5
 
 
 
7c5b41e
 
 
 
 
e964ae5
 
 
 
 
 
7c5b41e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e964ae5
 
 
 
 
 
 
7c5b41e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e964ae5
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
#!/usr/bin/env python3
"""
Relabel selected rows in a JSONL dataset via an OpenAI-compatible Responses API.

Designed for high-throughput cleanup with a stable prompt prefix and
`prompt_cache_key` to improve cache hit rates across calls.
"""

from __future__ import annotations

import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
import os
import re
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Sequence

import requests
from anifilebert.label_repairs import repair_jsonl_item


# Complete set of BIO tags a relabeled row may contain. validate_labels()
# rejects anything else, and response_schema() exposes the same set as the
# JSON-schema enum for the tool call.
ALLOWED_LABELS = {
    "O",
    "B-TITLE", "I-TITLE",
    "B-SEASON", "I-SEASON",
    "B-EPISODE", "I-EPISODE",
    "B-SPECIAL", "I-SPECIAL",
    "B-GROUP", "I-GROUP",
    "B-RESOLUTION", "I-RESOLUTION",
    "B-SOURCE", "I-SOURCE",
}

# Substrings of a filename that mark a language/dub edition. The default
# "language" selector in select_row() picks rows whose filename contains any
# of these.
LANG_MARKERS = (
    "中文版",
    "日语版",
    "国语版",
    "粤语版",
    "英语版",
    "英配版",
    "中配版",
    "日配版",
)

# Tokens that force_bracket_delimiters_o() always relabels as "O".
# NOTE(review): the last two entries repeat the ASCII "(" and ")" listed
# above, so they add nothing to the set; presumably they were meant to be the
# full-width parentheses -- confirm against the original file encoding.
BRACKET_DELIMITER_TOKENS = {
    "[",
    "]",
    "(",
    ")",
    "【",
    "】",
    "《",
    "》",
    "(",
    ")",
}

# Fixed system prompt sent as the "instructions" field of every request (see
# build_request_body()). It is identical across calls, which is what the
# module docstring refers to as the stable prompt prefix.
SYSTEM_INSTRUCTIONS = """You relabel anime filename tokens with BIO tags.

Allowed labels only:
O, B/I-TITLE, B/I-SEASON, B/I-EPISODE, B/I-SPECIAL, B/I-GROUP, B/I-RESOLUTION, B/I-SOURCE.

Hard rules:
1) Output exactly one label per token.
2) Language markers like 中文版/日语版/国语版/粤语版/英语版/英配版/中配版/日配版 must be SOURCE.
3) Episode identifiers (e.g. 01, 13, EP13, 第13集/話/话) must be EPISODE.
4) If title already appears before episode number, episode-name text after the episode number should be O (not TITLE).
5) Preserve obvious GROUP/RESOLUTION/SOURCE tags when present.
6) If bracket delimiters are split into standalone tokens (`[ ] ( ) 【 】 《 》 ( )`), they must be O.

Return strict JSON only:
{"results":[{"row_id":int,"labels":[str,...]}]}
No markdown. No explanation.
"""


@dataclass
class Row:
    """A selected dataset record paired with its position in the input file."""

    # 1-based line number of the record in the input JSONL (load_rows counts from 1).
    line_no: int
    # The parsed JSON object; load_rows stores this same dict in its full record list,
    # so mutating it here also updates the record that gets written out.
    record: Dict[str, Any]


class ConcurrentMeter:
    """Track how many tasks are in flight and integrate that count over time.

    ``active_time_accum`` is the running sum of ``current_active * elapsed``
    between successive updates, so dividing it by a wall-clock duration gives
    the average number of active tasks. All state changes happen under a lock.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.current_active = 0
        self.max_active = 0
        self.active_time_accum = 0.0
        self.last_ts = time.time()

    def _accumulate(self, now: float) -> None:
        """Fold the interval since the last update into the accumulator.

        Non-positive intervals are ignored and leave ``last_ts`` untouched.
        """
        elapsed = now - self.last_ts
        if elapsed <= 0:
            return
        self.active_time_accum += self.current_active * elapsed
        self.last_ts = now

    def task_start(self) -> None:
        """Record that one more task became active."""
        stamp = time.time()
        with self._lock:
            self._accumulate(stamp)
            self.current_active += 1
            self.max_active = max(self.max_active, self.current_active)

    def task_end(self) -> None:
        """Record that one task finished; the count never drops below zero."""
        stamp = time.time()
        with self._lock:
            self._accumulate(stamp)
            self.current_active = max(0, self.current_active - 1)

    def snapshot(self) -> Dict[str, float]:
        """Return the current counters as floats plus the sampling timestamp."""
        stamp = time.time()
        with self._lock:
            self._accumulate(stamp)
            view = {
                "current_active": float(self.current_active),
                "max_active": float(self.max_active),
                "active_time_accum": float(self.active_time_accum),
                "timestamp": stamp,
            }
        return view


@dataclass
class UsageStats:
    """Running totals of the token counters reported in API responses."""

    input_tokens: int = 0
    output_tokens: int = 0
    cached_tokens: int = 0
    reasoning_tokens: int = 0

    def add(self, other: "UsageStats") -> None:
        """Add every counter of *other* (coerced to ``int``) to this instance in place."""
        for counter in ("input_tokens", "output_tokens", "cached_tokens", "reasoning_tokens"):
            setattr(self, counter, getattr(self, counter) + int(getattr(other, counter)))


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments.

    Returns:
        The parsed namespace. ``--input``, ``--output`` and ``--api-base`` are
        required; every other option has a default.
    """
    parser = argparse.ArgumentParser(description="Relabel selected JSONL rows via Responses API")

    # Files, endpoint and model.
    parser.add_argument("--input", required=True, help="Input JSONL")
    parser.add_argument("--output", required=True, help="Output JSONL (can equal input)")
    parser.add_argument("--api-base", required=True, help="API base URL, e.g. http://host:port/v1")
    parser.add_argument("--api-key", default=None, help="API key; falls back to env ANIFILEBERT_RELABEL_API_KEY")
    parser.add_argument("--model", default="gpt-5.4-mini", help="Model name")

    # Which rows to process and how to slice them.
    parser.add_argument(
        "--selector",
        choices=("language", "discontinuous_title", "all"),
        default="language",
        help="Row selector",
    )
    parser.add_argument("--batch-size", type=int, default=12, help="Rows per request")
    parser.add_argument("--concurrency", type=int, default=4, help="Parallel request workers")
    parser.add_argument("--max-rows", type=int, default=0, help="Optional cap; 0 means no cap")
    parser.add_argument("--skip-selected", type=int, default=0, help="Skip this many selected rows before processing")
    parser.add_argument("--min-token-len", type=int, default=0, help="Only process rows with token length >= this value")
    parser.add_argument("--max-token-len", type=int, default=0, help="Only process rows with token length <= this value (0 = no limit)")
    parser.add_argument("--sort-by", choices=("none", "token_len_asc"), default="none", help="Optional ordering of selected rows")

    # Request behaviour.
    parser.add_argument("--retries", type=int, default=3, help="Retries per batch")
    parser.add_argument("--sleep-ms", type=int, default=150, help="Delay between successful calls")
    parser.add_argument("--prompt-cache-key", default="anifilebert-relabel-v1", help="Stable prompt cache key")
    parser.add_argument("--prompt-cache-retention", default="24h", help="Prompt cache retention hint")
    parser.add_argument("--reasoning-effort", default="medium", help="Reasoning effort (e.g. low/medium/high)")

    # Checkpointing, logging and cost accounting.
    parser.add_argument("--checkpoint-rows", type=int, default=100, help="Write checkpoint every N processed rows")
    parser.add_argument("--failure-log", default="reports/llm_relabel_failures.log", help="Failure log path")
    parser.add_argument("--perf-log", default="", help="Optional JSON perf summary path")
    parser.add_argument("--http-timeout", type=int, default=240, help="HTTP timeout in seconds per request")
    parser.add_argument("--usd-per-1m-input", type=float, default=0.75, help="Input token price (USD per 1M tokens)")
    parser.add_argument("--usd-per-1m-output", type=float, default=4.5, help="Output token price (USD per 1M tokens)")
    parser.add_argument(
        "--user-agent",
        default="Codex Desktop/0.133.0-alpha.1 (Windows 10.0.22631; x86_64) unknown (Codex Desktop; 26.519.41501)",
        help="User-Agent header",
    )

    parsed = parser.parse_args()
    return parsed


def select_row(record: Dict[str, Any], selector: str) -> bool:
    """Decide whether *record* should be relabeled under *selector*.

    Args:
        record: One parsed JSONL row.
        selector: ``"all"`` selects every row. ``"discontinuous_title"``
            selects rows whose TITLE labels form more than one run (a TITLE
            label, then at least one non-TITLE label, then TITLE again). Any
            other value selects rows whose ``filename`` contains one of
            ``LANG_MARKERS``.

    Returns:
        ``True`` when the row matches the selector.
    """
    if selector == "all":
        return True
    if selector == "discontinuous_title":
        labels = record.get("labels", [])
        if not isinstance(labels, list):
            return False
        seen_title = False
        seen_gap = False
        for lb in labels:
            # str() mirrors title_segments(): a non-string label (e.g. a JSON
            # null) counts as "not TITLE" instead of raising AttributeError.
            if str(lb).endswith("TITLE"):
                if seen_title and seen_gap:
                    return True
                seen_title = True
            elif seen_title:
                seen_gap = True
        return False
    filename = str(record.get("filename", ""))
    return any(marker in filename for marker in LANG_MARKERS)


def load_rows(path: Path, selector: str) -> tuple[List[Dict[str, Any]], List[Row]]:
    """Read a JSONL file and pick the rows matching *selector*.

    Blank (whitespace-only) lines are skipped rather than passed to the JSON
    parser, so a trailing empty line does not abort the load. Line numbers
    still count physical lines, starting at 1.

    Args:
        path: JSONL file to read as UTF-8.
        selector: Selector name forwarded to ``select_row``.

    Returns:
        ``(all_records, selected)``: every parsed record in file order, and a
        ``Row`` for each selected one. Each ``Row.record`` is the same dict
        object stored in ``all_records``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    all_records: List[Dict[str, Any]] = []
    selected: List[Row] = []
    with path.open("r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            if not line.strip():
                continue
            rec = json.loads(line)
            all_records.append(rec)
            if select_row(rec, selector):
                selected.append(Row(line_no=line_no, record=rec))
    return all_records, selected


def parse_model_json(text: str) -> Dict[str, Any]:
    """Parse the JSON object contained in a model's text reply.

    A surrounding Markdown code fence is stripped first; the optional
    ``json`` language tag is matched case-insensitively. If the remaining
    text is still not valid JSON, the span from the first ``{`` to the last
    ``}`` is tried, which recovers replies that wrap the object in prose.

    Args:
        text: Raw text returned by the model.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If no parseable JSON can be recovered.
    """
    raw = text.strip()
    raw = re.sub(r"^```(?:json)?\s*", "", raw, flags=re.IGNORECASE)
    raw = re.sub(r"\s*```$", "", raw)
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        start = raw.find("{")
        end = raw.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(raw[start:end + 1])


def build_user_payload(batch_rows: Sequence[Row]) -> str:
    """Serialize a batch of rows into the JSON string sent as the request input.

    Each row is identified by ``row_id``, its index within the batch, and
    carries the record's ``file_id``, ``filename``, ``tokens`` and existing
    labels (as ``current_labels``). Non-ASCII text is emitted unescaped.
    """
    payload_rows: List[Dict[str, Any]] = [
        {
            "row_id": position,
            "file_id": entry.record.get("file_id"),
            "filename": entry.record.get("filename"),
            "tokens": entry.record.get("tokens"),
            "current_labels": entry.record.get("labels"),
        }
        for position, entry in enumerate(batch_rows)
    ]
    envelope = {"rows": payload_rows}
    return json.dumps(envelope, ensure_ascii=False)


def extract_output_text(response_obj: Dict[str, Any]) -> str:
    """Return the text of the first ``output_text`` part in a response object.

    Output items without usable content (missing, ``null``, or not a dict)
    are skipped instead of raising, so items that carry no text parts do not
    hide a later message.

    Args:
        response_obj: Decoded JSON body of the API response.

    Returns:
        The ``text`` of the first content part whose ``type`` is
        ``"output_text"``; an empty string if that part has no text.

    Raises:
        ValueError: If no ``output_text`` part exists.
    """
    # "or []" covers both a missing key and an explicit null.
    for item in response_obj.get("output") or []:
        if not isinstance(item, dict):
            continue
        for content in item.get("content") or []:
            if isinstance(content, dict) and content.get("type") == "output_text":
                return content.get("text") or ""
    raise ValueError("No output_text found in response")


def extract_function_args(response_obj: Dict[str, Any], func_name: str) -> Dict[str, Any]:
    """Return the arguments of the first ``function_call`` item named *func_name*.

    Args:
        response_obj: Decoded JSON body of the API response.
        func_name: Name of the function call to look for.

    Returns:
        The call's arguments. A JSON string is decoded; arguments that are
        already a dict are returned as-is; a missing ``arguments`` key yields
        an empty dict.

    Raises:
        ValueError: If no matching ``function_call`` item exists (this also
            covers ``json.JSONDecodeError`` for malformed argument strings).
    """
    # "or []" covers both a missing key and an explicit null.
    for item in response_obj.get("output") or []:
        if not isinstance(item, dict):
            continue
        if item.get("type") == "function_call" and item.get("name") == func_name:
            arguments = item.get("arguments", "{}")
            if isinstance(arguments, dict):
                # Already decoded upstream; json.loads() would raise TypeError.
                return arguments
            return json.loads(arguments)
    raise ValueError(f"No function_call '{func_name}' found in response")


def validate_labels(tokens: Sequence[str], labels: Sequence[str]) -> bool:
    """Check that *labels* has one entry per token and only allowed tags.

    Returns:
        ``True`` when the lengths match and every label is a member of
        ``ALLOWED_LABELS``; ``False`` otherwise.
    """
    same_length = len(tokens) == len(labels)
    if not same_length:
        return False
    return all(lb in ALLOWED_LABELS for lb in labels)


def normalize_iob2_labels(labels: Sequence[str]) -> List[str]:
    """Rewrite a label sequence so B-/I- prefixes follow entity runs.

    Within a run of consecutive labels sharing an entity type, the first gets
    ``B-`` and the rest get ``I-``, regardless of the prefixes they came in
    with. Anything that is not a string starting with ``B-`` or ``I-``
    becomes ``"O"`` and ends the current run.
    """
    result: List[str] = []
    last_entity = ""
    for raw in labels:
        if isinstance(raw, str) and raw.startswith(("B-", "I-")):
            _, entity = raw.split("-", 1)
            marker = "I" if entity == last_entity else "B"
            result.append(f"{marker}-{entity}")
            last_entity = entity
        else:
            result.append("O")
            last_entity = ""
    return result


def title_segments(labels: Sequence[str]) -> List[tuple[int, int]]:
    """Locate every maximal run of TITLE labels.

    Returns:
        Half-open ``(start, end)`` index pairs, in order, one per run of
        consecutive labels whose string form ends with ``"TITLE"``.
    """
    spans: List[tuple[int, int]] = []
    run_start = None
    for pos, lb in enumerate(labels):
        is_title = str(lb).endswith("TITLE")
        if is_title and run_start is None:
            run_start = pos
        elif not is_title and run_start is not None:
            spans.append((run_start, pos))
            run_start = None
    # A run that reaches the end of the sequence is closed here.
    if run_start is not None:
        spans.append((run_start, len(labels)))
    return spans


def force_single_title_segment(tokens: Sequence[str], labels: Sequence[str]) -> List[str]:
    """Return labels in which TITLE forms at most one contiguous run.

    The labels are first normalized. When several TITLE runs remain, one is
    kept and the others are set to ``"O"``. The kept run is chosen by, in
    order: starting before the first EPISODE label, greater length, earlier
    start. If *tokens* and *labels* differ in length the labels are returned
    as a plain list without any change.
    """
    if len(tokens) != len(labels):
        return list(labels)

    normalized = normalize_iob2_labels(labels)
    candidates = title_segments(normalized)
    if len(candidates) <= 1:
        return normalized

    # Index of the first EPISODE label, or the sequence length if there is none.
    episode_at = len(normalized)
    for pos, lb in enumerate(normalized):
        if str(lb).endswith("EPISODE"):
            episode_at = pos
            break

    def rank(span: tuple[int, int]) -> tuple[int, int, int]:
        begin, stop = span
        precedes_episode = 1 if begin < episode_at else 0
        return (precedes_episode, stop - begin, -begin)

    keep_start, keep_end = max(candidates, key=rank)
    pruned = [
        "O" if str(lb).endswith("TITLE") and not (keep_start <= pos < keep_end) else lb
        for pos, lb in enumerate(normalized)
    ]
    # Re-normalize so prefixes stay consistent after runs were blanked out.
    return normalize_iob2_labels(pruned)


def force_bracket_delimiters_o(tokens: Sequence[str], labels: Sequence[str]) -> List[str]:
    """Label every standalone bracket-delimiter token as ``"O"``.

    Tokens found in ``BRACKET_DELIMITER_TOKENS`` whose label is not already
    ``"O"`` are reset to ``"O"`` and the whole sequence is then normalized.
    When nothing needs resetting, or when *tokens* and *labels* differ in
    length, a plain list copy of *labels* is returned.
    """
    if len(tokens) != len(labels):
        return list(labels)

    offenders = [
        pos
        for pos, tok in enumerate(tokens)
        if tok in BRACKET_DELIMITER_TOKENS and labels[pos] != "O"
    ]
    if not offenders:
        return list(labels)

    patched = list(labels)
    for pos in offenders:
        patched[pos] = "O"
    return normalize_iob2_labels(patched)


def response_schema() -> Dict[str, Any]:
    """Build the JSON schema for the ``submit_labels`` tool arguments.

    The schema describes an object with a required ``results`` array; each
    element requires an integer ``row_id`` and a ``labels`` array whose items
    are restricted to the sorted ``ALLOWED_LABELS``. Extra properties are
    disallowed at both levels.
    """
    label_item: Dict[str, Any] = {"type": "string", "enum": sorted(ALLOWED_LABELS)}
    result_item: Dict[str, Any] = {
        "type": "object",
        "additionalProperties": False,
        "properties": {
            "row_id": {"type": "integer"},
            "labels": {"type": "array", "items": label_item},
        },
        "required": ["row_id", "labels"],
    }
    schema: Dict[str, Any] = {
        "type": "object",
        "additionalProperties": False,
        "properties": {"results": {"type": "array", "items": result_item}},
        "required": ["results"],
    }
    return schema


def append_failure_log(path: str, message: str) -> None:
    """Append *message* as one line to the log file at *path*.

    Missing parent directories are created. Trailing whitespace on the
    message is stripped before a single newline is added.
    """
    log_file = Path(path)
    log_file.parent.mkdir(parents=True, exist_ok=True)
    entry = message.rstrip() + "\n"
    with log_file.open("a", encoding="utf-8") as handle:
        handle.write(entry)


def build_request_body(
    model: str,
    user_payload: str,
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    include_tools: bool = True,
    include_tool_choice: bool = True,
    include_reasoning: bool = True,
    include_cache_key: bool = True,
    include_cache_retention: bool = True,
) -> Dict[str, Any]:
    """Assemble the JSON body for one request to the responses endpoint.

    The body always carries ``model``, ``instructions`` and ``input``. Each
    ``include_*`` flag switches one optional field on or off so callers can
    drop fields a gateway rejects. ``tool_choice`` is only added when the
    tool definition itself is included.
    """
    request: Dict[str, Any] = {
        "model": model,
        "instructions": SYSTEM_INSTRUCTIONS,
        "input": user_payload,
    }
    if include_cache_key:
        request["prompt_cache_key"] = prompt_cache_key
    if include_cache_retention:
        request["prompt_cache_retention"] = prompt_cache_retention
    if include_reasoning:
        request["reasoning"] = {"effort": reasoning_effort}

    # Everything below concerns the structured-output tool.
    if not include_tools:
        return request

    submit_tool = {
        "type": "function",
        "name": "submit_labels",
        "description": "Submit relabeled BIO labels.",
        "parameters": response_schema(),
        "strict": True,
    }
    request["tools"] = [submit_tool]
    if include_tool_choice:
        request["tool_choice"] = {"type": "function", "name": "submit_labels"}
    return request


def parse_usage(response_obj: Dict[str, Any]) -> UsageStats:
    """Extract token counters from the ``usage`` section of a response.

    Missing or null sections and counters are treated as zero. Cached tokens
    are read from ``input_tokens_details`` and reasoning tokens from
    ``output_tokens_details``.
    """
    usage_block = response_obj.get("usage", {}) or {}
    input_details = usage_block.get("input_tokens_details", {}) or {}
    output_details = usage_block.get("output_tokens_details", {}) or {}

    def as_count(container: Dict[str, Any], key: str) -> int:
        return int(container.get(key, 0) or 0)

    return UsageStats(
        input_tokens=as_count(usage_block, "input_tokens"),
        output_tokens=as_count(usage_block, "output_tokens"),
        cached_tokens=as_count(input_details, "cached_tokens"),
        reasoning_tokens=as_count(output_details, "reasoning_tokens"),
    )


def relabel_batch(
    api_base: str,
    api_key: str,
    model: str,
    batch_rows: Sequence[Row],
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    user_agent: str,
    retries: int,
    failure_log: str,
    http_timeout: int,
) -> tuple[Dict[int, List[str]], UsageStats]:
    """Relabel one batch of rows with a single API request, retrying on failure.

    POSTs the batch to ``{api_base}/responses`` and accepts the reply only if
    it yields valid labels for every row in the batch. Any exception during an
    attempt (HTTP error, unparsable reply, missing or invalid rows) counts as
    a failed attempt. After an HTTP 400 the response text is inspected and at
    most one optional request field is dropped for the following attempts.

    Args:
        api_base: Base URL of the API; a trailing slash is ignored.
        api_key: Bearer token for the Authorization header.
        model: Model name placed in the request body.
        batch_rows: Rows to relabel; ``row_id`` in the reply is the index into
            this sequence.
        prompt_cache_key: Value for the ``prompt_cache_key`` body field.
        prompt_cache_retention: Value for the ``prompt_cache_retention`` field.
        reasoning_effort: Value for ``reasoning.effort``.
        user_agent: User-Agent header value.
        retries: Maximum number of attempts.
        failure_log: Path of the log file that receives failure details.
        http_timeout: Per-request timeout in seconds.

    Returns:
        ``(mapping, usage)`` where ``mapping`` maps each row index to its new
        labels and ``usage`` holds the token counters of the successful
        attempt only.

    Raises:
        RuntimeError: If no attempt succeeds; the message includes the last
            error seen.
    """
    url = f"{api_base.rstrip('/')}/responses"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "User-Agent": user_agent,
    }
    user_payload = build_user_payload(batch_rows)

    # Which optional body fields to send; entries are switched off one at a
    # time by the HTTP 400 handling below and stay off for later attempts.
    cfg = {
        "include_tools": True,
        "include_tool_choice": True,
        "include_reasoning": True,
        "include_cache_key": True,
        "include_cache_retention": True,
    }

    last_error: Exception | None = None
    for attempt in range(1, retries + 1):
        try:
            body = build_request_body(
                model=model,
                user_payload=user_payload,
                prompt_cache_key=prompt_cache_key,
                prompt_cache_retention=prompt_cache_retention,
                reasoning_effort=reasoning_effort,
                include_tools=cfg["include_tools"],
                include_tool_choice=cfg["include_tool_choice"],
                include_reasoning=cfg["include_reasoning"],
                include_cache_key=cfg["include_cache_key"],
                include_cache_retention=cfg["include_cache_retention"],
            )
            resp = requests.post(url, headers=headers, json=body, timeout=http_timeout)
            resp.raise_for_status()
            obj = resp.json()
            usage_stats = parse_usage(obj)
            # Prefer the structured tool call; if it cannot be extracted for
            # any reason, fall back to parsing JSON out of the text output.
            try:
                parsed = extract_function_args(obj, "submit_labels")
            except Exception:
                text = extract_output_text(obj)
                parsed = parse_model_json(text)
            results = parsed.get("results")
            if not isinstance(results, list):
                append_failure_log(
                    failure_log,
                    f"[invalid-results] model={model} batch={len(batch_rows)} parsed_keys={list(parsed.keys())}",
                )
                raise ValueError("response JSON missing 'results' list")

            # Keep only well-formed entries: a dict with an in-range integer
            # row_id and a label list that passes validate_labels().
            mapping: Dict[int, List[str]] = {}
            for item in results:
                if not isinstance(item, dict):
                    continue
                row_id = item.get("row_id")
                labels = item.get("labels")
                if not isinstance(row_id, int) or not isinstance(labels, list):
                    continue
                if row_id < 0 or row_id >= len(batch_rows):
                    continue
                tokens = batch_rows[row_id].record.get("tokens", [])
                if not validate_labels(tokens, labels):
                    append_failure_log(
                        failure_log,
                        f"[invalid-labels] file_id={batch_rows[row_id].record.get('file_id')} "
                        f"tokens_len={len(tokens)} labels_len={len(labels)}",
                    )
                    continue
                mapping[row_id] = labels

            # All-or-nothing: a batch with any missing row fails this attempt.
            if len(mapping) != len(batch_rows):
                missing = sorted(set(range(len(batch_rows))) - set(mapping))
                append_failure_log(
                    failure_log,
                    f"[missing] model={model} batch={len(batch_rows)} missing={missing}",
                )
                raise ValueError(f"incomplete/invalid rows from model: missing={missing}")

            return mapping, usage_stats
        except Exception as exc:  # noqa: BLE001
            last_error = exc
            # Some compatible gateways may not support all optional fields.
            # Downgrade progressively and keep structured tool output whenever possible.
            if isinstance(exc, requests.HTTPError) and exc.response is not None and exc.response.status_code == 400:
                # Only the first 1200 characters of the error body are logged and inspected.
                response_text = (exc.response.text or "")[:1200]
                lowered = response_text.lower()
                append_failure_log(
                    failure_log,
                    f"[http400] attempt={attempt} model={model} body_cfg={cfg} response={response_text!r}",
                )
                # At most one field is disabled per failed attempt, checked in
                # this order: cache retention, cache key, reasoning, tool
                # choice, tools.
                if "prompt_cache_retention" in lowered and cfg["include_cache_retention"]:
                    cfg["include_cache_retention"] = False
                elif "prompt_cache_key" in lowered and cfg["include_cache_key"]:
                    cfg["include_cache_key"] = False
                elif "reasoning" in lowered and cfg["include_reasoning"]:
                    cfg["include_reasoning"] = False
                elif "tool_choice" in lowered and cfg["include_tool_choice"]:
                    cfg["include_tool_choice"] = False
                elif "tools" in lowered and cfg["include_tools"]:
                    cfg["include_tools"] = False
            if attempt == retries:
                break
            # Linear back-off: 0.8 s after the first failure, 1.6 s after the second, ...
            time.sleep(0.8 * attempt)

    raise RuntimeError(f"failed relabel batch after {retries} attempts: {last_error}")


def write_jsonl(path: Path, records: Sequence[Dict[str, Any]]) -> None:
    """Write *records* to *path* as compact JSONL, replacing it in one step.

    The data is first written to a sibling ``<name>.tmp`` file, which then
    replaces *path*, so *path* never holds a half-written file. The parent
    directory is created if missing, matching how the failure log and perf
    log are written.

    Args:
        path: Destination file.
        records: JSON-serializable objects, written one per line as UTF-8
            with non-ASCII characters left unescaped.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    # newline="" keeps "\n" line endings on every platform.
    with tmp.open("w", encoding="utf-8", newline="") as f:
        f.writelines(
            json.dumps(rec, ensure_ascii=False, separators=(",", ":")) + "\n"
            for rec in records
        )
    tmp.replace(path)


def process_batch_with_fallback(
    api_base: str,
    api_key: str,
    model: str,
    batch: Sequence[Row],
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    user_agent: str,
    retries: int,
    failure_log: str,
    http_timeout: int,
) -> tuple[List[tuple[Row, List[str]]], UsageStats]:
    """Relabel a batch, degrading to per-row requests and then to a local fix.

    The whole batch is tried first. If that fails, each row is retried on its
    own with at least four attempts. A row that still fails is logged as
    ``[row-skip]`` and keeps its existing labels, reduced to a single TITLE
    run when tokens and labels are equal-length lists. Every resulting label
    list is then post-processed: single TITLE run, ``repair_jsonl_item``, and
    bracket delimiters forced to ``"O"``.

    Args:
        api_base, api_key, model, prompt_cache_key, prompt_cache_retention,
        reasoning_effort, user_agent, failure_log, http_timeout: Forwarded
            unchanged to ``relabel_batch``.
        batch: Rows to relabel.
        retries: Attempts for the batch request; per-row fallback requests
            use ``max(retries, 4)``.

    Returns:
        ``(updates, usage)``: one ``(row, new_labels)`` pair per row, and the
        token usage summed over the successful requests.
    """
    # Arguments shared by the batch request and the per-row fallback requests.
    shared_kwargs: Dict[str, Any] = {
        "api_base": api_base,
        "api_key": api_key,
        "model": model,
        "prompt_cache_key": prompt_cache_key,
        "prompt_cache_retention": prompt_cache_retention,
        "reasoning_effort": reasoning_effort,
        "user_agent": user_agent,
        "failure_log": failure_log,
        "http_timeout": http_timeout,
    }
    usage_total = UsageStats()
    try:
        mapping, usage = relabel_batch(batch_rows=batch, retries=retries, **shared_kwargs)
        usage_total.add(usage)
    except RuntimeError:
        mapping = {}
        for idx, row in enumerate(batch):
            try:
                single, usage = relabel_batch(
                    batch_rows=[row],
                    retries=max(retries, 4),
                    **shared_kwargs,
                )
                usage_total.add(usage)
                # A one-row batch reports its only row under row_id 0.
                mapping[idx] = single[0]
            except RuntimeError as exc:
                append_failure_log(
                    failure_log,
                    f"[row-skip] file_id={row.record.get('file_id')} line={row.line_no} reason={exc}",
                )
                # Hard fallback: enforce contiguous TITLE rather than keeping polluted labels.
                toks = row.record.get("tokens", [])
                lbs = row.record.get("labels", [])
                if isinstance(toks, list) and isinstance(lbs, list) and len(toks) == len(lbs):
                    mapping[idx] = force_single_title_segment(toks, lbs)
                else:
                    mapping[idx] = lbs

    updates: List[tuple[Row, List[str]]] = []
    for row_id, labels in mapping.items():
        row = batch[row_id]
        # Work on a shallow copy so the caller decides whether to apply the change.
        rec = dict(row.record)
        tokens = rec.get("tokens", [])
        if not isinstance(tokens, list) or not isinstance(labels, list):
            # Malformed record that reached the hard fallback: hand its labels
            # back untouched instead of crashing the worker in post-processing.
            updates.append((row, labels))
            continue
        rec["labels"] = force_single_title_segment(tokens, labels)
        repaired, _repairs = repair_jsonl_item(rec)
        new_labels = repaired.get("labels", rec.get("labels", []))
        new_labels = force_bracket_delimiters_o(tokens, new_labels)
        updates.append((row, new_labels))
    return updates, usage_total


def process_batch_timed(
    meter: ConcurrentMeter,
    api_base: str,
    api_key: str,
    model: str,
    batch: Sequence[Row],
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    user_agent: str,
    retries: int,
    failure_log: str,
    http_timeout: int,
) -> Dict[str, Any]:
    """Run one batch through ``process_batch_with_fallback`` and time it.

    The batch is registered with *meter* for the duration of the call, also
    when the call raises.

    Returns:
        A dict with ``updates`` (the ``(row, labels)`` pairs), ``elapsed``
        (seconds spent), ``batch_size`` and ``usage`` (token counters).
    """
    meter.task_start()
    started_at = time.time()
    try:
        batch_updates, batch_usage = process_batch_with_fallback(
            api_base=api_base,
            api_key=api_key,
            model=model,
            batch=batch,
            prompt_cache_key=prompt_cache_key,
            prompt_cache_retention=prompt_cache_retention,
            reasoning_effort=reasoning_effort,
            user_agent=user_agent,
            retries=retries,
            failure_log=failure_log,
            http_timeout=http_timeout,
        )
        outcome: Dict[str, Any] = {"updates": batch_updates}
        outcome["elapsed"] = time.time() - started_at
        outcome["batch_size"] = len(batch)
        outcome["usage"] = batch_usage
        return outcome
    finally:
        # Always unregister, so a failing batch does not inflate the active count.
        meter.task_end()


def _usage_cost_usd(usage: UsageStats, usd_per_1m_input: float, usd_per_1m_output: float) -> float:
    """Return the USD cost of *usage* at the given per-million-token prices."""
    return (usage.input_tokens / 1_000_000.0) * usd_per_1m_input + (
        usage.output_tokens / 1_000_000.0
    ) * usd_per_1m_output


def main() -> None:
    """Run the relabeling CLI end to end.

    Loads the input JSONL, narrows the selected rows with the token-length,
    sort, skip and cap options, relabels them in batches on a thread pool,
    writes the updated records to the output path (with periodic
    checkpoints), and prints and optionally saves a throughput/cost summary.

    Raises:
        SystemExit: If no API key is supplied through ``--api-key`` or the
            ``ANIFILEBERT_RELABEL_API_KEY`` environment variable.
    """
    args = parse_args()
    api_key = args.api_key or os.environ.get("ANIFILEBERT_RELABEL_API_KEY")
    if not api_key:
        raise SystemExit("Missing API key. Use --api-key or env ANIFILEBERT_RELABEL_API_KEY")

    input_path = Path(args.input)
    output_path = Path(args.output)

    all_records, selected_rows = load_rows(input_path, args.selector)
    if args.min_token_len > 0 or args.max_token_len > 0:
        filtered: List[Row] = []
        for row in selected_rows:
            tok_len = len(row.record.get("tokens", []))
            if tok_len < args.min_token_len:
                continue
            if args.max_token_len > 0 and tok_len > args.max_token_len:
                continue
            filtered.append(row)
        selected_rows = filtered
    if args.sort_by == "token_len_asc":
        selected_rows.sort(key=lambda r: len(r.record.get("tokens", [])))
    if args.skip_selected > 0:
        selected_rows = selected_rows[args.skip_selected:]
    if args.max_rows > 0:
        selected_rows = selected_rows[: args.max_rows]
    if not selected_rows:
        print("selected_rows=0; nothing to do")
        if output_path != input_path:
            write_jsonl(output_path, all_records)
        return

    total = len(selected_rows)
    changed = 0
    # The worker count is clamped to the range 1..8.
    concurrency = max(1, min(args.concurrency, 8))
    # A batch size below 1 would make range() raise (0) or yield no batches
    # at all (negative), so fall back to one row per request.
    batch_size = max(1, args.batch_size)
    batches: List[List[Row]] = [
        selected_rows[i:i + batch_size]
        for i in range(0, total, batch_size)
    ]

    done_rows = 0
    rows_since_checkpoint = 0
    wall_start = time.time()
    meter = ConcurrentMeter()
    total_batch_elapsed = 0.0
    completed_batches = 0
    usage_total = UsageStats()
    pricing_enabled = args.usd_per_1m_input > 0 or args.usd_per_1m_output > 0
    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        futures = [
            executor.submit(
                process_batch_timed,
                meter,
                api_base=args.api_base,
                api_key=api_key,
                model=args.model,
                batch=batch,
                prompt_cache_key=args.prompt_cache_key,
                prompt_cache_retention=args.prompt_cache_retention,
                reasoning_effort=args.reasoning_effort,
                user_agent=args.user_agent,
                retries=args.retries,
                failure_log=args.failure_log,
                http_timeout=args.http_timeout,
            )
            for batch in batches
        ]
        for fut in as_completed(futures):
            result = fut.result()
            updates = result["updates"]
            total_batch_elapsed += float(result["elapsed"])
            completed_batches += 1
            usage_total.add(result["usage"])
            # Records are only mutated here, on the main thread.
            for row, new_labels in updates:
                rec = row.record
                if rec.get("labels") != new_labels:
                    rec["labels"] = new_labels
                    changed += 1
            done_rows += len(updates)
            rows_since_checkpoint += len(updates)
            snap = meter.snapshot()
            wall_elapsed = max(1e-9, snap["timestamp"] - wall_start)
            rows_per_sec = done_rows / wall_elapsed
            avg_active = snap["active_time_accum"] / wall_elapsed
            in_tok_per_sec = usage_total.input_tokens / wall_elapsed
            out_tok_per_sec = usage_total.output_tokens / wall_elapsed
            hourly_usd = 0.0
            if pricing_enabled:
                cost = _usage_cost_usd(usage_total, args.usd_per_1m_input, args.usd_per_1m_output)
                hourly_usd = cost / wall_elapsed * 3600.0
            print(
                f"processed={done_rows}/{total} changed={changed} "
                f"rows_per_sec={rows_per_sec:.2f} active_now={int(snap['current_active'])} "
                f"avg_active={avg_active:.2f} max_active={int(snap['max_active'])}/{concurrency} "
                f"in_tok_s={in_tok_per_sec:.1f} out_tok_s={out_tok_per_sec:.1f} usd_h={hourly_usd:.3f}"
            )
            # Checkpoint once at least N rows have accumulated since the last
            # write. Testing "done_rows % N == 0" would only fire when the
            # running total happens to land exactly on a multiple of N.
            if args.checkpoint_rows > 0 and (
                rows_since_checkpoint >= args.checkpoint_rows or done_rows == total
            ):
                write_jsonl(output_path, all_records)
                rows_since_checkpoint = 0
            # This pause only slows the result-handling loop; every batch was
            # already submitted to the pool above.
            if args.sleep_ms > 0:
                time.sleep(args.sleep_ms / 1000.0)

    # rows in selected_rows reference dicts in all_records by identity, so changes are already reflected.
    write_jsonl(output_path, all_records)
    wall_total = time.time() - wall_start
    final_snap = meter.snapshot()
    avg_active = final_snap["active_time_accum"] / max(1e-9, wall_total)
    perf_summary = {
        "wall_seconds": wall_total,
        "rows_processed": done_rows,
        "rows_per_second": done_rows / max(1e-9, wall_total),
        "batches_completed": completed_batches,
        "avg_batch_seconds": total_batch_elapsed / max(1, completed_batches),
        "avg_active_workers": avg_active,
        "max_active_workers": int(final_snap["max_active"]),
        "configured_workers": concurrency,
        "input_tokens": usage_total.input_tokens,
        "output_tokens": usage_total.output_tokens,
        "cached_tokens": usage_total.cached_tokens,
        "reasoning_tokens": usage_total.reasoning_tokens,
        "input_tokens_per_sec": usage_total.input_tokens / max(1e-9, wall_total),
        "output_tokens_per_sec": usage_total.output_tokens / max(1e-9, wall_total),
        "input_tokens_per_hour": usage_total.input_tokens / max(1e-9, wall_total) * 3600.0,
        "output_tokens_per_hour": usage_total.output_tokens / max(1e-9, wall_total) * 3600.0,
        "usd_per_1m_input": args.usd_per_1m_input,
        "usd_per_1m_output": args.usd_per_1m_output,
    }
    if pricing_enabled:
        total_cost = _usage_cost_usd(usage_total, args.usd_per_1m_input, args.usd_per_1m_output)
        perf_summary["usd_total"] = total_cost
        perf_summary["usd_per_hour"] = total_cost / max(1e-9, wall_total) * 3600.0
    if args.perf_log:
        p = Path(args.perf_log)
        p.parent.mkdir(parents=True, exist_ok=True)
        p.write_text(json.dumps(perf_summary, ensure_ascii=False, indent=2), encoding="utf-8")
    print(
        f"perf wall={wall_total:.1f}s rows_per_sec={perf_summary['rows_per_second']:.2f} "
        f"avg_active={avg_active:.2f} max_active={int(final_snap['max_active'])}/{concurrency} "
        f"in_tok_s={perf_summary['input_tokens_per_sec']:.1f} out_tok_s={perf_summary['output_tokens_per_sec']:.1f}"
    )
    print(f"done selected_rows={total} changed_rows={changed} output={output_path}")

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()