AniFileBERT / anifilebert /label_repairs.py
ModerRAS's picture
Organize parser modules and tools
8c50d16
"""Deterministic label repairs for known weak-label blind spots."""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
# Single-character separator tokens. mark_adjacent_title_separators_o uses this
# set to decide which TITLE-labelled tokens directly after a repaired marker
# are reset to "O".
# NOTE(review): "~" appears twice, which is harmless in a set; the second one
# may originally have been a different tilde-like character - confirm.
SEPARATOR_CHARS = set(" \t-_.|~~")
# Season number for each Roman-numeral sequel marker, in both ASCII spelling
# and the single-code-point Unicode numeral forms. There is no entry for "I";
# the table starts at 2. Lookups are exact (case-sensitive).
ROMAN_NUMERAL_VALUES = {
    "II": 2,
    "III": 3,
    "IV": 4,
    "V": 5,
    "VI": 6,
    "VII": 7,
    "VIII": 8,
    "IX": 9,
    "Ⅱ": 2,
    "Ⅲ": 3,
    "Ⅳ": 4,
    "Ⅴ": 5,
    "Ⅵ": 6,
    "Ⅶ": 7,
    "Ⅷ": 8,
    "Ⅸ": 9,
}
# Value of each single CJK numeral character recognised by cn_number_to_int.
# Several characters map to the same value (alternative written forms of the
# same number); "十" is the only entry above 9.
CN_NUMERAL_VALUES = {
    "一": 1,
    "二": 2,
    "兩": 2,
    "两": 2,
    "貳": 2,
    "贰": 2,
    "弐": 2,
    "弍": 2,
    "三": 3,
    "參": 3,
    "叁": 3,
    "参": 3,
    "四": 4,
    "肆": 4,
    "五": 5,
    "伍": 5,
    "六": 6,
    "陸": 6,
    "陆": 6,
    "七": 7,
    "柒": 7,
    "八": 8,
    "捌": 8,
    "九": 9,
    "玖": 9,
    "十": 10,
}
# Season number for romanized sequel phrases. Keys are lower-case with single
# spaces, which is the form season_marker_number normalizes a marker to before
# looking it up here.
READING_MARKER_VALUES = {
    "ni no sara": 2,
    "ni no shou": 2,
    "ni no sho": 2,
    "ni no syo": 2,
    "ni no shō": 2,
    "ni gakki": 2,
    "sono ni": 2,
    "san no sara": 3,
    "san no shou": 3,
    "san no sho": 3,
    "san no syo": 3,
    "yon no sara": 4,
    "shi no sara": 4,
    "shin no sara": 4,
    "go no sara": 5,
    "gou no sara": 5,
}
# Bare "Ni" is often the Japanese particle に in romanized titles. Only repair
# it for titles that have been verified as a sequel marker in the release name.
# Maps the exact base title that must precede "Ni" to the season number that
# the repair records.
STANDALONE_NI_SEASON_BASES = {
    "Kakuriyo no Yadomeshi": 2,
}
# Anchored at the start of the text that follows a marker (leading whitespace
# allowed). Case-insensitive. Accepts any of:
#   * "-" or "_" followed by a 1-4 digit number or one of
#     NCOP/NCED/OP/ED/OVA/OAD/SP/END;
#   * "#" followed by a 1-4 digit number;
#   * an opening bracket followed by an optional "E"/"EP"/"#" prefix and a
#     1-4 digit number.
EPISODE_CONTEXT_RE = re.compile(
    r"^\s*(?:"
    r"[-_]\s*(?:\d{1,4}|NCOP|NCED|OP|ED|OVA|OAD|SP|END)\b|"
    r"#\s*\d{1,4}|"
    r"[\[\(【《]\s*(?:EP?|#)?\d{1,4}"
    r")",
    re.I,
)
# Searched anywhere in a filename to locate an episode span when no token is
# labelled EPISODE (see first_episode_end). Case-insensitive. Alternatives, in
# order: SxxEyy; a "-"/"_" separated number; a fully bracketed number with an
# optional E/EP/# prefix; and a number prefixed by E/EP/第/# with an optional
# 话/話/集 suffix. Each form allows an optional "v<digits>" version suffix.
EPISODE_SPAN_RE = re.compile(
    r"(?:"
    r"[Ss]\d{1,2}[Ee]\d{1,4}(?:v\d+)?|"
    r"(?:^|[\s._])[-_]\s*\d{1,4}(?:v\d+)?(?=$|[\s._\-\]\)】》\[])|"
    r"[\[\(【《](?:EP?|#)?\d{1,4}(?:v\d+)?[\]\)】》]|"
    r"(?:^|[\s._\-\[\(【《#])(?:EP?|第|#)\d{1,4}(?:v\d+)?(?:[话話集])?(?=$|[\s._\-\]\)】》])"
    r")",
    re.I,
)
# One non-nested bracket pair of any supported kind; exactly one of the four
# capture groups holds the inner text, depending on which bracket style matched.
BRACKET_RE = re.compile(r"\[([^\]]*)\]|\(([^)]*)\)|【([^】]*)】|《([^》]*)》")
# Resolution tokens such as 1080p, 4K or 1920x1080, not glued to ASCII
# letters or digits on either side.
RESOLUTION_RE = re.compile(r"(?<![A-Za-z0-9])(?:\d{3,4}[pP]|\d[Kk]|\d{3,4}[xX×]\d{3,4})(?![A-Za-z0-9])")
# Alternation shared by SOURCE_RE and SOURCE_TAG_RE. Everything listed here is
# labelled SOURCE by repair_structural_meta_labels.
SOURCE_TOKEN_PATTERN = (
    r"WEB[-_ ]?DL|WEB[-_ ]?Rip|BDRip|BluRay|BDMV|BD|DVDRip|DVD|TVRip|HDTV|"
    r"Netflix|NF|AMZN|Baha|CR|ABEMA|DSNP|U[-_ ]?NEXT|Hulu|AT[-_ ]?X|"
    r"x26[45]|h\.?26[45]|HEVC|AVC|AV1|AAC\d*(?:\.\d+)?|AAC|FLAC|MP3|DTS|Opus|"
    r"CHS|CHT|GB|BIG5|JPN?|JPSC|JPTC|繁中|简中"
)
# A single source token anywhere in a string, not glued to ASCII letters or
# digits on either side. Case-insensitive.
SOURCE_RE = re.compile(rf"(?<![A-Za-z0-9])(?:{SOURCE_TOKEN_PATTERN})(?![A-Za-z0-9])", re.I)
# A whole string made only of source tokens joined by one of & + / , _ -
# (optionally surrounded by whitespace). Case-insensitive.
SOURCE_TAG_RE = re.compile(
    rf"^(?:{SOURCE_TOKEN_PATTERN})(?:\s*(?:[&+/,_-]|,\s*)\s*(?:{SOURCE_TOKEN_PATTERN}))*$",
    re.I,
)
# "<keyword>: <anything>" tags where the keyword is one of the listed
# search/alias words. Case-insensitive.
SPECIAL_TAG_RE = re.compile(
    r"^(?:檢索|检索|搜索|搜寻|搜尋|别名|別名|alias|search|keyword)\s*[::].+",
    re.I,
)
# Compact codes: NCOP/NCED/OP/ED/PV/CM or OVA/OAD/SP with optional digits, or
# "IV" followed by at least one digit. Case-insensitive.
SPECIAL_CODE_RE = re.compile(
    r"^(?:NCOP|NCED|OP|ED|PV|CM)\d*$|^IV\d+$|^(?:OVA|OAD|SP)\d*$",
    re.I,
)
# Romanized sequel phrases, captured as group "marker". Case-sensitive (no
# re.I flag) and not glued to ASCII letters or digits on either side.
READING_MARKER_RE = re.compile(
    r"(?<![A-Za-z0-9])"
    r"(?P<marker>"
    r"Ni\s+no\s+(?:Sara|Shou|Sho|Syo|Shō)|"
    r"San\s+no\s+(?:Sara|Shou|Sho|Syo)|"
    r"(?:Yon|Shi|Shin)\s+no\s+Sara|"
    r"(?:Go|Gou)\s+no\s+Sara|"
    r"Ni\s+Gakki|"
    r"Sono\s+Ni"
    r")"
    r"(?![A-Za-z0-9])",
)
# Roman numerals II-IX (ASCII or Unicode numeral forms), captured as group
# "marker". Case-sensitive and not glued to ASCII letters or digits.
ROMAN_MARKER_RE = re.compile(
    r"(?<![A-Za-z0-9])"
    r"(?P<marker>II|III|IV|V|VI|VII|VIII|IX|[ⅡⅢⅣⅤⅥⅦⅧⅨ])"
    r"(?![A-Za-z0-9])"
)
# CJK sequel markers, captured as group "marker": either a single numeral
# character with an optional "ノ/の/之" + "章/期/季/部" suffix, or the explicit
# "第<number>季/期/部/章" form. There are no boundary lookarounds here, so the
# first alternative matches a numeral character wherever it occurs.
CJK_MARKER_RE = re.compile(
    r"(?P<marker>"
    r"[一二三四五六七八九十兩两貳贰弐弍參叁参肆伍陸陆柒捌玖](?:\s*(?:ノ|の|之)\s*(?:章|期|季|部))?|"
    r"第[一二三四五六七八九十兩两貳贰弐弍參叁参肆伍陸陆柒捌玖\d]+[季期部章]"
    r")"
)
@dataclass(frozen=True)
class LabelRepair:
    """Immutable record of one repair found or applied by this module."""

    # Repair category. Values produced in this file: "reading", "roman", "cjk",
    # "verified_bare_ni", "special", "source", "resolution".
    kind: str
    # Matched text (for whole-bracket repairs, the stripped bracket content).
    marker: str
    # Season number for sequel-marker repairs; 0 for structural meta repairs.
    value: int
    # Character offsets of the repaired span in the filename text; `end` is
    # exclusive.
    start: int
    end: int
