""" Inference script for anime filename parser. Loads a trained model and tokenizer, parses anime filenames, and outputs structured metadata. Usage: python -m anifilebert.inference "[ANi] 葬送的芙莉莲 S2 - 03 [1080P][WEB-DL]" python -m anifilebert.inference --input-file filenames.txt --output-file results.jsonl """ import argparse import json import re import sys from typing import Dict, List, Optional, Tuple import torch from transformers import BertForTokenClassification from .config import Config from .label_repairs import season_marker_number from .tokenizer import AnimeTokenizer, load_tokenizer # Chinese number mapping CN_NUM_MAP: Dict[str, int] = { "一": 1, "二": 2, "三": 3, "四": 4, "五": 5, "六": 6, "七": 7, "八": 8, "九": 9, "十": 10, } STANDALONE_SPECIAL_RE = re.compile( r"^(?:" r"(?:BD\s*)?Menu\s*\d{0,2}(?:-\d{1,2})?|" r"NCOP\s*\d{0,2}|NCED\s*\d{0,2}|" r"OP\s*\d{0,2}|ED(?:\s*E?\d{0,2})?|" r"PV\s*\d{0,2}|CM\s*\d{0,2}|" r"OVA\s*\d{0,2}|OAD\s*\d{0,2}|SP\s*\d{0,2}|IV\d+" r")$", re.I, ) BRACKETED_SEARCH_SPECIAL_RE = re.compile( r"[\[【((]\s*((?:檢索|检索|検索)\s*[::][^\]】))]+?)\s*[\]】))]" ) NEW_SHOW_BRACKET_TITLE_RE = re.compile( r"[★☆][^★☆\[\]【】()()]{0,24}(?:新番|月番)[^★☆\[\]【】()()]{0,24}[★☆]" r"\s*[\[【((]\s*([^\]】))]+?)\s*[\]】))]" ) def extract_season_number(text: str) -> Optional[int]: """ Extract season number from various season formats. Examples: "S2" → 2, "Season 2" → 2, "第二季" → 2, "1st Season" → 1 """ marker_value = season_marker_number(text) if marker_value is not None: return marker_value # Arabic digits match = re.search(r'(\d+)', text) if match: return int(match.group(1)) # Chinese digits for cn, num in CN_NUM_MAP.items(): if cn in text: return num return None def extract_episode_number(text: str) -> Optional[int]: """ Extract episode number from various episode formats. Examples: "03" → 3, "EP21" → 21, "第7话" → 7, "#01" → 1 """ match = re.search(r'(\d+)', text) if match: return int(match.group(1)) return None def extract_resolution(text: str) -> Optional[str]: """Extract resolution string (e.g., '1080P', '4K', '1920x1080').""" # Strip brackets for matching clean = text.strip("[]()【】") return clean if clean else None def normalize_field_text(text: str) -> str: return trim_decorations(text).strip(" \t-_.") def thin_source_priority(source: str) -> int: normalized = source.lower().replace("_", "-").replace(" ", "") if normalized in { "nf", "netflix", "amzn", "baha", "cr", "abema", "dsnp", "u-next", "hulu", "at-x", "web-dl", "webdl", "webrip", "web-rip", "bdrip", "bluray", "bdmv", "bd", "dvdrip", "dvd", "tvrip", "hdtv", }: return 90 if normalized in {"chs", "cht", "gb", "big5", "jpn", "jp", "jpsc", "jptc", "繁中", "简中"}: return 70 if normalized in { "x264", "x265", "h.264", "h264", "h.265", "h265", "hevc", "avc", "av1", "aac", "flac", "mp3", "dts", "opus", "10bit", "8bit", "hi10p", "ma10p", "srt", "srtx2", "ass", "assx2", }: return 20 return 40 if re.search(r"[&+/,]", source) else 30 def normalize_source_text(text: str) -> str: text = re.sub(r"\s+", "", text.strip()) text = re.sub(r"(?i)WEB[_ ]?DL", "WEB-DL", text) text = re.sub(r"(?i)WEB[_ ]?Rip", "WebRip", text) text = re.sub(r"(?i)U[_ ]?NEXT", "U-NEXT", text) text = re.sub(r"(?i)AT[_ ]?X", "AT-X", text) return text.replace("_", "-") def choose_thin_source(sources: List[str]) -> Optional[str]: cleaned = [normalize_source_text(source) for source in sources if normalize_field_text(source)] if not cleaned: return None return max(enumerate(cleaned), key=lambda item: (thin_source_priority(item[1]), -item[0]))[1] def normalize_standalone_special(text: str) -> Optional[str]: special = normalize_field_text(text) if not special: return None return special if STANDALONE_SPECIAL_RE.fullmatch(special) else None def extract_bracketed_search_special(text: str) -> Optional[str]: """Return bracketed search-note tags such as [檢索:...].""" for match in BRACKETED_SEARCH_SPECIAL_RE.finditer(text): special = normalize_field_text(match.group(1)) if special: return special return None def extract_new_show_bracket_title(text: str) -> Optional[str]: """Return title from release-promo layouts like ★04月新番★[葬送的芙莉莲].""" for match in NEW_SHOW_BRACKET_TITLE_RE.finditer(text): title = normalize_field_text(match.group(1)) if title: return title return None def display_token(token: str) -> str: """Make whitespace tokens visible in debug output.""" if token == " ": return "" if token == "\t": return "" return token def trim_decorations(text: str) -> str: """Trim outer release brackets from an extracted entity.""" return text.strip().strip("[]()【】《》()").strip() def join_entity_tokens(tokens: List[str], tokenizer: Optional[AnimeTokenizer] = None) -> str: """Join entity tokens according to the tokenizer granularity.""" if tokenizer is not None and getattr(tokenizer, "tokenizer_variant", "regex") == "char": return "".join(tokens) text = "".join(tokens) if " " in tokens: return text return text def labels_to_entities( tokens: List[str], labels: List[str], tokenizer: Optional[AnimeTokenizer] = None, ) -> List[Tuple[str, str]]: """ Convert BIO labels into entity spans. Illegal orphan I-X labels start a new entity so debug output exposes the model behavior instead of silently dropping tokens. """ entities: List[Tuple[str, str]] = [] current_entity: Optional[str] = None current_tokens: List[str] = [] for token, label in zip(tokens, labels): if label.startswith("B-"): if current_entity: entities.append((current_entity, join_entity_tokens(current_tokens, tokenizer))) current_entity = label[2:] current_tokens = [token] elif label.startswith("I-"): entity_type = label[2:] if current_entity == entity_type: current_tokens.append(token) else: if current_entity: entities.append((current_entity, join_entity_tokens(current_tokens, tokenizer))) current_entity = entity_type current_tokens = [token] else: if current_entity: entities.append((current_entity, join_entity_tokens(current_tokens, tokenizer))) current_entity = None current_tokens = [] if current_entity: entities.append((current_entity, join_entity_tokens(current_tokens, tokenizer))) return entities def is_allowed_bio_transition(previous_label: str, label: str) -> bool: """Return whether previous_label -> label is valid under IOB2.""" if label.startswith("I-"): entity = label[2:] return previous_label in {f"B-{entity}", f"I-{entity}"} return True _BIO_TRANSITION_CACHE: Dict[Tuple[Tuple[int, str], ...], torch.Tensor] = {} def bio_transition_mask(id2label: Dict[int, str]) -> torch.Tensor: """Return cached valid-transition mask shaped [prev_label, next_label].""" key = tuple(sorted((int(label_id), label) for label_id, label in id2label.items())) cached = _BIO_TRANSITION_CACHE.get(key) if cached is not None: return cached num_labels = max(id2label) + 1 if id2label else 0 mask = torch.zeros((num_labels, num_labels), dtype=torch.bool) for prev_id in range(num_labels): prev_label = id2label.get(prev_id, "O") for label_id in range(num_labels): label = id2label.get(label_id, "O") mask[prev_id, label_id] = is_allowed_bio_transition(prev_label, label) _BIO_TRANSITION_CACHE[key] = mask return mask def constrained_bio_decode(emissions: torch.Tensor, id2label: Dict[int, str]) -> List[int]: """ Decode token logits with hard BIO transition constraints. This is a lightweight CRF-style Viterbi decoder without learned transition weights. It prevents impossible orphan I-X spans at inference time. """ if emissions.numel() == 0: return [] num_tokens, num_labels = emissions.shape scores = emissions.detach().cpu() transition_mask = bio_transition_mask(id2label) backpointers = torch.zeros((num_tokens, num_labels), dtype=torch.long) dp = torch.full((num_labels,), float("-inf")) for label_id in range(num_labels): label = id2label.get(label_id, "O") if not label.startswith("I-"): dp[label_id] = scores[0, label_id] for idx in range(1, num_tokens): candidates = dp.unsqueeze(1).expand(num_labels, num_labels) candidates = candidates.masked_fill(~transition_mask, float("-inf")) best_scores, best_prev = candidates.max(dim=0) next_dp = best_scores + scores[idx] backpointers[idx] = best_prev dp = next_dp best_last = int(torch.argmax(dp).item()) decoded = [best_last] for idx in range(num_tokens - 1, 0, -1): decoded.append(int(backpointers[idx, decoded[-1]].item())) decoded.reverse() return decoded def postprocess( tokens: List[str], labels: List[str], tokenizer: Optional[AnimeTokenizer] = None, ) -> Dict: """ Convert BIO-labeled tokens into structured metadata. Merges consecutive B- / I- tokens of the same entity type, then extracts structured fields. """ result: Dict = { "title": None, "season": None, "episode": None, "group": None, "resolution": None, "source": None, "special": None, } entities = labels_to_entities(tokens, labels, tokenizer) grouped_entities: Dict[str, List[str]] = {} for entity_type, text in entities: grouped_entities.setdefault(entity_type, []).append(text) title_fragments = [ cleaned for text in grouped_entities.get("TITLE", []) if (cleaned := normalize_field_text(text)) ] if title_fragments: result["title"] = " ".join(title_fragments) for text in grouped_entities.get("SEASON", []): season_num = extract_season_number(text) if season_num is not None: result["season"] = season_num for text in grouped_entities.get("EPISODE", []): ep_num = extract_episode_number(text) if ep_num is not None: if result["episode"] is None: result["episode"] = ep_num for text in grouped_entities.get("GROUP", []): group = normalize_field_text(text) if result["group"] is None: result["group"] = group for text in grouped_entities.get("SPECIAL", []): special = normalize_field_text(text) result["special"] = special for text in grouped_entities.get("RESOLUTION", []): res = extract_resolution(text) if res: result["resolution"] = res result["source"] = choose_thin_source(grouped_entities.get("SOURCE", [])) whole_text = join_entity_tokens(tokens, tokenizer) new_show_title = extract_new_show_bracket_title(whole_text) if new_show_title is not None and ( result["title"] is None or result["title"].startswith(("★", "☆")) or "新番" in result["title"] or "月番" in result["title"] ): result["title"] = new_show_title search_special = extract_bracketed_search_special(whole_text) if search_special is not None: result["special"] = search_special standalone_special = normalize_standalone_special(whole_text) if standalone_special is not None: result.update( { "title": None, "season": None, "episode": None, "group": None, "resolution": None, "source": None, "special": standalone_special, } ) return result def parse_filename( filename: str, model: BertForTokenClassification, tokenizer: AnimeTokenizer, id2label: Dict[int, str], max_length: int = 64, debug: bool = False, constrain_bio: bool = True, ) -> Dict: """ Parse an anime filename and extract structured metadata. Args: filename: Raw anime filename string. model: Trained BertForTokenClassification model. tokenizer: AnimeTokenizer instance. id2label: Mapping from label ID to label string. max_length: Maximum sequence length (including special tokens). Returns: Dict with parsed fields (title, season, episode, etc.). """ # Tokenize tokens = tokenizer.tokenize(filename) if not tokens: return {"title": None, "season": None, "episode": None, "group": None, "resolution": None, "source": None, "special": None} # Convert to input IDs input_ids = tokenizer.convert_tokens_to_ids(tokens) embedding_size = model.get_input_embeddings().weight.shape[0] out_of_range_tokens = [ token for token, token_id in zip(tokens, input_ids) if token_id >= embedding_size ] if out_of_range_tokens: input_ids = [ token_id if token_id < embedding_size else tokenizer.unk_token_id for token_id in input_ids ] unk_token_id = tokenizer.unk_token_id unk_tokens = [token for token, token_id in zip(tokens, input_ids) if token_id == unk_token_id] # Add special tokens input_ids = [tokenizer.cls_token_id] + input_ids + [tokenizer.sep_token_id] attention_mask = [1] * len(input_ids) # Truncate if needed if len(input_ids) > max_length: input_ids = [input_ids[0]] + input_ids[1:max_length - 1] + [tokenizer.sep_token_id] attention_mask = [1] * len(input_ids) # Pad pad_len = max_length - len(input_ids) if pad_len > 0: input_ids += [tokenizer.pad_token_id] * pad_len attention_mask += [0] * pad_len # Predict device = next(model.parameters()).device input_tensor = torch.tensor([input_ids], device=device) mask_tensor = torch.tensor([attention_mask], device=device) # Remove special token predictions # Count real tokens used (minus CLS/SEP) real_token_count = len(tokens) # Truncate real tokens if we had to truncate available = min(real_token_count, max_length - 2) if available <= 0: return {"title": None, "season": None, "episode": None, "group": None, "resolution": None, "source": None, "special": None} with torch.no_grad(): logits = model(input_ids=input_tensor, attention_mask=mask_tensor).logits token_logits = logits[0, 1:1 + available, :] probabilities = torch.softmax(token_logits, dim=-1) scores, greedy_predictions = torch.max(probabilities, dim=-1) if constrain_bio: pred_labels = constrained_bio_decode(token_logits, id2label) selected_scores = [ probabilities[idx, label_id].detach().cpu().item() for idx, label_id in enumerate(pred_labels) ] else: pred_labels = greedy_predictions.detach().cpu().tolist() selected_scores = scores.detach().cpu().tolist() label_strings = [id2label.get(p, "O") for p in pred_labels] # Post-process result = postprocess( tokens[:available], label_strings, tokenizer=tokenizer, ) if debug: result["_debug"] = { "tokenizer_variant": getattr(tokenizer, "tokenizer_variant", "regex"), "decoder": "constrained_bio" if constrain_bio else "greedy", "postprocess": "thin_normalize", "max_length": max_length, "token_count": len(tokens), "available_token_count": available, "truncated": len(tokens) > available, "unk_count": len(unk_tokens), "unk_rate": len(unk_tokens) / len(tokens) if tokens else 0.0, "unk_tokens": unk_tokens[:50], "vocab_mismatch": bool(out_of_range_tokens), "model_embedding_size": int(embedding_size), "tokenizer_vocab_size": int(tokenizer.vocab_size), "out_of_range_tokens": out_of_range_tokens[:50], "tokens": tokens[:available], "labels": label_strings, "scores": [round(float(score), 4) for score in selected_scores], "token_table": [ { "i": i, "token": display_token(token), "id": int(token_id), "label": label, "score": round(float(score), 4), } for i, (token, token_id, label, score) in enumerate( zip(tokens[:available], input_ids[1:1 + available], label_strings, selected_scores) ) ], "entities": [ {"type": entity_type, "text": text} for entity_type, text in labels_to_entities(tokens[:available], label_strings, tokenizer) ], } return result def main(): parser = argparse.ArgumentParser(description="Anime filename parser") parser.add_argument("filename", nargs="?", type=str, help="Anime filename to parse") parser.add_argument("--input-file", type=str, help="File with filenames (one per line)") parser.add_argument("--output-file", type=str, help="Output file for results (JSONL)") parser.add_argument("--model-dir", type=str, default=".", help="Path to trained model directory") parser.add_argument("--tokenizer", choices=["regex", "char"], default=None, help="Tokenizer variant override. Defaults to checkpoint metadata") parser.add_argument("--max-length", type=int, default=64, help="Maximum sequence length") parser.add_argument("--debug", action="store_true", help="Include tokenizer, labels, scores, and entity spans in JSON output") parser.add_argument("--no-constrained-bio", action="store_true", help="Use greedy per-token decoding instead of constrained BIO Viterbi") args = parser.parse_args() # Load config cfg = Config() # Load tokenizer print(f"Loading tokenizer from {args.model_dir}...", file=sys.stderr) tokenizer = load_tokenizer(args.model_dir, args.tokenizer) # Load model print(f"Loading model from {args.model_dir}...", file=sys.stderr) model = BertForTokenClassification.from_pretrained(args.model_dir) model.eval() id2label = {int(k): v for k, v in getattr(model.config, "id2label", cfg.id2label).items()} max_length = args.max_length if max_length == 64: max_length = int(getattr(model.config, "max_seq_length", max_length)) # Process filenames filenames_to_parse: List[str] = [] if args.filename: filenames_to_parse.append(args.filename) if args.input_file: with open(args.input_file, 'r', encoding='utf-8') as f: filenames_to_parse.extend(line.strip() for line in f if line.strip()) if not filenames_to_parse: # Read from stdin filenames_to_parse.extend(sys.stdin.read().strip().splitlines()) # Parse and output results: List[Dict] = [] for fn in filenames_to_parse: if not fn.strip(): continue result = parse_filename( fn, model, tokenizer, id2label, max_length, debug=args.debug, constrain_bio=not args.no_constrained_bio, ) result["_input"] = fn results.append(result) if args.output_file is None: print(json.dumps(result, ensure_ascii=False)) if args.output_file: with open(args.output_file, 'w', encoding='utf-8') as f: for r in results: f.write(json.dumps(r, ensure_ascii=False) + '\n') print(f"Results saved to {args.output_file}", file=sys.stderr) if __name__ == "__main__": main()