AniFileBERT / anifilebert /inference.py
ModerRAS's picture
Train virtual-shard anime parser
359ff82
"""
Inference script for anime filename parser.
Loads a trained model and tokenizer, parses anime filenames,
and outputs structured metadata.
Usage:
python -m anifilebert.inference "[ANi] 葬送的芙莉莲 S2 - 03 [1080P][WEB-DL]"
python -m anifilebert.inference --input-file filenames.txt --output-file results.jsonl
"""
import argparse
import json
import re
import sys
from typing import Dict, List, Optional, Tuple
import torch
from transformers import BertForTokenClassification
from .config import Config
from .label_repairs import season_marker_number
from .tokenizer import AnimeTokenizer, load_tokenizer
# Chinese numerals 一 through 十 mapped to 1 through 10. Insertion order is
# ascending, starting with 一.
CN_NUM_MAP: Dict[str, int] = {
    numeral: value for value, numeral in enumerate("一二三四五六七八九十", start=1)
}
# Whole-string bonus-file markers: Menu / BD Menu, NCOP, NCED, OP, ED, PV, CM,
# OVA, OAD, SP (each with an optional short number) and IV<digits>.
# Case-insensitive and anchored with ^...$, so the entire normalized filename
# must be one of these markers.
STANDALONE_SPECIAL_RE = re.compile(
    r"^(?:"
    r"(?:BD\s*)?Menu\s*\d{0,2}(?:-\d{1,2})?|"
    r"NCOP\s*\d{0,2}|NCED\s*\d{0,2}|"
    r"OP\s*\d{0,2}|ED(?:\s*E?\d{0,2})?|"
    r"PV\s*\d{0,2}|CM\s*\d{0,2}|"
    r"OVA\s*\d{0,2}|OAD\s*\d{0,2}|SP\s*\d{0,2}|IV\d+"
    r")$",
    re.I,
)
# Bracketed search-note tags such as "[檢索:...]" (traditional, simplified and
# Japanese spellings of the keyword are accepted). Group 1 is the tag content
# (keyword, colon and text) without the surrounding brackets.
BRACKETED_SEARCH_SPECIAL_RE = re.compile(
    r"[\[【((]\s*((?:檢索|检索|検索)\s*[::][^\]】))]+?)\s*[\]】))]"
)
# Promo layout "★04月新番★[Title]": a ★/☆-delimited banner (at most 24 other
# characters on each side) containing 新番 or 月番, followed, after optional
# whitespace, by a bracketed title captured in group 1.
NEW_SHOW_BRACKET_TITLE_RE = re.compile(
    r"[★☆][^★☆\[\]【】()()]{0,24}(?:新番|月番)[^★☆\[\]【】()()]{0,24}[★☆]"
    r"\s*[\[【((]\s*([^\]】))]+?)\s*[\]】))]"
)
# Digit characters understood by parse_chinese_numeral.
_CN_DIGIT_VALUES: Dict[str, int] = {
    "零": 0, "〇": 0, "一": 1, "二": 2, "两": 2, "三": 3, "四": 4,
    "五": 5, "六": 6, "七": 7, "八": 8, "九": 9,
}
# Positional unit characters and their multipliers.
_CN_UNIT_VALUES: Dict[str, int] = {"十": 10, "百": 100, "千": 1000}
_CN_NUMERAL_RUN_RE = re.compile(r"[零〇一二两三四五六七八九十百千]+")


def parse_chinese_numeral(text: str) -> Optional[int]:
    """
    Parse the first run of Chinese numeral characters in *text* into an int.

    Units are positional, so "十二" -> 12, "二十" -> 20, "一百零五" -> 105.
    Bare digit runs are read digit by digit ("二〇二四" -> 2024).
    Returns None when *text* contains no Chinese numeral characters.
    """
    match = _CN_NUMERAL_RUN_RE.search(text)
    if match is None:
        return None
    total = 0
    pending = 0
    for char in match.group(0):
        unit = _CN_UNIT_VALUES.get(char)
        if unit is None:
            pending = pending * 10 + _CN_DIGIT_VALUES[char]
        else:
            # A unit with no explicit multiplier ("十二") means one of it.
            total += (pending or 1) * unit
            pending = 0
    return total + pending