def clean_marker_text(text: str) -> str:
    """Return *text* without surrounding whitespace and enclosing brackets.

    Whitespace is trimmed, then any run of bracket characters at either end,
    then whitespace again (for padding that sat inside the brackets).
    """
    bracket_chars = "[]()【】《》()"
    without_padding = text.strip()
    without_brackets = without_padding.strip(bracket_chars)
    return without_brackets.strip()
def cn_number_to_int(text: str) -> Optional[int]:
    """Convert a small decimal or CJK numeral string to an int.

    Supported inputs (after trimming whitespace):
      * decimal digit strings, e.g. "12";
      * any single character in CN_NUMERAL_VALUES (including "十" = 10);
      * "十X" (11-19), "X十" (10-90) and "X十Y" (11-99), where X and Y are
        characters whose table value is between 1 and 9.

    Returns:
        The integer value, or None when the text is not a well-formed number
        in one of the forms above.
    """
    text = text.strip()
    # isdecimal() (not isdigit()) so that characters such as "²", which int()
    # rejects, fall through to None instead of raising ValueError.
    if text.isdecimal():
        return int(text)
    if text in CN_NUMERAL_VALUES:
        return CN_NUMERAL_VALUES[text]

    def unit_digit(char: str) -> Optional[int]:
        # Only 1-9 may sit next to "十"; rejects unknown characters and "十".
        value = CN_NUMERAL_VALUES.get(char)
        if value is None or not 1 <= value <= 9:
            return None
        return value

    if len(text) == 2 and text[0] == "十":
        units = unit_digit(text[1])
        return None if units is None else 10 + units
    if len(text) == 2 and text[1] == "十":
        tens = unit_digit(text[0])
        return None if tens is None else tens * 10
    # The tens marker must be the middle character; anything else is malformed.
    if len(text) == 3 and text[1] == "十":
        tens = unit_digit(text[0])
        units = unit_digit(text[2])
        if tens is None or units is None:
            return None
        return tens * 10 + units
    return None
