ModerRAS committed on
Commit
e964ae5
·
1 Parent(s): 1804249

feat: add tool-call based llm relabel pipeline and update dataset pointer

Browse files
Files changed (2) hide show
  1. datasets/AnimeName +1 -1
  2. tools/llm_relabel_rows.py +444 -0
datasets/AnimeName CHANGED
@@ -1 +1 @@
1
- Subproject commit 56c54f9fb664335fc0c98f6c9dce8f2fbcc145a0
 
1
+ Subproject commit 9987cc8d7b7bf829d0022ee6e6a0b08de5327975
tools/llm_relabel_rows.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Relabel selected rows in a JSONL dataset via an OpenAI-compatible Responses API.
4
+
5
+ Designed for high-throughput cleanup with a stable prompt prefix and
6
+ `prompt_cache_key` to improve cache hit rates across calls.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ from concurrent.futures import ThreadPoolExecutor, as_completed
13
+ import json
14
+ import os
15
+ import re
16
+ import time
17
+ from dataclasses import dataclass
18
+ from pathlib import Path
19
+ from typing import Any, Dict, List, Sequence
20
+
21
+ import requests
22
+
23
+
24
# Entity categories used by the BIO scheme. Each category contributes a
# "B-" (begin) and an "I-" (inside) tag; "O" marks tokens outside any entity.
_ENTITY_TYPES = (
    "TITLE",
    "SEASON",
    "EPISODE",
    "SPECIAL",
    "GROUP",
    "RESOLUTION",
    "SOURCE",
)

# Complete set of labels accepted from the model and allowed in the dataset.
ALLOWED_LABELS = {"O"} | {
    f"{prefix}-{entity}" for entity in _ENTITY_TYPES for prefix in ("B", "I")
}
34
+
35
# Substrings in a filename that mark the audio/dub language of a release.
# The "language" row selector picks every record whose filename contains one.
LANG_MARKERS = (
    # "<language> version" markers
    "中文版", "日语版", "国语版", "粤语版", "英语版",
    # "<language> dub" markers
    "英配版", "中配版", "日配版",
)
45
+
46
# System prompt sent as `instructions` on every request. It must stay
# byte-stable across calls so the prompt prefix can be served from cache.
SYSTEM_INSTRUCTIONS = (
    "You relabel anime filename tokens with BIO tags.\n"
    "\n"
    "Allowed labels only:\n"
    "O, B/I-TITLE, B/I-SEASON, B/I-EPISODE, B/I-SPECIAL, B/I-GROUP, B/I-RESOLUTION, B/I-SOURCE.\n"
    "\n"
    "Hard rules:\n"
    "1) Output exactly one label per token.\n"
    "2) Language markers like 中文版/日语版/国语版/粤语版/英语版/英配版/中配版/日配版 must be SOURCE.\n"
    "3) Episode identifiers (e.g. 01, 13, EP13, 第13集/話/话) must be EPISODE.\n"
    "4) If title already appears before episode number, episode-name text after the episode number should be O (not TITLE).\n"
    "5) Preserve obvious GROUP/RESOLUTION/SOURCE tags when present.\n"
    "\n"
    "Return strict JSON only:\n"
    '{"results":[{"row_id":int,"labels":[str,...]}]}\n'
    "No markdown. No explanation.\n"
)
62
+
63
+
64
@dataclass
class Row:
    """A dataset record selected for relabeling, with its source position."""

    # 1-based line number of the record in the input JSONL file.
    line_no: int
    # Parsed JSON object for that line.
    record: Dict[str, Any]