def extract_season_number(text: str) -> Optional[int]:
    """
    Extract season number from various season formats.

    Resolution order:
    1. season_marker_number() for recognised season markers.
    2. The first run of Arabic digits.
    3. The first run of Chinese numerals, parsed positionally, so a
       "第十二季" span reaching this fallback yields 12 rather than 2.

    Examples:
        "S2" → 2, "Season 2" → 2, "第二季" → 2, "1st Season" → 1

    Returns None when no number can be found.
    """
    marker_value = season_marker_number(text)
    if marker_value is not None:
        return marker_value
    # Arabic digits
    match = re.search(r'(\d+)', text)
    if match:
        return int(match.group(1))
    # Chinese numerals
    return parse_chinese_numeral(text)
def extract_episode_number(text: str) -> Optional[int]:
    """
    Return the first run of digits in an EPISODE span as an int.

    Examples:
        "03" → 3, "EP21" → 21, "第7话" → 7, "#01" → 1

    Returns None when the span contains no digits.
    """
    digits = re.search(r'(\d+)', text)
    return int(digits.group(1)) if digits else None
def extract_resolution(text: str) -> Optional[str]:
    """
    Extract resolution string (e.g., '1080P', '4K', '1920x1080').

    Surrounding whitespace and the brackets []()【】 are removed. Whitespace is
    stripped both before and after bracket removal, so " [1080P] " and
    "【 4K 】" clean up fully. Returns None when nothing remains.
    """
    clean = text.strip().strip("[]()【】").strip()
    return clean or None
def normalize_field_text(text: str) -> str:
    """
    Clean an extracted entity for output.

    Removes outer release brackets via trim_decorations, then strips any
    leading/trailing spaces, tabs, '-', '_' and '.'.
    """
    trimmed = trim_decorations(text)
    return trimmed.strip(" \t-_.")
# (priority, normalized tags) tiers, checked in order by thin_source_priority.
_SOURCE_PRIORITY_TIERS = (
    (90, frozenset({
        "nf", "netflix", "amzn", "baha", "cr", "abema", "dsnp", "u-next", "hulu", "at-x",
        "web-dl", "webdl", "webrip", "web-rip", "bdrip", "bluray", "bdmv", "bd",
        "dvdrip", "dvd", "tvrip", "hdtv",
    })),
    (70, frozenset({"chs", "cht", "gb", "big5", "jpn", "jp", "jpsc", "jptc", "繁中", "简中"})),
    (20, frozenset({
        "x264", "x265", "h.264", "h264", "h.265", "h265", "hevc", "avc", "av1",
        "aac", "flac", "mp3", "dts", "opus", "10bit", "8bit", "hi10p", "ma10p",
        "srt", "srtx2", "ass", "assx2",
    })),
)


def thin_source_priority(source: str) -> int:
    """
    Rank a candidate SOURCE tag; higher values win in choose_thin_source.

    The tag is lower-cased, '_' becomes '-', and spaces are dropped before
    lookup. Priorities:
    - 90: distribution/platform tags.
    - 70: subtitle-language tags.
    - 20: codec/format tags.
    - 40: unknown tags containing a separator (& + / ,).
    - 30: anything else.
    """
    key = source.lower().replace("_", "-").replace(" ", "")
    for priority, members in _SOURCE_PRIORITY_TIERS:
        if key in members:
            return priority
    # Combined tags like "CHS&CHT" rank above a single unknown tag.
    if re.search(r"[&+/,]", source):
        return 40
    return 30
def normalize_source_text(text: str) -> str:
    """
    Canonicalise a SOURCE tag's spelling.

    Removes all whitespace and rewrites common variants case-insensitively
    (WEB_DL → WEB-DL, Web Rip → WebRip, U_NEXT → U-NEXT, AT X → AT-X). Any
    remaining '_' becomes '-'.
    """
    rewrites = (
        (r"(?i)WEB[_ ]?DL", "WEB-DL"),
        (r"(?i)WEB[_ ]?Rip", "WebRip"),
        (r"(?i)U[_ ]?NEXT", "U-NEXT"),
        (r"(?i)AT[_ ]?X", "AT-X"),
    )
    compact = re.sub(r"\s+", "", text.strip())
    for pattern, replacement in rewrites:
        compact = re.sub(pattern, replacement, compact)
    return compact.replace("_", "-")