def season_marker_number(text: str) -> Optional[int]:
    """Map a compact sequel marker (e.g. ``II``, ``Ni no Sara``) to a season.

    The marker is cleaned with clean_marker_text and then tried, in order, as
    a Roman numeral, a romanized reading phrase, the bare word "ni", an
    explicit "第N季/期/部/章" form, and a single CJK numeral with an optional
    "之章"-style suffix. Returns None when nothing applies.
    """
    marker = clean_marker_text(text)
    if not marker:
        return None

    roman_value = ROMAN_NUMERAL_VALUES.get(marker)
    if roman_value is not None:
        return roman_value

    # Reading phrases are compared lower-cased with collapsed whitespace.
    reading_key = re.sub(r"\s+", " ", marker.lower()).strip()
    reading_value = READING_MARKER_VALUES.get(reading_key)
    if reading_value is not None:
        return reading_value
    if reading_key == "ni":
        return 2

    ordinal = re.fullmatch(r"第(.+)[季期部章]", marker)
    if ordinal is not None:
        return cn_number_to_int(ordinal.group(1))

    compact = re.fullmatch(
        r"([一二三四五六七八九十兩两貳贰弐弍參叁参肆伍陸陆柒捌玖])(?:\s*(?:ノ|の|之)\s*(?:章|期|季|部))?",
        marker,
    )
    if compact is None:
        return None
    return cn_number_to_int(compact.group(1))
