| """
|
| BERT-CRF Model for Indian Address NER.
|
|
|
| Combines a multilingual BERT encoder with a Conditional Random Field (CRF)
|
| layer for improved sequence labeling performance.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| from transformers import AutoModel
|
| from transformers.modeling_outputs import TokenClassifierOutput
|
|
|
| from address_parser.models.config import ID2LABEL, LABEL2ID, ModelConfig
|
|
|
|
|
| class CRF(nn.Module):
|
| """
|
| Conditional Random Field layer for sequence labeling.
|
|
|
| Implements the forward algorithm for computing log-likelihood
|
| and Viterbi decoding for inference.
|
| """
|
|
|
| def __init__(self, num_tags: int, batch_first: bool = True):
|
| """
|
| Initialize CRF layer.
|
|
|
| Args:
|
| num_tags: Number of output tags
|
| batch_first: If True, input is (batch, seq, features)
|
| """
|
| super().__init__()
|
| self.num_tags = num_tags
|
| self.batch_first = batch_first
|
|
|
|
|
| self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))
|
|
|
|
|
| self.start_transitions = nn.Parameter(torch.randn(num_tags))
|
| self.end_transitions = nn.Parameter(torch.randn(num_tags))
|
|
|
| self._init_transitions()
|
|
|
| def _init_transitions(self):
|
| """Initialize transition parameters."""
|
| nn.init.uniform_(self.transitions, -0.1, 0.1)
|
| nn.init.uniform_(self.start_transitions, -0.1, 0.1)
|
| nn.init.uniform_(self.end_transitions, -0.1, 0.1)
|
|
|
| def forward(
|
| self,
|
| emissions: torch.Tensor,
|
| tags: torch.LongTensor,
|
| mask: torch.ByteTensor | None = None,
|
| reduction: str = "mean",
|
| ) -> torch.Tensor:
|
| """
|
| Compute negative log-likelihood loss.
|
|
|
| Args:
|
| emissions: Emission scores from BERT (batch, seq, num_tags)
|
| tags: Gold standard tags (batch, seq)
|
| mask: Mask for valid tokens (batch, seq)
|
| reduction: 'mean', 'sum', or 'none'
|
|
|
| Returns:
|
| Negative log-likelihood loss
|
| """
|
| if mask is None:
|
| mask = torch.ones_like(tags, dtype=torch.bool)
|
|
|
| if self.batch_first:
|
| emissions = emissions.transpose(0, 1)
|
| tags = tags.transpose(0, 1)
|
| mask = mask.transpose(0, 1)
|
|
|
|
|
| numerator = self._compute_score(emissions, tags, mask)
|
| denominator = self._compute_normalizer(emissions, mask)
|
| llh = numerator - denominator
|
|
|
| if reduction == "mean":
|
| return -llh.mean()
|
| elif reduction == "sum":
|
| return -llh.sum()
|
| else:
|
| return -llh
|
|
|
| def decode(
|
| self,
|
| emissions: torch.Tensor,
|
| mask: torch.ByteTensor | None = None,
|
| ) -> list[list[int]]:
|
| """
|
| Find the most likely tag sequence using Viterbi algorithm.
|
|
|
| Args:
|
| emissions: Emission scores (batch, seq, num_tags)
|
| mask: Mask for valid tokens (batch, seq)
|
|
|
| Returns:
|
| List of best tag sequences for each sample
|
| """
|
| if mask is None:
|
| mask = torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device)
|
|
|
| if self.batch_first:
|
| emissions = emissions.transpose(0, 1)
|
| mask = mask.transpose(0, 1)
|
|
|
| return self._viterbi_decode(emissions, mask)
|
|
|
| def _compute_score(
|
| self,
|
| emissions: torch.Tensor,
|
| tags: torch.LongTensor,
|
| mask: torch.BoolTensor
|
| ) -> torch.Tensor:
|
| """Compute the score of a given tag sequence."""
|
| seq_length, batch_size = tags.shape
|
| mask = mask.float()
|
|
|
|
|
| score = self.start_transitions[tags[0]]
|
|
|
| for i in range(seq_length - 1):
|
| current_tag = tags[i]
|
| next_tag = tags[i + 1]
|
|
|
|
|
| score += emissions[i, torch.arange(batch_size), current_tag] * mask[i]
|
|
|
|
|
| score += self.transitions[current_tag, next_tag] * mask[i + 1]
|
|
|
|
|
| last_tag_idx = mask.long().sum(dim=0) - 1
|
| last_tags = tags.gather(0, last_tag_idx.unsqueeze(0)).squeeze(0)
|
| score += emissions[last_tag_idx, torch.arange(batch_size), last_tags]
|
|
|
|
|
| score += self.end_transitions[last_tags]
|
|
|
| return score
|
|
|
| def _compute_normalizer(
|
| self,
|
| emissions: torch.Tensor,
|
| mask: torch.BoolTensor
|
| ) -> torch.Tensor:
|
| """Compute log-sum-exp of all possible tag sequences (partition function)."""
|
| seq_length = emissions.shape[0]
|
|
|
|
|
| score = self.start_transitions + emissions[0]
|
|
|
| for i in range(1, seq_length):
|
|
|
| broadcast_score = score.unsqueeze(2)
|
| broadcast_emissions = emissions[i].unsqueeze(1)
|
|
|
|
|
| next_score = broadcast_score + self.transitions + broadcast_emissions
|
|
|
|
|
| next_score = torch.logsumexp(next_score, dim=1)
|
|
|
|
|
| score = torch.where(mask[i].unsqueeze(1), next_score, score)
|
|
|
|
|
| score += self.end_transitions
|
|
|
| return torch.logsumexp(score, dim=1)
|
|
|
| def _viterbi_decode(
|
| self,
|
| emissions: torch.Tensor,
|
| mask: torch.BoolTensor
|
| ) -> list[list[int]]:
|
| """Viterbi decoding to find best tag sequence."""
|
| seq_length, batch_size, num_tags = emissions.shape
|
|
|
|
|
| score = self.start_transitions + emissions[0]
|
| history = []
|
|
|
| for i in range(1, seq_length):
|
| broadcast_score = score.unsqueeze(2)
|
| broadcast_emissions = emissions[i].unsqueeze(1)
|
|
|
| next_score = broadcast_score + self.transitions + broadcast_emissions
|
|
|
|
|
| next_score, indices = next_score.max(dim=1)
|
|
|
|
|
| score = torch.where(mask[i].unsqueeze(1), next_score, score)
|
| history.append(indices)
|
|
|
|
|
| score += self.end_transitions
|
|
|
|
|
| seq_ends = mask.long().sum(dim=0) - 1
|
| best_tags_list = []
|
|
|
| for batch_idx in range(batch_size):
|
|
|
| _, best_last_tag = score[batch_idx].max(dim=0)
|
| best_tags = [best_last_tag.item()]
|
|
|
|
|
| for hist in reversed(history[:seq_ends[batch_idx]]):
|
| best_last_tag = hist[batch_idx][best_tags[-1]]
|
| best_tags.append(best_last_tag.item())
|
|
|
| best_tags.reverse()
|
| best_tags_list.append(best_tags)
|
|
|
| return best_tags_list
|
|
|
|
|
class BertCRFForTokenClassification(nn.Module):
    """
    BERT model with CRF layer for token classification.

    A pretrained encoder loaded through ``AutoModel`` feeds a dropout and a
    linear classifier that produces per-token emission scores. When
    ``config.use_crf`` is set, a ``CRF`` layer scores and decodes whole tag
    sequences; otherwise the model falls back to token-level cross-entropy
    and arg-max decoding.
    """

    def __init__(self, config: ModelConfig):
        """
        Initialize BERT-CRF model.

        Args:
            config: Model configuration; reads ``model_name``, ``cache_dir``,
                ``num_labels``, ``hidden_size``, ``classifier_dropout`` and
                ``use_crf`` here (``crf_reduction`` is read in ``forward``)
        """
        super().__init__()
        self.config = config
        self.num_labels = config.num_labels

        # Pretrained encoder.
        self.bert = AutoModel.from_pretrained(
            config.model_name,
            cache_dir=config.cache_dir,
        )

        # Classification head producing per-token emission scores.
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Optional CRF on top of the emissions.
        if config.use_crf:
            self.crf = CRF(num_tags=config.num_labels, batch_first=True)
        else:
            self.crf = None

        self.id2label = ID2LABEL
        self.label2id = LABEL2ID

        # Callable used by ``decode``; resolved lazily on first use.
        self._compiled_forward = None
        # True only while ``_compiled_forward`` is a torch.compile wrapper.
        self._uses_compiled_forward = False

    def _get_compiled_forward(self):
        """
        Lazy compile forward pass on first inference call.

        Returns the cached callable when one exists. Otherwise wraps
        ``self.forward`` with ``torch.compile`` unless compilation is
        disabled (``TORCH_COMPILE_DISABLE=1``), the platform is Windows, or
        ``torch.compile`` is unavailable; in those cases, or if wrapping
        raises, the plain ``self.forward`` is used.
        """
        if self._compiled_forward is not None:
            return self._compiled_forward

        import os
        import sys

        skip_compile = (
            os.environ.get("TORCH_COMPILE_DISABLE", "0") == "1"
            or sys.platform == "win32"
        )

        compiled = None
        if not skip_compile and hasattr(torch, "compile"):
            try:
                compiled = torch.compile(
                    self.forward,
                    backend="inductor",
                    mode="reduce-overhead",
                    dynamic=True,
                )
            except Exception:
                # Best effort: any failure to set up compilation means eager.
                compiled = None

        self._uses_compiled_forward = compiled is not None
        self._compiled_forward = compiled if compiled is not None else self.forward
        return self._compiled_forward

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool = True,
    ):
        """
        Forward pass.

        Args:
            input_ids: Input token IDs (batch, seq)
            attention_mask: Attention mask (batch, seq)
            token_type_ids: Token type IDs (batch, seq)
            labels: Gold standard labels for training (batch, seq)
            return_dict: Return as dict or tuple

        Returns:
            TokenClassifierOutput with loss, logits, hidden states and
            attentions, or a tuple ``(loss?, logits, *extra_encoder_outputs)``
            when ``return_dict`` is False
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)

        # Per-token emission scores.
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.crf is not None:
                mask = attention_mask.bool() if attention_mask is not None else None

                # The CRF indexes its score tables with the labels, so the
                # -100 ignore marker is replaced by tag 0. Those positions
                # still contribute to the loss wherever the mask is on.
                crf_labels = labels.clone()
                crf_labels[crf_labels == -100] = 0
                loss = self.crf(logits, crf_labels, mask=mask, reduction=self.config.crf_reduction)
            else:
                # Token-level cross-entropy. Padding is dropped via the
                # attention mask when one is given; labels equal to -100 are
                # skipped by CrossEntropyLoss' default ignore_index.
                loss_fct = nn.CrossEntropyLoss()
                active_logits = logits.view(-1, self.num_labels)
                active_labels = labels.view(-1)
                if attention_mask is not None:
                    active_loss = attention_mask.view(-1) == 1
                    active_logits = active_logits[active_loss]
                    active_labels = active_labels[active_loss]
                loss = loss_fct(active_logits, active_labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def decode(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> list[list[int]]:
        """
        Decode input to tag sequences using compiled forward pass.

        Puts the module in eval mode and runs without gradients. If the
        compiled forward raises (torch.compile defers real compilation to
        the first call, so backend failures surface here), a warning is
        emitted and the eager ``forward`` is used from then on.

        Args:
            input_ids: Input token IDs (batch, seq)
            attention_mask: Attention mask (batch, seq)
            token_type_ids: Token type IDs (batch, seq)

        Returns:
            List of predicted tag sequences. With a CRF each sequence covers
            only the unmasked positions; without one, the arg-max tags for
            every position (padding included) are returned.
        """
        self.eval()
        with torch.no_grad():
            forward_fn = self._get_compiled_forward()
            try:
                outputs = forward_fn(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                )
            except Exception as exc:
                if not self._uses_compiled_forward:
                    # Already eager: this is a genuine error.
                    raise
                import warnings

                warnings.warn(
                    f"Compiled forward failed ({exc!r}); falling back to eager execution.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                self._compiled_forward = self.forward
                self._uses_compiled_forward = False
                outputs = self.forward(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                )

            logits = outputs.logits

            if self.crf is not None:
                mask = attention_mask.bool() if attention_mask is not None else None
                predictions = self.crf.decode(logits, mask=mask)
            else:
                predictions = logits.argmax(dim=-1).tolist()

        return predictions

    def save_pretrained(self, save_directory: str):
        """
        Save model to directory.

        Writes the full state dict to ``pytorch_model.bin`` and a JSON
        summary of the configuration and label maps to ``config.json``.

        Args:
            save_directory: Target directory; created if missing
        """
        import json
        import os

        os.makedirs(save_directory, exist_ok=True)

        # Model weights (encoder, classifier and CRF parameters).
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

        # NOTE(review): ``crf_reduction`` and ``cache_dir`` are not persisted,
        # so a reloaded model uses the ModelConfig defaults for them.
        config_dict = {
            "model_name": self.config.model_name,
            "num_labels": self.config.num_labels,
            "use_crf": self.config.use_crf,
            "hidden_size": self.config.hidden_size,
            "classifier_dropout": self.config.classifier_dropout,
            "id2label": self.id2label,
            "label2id": self.label2id,
        }
        with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
            json.dump(config_dict, f, indent=2)

    @classmethod
    def from_pretrained(cls, model_path: str, device: str = "cpu"):
        """
        Load model from directory.

        Args:
            model_path: Directory containing ``config.json`` and
                ``pytorch_model.bin`` as written by ``save_pretrained``
            device: Device the weights are mapped to and the model moved to

        Returns:
            The restored model, placed on ``device``
        """
        import json
        import os

        with open(os.path.join(model_path, "config.json"), encoding="utf-8") as f:
            config_dict = json.load(f)

        config = ModelConfig(
            model_name=config_dict["model_name"],
            num_labels=config_dict["num_labels"],
            use_crf=config_dict["use_crf"],
            hidden_size=config_dict["hidden_size"],
            classifier_dropout=config_dict["classifier_dropout"],
        )

        model = cls(config)
        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from trusted sources (consider ``weights_only=True`` where the
        # installed torch version supports it).
        state_dict = torch.load(
            os.path.join(model_path, "pytorch_model.bin"),
            map_location=device,
        )
        model.load_state_dict(state_dict)
        model.to(device)

        return model
|
|
|