def choose_thin_source(sources: List[str]) -> Optional[str]:
    """
    Pick the single most informative SOURCE tag from the extracted spans.

    Steps:
    1. Trim each span with normalize_field_text, the same cleanup the other
       fields receive.
    2. Canonicalise it with normalize_source_text.
    3. Rank it with thin_source_priority.

    The highest priority wins; ties keep the earliest span. Returns None when
    no span survives cleanup.
    """
    candidates: List[str] = []
    for source in sources:
        trimmed = normalize_field_text(source)
        if not trimmed:
            continue
        # Rank the trimmed text: "[WEB-DL]" would not match the "web-dl" tier.
        normalized = normalize_source_text(trimmed)
        if normalized:
            candidates.append(normalized)
    if not candidates:
        return None
    # Negated index makes max() prefer the earliest candidate on ties.
    return max(enumerate(candidates), key=lambda item: (thin_source_priority(item[1]), -item[0]))[1]
def normalize_standalone_special(text: str) -> Optional[str]:
    """
    Return the cleaned text if it is entirely a special marker, else None.

    The text is cleaned with normalize_field_text and must fully match
    STANDALONE_SPECIAL_RE (e.g. "NCOP1", "Menu", "PV 2").
    """
    candidate = normalize_field_text(text)
    if candidate and STANDALONE_SPECIAL_RE.fullmatch(candidate):
        return candidate
    return None
def extract_bracketed_search_special(text: str) -> Optional[str]:
    """
    Return the first non-empty bracketed search-note tag such as [檢索:...].

    Each BRACKETED_SEARCH_SPECIAL_RE match is cleaned with
    normalize_field_text; returns None if no match survives cleaning.
    """
    candidates = (
        normalize_field_text(match.group(1))
        for match in BRACKETED_SEARCH_SPECIAL_RE.finditer(text)
    )
    return next((candidate for candidate in candidates if candidate), None)
def extract_new_show_bracket_title(text: str) -> Optional[str]:
    """
    Return the title from promo layouts like ★04月新番★[葬送的芙莉莲].

    Uses the first NEW_SHOW_BRACKET_TITLE_RE match whose bracketed part is
    non-empty after normalize_field_text; returns None otherwise.
    """
    titles = (
        normalize_field_text(match.group(1))
        for match in NEW_SHOW_BRACKET_TITLE_RE.finditer(text)
    )
    return next(filter(None, titles), None)
def display_token(token: str) -> str:
    """Render a space token as "<SPACE>" and a tab token as "<TAB>" for debug tables."""
    if token in (" ", "\t"):
        return "<SPACE>" if token == " " else "<TAB>"
    return token
def trim_decorations(text: str) -> str:
    """
    Remove surrounding whitespace and outer bracket characters from an entity.

    Whitespace is stripped, then any run of bracket characters at either end,
    then whitespace again.
    """
    inner = text.strip().strip("[]()【】《》()")
    return inner.strip()
def join_entity_tokens(tokens: List[str], tokenizer: Optional["AnimeTokenizer"] = None) -> str:
    """
    Join entity tokens into a single string.

    Tokens are concatenated as-is, so any whitespace that is present as tokens
    is preserved. The "regex" and "char" tokenizer variants are joined the same
    way. `tokenizer` is accepted for call-site compatibility and does not
    currently change the result.
    """
    return "".join(tokens)
def labels_to_entities(
    tokens: List[str],
    labels: List[str],
    tokenizer: Optional[AnimeTokenizer] = None,
) -> List[Tuple[str, str]]:
    """
    Group BIO-tagged tokens into (entity_type, text) spans, in order.

    Rules:
    - B-X always opens a new span.
    - I-X extends the open span only when that span is also of type X.
    - An I-X that does not continue an X span (an illegal orphan) opens a new
      span rather than being dropped, so debug output exposes the model's
      behaviour.
    - Any other label (e.g. "O") closes the open span.

    Span text comes from join_entity_tokens. Extra tokens or labels beyond the
    shorter list are ignored.
    """
    spans: List[Tuple[str, str]] = []
    open_type: Optional[str] = None
    open_tokens: List[str] = []

    def close_open_span() -> None:
        # An empty type string (from a bare "B-") is falsy and never emitted.
        if open_type:
            spans.append((open_type, join_entity_tokens(open_tokens, tokenizer)))

    for token, label in zip(tokens, labels):
        prefix, entity_type = label[:2], label[2:]
        if prefix == "I-" and entity_type == open_type:
            open_tokens.append(token)
            continue
        close_open_span()
        if prefix in ("B-", "I-"):
            open_type, open_tokens = entity_type, [token]
        else:
            open_type, open_tokens = None, []
    close_open_span()
    return spans