def token_offsets_in_text(text: str, tokens: Sequence[str]) -> Optional[List[Tuple[int, int]]]:
    """Locate each token in *text*, scanning left to right.

    Every token is searched for at or after the end of the previous one, so
    characters of *text* not covered by any token are skipped. An empty token
    gets a zero-width span at the current scan position.

    Returns:
        One ``(start, end)`` pair per token (end exclusive), or None as soon
        as a token cannot be found.
    """
    spans: List[Tuple[int, int]] = []
    scan_from = 0
    for piece in tokens:
        if piece == "":
            begin = finish = scan_from
        else:
            begin = text.find(piece, scan_from)
            if begin < 0:
                return None
            finish = begin + len(piece)
            scan_from = finish
        spans.append((begin, finish))
    return spans
def has_episode_context(text: str, marker_end: int) -> bool:
    """Tell whether the text right after a marker looks like an episode.

    The remainder of *text* from *marker_end* is tested against
    EPISODE_CONTEXT_RE directly, then once more after dropping a closing
    bracket and up to two bracketed filler tags (menu/NCOP/OP/...).
    """
    remainder = text[marker_end:]
    if EPISODE_CONTEXT_RE.match(remainder) is not None:
        return True

    # Some releases put a season marker at the end of a title bracket and the
    # episode in the next bracket: `[Title 貳之章][01]`.
    closing_bracket = r"^[\]\)】》]\s*"
    filler_tags = r"^(?:[\[\(【《]\s*(?:menu|menus|bdmenu|ncop|nced|op|ed|ova|oad|sp)\s*[\]\)】》]\s*){0,2}"
    remainder = re.sub(closing_bracket, "", remainder.lstrip())
    remainder = re.sub(filler_tags, "", remainder, flags=re.I)
    return EPISODE_CONTEXT_RE.match(remainder) is not None
