"""
BERT-CRF Model for Indian Address NER.
Combines a multilingual BERT encoder with a Conditional Random Field (CRF)
layer for improved sequence labeling performance.
"""
from collections.abc import Callable

import torch
import torch.nn as nn
from transformers import AutoModel
from transformers.modeling_outputs import TokenClassifierOutput

from address_parser.models.config import ID2LABEL, LABEL2ID, ModelConfig
class CRF(nn.Module):
"""
Conditional Random Field layer for sequence labeling.
Implements the forward algorithm for computing log-likelihood
and Viterbi decoding for inference.
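
    Example (a minimal sketch with random emission scores)::

        crf = CRF(num_tags=5, batch_first=True)
        emissions = torch.randn(2, 7, 5)            # (batch, seq, num_tags)
        tags = torch.randint(0, 5, (2, 7))          # gold tag indices
        mask = torch.ones(2, 7, dtype=torch.bool)
        loss = crf(emissions, tags, mask=mask)      # negative log-likelihood
        best = crf.decode(emissions, mask=mask)     # list of tag-id lists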
"""
def __init__(self, num_tags: int, batch_first: bool = True):
"""
Initialize CRF layer.
Args:
num_tags: Number of output tags
batch_first: If True, input is (batch, seq, features)
"""
super().__init__()
self.num_tags = num_tags
self.batch_first = batch_first
        # Transition matrix: transitions[i, j] = score of transitioning from tag i to tag j
        # (allocated empty; values are set in _init_transitions below)
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
        # Start and end transition scores
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
self._init_transitions()
def _init_transitions(self):
"""Initialize transition parameters."""
nn.init.uniform_(self.transitions, -0.1, 0.1)
nn.init.uniform_(self.start_transitions, -0.1, 0.1)
nn.init.uniform_(self.end_transitions, -0.1, 0.1)
def forward(
self,
emissions: torch.Tensor,
tags: torch.LongTensor,
        mask: torch.BoolTensor | None = None,
reduction: str = "mean",
) -> torch.Tensor:
"""
Compute negative log-likelihood loss.
Args:
emissions: Emission scores from BERT (batch, seq, num_tags)
tags: Gold standard tags (batch, seq)
mask: Mask for valid tokens (batch, seq)
reduction: 'mean', 'sum', or 'none'
Returns:
Negative log-likelihood loss
"""
if mask is None:
mask = torch.ones_like(tags, dtype=torch.bool)
if self.batch_first:
emissions = emissions.transpose(0, 1)
tags = tags.transpose(0, 1)
mask = mask.transpose(0, 1)
# Compute log-likelihood
numerator = self._compute_score(emissions, tags, mask)
denominator = self._compute_normalizer(emissions, mask)
llh = numerator - denominator
if reduction == "mean":
return -llh.mean()
elif reduction == "sum":
return -llh.sum()
else:
return -llh
def decode(
self,
emissions: torch.Tensor,
        mask: torch.BoolTensor | None = None,
) -> list[list[int]]:
"""
Find the most likely tag sequence using Viterbi algorithm.
Args:
emissions: Emission scores (batch, seq, num_tags)
mask: Mask for valid tokens (batch, seq)
Returns:
List of best tag sequences for each sample
"""
if mask is None:
mask = torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device)
if self.batch_first:
emissions = emissions.transpose(0, 1)
mask = mask.transpose(0, 1)
return self._viterbi_decode(emissions, mask)
def _compute_score(
self,
emissions: torch.Tensor,
tags: torch.LongTensor,
mask: torch.BoolTensor
) -> torch.Tensor:
"""Compute the score of a given tag sequence."""
        seq_length, batch_size = tags.shape
        mask = mask.float()
        # Start transition plus emission at the first position
        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]
        for i in range(1, seq_length):
            # Transition from the previous tag, zeroed out on padding
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]
            # Emission at position i, zeroed out on padding
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]
        # End transition from the last valid tag of each sequence
        seq_ends = mask.long().sum(dim=0) - 1
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        score += self.end_transitions[last_tags]
        return score
def _compute_normalizer(
self,
emissions: torch.Tensor,
mask: torch.BoolTensor
) -> torch.Tensor:
"""Compute log-sum-exp of all possible tag sequences (partition function)."""
seq_length = emissions.shape[0]
# Initialize with start transitions
score = self.start_transitions + emissions[0]
for i in range(1, seq_length):
# Broadcast score and transitions for all combinations
broadcast_score = score.unsqueeze(2)
broadcast_emissions = emissions[i].unsqueeze(1)
# Compute next scores
next_score = broadcast_score + self.transitions + broadcast_emissions
# Log-sum-exp
next_score = torch.logsumexp(next_score, dim=1)
# Mask
score = torch.where(mask[i].unsqueeze(1), next_score, score)
# Add end transitions
score += self.end_transitions
return torch.logsumexp(score, dim=1)
def _viterbi_decode(
self,
emissions: torch.Tensor,
mask: torch.BoolTensor
) -> list[list[int]]:
"""Viterbi decoding to find best tag sequence."""
seq_length, batch_size, num_tags = emissions.shape
# Initialize
score = self.start_transitions + emissions[0]
history = []
for i in range(1, seq_length):
broadcast_score = score.unsqueeze(2)
broadcast_emissions = emissions[i].unsqueeze(1)
next_score = broadcast_score + self.transitions + broadcast_emissions
# Find best previous tag for each current tag
next_score, indices = next_score.max(dim=1)
# Apply mask
score = torch.where(mask[i].unsqueeze(1), next_score, score)
history.append(indices)
# Add end transitions
score += self.end_transitions
# Backtrack
seq_ends = mask.long().sum(dim=0) - 1
best_tags_list = []
for batch_idx in range(batch_size):
# Best last tag
_, best_last_tag = score[batch_idx].max(dim=0)
best_tags = [best_last_tag.item()]
# Backtrack through history
for hist in reversed(history[:seq_ends[batch_idx]]):
best_last_tag = hist[batch_idx][best_tags[-1]]
best_tags.append(best_last_tag.item())
best_tags.reverse()
best_tags_list.append(best_tags)
return best_tags_list
class BertCRFForTokenClassification(nn.Module):
"""
BERT model with CRF layer for token classification.
This combines a multilingual BERT encoder with a CRF layer
for improved sequence labeling on NER tasks.
"""
def __init__(self, config: ModelConfig):
"""
Initialize BERT-CRF model.
Args:
config: Model configuration
"""
super().__init__()
self.config = config
self.num_labels = config.num_labels
# Load pretrained BERT
self.bert = AutoModel.from_pretrained(
config.model_name,
cache_dir=config.cache_dir,
)
# Dropout
self.dropout = nn.Dropout(config.classifier_dropout)
# Classification head
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# CRF layer
if config.use_crf:
self.crf = CRF(num_tags=config.num_labels, batch_first=True)
else:
self.crf = None
# Label mappings
self.id2label = ID2LABEL
self.label2id = LABEL2ID
        # Optional torch.compile (available since PyTorch 2.0), applied
        # lazily on the first inference call so training is unaffected
        self._compiled_forward: Callable | None = None
    def _get_compiled_forward(self) -> Callable:
        """Lazily compile the forward pass on the first inference call."""
        import os
        import sys

        if self._compiled_forward is None:
            # Skip torch.compile when explicitly disabled via the
            # TORCH_COMPILE_DISABLE=1 environment variable, or on Windows:
            # the inductor backend needs a C++ compiler (cl on Windows,
            # gcc/clang on Linux), which is often unavailable there.
            skip_compile = (
                os.environ.get("TORCH_COMPILE_DISABLE", "0") == "1"
                or sys.platform == "win32"
            )
            if not skip_compile and hasattr(torch, "compile"):
                try:
                    self._compiled_forward = torch.compile(
                        self.forward,
                        backend="inductor",
                        mode="reduce-overhead",
                        dynamic=True,
                    )
                except Exception:
                    # Fall back to eager execution if compilation fails
                    self._compiled_forward = self.forward
            else:
                self._compiled_forward = self.forward
        return self._compiled_forward
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
token_type_ids: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
return_dict: bool = True,
):
"""
Forward pass.
Args:
input_ids: Input token IDs (batch, seq)
attention_mask: Attention mask (batch, seq)
token_type_ids: Token type IDs (batch, seq)
labels: Gold standard labels for training (batch, seq)
return_dict: Return as dict or tuple
Returns:
TokenClassifierOutput with loss, logits, hidden states
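
        Example (training-step sketch; ``batch`` and ``optimizer`` are
        assumed to come from the surrounding training loop)::

            out = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
            )
            out.loss.backward()
            optimizer.step()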
"""
# BERT encoding
outputs = self.bert(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
# Classification logits
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
if self.crf is not None:
# CRF loss - need to handle -100 (ignore_index) labels
mask = attention_mask.bool() if attention_mask is not None else None
# Replace -100 with 0 (will be masked out anyway)
crf_labels = labels.clone()
crf_labels[crf_labels == -100] = 0
loss = self.crf(logits, crf_labels, mask=mask, reduction=self.config.crf_reduction)
            else:
                # Standard cross-entropy; -100 labels are skipped via the
                # default ignore_index of CrossEntropyLoss
                loss_fct = nn.CrossEntropyLoss()
                if attention_mask is not None:
                    active_loss = attention_mask.view(-1) == 1
                    active_logits = logits.view(-1, self.num_labels)[active_loss]
                    active_labels = labels.view(-1)[active_loss]
                    loss = loss_fct(active_logits, active_labels)
                else:
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def decode(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
token_type_ids: torch.Tensor | None = None,
) -> list[list[int]]:
"""
Decode input to tag sequences using compiled forward pass.
Args:
input_ids: Input token IDs (batch, seq)
attention_mask: Attention mask (batch, seq)
token_type_ids: Token type IDs (batch, seq)
Returns:
List of predicted tag sequences
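
        Example (a sketch; assumes a tokenizer matching ``config.model_name``)::

            enc = tokenizer("12 MG Road, Bengaluru 560001", return_tensors="pt")
            tag_ids = model.decode(enc["input_ids"], enc["attention_mask"])
            tags = [[model.id2label[i] for i in seq] for seq in tag_ids]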
"""
self.eval()
with torch.no_grad():
            # Use the (optionally compiled) forward pass for inference
forward_fn = self._get_compiled_forward()
outputs = forward_fn(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
logits = outputs.logits
if self.crf is not None:
mask = attention_mask.bool() if attention_mask is not None else None
predictions = self.crf.decode(logits, mask=mask)
            else:
                # Greedy argmax; trim padding so the output format matches
                # the CRF branch (one tag id per valid token)
                pred = logits.argmax(dim=-1)
                if attention_mask is not None:
                    lengths = attention_mask.long().sum(dim=1).tolist()
                    predictions = [p[:n].tolist() for p, n in zip(pred, lengths)]
                else:
                    predictions = pred.tolist()
return predictions
def save_pretrained(self, save_directory: str):
"""Save model to directory."""
import json
import os
os.makedirs(save_directory, exist_ok=True)
# Save model weights
torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
# Save config
config_dict = {
"model_name": self.config.model_name,
"num_labels": self.config.num_labels,
"use_crf": self.config.use_crf,
"hidden_size": self.config.hidden_size,
"classifier_dropout": self.config.classifier_dropout,
"id2label": self.id2label,
"label2id": self.label2id,
}
with open(os.path.join(save_directory, "config.json"), "w") as f:
json.dump(config_dict, f, indent=2)
@classmethod
def from_pretrained(cls, model_path: str, device: str = "cpu"):
"""Load model from directory."""
        import json
        import os

        with open(os.path.join(model_path, "config.json")) as f:
            config_dict = json.load(f)
        config = ModelConfig(
            model_name=config_dict["model_name"],
            num_labels=config_dict["num_labels"],
            use_crf=config_dict["use_crf"],
            hidden_size=config_dict["hidden_size"],
            classifier_dropout=config_dict["classifier_dropout"],
        )
        model = cls(config)
        state_dict = torch.load(
            os.path.join(model_path, "pytorch_model.bin"),
            map_location=device,
        )
        model.load_state_dict(state_dict)
model.to(device)
return model
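

if __name__ == "__main__":
    # Smoke test for the CRF layer alone (a minimal sketch: no BERT weights
    # are downloaded). The second sequence is padded to exercise masking.
    torch.manual_seed(0)
    crf = CRF(num_tags=5, batch_first=True)
    emissions = torch.randn(2, 7, 5)              # (batch, seq, num_tags)
    tags = torch.randint(0, 5, (2, 7))            # random gold tags
    mask = torch.ones(2, 7, dtype=torch.bool)
    mask[1, 4:] = False                           # second sequence has length 4
    loss = crf(emissions, tags, mask=mask)        # mean negative log-likelihood
    paths = crf.decode(emissions, mask=mask)      # best tag sequence per sample
    print(f"CRF NLL: {loss.item():.4f}")
    print(f"Decoded lengths: {[len(p) for p in paths]}")  # expect [7, 4]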