def is_allowed_bio_transition(previous_label: str, label: str) -> bool:
    """
    Return whether previous_label -> label is valid under IOB2.

    Only I-X is restricted: it must follow B-X or I-X. Every other label may
    follow anything.
    """
    if not label.startswith("I-"):
        return True
    entity_type = label[2:]
    return previous_label == "B-" + entity_type or previous_label == "I-" + entity_type
# Transition masks keyed by the sorted (label_id, label) pairs of an id2label
# mapping, so repeated decodes with the same label set reuse one tensor.
_BIO_TRANSITION_CACHE: Dict[Tuple[Tuple[int, str], ...], torch.Tensor] = {}


def bio_transition_mask(id2label: Dict[int, str]) -> torch.Tensor:
    """
    Return cached valid-transition mask shaped [prev_label, next_label].

    mask[p, n] is True when label n may follow label p under IOB2. Ids are
    normalised with int(), so mappings with string keys ("0", "1", ...) behave
    like int-keyed ones. Ids absent from the mapping (gaps below the maximum
    id) are treated as "O".

    The returned tensor is shared through the module cache; callers must not
    modify it in place.
    """
    # Normalise keys once; the cache key, size and lookups all use this view.
    labels_by_id = {int(label_id): label for label_id, label in id2label.items()}
    key = tuple(sorted(labels_by_id.items()))
    cached = _BIO_TRANSITION_CACHE.get(key)
    if cached is not None:
        return cached
    num_labels = max(labels_by_id) + 1 if labels_by_id else 0
    labels = [labels_by_id.get(label_id, "O") for label_id in range(num_labels)]
    mask = torch.zeros((num_labels, num_labels), dtype=torch.bool)
    for prev_id, prev_label in enumerate(labels):
        for label_id, label in enumerate(labels):
            mask[prev_id, label_id] = is_allowed_bio_transition(prev_label, label)
    _BIO_TRANSITION_CACHE[key] = mask
    return mask
def constrained_bio_decode(emissions: torch.Tensor, id2label: Dict[int, str]) -> List[int]:
    """
    Decode token logits with hard BIO transition constraints.

    This is a lightweight CRF-style Viterbi decoder without learned transition
    weights. It prevents impossible orphan I-X spans at inference time.

    Args:
        emissions: 2-D tensor unpacked as (num_tokens, num_labels) of per-token
            label scores. It is detached and copied to CPU before decoding.
        id2label: Mapping from label id to label string. Ids missing from it
            are treated as "O".

    Returns:
        One label id per token (an empty list for an empty tensor). The ids
        form the sequence with the highest summed emission score among those
        allowed by bio_transition_mask. The first token may not take an I-
        label.

    NOTE(review): assumes bio_transition_mask(id2label) is
    (num_labels, num_labels), i.e. id2label covers exactly the emission
    columns. A mismatch would presumably fail in masked_fill; confirm
    against the model config.
    """
    if emissions.numel() == 0:
        return []
    num_tokens, num_labels = emissions.shape
    scores = emissions.detach().cpu()
    transition_mask = bio_transition_mask(id2label)
    # backpointers[t, j]: best label at token t-1 given label j at token t.
    backpointers = torch.zeros((num_tokens, num_labels), dtype=torch.long)
    # dp[j]: best total score of a valid path ending in label j at the current token.
    dp = torch.full((num_labels,), float("-inf"))
    # A sequence cannot start inside an entity, so only non-I- labels are seeded.
    for label_id in range(num_labels):
        label = id2label.get(label_id, "O")
        if not label.startswith("I-"):
            dp[label_id] = scores[0, label_id]
    for idx in range(1, num_tokens):
        # candidates[prev, next] = dp[prev]; forbidden transitions become -inf.
        candidates = dp.unsqueeze(1).expand(num_labels, num_labels)
        candidates = candidates.masked_fill(~transition_mask, float("-inf"))
        # Best predecessor for every next label (max over the prev axis).
        best_scores, best_prev = candidates.max(dim=0)
        next_dp = best_scores + scores[idx]
        backpointers[idx] = best_prev
        dp = next_dp
    # Backtrack from the best-scoring final label.
    best_last = int(torch.argmax(dp).item())
    decoded = [best_last]
    for idx in range(num_tokens - 1, 0, -1):
        decoded.append(int(backpointers[idx, decoded[-1]].item()))
    decoded.reverse()
    return decoded