def find_sequel_season_markers(text: str) -> List[LabelRepair]:
    """Collect sequel markers in *text* that should be labelled SEASON.

    A candidate is kept only when it resolves to a season number and is
    followed by episode context. Overlapping candidates are reduced to the
    longest one; the result is ordered by position.
    """
    candidates: List[LabelRepair] = []

    generic_patterns = (
        (READING_MARKER_RE, "reading"),
        (ROMAN_MARKER_RE, "roman"),
        (CJK_MARKER_RE, "cjk"),
    )
    for regex, kind in generic_patterns:
        for hit in regex.finditer(text):
            marker_text = hit.group("marker")
            season = season_marker_number(marker_text)
            if season is not None and has_episode_context(text, hit.end()):
                candidates.append(LabelRepair(kind, marker_text, season, hit.start(), hit.end()))

    # Bare "Ni" counts only directly after an explicitly verified base title.
    for base_title, season in STANDALONE_NI_SEASON_BASES.items():
        bare_ni = re.compile(rf"(?<![A-Za-z0-9]){re.escape(base_title)}\s+(?P<marker>Ni)(?![A-Za-z0-9])")
        for hit in bare_ni.finditer(text):
            ni_start, ni_end = hit.span("marker")
            if has_episode_context(text, ni_end):
                candidates.append(
                    LabelRepair(
                        kind="verified_bare_ni",
                        marker=hit.group("marker"),
                        value=season,
                        start=ni_start,
                        end=ni_end,
                    )
                )

    candidates.sort(key=lambda item: (item.start, item.end))

    kept: List[LabelRepair] = []
    for candidate in candidates:
        overlaps_last = bool(kept) and candidate.start < kept[-1].end
        if not overlaps_last:
            kept.append(candidate)
            continue
        # On overlap the strictly longer span wins; ties keep the earlier one.
        incumbent = kept[-1]
        if candidate.end - candidate.start > incumbent.end - incumbent.start:
            kept[-1] = candidate
    return kept
def labels_have_season_before(labels: Sequence[str], offsets: Sequence[Tuple[int, int]], marker_start: int) -> bool:
    """Return True if a SEASON-labelled token ends at or before *marker_start*."""
    for label, (_tok_start, tok_end) in zip(labels, offsets):
        if label.endswith("SEASON") and tok_end <= marker_start:
            return True
    return False
def token_indices_for_span(offsets: Sequence[Tuple[int, int]], start: int, end: int) -> List[int]:
    """Return the indices of tokens whose span overlaps ``[start, end)``.

    Tokens that merely touch the span at a boundary are not included.
    """
    overlapping: List[int] = []
    for position, (tok_start, tok_end) in enumerate(offsets):
        if tok_end <= start or tok_start >= end:
            continue
        overlapping.append(position)
    return overlapping
def label_span(labels: List[str], indices: Sequence[int], entity: str) -> None:
    """Write IOB labels for *entity* onto ``labels`` at *indices*, in place.

    The first index gets ``B-<entity>`` unless the token just before it is
    already labelled with the same entity, in which case the whole span
    continues with ``I-<entity>``. Empty *indices* is a no-op.
    """
    if not indices:
        return
    head = indices[0]
    continues_previous = head > 0 and labels[head - 1].endswith(entity)
    begin_tag = f"B-{entity}"
    inside_tag = f"I-{entity}"
    for position, idx in enumerate(indices):
        opens_entity = position == 0 and not continues_previous
        labels[idx] = begin_tag if opens_entity else inside_tag
def label_span_if_changed(labels: List[str], indices: Sequence[int], entity: str) -> bool:
    """Label *indices* as *entity* and report whether anything changed.

    The expected labels follow the same rule as label_span: ``B-<entity>``
    first (or ``I-<entity>`` when the preceding token already carries the
    entity), then ``I-<entity>``.

    Returns:
        True when ``labels`` was modified; False when the span was already
        labelled as expected or *indices* is empty.
    """
    # An empty span has nothing to relabel. Without this guard the comparison
    # below would be [] against a one-element list and report a change.
    if not indices:
        return False
    continues_previous = indices[0] > 0 and labels[indices[0] - 1].endswith(entity)
    first_label = f"I-{entity}" if continues_previous else f"B-{entity}"
    expected = [first_label] + [f"I-{entity}"] * (len(indices) - 1)
    if [labels[idx] for idx in indices] == expected:
        return False
    label_span(labels, indices, entity)
    return True