68
+
69
+
70
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments.

    Returns:
        The parsed namespace; option names map to attributes with dashes
        replaced by underscores (e.g. ``--batch-size`` -> ``batch_size``).
    """
    parser = argparse.ArgumentParser(description="Relabel selected JSONL rows via Responses API")

    # Mandatory paths and endpoint.
    for flag, help_text in (
        ("--input", "Input JSONL"),
        ("--output", "Output JSONL (can equal input)"),
        ("--api-base", "API base URL, e.g. http://host:port/v1"),
    ):
        parser.add_argument(flag, required=True, help=help_text)

    parser.add_argument(
        "--api-key",
        default=None,
        help="API key; falls back to env ANIFILEBERT_RELABEL_API_KEY",
    )
    parser.add_argument("--model", default="gpt-5.4-mini", help="Model name")
    parser.add_argument(
        "--selector",
        default="language",
        choices=("language", "discontinuous_title", "all"),
        help="Row selector",
    )

    # Integer tuning knobs: (flag, default, help).
    for flag, default, help_text in (
        ("--batch-size", 12, "Rows per request"),
        ("--concurrency", 4, "Parallel request workers"),
        ("--max-rows", 0, "Optional cap; 0 means no cap"),
        ("--skip-selected", 0, "Skip this many selected rows before processing"),
        ("--retries", 3, "Retries per batch"),
        ("--sleep-ms", 150, "Delay between successful calls"),
    ):
        parser.add_argument(flag, type=int, default=default, help=help_text)

    # Free-form request options forwarded to the API: (flag, default, help).
    for flag, default, help_text in (
        ("--prompt-cache-key", "anifilebert-relabel-v1", "Stable prompt cache key"),
        ("--prompt-cache-retention", "24h", "Prompt cache retention hint"),
        ("--reasoning-effort", "medium", "Reasoning effort (e.g. low/medium/high)"),
    ):
        parser.add_argument(flag, default=default, help=help_text)

    parser.add_argument(
        "--checkpoint-rows",
        type=int,
        default=100,
        help="Write checkpoint every N processed rows",
    )
    parser.add_argument(
        "--failure-log",
        default="reports/llm_relabel_failures.log",
        help="Failure log path",
    )
    parser.add_argument(
        "--user-agent",
        default="Codex Desktop/0.133.0-alpha.1 (Windows 10.0.22631; x86_64) unknown (Codex Desktop; 26.519.41501)",
        help="User-Agent header",
    )
    return parser.parse_args()
100
+
101
+
102
def select_row(record: Dict[str, Any], selector: str) -> bool:
    """Decide whether a record should be sent for relabeling.

    Args:
        record: Parsed JSONL record; reads its ``labels`` or ``filename`` key
            depending on the selector.
        selector: ``"all"`` selects every record. ``"discontinuous_title"``
            selects records whose TITLE labels form more than one run (a TITLE
            label, then at least one non-TITLE label, then TITLE again). Any
            other value selects records whose filename contains one of
            ``LANG_MARKERS``.

    Returns:
        True when the record matches the selector.
    """
    if selector == "all":
        return True

    if selector == "discontinuous_title":
        labels = record.get("labels", [])
        if not isinstance(labels, list):
            return False
        seen_title = False
        seen_gap = False
        for lb in labels:
            # Non-string labels (e.g. null in a malformed record) count as
            # "not a title" instead of raising AttributeError on .endswith.
            if isinstance(lb, str) and lb.endswith("TITLE"):
                if seen_title and seen_gap:
                    return True
                seen_title = True
            elif seen_title:
                seen_gap = True
        return False

    filename = str(record.get("filename", ""))
    return any(marker in filename for marker in LANG_MARKERS)
122
+
123
+
124
def load_rows(path: Path, selector: str) -> tuple[List[Dict[str, Any]], List[Row]]:
    """Read a JSONL file and pick out the records to relabel.

    Args:
        path: JSONL file to read (UTF-8).
        selector: Selector name passed to ``select_row``.

    Returns:
        ``(all_records, selected)``. ``all_records`` holds every parsed record
        in file order. ``selected`` holds a ``Row`` for each matching record;
        each ``Row.record`` is the same dict object stored in ``all_records``,
        so in-place edits are visible through both.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    all_records: List[Dict[str, Any]] = []
    selected: List[Row] = []
    with path.open("r", encoding="utf-8") as f:
        # line_no stays the physical 1-based line number even when blank
        # lines are skipped, so failure-log entries point at the right line.
        for line_no, line in enumerate(f, 1):
            if not line.strip():
                # Blank or whitespace-only lines are not records; json.loads
                # would raise on them.
                continue
            rec = json.loads(line)
            all_records.append(rec)
            if select_row(rec, selector):
                selected.append(Row(line_no=line_no, record=rec))
    return all_records, selected