def postprocess(
    tokens: List[str],
    labels: List[str],
    tokenizer: Optional[AnimeTokenizer] = None,
) -> Dict:
    """
    Convert BIO-labeled tokens into structured metadata.

    Merges consecutive B- / I- tokens of the same entity type, then extracts
    structured fields. Precedence when several spans of one type exist:

    - title: all non-empty TITLE fragments, joined with a space.
    - season: the last parseable SEASON span.
    - episode: the first parseable EPISODE span.
    - group: the first GROUP span that is non-empty after normalization.
    - special: the last SPECIAL span that is non-empty after normalization.
    - resolution: the last non-empty RESOLUTION span.
    - source: chosen by choose_thin_source.

    Whole-string heuristics then run on the joined tokens:
    - A "★…新番/月番…★[Title]" promo title replaces a missing or banner-like
      title.
    - A bracketed search tag overrides special.
    - A string that is only a special marker (Menu, NCOP, OP, PV, ...) clears
      every other field.

    Returns:
        Dict with keys title, season, episode, group, resolution, source and
        special. Fields with no usable span are None.
    """
    result: Dict = {
        "title": None,
        "season": None,
        "episode": None,
        "group": None,
        "resolution": None,
        "source": None,
        "special": None,
    }
    entities = labels_to_entities(tokens, labels, tokenizer)
    grouped_entities: Dict[str, List[str]] = {}
    for entity_type, text in entities:
        grouped_entities.setdefault(entity_type, []).append(text)

    title_fragments = [
        cleaned for text in grouped_entities.get("TITLE", [])
        if (cleaned := normalize_field_text(text))
    ]
    if title_fragments:
        result["title"] = " ".join(title_fragments)

    # Season: the last parseable span wins.
    for text in grouped_entities.get("SEASON", []):
        season_num = extract_season_number(text)
        if season_num is not None:
            result["season"] = season_num

    # Episode: the first parseable span wins.
    for text in grouped_entities.get("EPISODE", []):
        ep_num = extract_episode_number(text)
        if ep_num is not None:
            result["episode"] = ep_num
            break

    # Group: the first span that is non-empty after normalization wins. A span
    # made only of decorations must not block a later real group with "".
    for text in grouped_entities.get("GROUP", []):
        group = normalize_field_text(text)
        if group:
            result["group"] = group
            break

    # Special: the last non-empty span wins. Empty spans must not erase an
    # earlier valid value.
    for text in grouped_entities.get("SPECIAL", []):
        special = normalize_field_text(text)
        if special:
            result["special"] = special

    for text in grouped_entities.get("RESOLUTION", []):
        res = extract_resolution(text)
        if res:
            result["resolution"] = res

    result["source"] = choose_thin_source(grouped_entities.get("SOURCE", []))

    # Whole-string heuristics over the (possibly truncated) token sequence.
    whole_text = join_entity_tokens(tokens, tokenizer)
    new_show_title = extract_new_show_bracket_title(whole_text)
    if new_show_title is not None and (
        result["title"] is None
        or result["title"].startswith(("★", "☆"))
        or "新番" in result["title"]
        or "月番" in result["title"]
    ):
        result["title"] = new_show_title

    search_special = extract_bracketed_search_special(whole_text)
    if search_special is not None:
        result["special"] = search_special

    standalone_special = normalize_standalone_special(whole_text)
    if standalone_special is not None:
        # A bare bonus-file marker carries no other metadata.
        result.update(
            {
                "title": None,
                "season": None,
                "episode": None,
                "group": None,
                "resolution": None,
                "source": None,
                "special": standalone_special,
            }
        )
    return result
def _empty_parse_result() -> Dict:
    """Return the all-None metadata dict used when nothing can be parsed."""
    return {
        "title": None,
        "season": None,
        "episode": None,
        "group": None,
        "resolution": None,
        "source": None,
        "special": None,
    }