def safe_to_overwrite_meta(labels: Sequence[str], indices: Sequence[int]) -> bool:
    """Return True if the span may be relabelled as structural metadata.

    An empty span is never safe, and neither is one that contains a token
    currently labelled GROUP, EPISODE or SEASON.
    """
    protected_entities = ("GROUP", "EPISODE", "SEASON")
    if not indices:
        return False
    for idx in indices:
        if labels[idx].endswith(protected_entities):
            return False
    return True
def mark_adjacent_title_separators_o(
    tokens: Sequence[str],
    labels: List[str],
    marker_indices: Sequence[int],
) -> None:
    """Reset TITLE-labelled separator tokens around a marker to "O", in place.

    The two directions use different rules: to the left only tokens that are
    empty or all whitespace are cleared, while to the right any token that is
    a member of SEPARATOR_CHARS is cleared. Each walk stops at the first token
    that does not qualify or is not labelled TITLE.
    """
    if not marker_indices:
        return

    left = marker_indices[0] - 1
    while left >= 0:
        is_blank = "".join(tokens[left]).strip() == ""
        if not (is_blank and labels[left].endswith("TITLE")):
            break
        labels[left] = "O"
        left -= 1

    right = marker_indices[-1] + 1
    token_count = len(tokens)
    while right < token_count:
        if tokens[right] not in SEPARATOR_CHARS or not labels[right].endswith("TITLE"):
            break
        labels[right] = "O"
        right += 1
def first_episode_end(labels: Sequence[str], offsets: Sequence[Tuple[int, int]], text: str) -> int:
    """Return the end offset of the earliest episode span.

    The smallest end offset among EPISODE-labelled tokens is preferred. With
    no such token, the end of the first EPISODE_SPAN_RE match in *text* is
    used, and 0 when that finds nothing either.
    """
    earliest_labelled = min(
        (tok_end for label, (_tok_start, tok_end) in zip(labels, offsets) if label.endswith("EPISODE")),
        default=None,
    )
    if earliest_labelled is not None:
        return earliest_labelled
    found = EPISODE_SPAN_RE.search(text)
    if found is None:
        return 0
    return found.end()
def bracket_content_spans(text: str) -> Iterable[Tuple[str, int, int, int, int]]:
    """Yield one tuple per bracket pair matched by BRACKET_RE in *text*.

    Each tuple is ``(clean, inner_start, inner_end, bracket_start,
    bracket_end)``. ``clean`` is the whitespace-stripped content, whereas
    ``inner_start``/``inner_end`` delimit the *unstripped* content, so offsets
    measured inside ``clean`` do not map directly onto them when the content
    is padded.
    """
    for hit in BRACKET_RE.finditer(text):
        captured = [value for value in hit.groups() if value is not None]
        if not captured:
            continue
        content = captured[0] or ""
        open_at, close_at = hit.span()
        # The opening delimiter is one code point in all supported bracket forms.
        content_start = open_at + 1
        content_end = content_start + len(content)
        yield content.strip(), content_start, content_end, open_at, close_at
