""" Tiny BERT models for anime filename token classification. The default linear token-classification head is kept for compatibility. A learned linear-chain CRF head is also available for structural sequence-label training while preserving the same emission logits used by the thin runtime. """ from __future__ import annotations import os from typing import List, Optional import torch from torch import nn from transformers import BertConfig, BertForTokenClassification, BertModel, BertPreTrainedModel from transformers.modeling_outputs import TokenClassifierOutput from transformers.modeling_utils import PreTrainedModel from .config import Config class LinearChainCRF(nn.Module): """A small batched linear-chain CRF for BIO token classification.""" def __init__(self, num_labels: int, id2label: Optional[dict] = None) -> None: super().__init__() self.num_labels = num_labels self.start_transitions = nn.Parameter(torch.zeros(num_labels)) self.end_transitions = nn.Parameter(torch.zeros(num_labels)) self.transitions = nn.Parameter(torch.zeros(num_labels, num_labels)) self.register_buffer("start_allowed", torch.ones(num_labels, dtype=torch.bool)) self.register_buffer("transition_allowed", torch.ones(num_labels, num_labels, dtype=torch.bool)) if id2label: self._configure_bio_masks(id2label) @staticmethod def _normalize_label_map(id2label: dict) -> dict[int, str]: return {int(label_id): str(label) for label_id, label in id2label.items()} def _configure_bio_masks(self, id2label: dict) -> None: label_map = self._normalize_label_map(id2label) for prev_id in range(self.num_labels): prev_label = label_map.get(prev_id, "O") self.start_allowed[prev_id] = not prev_label.startswith("I-") for next_id in range(self.num_labels): next_label = label_map.get(next_id, "O") if next_label.startswith("I-"): entity = next_label[2:] allowed = prev_label in {f"B-{entity}", f"I-{entity}"} else: allowed = True self.transition_allowed[prev_id, next_id] = allowed def neg_log_likelihood( self, emissions: torch.Tensor, tags: torch.Tensor, mask: torch.Tensor, ) -> torch.Tensor: """Return mean negative log likelihood for a padded batch.""" if emissions.ndim != 3: raise ValueError("emissions must have shape [batch, seq, labels]") if tags.shape != emissions.shape[:2]: raise ValueError("tags must have shape [batch, seq]") if mask.shape != tags.shape: raise ValueError("mask must have shape [batch, seq]") mask = mask.bool() lengths = mask.long().sum(dim=1) if torch.any(lengths == 0): raise ValueError("CRF received an empty token sequence") safe_tags = tags.masked_fill(~mask, 0) log_partition = self._compute_log_partition(emissions, mask) gold_score = self._compute_gold_score(emissions, safe_tags, mask, lengths) return (log_partition - gold_score).mean() def _compute_log_partition(self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, _num_labels = emissions.shape emissions = emissions.float() start_transitions = self.start_transitions.float() transition_scores = self.transitions.float() scores = start_transitions + emissions[:, 0] for idx in range(1, sequence_length): next_scores = ( scores.unsqueeze(2) + transition_scores.unsqueeze(0) + emissions[:, idx].unsqueeze(1) ) next_scores = torch.logsumexp(next_scores, dim=1) scores = torch.where(mask[:, idx].unsqueeze(1), next_scores, scores) scores = scores + self.end_transitions return torch.logsumexp(scores, dim=1) def _compute_gold_score( self, emissions: torch.Tensor, tags: torch.Tensor, mask: torch.Tensor, lengths: torch.Tensor, ) -> torch.Tensor: emissions = emissions.float() start_transitions = self.start_transitions.float() transition_scores = self.transitions.float() end_transitions = self.end_transitions.float() batch_indices = torch.arange(emissions.shape[0], device=emissions.device) score = start_transitions[tags[:, 0]] score = score + emissions[batch_indices, 0, tags[:, 0]] for idx in range(1, emissions.shape[1]): transition_score = transition_scores[tags[:, idx - 1], tags[:, idx]] emission_score = emissions[batch_indices, idx, tags[:, idx]] score = score + (transition_score + emission_score) * mask[:, idx] last_tag_indices = (lengths - 1).unsqueeze(1) last_tags = tags.gather(1, last_tag_indices).squeeze(1) return score + end_transitions[last_tags] def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> List[List[int]]: """Viterbi decode a padded batch and return variable-length label IDs.""" if emissions.ndim != 3: raise ValueError("emissions must have shape [batch, seq, labels]") mask = mask.bool() lengths = mask.long().sum(dim=1) if torch.any(lengths == 0): raise ValueError("CRF received an empty token sequence") start_transitions = self.start_transitions.masked_fill(~self.start_allowed, float("-inf")) transition_scores = self.transitions.masked_fill(~self.transition_allowed, float("-inf")) scores = start_transitions + emissions[:, 0] history: List[torch.Tensor] = [] for idx in range(1, emissions.shape[1]): next_scores = scores.unsqueeze(2) + transition_scores.unsqueeze(0) best_scores, best_tags = next_scores.max(dim=1) best_scores = best_scores + emissions[:, idx] scores = torch.where(mask[:, idx].unsqueeze(1), best_scores, scores) history.append(best_tags) scores = scores + self.end_transitions best_last_tags = scores.argmax(dim=1) paths: List[List[int]] = [] for batch_idx in range(emissions.shape[0]): length = int(lengths[batch_idx].item()) best_tag = int(best_last_tags[batch_idx].item()) path = [best_tag] for hist in reversed(history[: max(0, length - 1)]): best_tag = int(hist[batch_idx, best_tag].item()) path.append(best_tag) path.reverse() paths.append(path) return paths class BertCrfForTokenClassification(BertPreTrainedModel): """BERT emission classifier trained with a learned CRF sequence loss.""" config_class = BertConfig def __init__(self, config: BertConfig) -> None: super().__init__(config) self.num_labels = config.num_labels self.bert = BertModel(config, add_pooling_layer=False) classifier_dropout = getattr(config, "classifier_dropout", None) dropout_prob = classifier_dropout if classifier_dropout is not None else config.hidden_dropout_prob self.dropout = nn.Dropout(dropout_prob) self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.crf = LinearChainCRF(config.num_labels, getattr(config, "id2label", None)) self.post_init() # Keep CRF transitions neutral when bootstrapping from a linear checkpoint. nn.init.zeros_(self.crf.start_transitions) nn.init.zeros_(self.crf.end_transitions) nn.init.zeros_(self.crf.transitions) def _crf_inputs( self, logits: torch.Tensor, labels: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: if logits.shape[1] <= 2: raise ValueError("CRF token classification expects CLS, tokens, and SEP positions") emissions = logits[:, 1:-1, :] if attention_mask is None: if labels is None: mask = torch.ones(emissions.shape[:2], dtype=torch.bool, device=logits.device) else: mask = labels[:, 1:-1].ne(-100) else: if labels is None: real_lengths = attention_mask.long().sum(dim=1).clamp_min(2) - 2 positions = torch.arange(emissions.shape[1], device=logits.device).unsqueeze(0) mask = positions < real_lengths.unsqueeze(1) else: mask = attention_mask[:, 1:-1].bool() mask = mask & labels[:, 1:-1].ne(-100) tags = labels[:, 1:-1] if labels is not None else None return emissions, tags, mask def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> TokenClassifierOutput: return_dict = return_dict if return_dict is not None else getattr(self.config, "return_dict", True) outputs = self.bert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = self.dropout(outputs[0]) logits = self.classifier(sequence_output) loss = None if labels is not None: emissions, tags, mask = self._crf_inputs(logits, labels, attention_mask) if tags is None: raise ValueError("labels are required for CRF loss") loss = self.crf.neg_log_likelihood(emissions, tags, mask) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def decode(self, logits: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> List[List[int]]: """Decode full-sequence logits, excluding CLS/SEP and padding positions.""" emissions, _tags, mask = self._crf_inputs(logits, None, attention_mask) return self.crf.decode(emissions, mask) def build_bert_config(config: Config) -> BertConfig: """Build the Hugging Face BERT config shared by both model heads.""" return BertConfig( vocab_size=config.vocab_size, hidden_size=config.hidden_size, num_hidden_layers=config.num_hidden_layers, num_attention_heads=config.num_attention_heads, intermediate_size=config.intermediate_size, max_position_embeddings=config.max_position_embeddings, num_labels=config.num_labels, hidden_dropout_prob=config.hidden_dropout_prob, attention_probs_dropout_prob=config.attention_probs_dropout_prob, id2label=config.id2label, label2id=config.label2id, ) def normalize_model_head(model_head: Optional[str]) -> str: head = (model_head or "linear").strip().lower() if head not in {"linear", "crf"}: raise ValueError(f"Unsupported model head: {model_head}") return head def create_model(config: Config, model_head: str = "linear") -> PreTrainedModel: """ Create a Tiny BERT model for token classification. Args: config: Config object with model hyperparameters. model_head: ``linear`` for Hugging Face's standard token classifier or ``crf`` for a learned linear-chain CRF sequence head. """ head = normalize_model_head(model_head) bert_config = build_bert_config(config) bert_config.model_head = head if head == "crf": bert_config.architectures = ["BertCrfForTokenClassification"] return BertCrfForTokenClassification(bert_config) bert_config.architectures = ["BertForTokenClassification"] return BertForTokenClassification(bert_config) def infer_model_head(config: BertConfig) -> str: head = getattr(config, "model_head", None) if head: return normalize_model_head(str(head)) architectures = getattr(config, "architectures", None) or [] if any("Crf" in str(architecture) or "CRF" in str(architecture) for architecture in architectures): return "crf" return "linear" def load_model(model_dir: str, model_head: Optional[str] = None) -> PreTrainedModel: """Load a linear or CRF token classifier from a Hugging Face checkpoint.""" config = BertConfig.from_pretrained(model_dir) head = normalize_model_head(model_head) if model_head is not None else infer_model_head(config) if head == "crf": return BertCrfForTokenClassification.from_pretrained(model_dir) return BertForTokenClassification.from_pretrained(model_dir) def save_model_head_config(model: PreTrainedModel, model_head: str) -> None: """Persist the selected head in config.json for later auto-loading.""" head = normalize_model_head(model_head) model.config.model_head = head model.config.architectures = [ "BertCrfForTokenClassification" if head == "crf" else "BertForTokenClassification" ] def count_parameters(model) -> int: """Count total trainable parameters in a model.""" return sum(p.numel() for p in model.parameters()) def print_model_summary(model): """Print model architecture summary with parameter count.""" total_params = count_parameters(model) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"Total parameters: {total_params:,}") print(f"Trainable parameters: {trainable_params:,}") print(f"Parameter limit: 5,000,000") if total_params < 5_000_000: print(f"[OK] Within 5M limit ({(5_000_000 - total_params):,} remaining)") else: print(f"[FAIL] Exceeds 5M limit by {total_params - 5_000_000:,}") return total_params if __name__ == "__main__": cfg = Config() cfg.vocab_size = 3000 model = create_model(cfg, model_head=os.environ.get("ANIFILEBERT_MODEL_HEAD", "linear")) print_model_summary(model)