def parse_filename(
    filename: str,
    model: BertForTokenClassification,
    tokenizer: AnimeTokenizer,
    id2label: Dict[int, str],
    max_length: int = 64,
    debug: bool = False,
    constrain_bio: bool = True,
) -> Dict:
    """
    Parse an anime filename and extract structured metadata.

    Args:
        filename: Raw anime filename string.
        model: Trained BertForTokenClassification model.
        tokenizer: AnimeTokenizer instance.
        id2label: Mapping from label ID to label string.
        max_length: Maximum sequence length (including special tokens). Tokens
            beyond max_length - 2 are truncated.
        debug: When True, attach a "_debug" dict containing:
            - tokens, ids, labels and scores;
            - the extracted entities;
            - unknown-token and vocabulary-mismatch diagnostics.
        constrain_bio: Decode with constrained_bio_decode (no orphan I-X
            labels) instead of per-token argmax.

    Returns:
        Dict with parsed fields (title, season, episode, group, resolution,
        source, special), plus "_debug" when requested. All fields are None
        when the filename yields no tokens or max_length leaves no room for a
        real token.
    """
    tokens = tokenizer.tokenize(filename)
    # Real tokens that fit alongside [CLS] and [SEP]; the rest are truncated.
    available = min(len(tokens), max_length - 2)
    if available <= 0:
        return _empty_parse_result()

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # Ids beyond the model's embedding table would index out of range
    # (tokenizer/model vocabulary mismatch), so map them to [UNK].
    embedding_size = model.get_input_embeddings().weight.shape[0]
    out_of_range_tokens = [
        token for token, token_id in zip(tokens, input_ids)
        if token_id >= embedding_size
    ]
    if out_of_range_tokens:
        input_ids = [
            token_id if token_id < embedding_size else tokenizer.unk_token_id
            for token_id in input_ids
        ]
    unk_token_id = tokenizer.unk_token_id
    unk_tokens = [token for token, token_id in zip(tokens, input_ids) if token_id == unk_token_id]

    # [CLS] + truncated real tokens + [SEP], right-padded to max_length.
    model_token_ids = list(input_ids[:available])
    sequence = [tokenizer.cls_token_id] + model_token_ids + [tokenizer.sep_token_id]
    pad_len = max_length - len(sequence)
    attention_mask = [1] * len(sequence) + [0] * pad_len
    sequence = sequence + [tokenizer.pad_token_id] * pad_len

    device = next(model.parameters()).device
    input_tensor = torch.tensor([sequence], device=device)
    mask_tensor = torch.tensor([attention_mask], device=device)

    with torch.no_grad():
        logits = model(input_ids=input_tensor, attention_mask=mask_tensor).logits

    # Keep only the real-token positions: drop [CLS], [SEP] and padding.
    token_logits = logits[0, 1:1 + available, :]
    probabilities = torch.softmax(token_logits, dim=-1)
    scores, greedy_predictions = torch.max(probabilities, dim=-1)

    if constrain_bio:
        pred_labels = constrained_bio_decode(token_logits, id2label)
        # Gather every chosen label's probability in one indexing op instead
        # of one .item() call (and device sync) per token.
        positions = torch.arange(len(pred_labels), device=probabilities.device)
        chosen = torch.tensor(pred_labels, dtype=torch.long, device=probabilities.device)
        selected_scores = probabilities[positions, chosen].detach().cpu().tolist()
    else:
        pred_labels = greedy_predictions.detach().cpu().tolist()
        selected_scores = scores.detach().cpu().tolist()
    label_strings = [id2label.get(p, "O") for p in pred_labels]

    result = postprocess(
        tokens[:available],
        label_strings,
        tokenizer=tokenizer,
    )

    if debug:
        result["_debug"] = {
            "tokenizer_variant": getattr(tokenizer, "tokenizer_variant", "regex"),
            "decoder": "constrained_bio" if constrain_bio else "greedy",
            "postprocess": "thin_normalize",
            "max_length": max_length,
            "token_count": len(tokens),
            "available_token_count": available,
            "truncated": len(tokens) > available,
            "unk_count": len(unk_tokens),
            "unk_rate": len(unk_tokens) / len(tokens) if tokens else 0.0,
            "unk_tokens": unk_tokens[:50],
            "vocab_mismatch": bool(out_of_range_tokens),
            "model_embedding_size": int(embedding_size),
            "tokenizer_vocab_size": int(tokenizer.vocab_size),
            "out_of_range_tokens": out_of_range_tokens[:50],
            "tokens": tokens[:available],
            "labels": label_strings,
            "scores": [round(float(score), 4) for score in selected_scores],
            "token_table": [
                {
                    "i": i,
                    "token": display_token(token),
                    "id": int(token_id),
                    "label": label,
                    "score": round(float(score), 4),
                }
                for i, (token, token_id, label, score) in enumerate(
                    zip(tokens[:available], model_token_ids, label_strings, selected_scores)
                )
            ],
            "entities": [
                {"type": entity_type, "text": text}
                for entity_type, text in labels_to_entities(tokens[:available], label_strings, tokenizer)
            ],
        }
    return result