def repair_structural_meta_labels(
    text: str,
    tokens: Sequence[str],
    labels: List[str],
    offsets: Sequence[Tuple[int, int]],
) -> List[LabelRepair]:
    """Relabel SPECIAL / SOURCE / RESOLUTION spans that follow the episode.

    Two passes are made over *text*, both restricted to positions at or after
    the first episode span:

    1. Bracketed content. A bracket whose whole content is a special tag/code
       or a pure source tag is labelled as one span; otherwise individual
       resolution and source tokens inside it are labelled.
    2. The whole text, for resolution and source tokens that are not
       bracketed (e.g. dot-separated WEB release names).

    A span is only relabelled when safe_to_overwrite_meta allows it and its
    labels actually change.

    Args:
        text: Filename text that *offsets* refer to.
        tokens: Tokens of the row (not read here; kept for interface parity).
        labels: Labels of the row; modified in place.
        offsets: ``(start, end)`` character span of each token in *text*.

    Returns:
        One LabelRepair (with ``value`` 0) per span that was changed.
    """
    repairs: List[LabelRepair] = []
    episode_end = first_episode_end(labels, offsets, text)
    token_patterns = ((RESOLUTION_RE, "RESOLUTION"), (SOURCE_RE, "SOURCE"))

    def relabel(start: int, end: int, entity: str, marker: str) -> None:
        # Apply one repair if it is permitted and changes something.
        indices = token_indices_for_span(offsets, start, end)
        if safe_to_overwrite_meta(labels, indices) and label_span_if_changed(labels, indices, entity):
            repairs.append(LabelRepair(entity.lower(), marker, 0, start, end))

    for clean, inner_start, inner_end, bracket_start, _bracket_end in bracket_content_spans(text):
        if bracket_start < episode_end or not clean:
            continue
        if SPECIAL_TAG_RE.fullmatch(clean) or SPECIAL_CODE_RE.fullmatch(clean):
            relabel(inner_start, inner_end, "SPECIAL", clean)
            continue
        if SOURCE_TAG_RE.fullmatch(clean):
            relabel(inner_start, inner_end, "SOURCE", clean)
            continue
        # Search the raw slice rather than the stripped `clean`: match offsets
        # must be relative to inner_start, and stripping leading whitespace
        # would shift every span left for padded brackets such as "[ 1080p ]".
        raw_inner = text[inner_start:inner_end]
        for pattern, entity in token_patterns:
            for match in pattern.finditer(raw_inner):
                relabel(inner_start + match.start(), inner_start + match.end(), entity, match.group(0))

    # Dot-separated WEB names often carry source/resolution after SxxEyy without
    # brackets. Repair only after the episode span to avoid touching titles.
    for pattern, entity in token_patterns:
        for match in pattern.finditer(text):
            if match.start() < episode_end:
                continue
            relabel(match.start(), match.end(), entity, match.group(0))
    return repairs
