ModerRAS's picture
Implement schema v2 anime filename labels
ed49faa
raw
history blame
19.8 kB
"""
Tiny BERT models for anime filename token classification.
The default linear token-classification head is kept for compatibility. A
learned linear-chain CRF head is also available for structural sequence-label
training while preserving the same emission logits used by the thin runtime.
"""
from __future__ import annotations
import os
from typing import List, Optional
import torch
from torch import nn
from transformers import BertConfig, BertForTokenClassification, BertModel, BertPreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.modeling_utils import PreTrainedModel
from .config import Config
from .labels import infer_legacy_id2label, label_migration_sources
class LinearChainCRF(nn.Module):
    """A small batched linear-chain CRF for BIO token classification.

    The module holds learned start, end and pairwise transition scores. When
    ``id2label`` is supplied, BIO constraints are stored as boolean buffers:
    an ``I-X`` label may not start a sequence and may only follow ``B-X`` or
    ``I-X``. Those constraints are applied by :meth:`decode` only;
    :meth:`neg_log_likelihood` uses the unconstrained scores.

    Masks may contain gaps, for example ``-100`` labels inside a row or a
    masked first position. Masked positions are skipped, and the remaining
    positions of each row are scored and decoded as one contiguous chain.
    """

    def __init__(self, num_labels: int, id2label: Optional[dict] = None) -> None:
        super().__init__()
        self.num_labels = num_labels
        # Learned scores start at zero, i.e. a neutral CRF.
        self.start_transitions = nn.Parameter(torch.zeros(num_labels))
        self.end_transitions = nn.Parameter(torch.zeros(num_labels))
        # transitions[prev, next] scores moving from label ``prev`` to ``next``.
        self.transitions = nn.Parameter(torch.zeros(num_labels, num_labels))
        # Decode-time structural constraints; all True means unconstrained.
        self.register_buffer("start_allowed", torch.ones(num_labels, dtype=torch.bool))
        self.register_buffer(
            "transition_allowed", torch.ones(num_labels, num_labels, dtype=torch.bool)
        )
        if id2label:
            self._configure_bio_masks(id2label)

    @staticmethod
    def _normalize_label_map(id2label: dict) -> dict[int, str]:
        """Coerce label-map keys to ``int`` and values to ``str``."""
        return {int(label_id): str(label) for label_id, label in id2label.items()}

    def _configure_bio_masks(self, id2label: dict) -> None:
        """Fill ``start_allowed``/``transition_allowed`` from BIO label names.

        IDs missing from ``id2label`` are treated as ``"O"`` (unconstrained).
        """
        label_map = self._normalize_label_map(id2label)
        for prev_id in range(self.num_labels):
            prev_label = label_map.get(prev_id, "O")
            self.start_allowed[prev_id] = not prev_label.startswith("I-")
            for next_id in range(self.num_labels):
                next_label = label_map.get(next_id, "O")
                if next_label.startswith("I-"):
                    entity = next_label[2:]
                    allowed = prev_label in {f"B-{entity}", f"I-{entity}"}
                else:
                    allowed = True
                self.transition_allowed[prev_id, next_id] = allowed

    @staticmethod
    def _left_align(
        emissions: torch.Tensor,
        mask: torch.Tensor,
        tags: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Move each row's valid positions to the front, keeping their order.

        The gold-score and Viterbi recursions need a prefix mask: valid
        positions first, masked ones last. Only then does the last valid tag
        sit at ``length - 1`` and consecutive valid positions sit next to each
        other. Inputs that already satisfy this are returned unchanged.
        """
        # A valid position right after a masked one means a gap or masked start.
        if mask.shape[1] < 2 or not bool((mask[:, 1:] & ~mask[:, :-1]).any()):
            return emissions, mask, tags
        # A stable sort on "is masked" keeps valid positions in original order.
        order = torch.sort((~mask).long(), dim=1, stable=True).indices
        emissions = emissions.gather(1, order.unsqueeze(-1).expand_as(emissions))
        mask = mask.gather(1, order)
        if tags is not None:
            tags = tags.gather(1, order)
        return emissions, mask, tags

    def neg_log_likelihood(
        self,
        emissions: torch.Tensor,
        tags: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """Return mean negative log likelihood for a padded batch.

        Args:
            emissions: ``[batch, seq, labels]`` per-position label scores.
            tags: ``[batch, seq]`` gold label IDs. Values at masked positions
                are ignored, so placeholders such as ``-100`` are fine there.
            mask: ``[batch, seq]``; truthy where a position belongs to the
                sequence. Gaps are allowed.

        Returns:
            A scalar tensor: the batch mean of ``log Z - score(tags)``.

        Raises:
            ValueError: on a shape mismatch or when a row has no valid position.
        """
        if emissions.ndim != 3:
            raise ValueError("emissions must have shape [batch, seq, labels]")
        if tags.shape != emissions.shape[:2]:
            raise ValueError("tags must have shape [batch, seq]")
        if mask.shape != tags.shape:
            raise ValueError("mask must have shape [batch, seq]")
        mask = mask.bool()
        lengths = mask.long().sum(dim=1)
        if torch.any(lengths == 0):
            raise ValueError("CRF received an empty token sequence")
        # Replace ignored labels (e.g. -100) so they are valid indices.
        safe_tags = tags.masked_fill(~mask, 0)
        emissions, mask, safe_tags = self._left_align(emissions, mask, safe_tags)
        log_partition = self._compute_log_partition(emissions, mask)
        gold_score = self._compute_gold_score(emissions, safe_tags, mask, lengths)
        return (log_partition - gold_score).mean()

    def _compute_log_partition(self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Forward algorithm: log-sum-exp of scores over all label paths.

        Expects a valid first position. Masked positions carry the running
        scores forward unchanged. Computation is done in float32.
        """
        emissions = emissions.float()
        transition_scores = self.transitions.float()
        scores = self.start_transitions.float() + emissions[:, 0]
        for idx in range(1, emissions.shape[1]):
            # [batch, prev, 1] + [1, prev, next] + [batch, 1, next]
            next_scores = torch.logsumexp(
                scores.unsqueeze(2)
                + transition_scores.unsqueeze(0)
                + emissions[:, idx].unsqueeze(1),
                dim=1,
            )
            scores = torch.where(mask[:, idx].unsqueeze(1), next_scores, scores)
        scores = scores + self.end_transitions.float()
        return torch.logsumexp(scores, dim=1)

    def _compute_gold_score(
        self,
        emissions: torch.Tensor,
        tags: torch.Tensor,
        mask: torch.Tensor,
        lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Score the given tag path per row (requires a prefix mask)."""
        emissions = emissions.float()
        start_transitions = self.start_transitions.float()
        transition_scores = self.transitions.float()
        end_transitions = self.end_transitions.float()
        batch_indices = torch.arange(emissions.shape[0], device=emissions.device)
        score = start_transitions[tags[:, 0]]
        score = score + emissions[batch_indices, 0, tags[:, 0]]
        for idx in range(1, emissions.shape[1]):
            transition_score = transition_scores[tags[:, idx - 1], tags[:, idx]]
            emission_score = emissions[batch_indices, idx, tags[:, idx]]
            score = score + (transition_score + emission_score) * mask[:, idx]
        # With a prefix mask the last valid tag of each row is at length - 1.
        last_tag_indices = (lengths - 1).unsqueeze(1)
        last_tags = tags.gather(1, last_tag_indices).squeeze(1)
        return score + end_transitions[last_tags]

    def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> List[List[int]]:
        """Viterbi decode a padded batch and return variable-length label IDs.

        BIO constraints from ``start_allowed``/``transition_allowed`` are
        enforced. Each returned list holds one label ID per valid (unmasked)
        position of its row, in position order.

        Raises:
            ValueError: on a shape mismatch or when a row has no valid position.
        """
        if emissions.ndim != 3:
            raise ValueError("emissions must have shape [batch, seq, labels]")
        if mask.shape != emissions.shape[:2]:
            raise ValueError("mask must have shape [batch, seq]")
        mask = mask.bool()
        lengths = mask.long().sum(dim=1)
        if torch.any(lengths == 0):
            raise ValueError("CRF received an empty token sequence")
        emissions, mask, _ = self._left_align(emissions, mask)
        start_transitions = self.start_transitions.masked_fill(~self.start_allowed, float("-inf"))
        transition_scores = self.transitions.masked_fill(~self.transition_allowed, float("-inf"))
        scores = start_transitions + emissions[:, 0]
        history: List[torch.Tensor] = []
        for idx in range(1, emissions.shape[1]):
            next_scores = scores.unsqueeze(2) + transition_scores.unsqueeze(0)
            best_scores, best_tags = next_scores.max(dim=1)
            best_scores = best_scores + emissions[:, idx]
            scores = torch.where(mask[:, idx].unsqueeze(1), best_scores, scores)
            history.append(best_tags)
        scores = scores + self.end_transitions
        best_last_tags = scores.argmax(dim=1).tolist()
        # One host transfer instead of a device sync per backtrace step.
        history_list = torch.stack(history).tolist() if history else []
        paths: List[List[int]] = []
        for batch_idx, length in enumerate(lengths.tolist()):
            best_tag = best_last_tags[batch_idx]
            path = [best_tag]
            # history_list[step] holds back-pointers into position ``step``.
            for step in range(length - 2, -1, -1):
                best_tag = history_list[step][batch_idx][best_tag]
                path.append(best_tag)
            path.reverse()
            paths.append(path)
        return paths
class BertCrfForTokenClassification(BertPreTrainedModel):
    """BERT emission classifier trained with a learned CRF sequence loss.

    ``forward`` returns per-token ``logits`` laid out like Hugging Face's
    ``BertForTokenClassification``, so the emissions are usable without the
    CRF. The CRF loss and :meth:`decode` only look at positions strictly
    between the first (CLS) and last (SEP) token.
    """

    config_class = BertConfig

    def __init__(self, config: BertConfig) -> None:
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config, add_pooling_layer=False)
        head_dropout = getattr(config, "classifier_dropout", None)
        if head_dropout is None:
            head_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(head_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = LinearChainCRF(config.num_labels, getattr(config, "id2label", None))
        self.post_init()
        # Zero the CRF scores after post_init() so that bootstrapping from a
        # linear checkpoint starts from neutral transitions.
        for crf_scores in (
            self.crf.start_transitions,
            self.crf.end_transitions,
            self.crf.transitions,
        ):
            nn.init.zeros_(crf_scores)

    def _crf_inputs(
        self,
        logits: torch.Tensor,
        labels: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        """Strip CLS/SEP and build the CRF ``(emissions, tags, mask)`` triple.

        With labels, a position is valid when its label is not -100 (and,
        when given, the attention mask is set). Without labels, the attention
        mask is assumed to be right-padded: the first ``attended - 2`` inner
        positions are valid. With neither, every inner position is valid.
        """
        if logits.shape[1] <= 2:
            raise ValueError("CRF token classification expects CLS, tokens, and SEP positions")
        emissions = logits[:, 1:-1, :]
        tags = None if labels is None else labels[:, 1:-1]
        if tags is not None:
            mask = tags.ne(-100)
            if attention_mask is not None:
                mask = attention_mask[:, 1:-1].bool() & mask
        elif attention_mask is not None:
            token_counts = attention_mask.long().sum(dim=1).clamp_min(2) - 2
            offsets = torch.arange(emissions.shape[1], device=logits.device)
            mask = offsets.unsqueeze(0) < token_counts.unsqueeze(1)
        else:
            mask = torch.ones(emissions.shape[:2], dtype=torch.bool, device=logits.device)
        return emissions, tags, mask

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> TokenClassifierOutput:
        """Run BERT and the emission head; add the CRF NLL as ``loss`` if labelled.

        ``labels`` are aligned with the full input, including CLS/SEP, and are
        sliced the same way as ``logits``. Returns a ``TokenClassifierOutput``.
        With ``return_dict=False`` it returns a tuple ``([loss,] logits, ...)``.
        """
        if return_dict is None:
            return_dict = getattr(self.config, "return_dict", True)
        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.classifier(self.dropout(encoder_outputs[0]))
        loss = None
        if labels is not None:
            emissions, tags, mask = self._crf_inputs(logits, labels, attention_mask)
            if tags is None:
                raise ValueError("labels are required for CRF loss")
            loss = self.crf.neg_log_likelihood(emissions, tags, mask)
        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        tuple_output = (logits,) + encoder_outputs[2:]
        return tuple_output if loss is None else (loss,) + tuple_output

    def decode(self, logits: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> List[List[int]]:
        """Viterbi-decode full-sequence ``logits``, skipping CLS/SEP and padding."""
        emissions, _, mask = self._crf_inputs(logits, None, attention_mask)
        return self.crf.decode(emissions, mask)
def build_bert_config(config: Config) -> BertConfig:
    """Build the Hugging Face BERT config shared by both model heads.

    Each listed attribute of the project ``Config`` is forwarded unchanged
    as the ``BertConfig`` keyword of the same name. This includes the label
    maps and the extra ``label_schema_version`` attribute.
    """
    shared_fields = (
        "vocab_size",
        "hidden_size",
        "num_hidden_layers",
        "num_attention_heads",
        "intermediate_size",
        "max_position_embeddings",
        "num_labels",
        "hidden_dropout_prob",
        "attention_probs_dropout_prob",
        "id2label",
        "label2id",
        "label_schema_version",
    )
    return BertConfig(**{field: getattr(config, field) for field in shared_fields})
def normalize_model_head(model_head: Optional[str]) -> str:
    """Return ``"linear"`` or ``"crf"`` for a user-supplied head name.

    ``None`` or an empty string selects ``"linear"``. Matching ignores case
    and surrounding whitespace.

    Raises:
        ValueError: for any other head name.
    """
    normalized = "linear" if not model_head else model_head.strip().lower()
    if normalized in ("linear", "crf"):
        return normalized
    raise ValueError(f"Unsupported model head: {model_head}")
def create_model(config: Config, model_head: str = "linear") -> PreTrainedModel:
    """
    Create a Tiny BERT model for token classification.

    Args:
        config: Config object with model hyperparameters.
        model_head: ``linear`` for Hugging Face's standard token classifier or
            ``crf`` for a learned linear-chain CRF sequence head. Case and
            surrounding whitespace are ignored.

    Returns:
        A newly constructed model whose config records ``model_head`` and
        ``architectures``.

    Raises:
        ValueError: if ``model_head`` is neither ``linear`` nor ``crf``.
    """
    head = normalize_model_head(model_head)
    model_class, architecture = {
        "crf": (BertCrfForTokenClassification, "BertCrfForTokenClassification"),
        "linear": (BertForTokenClassification, "BertForTokenClassification"),
    }[head]
    bert_config = build_bert_config(config)
    bert_config.model_head = head
    bert_config.architectures = [architecture]
    return model_class(bert_config)
def infer_model_head(config: BertConfig) -> str:
    """Work out which head a checkpoint config describes.

    An explicit, truthy ``model_head`` attribute wins and is normalized.
    Otherwise any architecture name containing ``"Crf"`` or ``"CRF"`` means
    ``"crf"``. Everything else is ``"linear"``.
    """
    explicit_head = getattr(config, "model_head", None)
    if explicit_head:
        return normalize_model_head(str(explicit_head))
    for architecture in getattr(config, "architectures", None) or []:
        name = str(architecture)
        if "Crf" in name or "CRF" in name:
            return "crf"
    return "linear"
def load_model(model_dir: str, model_head: Optional[str] = None) -> PreTrainedModel:
    """Load a linear or CRF token classifier from a Hugging Face checkpoint.

    Args:
        model_dir: Checkpoint location passed to ``from_pretrained``.
        model_head: Force ``linear`` or ``crf``. When ``None``, the head is
            inferred from the checkpoint's config.
    """
    checkpoint_config = BertConfig.from_pretrained(model_dir)
    if model_head is None:
        head = infer_model_head(checkpoint_config)
    else:
        head = normalize_model_head(model_head)
    model_class = BertCrfForTokenClassification if head == "crf" else BertForTokenClassification
    return model_class.from_pretrained(model_dir)
def _model_id2label_for_migration(model: PreTrainedModel) -> dict[int, str]:
raw_id2label = getattr(model.config, "id2label", None) or {}
normalized = {int(label_id): str(label) for label_id, label in raw_id2label.items()}
classifier = getattr(model, "classifier", None)
out_features = getattr(classifier, "out_features", None)
if out_features is not None and len(normalized) != int(out_features):
inferred = infer_legacy_id2label(int(out_features))
if inferred is not None:
return inferred
return normalized
def _copy_classifier_rows(
    classifier: nn.Linear,
    new_classifier: nn.Linear,
    old_label2id: dict[str, int],
    old_num_labels: int,
    target_label2id: dict[str, int],
) -> tuple[dict[int, int], int]:
    """Copy each target row from the first usable ``label_migration_sources`` entry.

    Returns ``(row_sources, copied)``. ``row_sources`` maps each target row
    ID to the old row it was copied from; ``copied`` counts the copies made.
    """
    old_weight = classifier.weight.detach()
    old_bias = None if classifier.bias is None else classifier.bias.detach()
    row_sources: dict[int, int] = {}
    copied = 0
    for target_label, target_id in target_label2id.items():
        source_id = next(
            (
                old_label2id[source_label]
                for source_label in label_migration_sources(target_label)
                if source_label in old_label2id and old_label2id[source_label] < old_num_labels
            ),
            None,
        )
        if source_id is None:
            continue
        new_classifier.weight.data[target_id].copy_(old_weight[source_id])
        if new_classifier.bias is not None and old_bias is not None:
            new_classifier.bias.data[target_id].copy_(old_bias[source_id])
        row_sources[target_id] = source_id
        copied += 1
    return row_sources, copied


def _migrate_crf_scores(
    old_crf: nn.Module,
    row_sources: dict[int, int],
    new_num_labels: int,
    target_id2label: dict[int, str],
    device: torch.device,
    dtype: torch.dtype,
) -> LinearChainCRF:
    """Build a CRF for the target schema, carrying over scores of copied rows.

    Rows without a source keep zero start/end/transition scores.
    """
    new_crf = LinearChainCRF(new_num_labels, target_id2label).to(device=device, dtype=dtype)
    for scores in (new_crf.start_transitions, new_crf.end_transitions, new_crf.transitions):
        nn.init.zeros_(scores)
    old_boundary_size = old_crf.start_transitions.shape[0]
    old_from_size = old_crf.transitions.shape[0]
    old_to_size = old_crf.transitions.shape[1]
    with torch.no_grad():
        for target_id, source_id in row_sources.items():
            if source_id < old_boundary_size:
                new_crf.start_transitions[target_id].copy_(old_crf.start_transitions[source_id])
                new_crf.end_transitions[target_id].copy_(old_crf.end_transitions[source_id])
        for target_to_id, source_to_id in row_sources.items():
            if source_to_id >= old_to_size:
                continue
            for target_from_id, source_from_id in row_sources.items():
                if source_from_id < old_from_size:
                    new_crf.transitions[target_from_id, target_to_id].copy_(
                        old_crf.transitions[source_from_id, source_to_id]
                    )
    return new_crf


def migrate_token_classifier_labels(
    model: PreTrainedModel,
    target_label2id: dict[str, int],
    target_id2label: dict[int, str],
) -> dict[str, object]:
    """
    Expand or reorder token-classification label rows for the shared schema.

    Exact labels are copied by name. Legacy 15-label TITLE rows initialize all
    title-like rows, and legacy SEASON rows initialize PATH_SEASON.

    Each target row takes the first source label from
    ``label_migration_sources`` that exists in the old head. Rows without a
    source get a fresh normal(0, ``initializer_range``) weight and a zero
    bias. For models with a ``crf`` attribute, start/end/transition scores
    are carried over for every copied row.

    Returns a summary dict. When the model has no ``nn.Linear`` classifier
    it is ``{"changed": False, "reason": "no_linear_classifier"}``.
    """
    classifier = getattr(model, "classifier", None)
    if not isinstance(classifier, nn.Linear):
        return {"changed": False, "reason": "no_linear_classifier"}

    target_id2label = {int(label_id): str(label) for label_id, label in target_id2label.items()}
    target_label2id = {str(label): int(label_id) for label, label_id in target_label2id.items()}
    old_id2label = _model_id2label_for_migration(model)
    old_label2id = {label: label_id for label_id, label in old_id2label.items()}
    old_num_labels = int(classifier.out_features)
    new_num_labels = len(target_label2id)

    if old_num_labels == new_num_labels and all(
        old_id2label.get(idx) == target_id2label.get(idx) for idx in range(new_num_labels)
    ):
        # Same labels in the same order: only normalize the config maps.
        model.config.num_labels = new_num_labels
        model.config.id2label = target_id2label
        model.config.label2id = target_label2id
        return {"changed": False, "copied": new_num_labels, "target_labels": new_num_labels}

    param_device = classifier.weight.device
    param_dtype = classifier.weight.dtype
    new_classifier = nn.Linear(
        classifier.in_features,
        new_num_labels,
        bias=classifier.bias is not None,
        device=param_device,
        dtype=param_dtype,
    )
    nn.init.normal_(
        new_classifier.weight,
        mean=0.0,
        std=getattr(model.config, "initializer_range", 0.02),
    )
    if new_classifier.bias is not None:
        nn.init.zeros_(new_classifier.bias)
    row_sources, copied = _copy_classifier_rows(
        classifier, new_classifier, old_label2id, old_num_labels, target_label2id
    )

    model.classifier = new_classifier
    model.num_labels = new_num_labels
    # Set num_labels before the label maps so the maps are the final values.
    model.config.num_labels = new_num_labels
    model.config.id2label = target_id2label
    model.config.label2id = target_label2id
    if hasattr(model, "crf"):
        model.crf = _migrate_crf_scores(
            model.crf, row_sources, new_num_labels, target_id2label, param_device, param_dtype
        )
    return {
        "changed": True,
        "source_labels": old_num_labels,
        "target_labels": new_num_labels,
        "copied": copied,
    }
def save_model_head_config(model: PreTrainedModel, model_head: str) -> None:
    """Persist the selected head in config.json for later auto-loading.

    Sets ``model.config.model_head`` and ``model.config.architectures``,
    which ``infer_model_head`` reads when no explicit head is given.
    """
    head = normalize_model_head(model_head)
    architecture = {"crf": "BertCrfForTokenClassification"}.get(head, "BertForTokenClassification")
    model.config.model_head = head
    model.config.architectures = [architecture]
def count_parameters(model: nn.Module, *, trainable_only: bool = False) -> int:
    """Count the parameters in ``model``.

    Args:
        model: Module whose ``parameters()`` are counted.
        trainable_only: When True, count only parameters with
            ``requires_grad``. By default every parameter is counted, frozen
            ones included; ``print_model_summary`` uses this default as its
            total.

    Returns:
        The number of scalar parameter elements.
    """
    return sum(
        param.numel()
        for param in model.parameters()
        if param.requires_grad or not trainable_only
    )
def print_model_summary(model: nn.Module, *, limit: int = 5_000_000) -> int:
    """Print total/trainable parameter counts and check them against a budget.

    Args:
        model: Module to summarize.
        limit: Maximum allowed total parameter count, inclusive. Defaults to
            5,000,000, which is reported as "5M".

    Returns:
        The total parameter count.
    """
    total_params = count_parameters(model)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # Short budget label, e.g. "5M" for the default limit.
    limit_label = f"{limit // 1_000_000}M" if limit % 1_000_000 == 0 else f"{limit:,}"
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Parameter limit: {limit:,}")
    # Use <= so a model exactly at the budget passes instead of being
    # reported as exceeding it by 0.
    if total_params <= limit:
        print(f"[OK] Within {limit_label} limit ({limit - total_params:,} remaining)")
    else:
        print(f"[FAIL] Exceeds {limit_label} limit by {total_params - limit:,}")
    return total_params
if __name__ == "__main__":
    # Build a demo model with a 3,000-token vocabulary and print its size.
    # ANIFILEBERT_MODEL_HEAD picks the head ("linear" when unset).
    demo_config = Config()
    demo_config.vocab_size = 3000
    selected_head = os.environ.get("ANIFILEBERT_MODEL_HEAD", "linear")
    demo_model = create_model(demo_config, model_head=selected_head)
    print_model_summary(demo_model)