ModerRAS commited on
Commit
fed9d99
·
1 Parent(s): 8165cc3

Add robust LLM relabel pipeline and enforce contiguous title

Browse files
datasets/AnimeName CHANGED
@@ -1 +1 @@
1
- Subproject commit 5de6ddeed7dafd43207953072a9e197f13b32077
 
1
+ Subproject commit ad48d8da74cf8e611a14f22ffc2a9734872e1f03
tools/enforce_contiguous_title.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Enforce a single contiguous TITLE span for every JSONL row.
4
+
5
+ This script is deterministic and streaming-friendly for very large datasets.
6
+ It is intended as a hard safety pass before/alongside LLM relabeling.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ import json
13
+ from pathlib import Path
14
+ from typing import Dict, List, Sequence, Tuple
15
+
16
+ from anifilebert.label_repairs import repair_jsonl_item
17
+
18
+
19
+ def parse_args() -> argparse.Namespace:
20
+ parser = argparse.ArgumentParser(description="Force contiguous TITLE spans in JSONL labels")
21
+ parser.add_argument("--input", required=True, help="Input JSONL")
22
+ parser.add_argument("--output", required=True, help="Output JSONL")
23
+ parser.add_argument("--manifest-output", default="", help="Optional manifest JSON")
24
+ parser.add_argument("--progress", type=int, default=50000, help="Progress print interval")
25
+ return parser.parse_args()
26
+
27
+
28
+ def normalize_iob2(labels: Sequence[str]) -> List[str]:
29
+ out: List[str] = []
30
+ prev = ""
31
+ for lb in labels:
32
+ if not isinstance(lb, str) or not lb.startswith(("B-", "I-")):
33
+ out.append("O")
34
+ prev = ""
35
+ continue
36
+ entity = lb.split("-", 1)[1]
37
+ prefix = "I" if prev == entity else "B"
38
+ out.append(f"{prefix}-{entity}")
39
+ prev = entity
40
+ return out
41
+
42
+
43
+ def is_discontinuous_title(labels: Sequence[str]) -> bool:
44
+ seen_title = False
45
+ seen_gap = False
46
+ for lb in labels:
47
+ is_title = isinstance(lb, str) and lb.endswith("TITLE")
48
+ if is_title:
49
+ if seen_title and seen_gap:
50
+ return True
51
+ seen_title = True
52
+ elif seen_title:
53
+ seen_gap = True
54
+ return False
55
+
56
+
57
+ def title_segments(labels: Sequence[str]) -> List[Tuple[int, int]]:
58
+ segs: List[Tuple[int, int]] = []
59
+ i = 0
60
+ n = len(labels)
61
+ while i < n:
62
+ if str(labels[i]).endswith("TITLE"):
63
+ j = i + 1
64
+ while j < n and str(labels[j]).endswith("TITLE"):
65
+ j += 1
66
+ segs.append((i, j))
67
+ i = j
68
+ else:
69
+ i += 1
70
+ return segs
71
+
72
+
73
+ def first_episode_or_special_index(labels: Sequence[str]) -> int:
74
+ for idx, lb in enumerate(labels):
75
+ text = str(lb)
76
+ if text.endswith("EPISODE") or text.endswith("SPECIAL"):
77
+ return idx
78
+ return len(labels)
79
+
80
+
81
+ def pick_primary_title_segment(labels: Sequence[str], segs: Sequence[Tuple[int, int]]) -> Tuple[int, int]:
82
+ if not segs:
83
+ return (-1, -1)
84
+ bound = first_episode_or_special_index(labels)
85
+ before = [seg for seg in segs if seg[0] < bound]
86
+ # Prefer the earliest title span before episode/special boundary.
87
+ if before:
88
+ return min(before, key=lambda seg: seg[0])
89
+ return min(segs, key=lambda seg: seg[0])
90
+
91
+
92
+ def enforce_contiguous_title(labels: Sequence[str]) -> List[str]:
93
+ fixed = normalize_iob2(labels)
94
+ segs = title_segments(fixed)
95
+ if len(segs) <= 1:
96
+ return fixed
97
+ keep_start, keep_end = pick_primary_title_segment(fixed, segs)
98
+ if keep_start < 0:
99
+ return fixed
100
+
101
+ out = list(fixed)
102
+ for idx, lb in enumerate(out):
103
+ if str(lb).endswith("TITLE") and not (keep_start <= idx < keep_end):
104
+ out[idx] = "O"
105
+ return normalize_iob2(out)
106
+
107
+
108
+ def main() -> None:
109
+ args = parse_args()
110
+ input_path = Path(args.input)
111
+ output_path = Path(args.output)
112
+ manifest_path = Path(args.manifest_output) if args.manifest_output else output_path.with_suffix(".contiguous_title.manifest.json")
113
+ output_path.parent.mkdir(parents=True, exist_ok=True)
114
+ manifest_path.parent.mkdir(parents=True, exist_ok=True)
115
+
116
+ rows = 0
117
+ changed_rows = 0
118
+ bad_before = 0
119
+ bad_after = 0
120
+ invalid_rows = 0
121
+
122
+ tmp_path = output_path.with_suffix(output_path.suffix + ".tmp")
123
+ with input_path.open("r", encoding="utf-8") as src, tmp_path.open("w", encoding="utf-8", newline="\n") as dst:
124
+ for line in src:
125
+ line = line.rstrip("\n")
126
+ if not line:
127
+ continue
128
+ rows += 1
129
+ rec = json.loads(line)
130
+ tokens = rec.get("tokens", [])
131
+ labels = rec.get("labels", [])
132
+ if not isinstance(tokens, list) or not isinstance(labels, list) or len(tokens) != len(labels):
133
+ invalid_rows += 1
134
+ dst.write(json.dumps(rec, ensure_ascii=False, separators=(",", ":")) + "\n")
135
+ continue
136
+
137
+ if is_discontinuous_title(labels):
138
+ bad_before += 1
139
+
140
+ new_labels = enforce_contiguous_title(labels)
141
+ out_rec: Dict = dict(rec)
142
+ out_rec["labels"] = new_labels
143
+ repaired, _ = repair_jsonl_item(out_rec)
144
+ out_labels = repaired.get("labels", new_labels)
145
+ if is_discontinuous_title(out_labels):
146
+ bad_after += 1
147
+
148
+ if out_labels != labels:
149
+ changed_rows += 1
150
+ repaired["labels"] = out_labels
151
+ dst.write(json.dumps(repaired, ensure_ascii=False, separators=(",", ":")) + "\n")
152
+
153
+ if args.progress > 0 and rows % args.progress == 0:
154
+ print(
155
+ f"rows={rows} changed={changed_rows} "
156
+ f"bad_before={bad_before} bad_after={bad_after} invalid={invalid_rows}"
157
+ )
158
+
159
+ tmp_path.replace(output_path)
160
+
161
+ manifest = {
162
+ "input": str(input_path),
163
+ "output": str(output_path),
164
+ "rows": rows,
165
+ "changed_rows": changed_rows,
166
+ "discontinuous_before": bad_before,
167
+ "discontinuous_after": bad_after,
168
+ "invalid_rows": invalid_rows,
169
+ }
170
+ manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8")
171
+ print(json.dumps(manifest, ensure_ascii=False, indent=2))
172
+
173
+
174
+ if __name__ == "__main__":
175
+ main()
176
+
tools/llm_relabel_rows.py CHANGED
@@ -20,6 +20,7 @@ from pathlib import Path
20
  from typing import Any, Dict, List, Sequence