def repair_known_label_issues(
    item: Dict,
) -> Tuple[List[str], List[str], List[LabelRepair]]:
    """
    Repair known weak-label issues.

    The repair is intentionally conservative:
    - sequel markers must be immediately before an episode/special context;
    - sequel marker spans must currently be part of TITLE/O, not group/meta;
    - rows that already have a season before the marker are left alone;
    - structural meta repairs only touch spans after the first episode.

    Args:
        item: Row with "tokens" and "labels" lists and an optional "filename".
            When the tokens cannot be located in the filename, the
            concatenated tokens are used as the text instead.

    Returns:
        ``(tokens, labels, repairs)``: the tokens as strings, the repaired
        label list (a new list; *item* is not modified), and the repairs that
        were applied. Rows whose token and label counts differ, or whose
        tokens cannot be aligned, come back unchanged with no repairs.
    """
    source_tokens = [str(token) for token in item.get("tokens", [])]
    source_labels = [str(label) for label in item.get("labels", [])]
    if len(source_tokens) != len(source_labels):
        return source_tokens, source_labels, []

    filename = str(item.get("filename") or "")
    text = filename if filename else "".join(source_tokens)
    offsets = token_offsets_in_text(text, source_tokens)
    if offsets is None:
        text = "".join(source_tokens)
        offsets = token_offsets_in_text(text, source_tokens)
    if offsets is None:
        return source_tokens, source_labels, []

    repaired_labels = list(source_labels)
    applied: List[LabelRepair] = []

    # Cheap substring prefilter so the marker regexes only run on rows that
    # can contain a sequel marker.
    quick_text = text.lower()
    has_sequel_marker_hint = any(
        needle in text or needle in quick_text
        for needle in (
            " II", " III", " IV", " V", " VI", " VII", " VIII", " IX",
            "Ⅱ", "Ⅲ", "Ⅳ", "Ⅴ", "Ⅵ", "Ⅶ", "Ⅷ", "Ⅸ",
            "之章", "之期", "之季", "之部", "ノ章", "ノ期", "の章", "の期",
            "貳", "贰", "弐", "弍", "參", "叁", "参", "肆", "陸", "陆",
            "Ni ", " ni ", " no Sara", "Gakki",
            # READING_MARKER_RE also accepts these phrases, but none of the
            # hints above occurs in them, so without their own entries rows
            # using them would never reach the marker search.
            "San no Sho", "San no Syo", "Sono Ni",
        )
    )
    if has_sequel_marker_hint:
        for repair in find_sequel_season_markers(text):
            # Checked against the labels repaired so far, so at most the
            # first qualifying marker of a row becomes SEASON.
            if labels_have_season_before(repaired_labels, offsets, repair.start):
                continue
            indices = token_indices_for_span(offsets, repair.start, repair.end)
            if not indices:
                continue
            existing = [repaired_labels[idx] for idx in indices]
            if any(
                label.endswith(("GROUP", "EPISODE", "RESOLUTION", "SOURCE", "SPECIAL"))
                for label in existing
            ):
                continue
            if not any(label.endswith("TITLE") for label in existing):
                continue
            label_span(repaired_labels, indices, "SEASON")
            mark_adjacent_title_separators_o(source_tokens, repaired_labels, indices)
            applied.append(repair)

    applied.extend(repair_structural_meta_labels(text, source_tokens, repaired_labels, offsets))
    return source_tokens, repaired_labels, applied
def repair_sequel_season_labels(
    item: Dict,
) -> Tuple[List[str], List[str], List[LabelRepair]]:
    """Backward-compatible wrapper for callers that repair known label issues.

    Delegates to repair_known_label_issues and returns its result unchanged,
    so structural meta repairs are included, not only sequel-season repairs.
    """
    return repair_known_label_issues(item)
def repair_jsonl_item(item: Dict) -> Tuple[Dict, List[LabelRepair]]:
    """Repair one JSONL row and normalize its labels to IOB2.

    Returns:
        ``(row, repairs)``. The original *item* object is returned when there
        were no repairs and normalization left the labels unchanged;
        otherwise a shallow copy with updated "labels" (and "tokens" when
        repairs were applied) is returned.
    """
    tokens, raw_labels, repairs = repair_known_label_issues(item)
    labels = normalize_iob2(raw_labels)

    if not repairs and labels == item.get("labels", []):
        return item, []

    repaired = dict(item)
    if repairs:
        repaired["tokens"] = tokens
    repaired["labels"] = labels
    return repaired, (repairs if repairs else [])
def normalize_iob2(labels: Sequence[str]) -> List[str]:
    """Return *labels* rewritten as a consistent IOB2 sequence.

    Any label without a ``B-``/``I-`` prefix becomes "O". For the rest the
    prefix is recomputed: ``I-`` when the previous label carried the same
    entity, ``B-`` otherwise. Adjacent spans of the same entity are therefore
    merged into one.
    """
    result: List[str] = []
    open_entity: Optional[str] = None
    for raw in labels:
        if raw.startswith(("B-", "I-")):
            _old_prefix, entity = raw.split("-", 1)
            tag = f"I-{entity}" if entity == open_entity else f"B-{entity}"
            open_entity = entity
        else:
            tag = "O"
            open_entity = None
        result.append(tag)
    return result