134
+
135
+
136
def parse_model_json(text: str) -> Dict[str, Any]:
    """Parse JSON from model output, tolerating a surrounding Markdown fence.

    Args:
        text: Raw model text, optionally wrapped in a fenced code block
            (with or without a ``json`` language tag).

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If the remaining text is not valid JSON.
    """
    cleaned = text.strip()
    # Strip an opening fence at the very start, then a closing fence at the end.
    for fence_pattern in (r"^```(?:json)?\s*", r"\s*```$"):
        cleaned = re.sub(fence_pattern, "", cleaned)
    return json.loads(cleaned)
141
+
142
+
143
def build_user_payload(batch_rows: Sequence[Row]) -> str:
    """Serialize a batch of rows into the JSON string sent as the request input.

    Each row is identified by ``row_id``, its index within the batch, which
    the model echoes back so results can be matched to rows.

    Args:
        batch_rows: Rows to include, in batch order.

    Returns:
        A JSON document ``{"rows": [...]}`` with non-ASCII characters kept as-is.
    """
    payload_rows: List[Dict[str, Any]] = [
        {
            "row_id": position,
            "file_id": entry.record.get("file_id"),
            "filename": entry.record.get("filename"),
            "tokens": entry.record.get("tokens"),
            "current_labels": entry.record.get("labels"),
        }
        for position, entry in enumerate(batch_rows)
    ]
    return json.dumps({"rows": payload_rows}, ensure_ascii=False)
157
+
158
+
159
def extract_output_text(response_obj: Dict[str, Any]) -> str:
    """Return the text of the first ``output_text`` part in a response object.

    Args:
        response_obj: Decoded response JSON; reads its ``output`` list and each
            item's ``content`` list.

    Returns:
        The part's ``text`` value, or ``""`` if that part has no ``text`` key.

    Raises:
        ValueError: If no part of type ``output_text`` exists.
    """
    not_found = object()
    text_parts = (
        part.get("text", "")
        for entry in response_obj.get("output", [])
        for part in entry.get("content", [])
        if part.get("type") == "output_text"
    )
    first = next(text_parts, not_found)
    if first is not_found:
        raise ValueError("No output_text found in response")
    return first