21
 
22
  import requests
 
23
 
24
 
25
  ALLOWED_LABELS = {
@@ -151,6 +152,7 @@ def parse_args() -> argparse.Namespace:
151
  p.add_argument("--checkpoint-rows", type=int, default=100, help="Write checkpoint every N processed rows")
152
  p.add_argument("--failure-log", default="reports/llm_relabel_failures.log", help="Failure log path")
153
  p.add_argument("--perf-log", default="", help="Optional JSON perf summary path")
 
154
  p.add_argument("--usd-per-1m-input", type=float, default=0.75, help="Input token price (USD per 1M tokens)")
155
  p.add_argument("--usd-per-1m-output", type=float, default=4.5, help="Output token price (USD per 1M tokens)")
156
  p.add_argument(
@@ -244,6 +246,64 @@ def validate_labels(tokens: Sequence[str], labels: Sequence[str]) -> bool:
244
  return True
245
 
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  def response_schema() -> Dict[str, Any]:
248
  return {
249
  "type": "object",
@@ -276,6 +336,44 @@ def append_failure_log(path: str, message: str) -> None:
276
  f.write(message.rstrip() + "\n")
277
 
278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  def parse_usage(response_obj: Dict[str, Any]) -> UsageStats:
280
  usage = response_obj.get("usage", {}) or {}
281
  in_details = usage.get("input_tokens_details", {}) or {}
@@ -299,6 +397,7 @@ def relabel_batch(
299
  user_agent: str,
300
  retries: int,
301
  failure_log: str,
 
302
  ) -> tuple[Dict[int, List[str]], UsageStats]:
303
  url = f"{api_base.rstrip('/')}/responses"
304
  headers = {
@@ -308,29 +407,30 @@ def relabel_batch(
308
  }
309
  user_payload = build_user_payload(batch_rows)
310
 
311
- body = {
312
- "model": model,
313
- "instructions": SYSTEM_INSTRUCTIONS,
314
- "input": user_payload,
315
- "prompt_cache_key": prompt_cache_key,
316
- "prompt_cache_retention": prompt_cache_retention,
317
- "reasoning": {"effort": reasoning_effort},
318
- "tools": [
319
- {
320
- "type": "function",
321
- "name": "submit_labels",
322
- "description": "Submit relabeled BIO labels.",
323
- "parameters": response_schema(),
324
- "strict": True,
325
- }
326
- ],
327
- "tool_choice": {"type": "function", "name": "submit_labels"},
328
  }
329
 
330
  last_error: Exception | None = None
331
  for attempt in range(1, retries + 1):
332
  try:
333
- resp = requests.post(url, headers=headers, json=body, timeout=120)
 
 
 
 
 
 
 
 
 
 
 
 
334
  resp.raise_for_status()
335
  obj = resp.json()
336
  usage_stats = parse_usage(obj)
@@ -378,12 +478,25 @@ def relabel_batch(
378
  return mapping, usage_stats
379
  except Exception as exc: # noqa: BLE001
380
  last_error = exc
381
- # Some compatible gateways may not support prompt caching or reasoning fields.
 
382
  if isinstance(exc, requests.HTTPError) and exc.response is not None and exc.response.status_code == 400:
383
- body.pop("prompt_cache_retention", None)
384
- body.pop("reasoning", None)
385
- body.pop("tools", None)
386
- body.pop("tool_choice", None)
 
 
 
 
 
 
 
 
 
 
 
 
387
  if attempt == retries:
388
  break
389
  time.sleep(0.8 * attempt)
@@ -410,6 +523,7 @@ def process_batch_with_fallback(
410
  user_agent: str,
411
  retries: int,
412
  failure_log: str,
 
413
  ) -> List[tuple[Row, List[str]]]:
414
  usage_total = UsageStats()
415
  try:
@@ -424,6 +538,7 @@ def process_batch_with_fallback(
424
  user_agent=user_agent,
425
  retries=retries,
426
  failure_log=failure_log,
 
427
  )
428
  usage_total.add(usage)
429
  except RuntimeError:
@@ -441,6 +556,7 @@ def process_batch_with_fallback(
441
  user_agent=user_agent,
442
  retries=max(retries, 4),
443
  failure_log=failure_log,
 
444
  )
445
  usage_total.add(usage)
446
  mapping[idx] = single[0]
@@ -449,8 +565,23 @@ def process_batch_with_fallback(
449
  failure_log,
450
  f"[row-skip] file_id={row.record.get('file_id')} line={row.line_no} reason={exc}",
451
  )
452
- mapping[idx] = row.record.get("labels", [])
453
- return [(batch[row_id], labels) for row_id, labels in mapping.items()], usage_total
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
 
455
 
456
  def process_batch_timed(
@@ -465,6 +596,7 @@ def process_batch_timed(
465
  user_agent: str,
466
  retries: int,
467
  failure_log: str,
 
468
  ) -> Dict[str, Any]:
469
  meter.task_start()
470
  t0 = time.time()
@@ -480,6 +612,7 @@ def process_batch_timed(
480
  user_agent=user_agent,
481
  retries=retries,
482
  failure_log=failure_log,
 
483
  )
484
  return {
485
  "updates": updates,
@@ -552,6 +685,7 @@ def main() -> None:
552
  user_agent=args.user_agent,
553
  retries=args.retries,
554
  failure_log=args.failure_log,
 
555
  )
556
  for batch in batches
557
  ]
 
20
  from typing import Any, Dict, List, Sequence
21
 
22
  import requests
23
+ from anifilebert.label_repairs import repair_jsonl_item
24
 
25
 
26
  ALLOWED_LABELS = {
 
152
  p.add_argument("--checkpoint-rows", type=int, default=100, help="Write checkpoint every N processed rows")
153
  p.add_argument("--failure-log", default="reports/llm_relabel_failures.log", help="Failure log path")
154
  p.add_argument("--perf-log", default="", help="Optional JSON perf summary path")
155
+ p.add_argument("--http-timeout", type=int, default=240, help="HTTP timeout in seconds per request")
156
  p.add_argument("--usd-per-1m-input", type=float, default=0.75, help="Input token price (USD per 1M tokens)")
157
  p.add_argument("--usd-per-1m-output", type=float, default=4.5, help="Output token price (USD per 1M tokens)")
158
  p.add_argument(
 
246
  return True
247
 
248
 
249
+ def normalize_iob2_labels(labels: Sequence[str]) -> List[str]:
250
+ normalized: List[str] = []
251
+ prev_entity = ""
252
+ for lb in labels:
253
+ if not isinstance(lb, str) or not lb.startswith(("B-", "I-")):
254
+ normalized.append("O")
255
+ prev_entity = ""
256
+ continue
257
+ entity = lb.split("-", 1)[1]
258
+ prefix = "I" if prev_entity == entity else "B"
259
+ normalized.append(f"{prefix}-{entity}")
260
+ prev_entity = entity
261
+ return normalized
262
+
263
+
264
+ def title_segments(labels: Sequence[str]) -> List[tuple[int, int]]:
265
+ segments: List[tuple[int, int]] = []
266
+ i = 0
267
+ n = len(labels)
268
+ while i < n:
269
+ if str(labels[i]).endswith("TITLE"):
270
+ j = i + 1
271
+ while j < n and str(labels[j]).endswith("TITLE"):
272
+ j += 1
273
+ segments.append((i, j))
274
+ i = j
275
+ else:
276
+ i += 1
277
+ return segments
278
+
279
+
280
+ def force_single_title_segment(tokens: Sequence[str], labels: Sequence[str]) -> List[str]:
281
+ """Guarantee TITLE is a single contiguous segment."""
282
+ if len(tokens) != len(labels):
283
+ return list(labels)
284
+ fixed = normalize_iob2_labels(labels)
285
+ segs = title_segments(fixed)
286
+ if len(segs) <= 1:
287
+ return fixed
288
+
289
+ first_episode = next((idx for idx, lb in enumerate(fixed) if str(lb).endswith("EPISODE")), len(fixed))
290
+
291
+ def score(seg: tuple[int, int]) -> tuple[int, int, int]:
292
+ start, end = seg
293
+ length = end - start
294
+ before_episode = 1 if start < first_episode else 0
295
+ return (before_episode, length, -start)
296
+
297
+ keep = max(segs, key=score)
298
+ ks, ke = keep
299
+ out = list(fixed)
300
+ for i in range(len(out)):
301
+ if str(out[i]).endswith("TITLE") and not (ks <= i < ke):
302
+ out[i] = "O"
303
+ out = normalize_iob2_labels(out)
304
+ return out
305
+
306
+
307
  def response_schema() -> Dict[str, Any]:
308
  return {
309
  "type": "object",
 
336
  f.write(message.rstrip() + "\n")
337
 
338
 
339
+ def build_request_body(
340
+ model: str,
341
+ user_payload: str,
342
+ prompt_cache_key: str,
343
+ prompt_cache_retention: str,
344
+ reasoning_effort: str,
345
+ include_tools: bool = True,
346
+ include_tool_choice: bool = True,
347
+ include_reasoning: bool = True,
348
+ include_cache_key: bool = True,
349
+ include_cache_retention: bool = True,
350
+ ) -> Dict[str, Any]:
351
+ body: Dict[str, Any] = {
352
+ "model": model,
353
+ "instructions": SYSTEM_INSTRUCTIONS,
354
+ "input": user_payload,
355
+ }
356
+ if include_cache_key:
357
+ body["prompt_cache_key"] = prompt_cache_key
358
+ if include_cache_retention:
359
+ body["prompt_cache_retention"] = prompt_cache_retention
360
+ if include_reasoning:
361
+ body["reasoning"] = {"effort": reasoning_effort}
362
+ if include_tools:
363
+ body["tools"] = [
364
+ {
365
+ "type": "function",
366
+ "name": "submit_labels",
367
+ "description": "Submit relabeled BIO labels.",
368
+ "parameters": response_schema(),
369
+ "strict": True,
370
+ }
371
+ ]
372
+ if include_tool_choice and include_tools:
373
+ body["tool_choice"] = {"type": "function", "name": "submit_labels"}
374
+ return body
375
+
376
+
377
  def parse_usage(response_obj: Dict[str, Any]) -> UsageStats:
378
  usage = response_obj.get("usage", {}) or {}
379
  in_details = usage.get("input_tokens_details", {}) or {}
 
397
  user_agent: str,
398
  retries: int,
399
  failure_log: str,
400
+ http_timeout: int,
401
  ) -> tuple[Dict[int, List[str]], UsageStats]:
402
  url = f"{api_base.rstrip('/')}/responses"
403
  headers = {
 
407
  }
408
  user_payload = build_user_payload(batch_rows)
409
 
410
+ cfg = {
411
+ "include_tools": True,
412
+ "include_tool_choice": True,
413
+ "include_reasoning": True,
414
+ "include_cache_key": True,
415
+ "include_cache_retention": True,
 
 
 
 
 
 
 
 
 
 
 
416
  }
417
 
418
  last_error: Exception | None = None
419
  for attempt in range(1, retries + 1):
420
  try:
421
+ body = build_request_body(
422
+ model=model,
423
+ user_payload=user_payload,
424
+ prompt_cache_key=prompt_cache_key,
425
+ prompt_cache_retention=prompt_cache_retention,
426
+ reasoning_effort=reasoning_effort,
427
+ include_tools=cfg["include_tools"],
428
+ include_tool_choice=cfg["include_tool_choice"],
429
+ include_reasoning=cfg["include_reasoning"],
430
+ include_cache_key=cfg["include_cache_key"],
431
+ include_cache_retention=cfg["include_cache_retention"],
432
+ )
433
+ resp = requests.post(url, headers=headers, json=body, timeout=http_timeout)
434
  resp.raise_for_status()
435
  obj = resp.json()
436
  usage_stats = parse_usage(obj)
 
478
  return mapping, usage_stats
479
  except Exception as exc: # noqa: BLE001
480
  last_error = exc
481
+ # Some compatible gateways may not support all optional fields.
482
+ # Downgrade progressively and keep structured tool output whenever possible.
483
  if isinstance(exc, requests.HTTPError) and exc.response is not None and exc.response.status_code == 400:
484
+ response_text = (exc.response.text or "")[:1200]
485
+ lowered = response_text.lower()
486
+ append_failure_log(
487
+ failure_log,
488
+ f"[http400] attempt={attempt} model={model} body_cfg={cfg} response={response_text!r}",
489
+ )
490
+ if "prompt_cache_retention" in lowered and cfg["include_cache_retention"]:
491
+ cfg["include_cache_retention"] = False
492
+ elif "prompt_cache_key" in lowered and cfg["include_cache_key"]:
493
+ cfg["include_cache_key"] = False
494
+ elif "reasoning" in lowered and cfg["include_reasoning"]:
495
+ cfg["include_reasoning"] = False
496
+ elif "tool_choice" in lowered and cfg["include_tool_choice"]:
497
+ cfg["include_tool_choice"] = False
498
+ elif "tools" in lowered and cfg["include_tools"]:
499
+ cfg["include_tools"] = False
500
  if attempt == retries:
501
  break
502
  time.sleep(0.8 * attempt)
 
523
  user_agent: str,
524
  retries: int,
525
  failure_log: str,
526
+ http_timeout: int,
527
  ) -> List[tuple[Row, List[str]]]:
528
  usage_total = UsageStats()
529
  try:
 
538
  user_agent=user_agent,
539
  retries=retries,
540
  failure_log=failure_log,
541
+ http_timeout=http_timeout,
542
  )
543
  usage_total.add(usage)
544
  except RuntimeError:
 
556
  user_agent=user_agent,
557
  retries=max(retries, 4),
558
  failure_log=failure_log,
559
+ http_timeout=http_timeout,
560
  )
561
  usage_total.add(usage)
562
  mapping[idx] = single[0]
 
565
  failure_log,
566
  f"[row-skip] file_id={row.record.get('file_id')} line={row.line_no} reason={exc}",
567
  )
568
+ # Hard fallback: enforce contiguous TITLE rather than keeping polluted labels.
569
+ toks = row.record.get("tokens", [])
570
+ lbs = row.record.get("labels", [])
571
+ if isinstance(toks, list) and isinstance(lbs, list) and len(toks) == len(lbs):
572
+ mapping[idx] = force_single_title_segment(toks, lbs)
573
+ else:
574
+ mapping[idx] = lbs
575
+
576
+ updates: List[tuple[Row, List[str]]] = []
577
+ for row_id, labels in mapping.items():
578
+ row = batch[row_id]
579
+ rec = dict(row.record)
580
+ rec["labels"] = force_single_title_segment(rec.get("tokens", []), labels)
581
+ repaired, _repairs = repair_jsonl_item(rec)
582
+ new_labels = repaired.get("labels", rec.get("labels", []))
583
+ updates.append((row, new_labels))
584
+ return updates, usage_total
585
 
586
 
587
  def process_batch_timed(
 
596
  user_agent: str,
597
  retries: int,
598
  failure_log: str,
599
+ http_timeout: int,
600
  ) -> Dict[str, Any]:
601
  meter.task_start()
602
  t0 = time.time()
 
612
  user_agent=user_agent,
613
  retries=retries,
614
  failure_log=failure_log,
615
+ http_timeout=http_timeout,
616
  )
617
  return {
618
  "updates": updates,
 
685
  user_agent=args.user_agent,
686
  retries=args.retries,
687
  failure_log=args.failure_log,
688
+ http_timeout=args.http_timeout,
689
  )
690
  for batch in batches
691
  ]