def main():
    """
    Command-line entry point: parse filenames and emit one JSON object per line.

    Input sources:
    - The positional filename and/or --input-file (one per line, blank lines
      skipped).
    - stdin, read only when neither of the above is given.

    Each result is written to stdout, or to --output-file (JSONL) when
    provided, as soon as it is parsed.
    """
    import contextlib

    parser = argparse.ArgumentParser(description="Anime filename parser")
    parser.add_argument("filename", nargs="?", type=str, help="Anime filename to parse")
    parser.add_argument("--input-file", type=str, help="File with filenames (one per line)")
    parser.add_argument("--output-file", type=str, help="Output file for results (JSONL)")
    parser.add_argument("--model-dir", type=str, default=".",
                        help="Path to trained model directory")
    parser.add_argument("--tokenizer", choices=["regex", "char"], default=None,
                        help="Tokenizer variant override. Defaults to checkpoint metadata")
    parser.add_argument("--max-length", type=int, default=None,
                        help="Maximum sequence length. Defaults to the checkpoint's "
                             "max_seq_length, or 64 if the checkpoint has none")
    parser.add_argument("--debug", action="store_true",
                        help="Include tokenizer, labels, scores, and entity spans in JSON output")
    parser.add_argument("--no-constrained-bio", action="store_true",
                        help="Use greedy per-token decoding instead of constrained BIO Viterbi")
    args = parser.parse_args()

    # Fallback label map if the checkpoint config has none.
    cfg = Config()

    print(f"Loading tokenizer from {args.model_dir}...", file=sys.stderr)
    tokenizer = load_tokenizer(args.model_dir, args.tokenizer)

    print(f"Loading model from {args.model_dir}...", file=sys.stderr)
    model = BertForTokenClassification.from_pretrained(args.model_dir)
    model.eval()
    id2label = {int(k): v for k, v in getattr(model.config, "id2label", cfg.id2label).items()}

    # None means "not given", so an explicit --max-length (even 64) is never
    # overridden by the checkpoint value.
    if args.max_length is not None:
        max_length = args.max_length
    else:
        max_length = int(getattr(model.config, "max_seq_length", 64))

    filenames_to_parse: List[str] = []
    if args.filename:
        filenames_to_parse.append(args.filename)
    if args.input_file:
        with open(args.input_file, 'r', encoding='utf-8') as f:
            filenames_to_parse.extend(line.strip() for line in f if line.strip())
    if not args.filename and not args.input_file:
        # Fall back to stdin only when no other input was requested, so an
        # empty --input-file does not leave the CLI blocked on stdin. Lines are
        # stripped the same way as --input-file lines.
        filenames_to_parse.extend(
            line.strip() for line in sys.stdin.read().splitlines() if line.strip()
        )

    if args.output_file:
        sink = open(args.output_file, 'w', encoding='utf-8')
    else:
        # stdout must not be closed when the with-block exits.
        sink = contextlib.nullcontext(sys.stdout)
    with sink as out:
        for fn in filenames_to_parse:
            if not fn.strip():
                continue
            result = parse_filename(
                fn,
                model,
                tokenizer,
                id2label,
                max_length,
                debug=args.debug,
                constrain_bio=not args.no_constrained_bio,
            )
            result["_input"] = fn
            out.write(json.dumps(result, ensure_ascii=False) + '\n')
    if args.output_file:
        print(f"Results saved to {args.output_file}", file=sys.stderr)
# Script entry point, e.g. `python -m anifilebert.inference "<filename>"`.
if __name__ == "__main__":
    main()