166
+
167
+
168
def extract_function_args(response_obj: Dict[str, Any], func_name: str) -> Dict[str, Any]:
    """Decode the arguments of the first matching function call in a response.

    Args:
        response_obj: Decoded response JSON; reads its ``output`` list.
        func_name: Name of the function call to look for.

    Returns:
        The JSON-decoded ``arguments`` string of the first output item whose
        ``type`` is ``function_call`` and whose ``name`` equals ``func_name``
        (``{}`` when the item has no ``arguments`` key).

    Raises:
        ValueError: If no such item exists, or its arguments are not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    for entry in response_obj.get("output", []):
        is_call = entry.get("type") == "function_call"
        if not is_call or entry.get("name") != func_name:
            continue
        raw_arguments = entry.get("arguments", "{}")
        return json.loads(raw_arguments)
    raise ValueError(f"No function_call '{func_name}' found in response")
174
+
175
+
176
def validate_labels(tokens: Sequence[str], labels: Sequence[str]) -> bool:
    """Check that a label sequence is usable for the given tokens.

    Args:
        tokens: Tokens of one record.
        labels: Candidate labels, typically decoded from model output.

    Returns:
        True only if there is exactly one label per token and every label is a
        string contained in ``ALLOWED_LABELS``.
    """
    if len(tokens) != len(labels):
        return False
    # The isinstance check runs first so that unhashable values decoded from
    # model JSON (lists, dicts) are reported as invalid instead of raising
    # TypeError on the set lookup.
    return all(isinstance(lb, str) and lb in ALLOWED_LABELS for lb in labels)
183
+
184
+
185
def response_schema() -> Dict[str, Any]:
    """Build the JSON Schema for the arguments of the label-submission tool.

    Returns:
        A schema for ``{"results": [{"row_id": int, "labels": [str, ...]}]}``
        in which every label is restricted to ``ALLOWED_LABELS`` (sorted for a
        stable enum order) and no extra properties are allowed at either level.
    """
    label_item = {"type": "string", "enum": sorted(ALLOWED_LABELS)}
    result_item = {
        "type": "object",
        "additionalProperties": False,
        "properties": {
            "row_id": {"type": "integer"},
            "labels": {"type": "array", "items": label_item},
        },
        "required": ["row_id", "labels"],
    }
    results_array = {"type": "array", "items": result_item}
    return {
        "type": "object",
        "additionalProperties": False,
        "properties": {"results": results_array},
        "required": ["results"],
    }
208
+
209
+
210
def append_failure_log(path: str, message: str) -> None:
    """Append one line to the failure log, creating parent directories as needed.

    Args:
        path: Log file path.
        message: Text to record; trailing whitespace is stripped and a single
            newline is appended.
    """
    log_path = Path(path)
    log_path.parent.mkdir(parents=True, exist_ok=True)
    entry = f"{message.rstrip()}\n"
    with log_path.open("a", encoding="utf-8") as handle:
        handle.write(entry)
215
+
216
+
217
def relabel_batch(
    api_base: str,
    api_key: str,
    model: str,
    batch_rows: Sequence[Row],
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    user_agent: str,
    retries: int,
    failure_log: str,
) -> Dict[int, List[str]]:
    """Ask the model to relabel one batch and return validated labels per row.

    The request forces a ``submit_labels`` function call; if the response has
    no such call, the first ``output_text`` part is parsed as JSON instead.
    A batch only succeeds when every row received a valid label list.

    Args:
        api_base: API base URL; ``/responses`` is appended.
        api_key: Bearer token.
        model: Model name.
        batch_rows: Rows to relabel; ``row_id`` in the response is the index
            into this sequence.
        prompt_cache_key: Value for the ``prompt_cache_key`` request field.
        prompt_cache_retention: Value for ``prompt_cache_retention``.
        reasoning_effort: Value for ``reasoning.effort``.
        user_agent: ``User-Agent`` header value.
        retries: Maximum number of attempts.
        failure_log: Path of the log that receives diagnostic lines.

    Returns:
        Mapping from batch index to the new label list, covering every row.

    Raises:
        RuntimeError: If no attempt produced a complete, valid result. The last
            underlying error is attached as ``__cause__``.
    """
    url = f"{api_base.rstrip('/')}/responses"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "User-Agent": user_agent,
    }
    user_payload = build_user_payload(batch_rows)

    body = {
        "model": model,
        "instructions": SYSTEM_INSTRUCTIONS,
        "input": user_payload,
        "prompt_cache_key": prompt_cache_key,
        "prompt_cache_retention": prompt_cache_retention,
        "reasoning": {"effort": reasoning_effort},
        "tools": [
            {
                "type": "function",
                "name": "submit_labels",
                "description": "Submit relabeled BIO labels.",
                "parameters": response_schema(),
                "strict": True,
            }
        ],
        "tool_choice": {"type": "function", "name": "submit_labels"},
    }

    last_error: Exception | None = None
    for attempt in range(1, retries + 1):
        try:
            resp = requests.post(url, headers=headers, json=body, timeout=120)
            resp.raise_for_status()
            obj = resp.json()
            try:
                parsed = extract_function_args(obj, "submit_labels")
            except Exception:  # noqa: BLE001
                # No usable function call: fall back to plain-text JSON output.
                text = extract_output_text(obj)
                parsed = parse_model_json(text)

            results = parsed.get("results") if isinstance(parsed, dict) else None
            if not isinstance(results, list):
                # A non-dict payload is logged too, rather than failing on .keys().
                parsed_keys = list(parsed.keys()) if isinstance(parsed, dict) else type(parsed).__name__
                append_failure_log(
                    failure_log,
                    f"[invalid-results] model={model} batch={len(batch_rows)} parsed_keys={parsed_keys}",
                )
                raise ValueError("response JSON missing 'results' list")

            mapping: Dict[int, List[str]] = {}
            for item in results:
                if not isinstance(item, dict):
                    continue
                row_id = item.get("row_id")
                labels = item.get("labels")
                # bool is a subclass of int; JSON true/false must not be
                # accepted as row index 1/0.
                if isinstance(row_id, bool) or not isinstance(row_id, int):
                    continue
                if not isinstance(labels, list):
                    continue
                if row_id < 0 or row_id >= len(batch_rows):
                    continue
                tokens = batch_rows[row_id].record.get("tokens", [])
                if not validate_labels(tokens, labels):
                    append_failure_log(
                        failure_log,
                        f"[invalid-labels] file_id={batch_rows[row_id].record.get('file_id')} "
                        f"tokens_len={len(tokens)} labels_len={len(labels)}",
                    )
                    continue
                mapping[row_id] = labels

            if len(mapping) != len(batch_rows):
                missing = sorted(set(range(len(batch_rows))) - set(mapping))
                append_failure_log(
                    failure_log,
                    f"[missing] model={model} batch={len(batch_rows)} missing={missing}",
                )
                raise ValueError(f"incomplete/invalid rows from model: missing={missing}")

            return mapping
        except Exception as exc:  # noqa: BLE001
            last_error = exc
            # Some compatible gateways may not support prompt caching or reasoning fields.
            # After an HTTP 400 the optional fields are dropped for all later attempts.
            if isinstance(exc, requests.HTTPError) and exc.response is not None and exc.response.status_code == 400:
                body.pop("prompt_cache_retention", None)
                body.pop("reasoning", None)
                body.pop("tools", None)
                body.pop("tool_choice", None)
            if attempt == retries:
                break
            # Linear backoff: 0.8s, 1.6s, 2.4s, ...
            time.sleep(0.8 * attempt)

    # Chain the cause so the original traceback is not lost.
    raise RuntimeError(f"failed relabel batch after {retries} attempts: {last_error}") from last_error
317
+
318
+
319
def write_jsonl(path: Path, records: Sequence[Dict[str, Any]]) -> None:
    """Write records as compact JSONL, replacing ``path`` in one rename step.

    The data is first written to a sibling file named ``<path>.tmp`` and then
    moved over ``path``, so ``path`` is only replaced once the write finished.

    Args:
        path: Destination file.
        records: Records to serialize, one JSON object per line.
    """
    staging = path.with_suffix(path.suffix + ".tmp")
    with staging.open("w", encoding="utf-8", newline="") as out:
        for record in records:
            serialized = json.dumps(record, ensure_ascii=False, separators=(",", ":"))
            out.write(serialized + "\n")
    staging.replace(path)
325
+
326
+
327
def process_batch_with_fallback(
    api_base: str,
    api_key: str,
    model: str,
    batch: Sequence[Row],
    prompt_cache_key: str,
    prompt_cache_retention: str,
    reasoning_effort: str,
    user_agent: str,
    retries: int,
    failure_log: str,
) -> List[tuple[Row, List[str]]]:
    """Relabel a batch, retrying row by row if the whole batch fails.

    If the batched request raises ``RuntimeError``, each row is sent on its own
    with at least four attempts. A row that still fails is logged as
    ``[row-skip]`` and keeps its current labels.

    Args:
        api_base, api_key, model, prompt_cache_key, prompt_cache_retention,
        reasoning_effort, user_agent, failure_log: Forwarded to ``relabel_batch``.
        batch: Rows to relabel.
        retries: Attempts for the batched request.

    Returns:
        ``(row, labels)`` pairs, one per row of ``batch``.
    """
    # Arguments shared by the batched call and the per-row fallback calls.
    shared = {
        "api_base": api_base,
        "api_key": api_key,
        "model": model,
        "prompt_cache_key": prompt_cache_key,
        "prompt_cache_retention": prompt_cache_retention,
        "reasoning_effort": reasoning_effort,
        "user_agent": user_agent,
        "failure_log": failure_log,
    }

    try:
        mapping = relabel_batch(batch_rows=batch, retries=retries, **shared)
    except RuntimeError:
        mapping = {}
        single_row_retries = max(retries, 4)
        for idx, row in enumerate(batch):
            try:
                single = relabel_batch(batch_rows=[row], retries=single_row_retries, **shared)
            except RuntimeError as exc:
                append_failure_log(
                    failure_log,
                    f"[row-skip] file_id={row.record.get('file_id')} line={row.line_no} reason={exc}",
                )
                # Keep whatever labels the record already has.
                mapping[idx] = row.record.get("labels", [])
            else:
                # A one-row batch always reports its row under index 0.
                mapping[idx] = single[0]

    return [(batch[row_id], labels) for row_id, labels in mapping.items()]
376
+
377
+
378
def main() -> None:
    """Run the relabel pipeline: load, select, relabel in parallel, write output.

    Reads the API key from ``--api-key`` or ``ANIFILEBERT_RELABEL_API_KEY``,
    selects rows from the input JSONL, sends them in batches through a thread
    pool, applies the returned labels in place, and writes the full record
    list to the output path (with periodic checkpoints).

    Raises:
        SystemExit: If no API key is available.
    """
    args = parse_args()
    api_key = args.api_key or os.environ.get("ANIFILEBERT_RELABEL_API_KEY")
    if not api_key:
        raise SystemExit("Missing API key. Use --api-key or env ANIFILEBERT_RELABEL_API_KEY")

    input_path = Path(args.input)
    output_path = Path(args.output)

    all_records, selected_rows = load_rows(input_path, args.selector)
    if args.skip_selected > 0:
        selected_rows = selected_rows[args.skip_selected:]
    if args.max_rows > 0:
        selected_rows = selected_rows[: args.max_rows]
    if not selected_rows:
        print("selected_rows=0; nothing to do")
        if output_path != input_path:
            write_jsonl(output_path, all_records)
        return

    total = len(selected_rows)
    changed = 0
    # Worker count is clamped to the range 1..8.
    concurrency = max(1, min(args.concurrency, 8))
    # A batch size below 1 would make range() fail (step 0) or yield no batches.
    batch_size = max(1, args.batch_size)
    batches: List[List[Row]] = [
        selected_rows[i:i + batch_size]
        for i in range(0, total, batch_size)
    ]

    done_rows = 0
    last_checkpoint = 0
    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        futures = [
            executor.submit(
                process_batch_with_fallback,
                api_base=args.api_base,
                api_key=api_key,
                model=args.model,
                batch=batch,
                prompt_cache_key=args.prompt_cache_key,
                prompt_cache_retention=args.prompt_cache_retention,
                reasoning_effort=args.reasoning_effort,
                user_agent=args.user_agent,
                retries=args.retries,
                failure_log=args.failure_log,
            )
            for batch in batches
        ]
        for fut in as_completed(futures):
            updates = fut.result()
            for row, new_labels in updates:
                rec = row.record
                if rec.get("labels") != new_labels:
                    rec["labels"] = new_labels
                    changed += 1
            done_rows += len(updates)
            print(f"processed={done_rows}/{total} changed={changed}")
            # Checkpoint once at least N rows were processed since the last
            # one. An exact-multiple test would almost never fire, because
            # done_rows advances in batch-sized steps.
            if args.checkpoint_rows > 0 and (
                done_rows - last_checkpoint >= args.checkpoint_rows or done_rows == total
            ):
                write_jsonl(output_path, all_records)
                last_checkpoint = done_rows
            # NOTE(review): all batches are already submitted at this point, so
            # this delay paces result handling only, not the outgoing requests.
            if args.sleep_ms > 0:
                time.sleep(args.sleep_ms / 1000.0)

    # rows in selected_rows reference dicts in all_records by identity, so changes are already reflected.
    write_jsonl(output_path, all_records)
    print(f"done selected_rows={total} changed_rows={changed} output={output_path}")


if __name__ == "__main__":
    main()