Spaces:
Sleeping
Sleeping
| """ | |
| BERT-CRF Model for Indian Address NER. | |
| Combines a multilingual BERT encoder with a Conditional Random Field (CRF) | |
| layer for improved sequence labeling performance. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoModel | |
| from transformers.modeling_outputs import TokenClassifierOutput | |
| from address_parser.models.config import ID2LABEL, LABEL2ID, ModelConfig | |
| class CRF(nn.Module): | |
| """ | |
| Conditional Random Field layer for sequence labeling. | |
| Implements the forward algorithm for computing log-likelihood | |
| and Viterbi decoding for inference. | |
| """ | |
| def __init__(self, num_tags: int, batch_first: bool = True): | |
| """ | |
| Initialize CRF layer. | |
| Args: | |
| num_tags: Number of output tags | |
| batch_first: If True, input is (batch, seq, features) | |
| """ | |
| super().__init__() | |
| self.num_tags = num_tags | |
| self.batch_first = batch_first | |
| # Transition matrix: transitions[i, j] = score of transitioning from tag i to tag j | |
| self.transitions = nn.Parameter(torch.randn(num_tags, num_tags)) | |
| # Start and end transition scores | |
| self.start_transitions = nn.Parameter(torch.randn(num_tags)) | |
| self.end_transitions = nn.Parameter(torch.randn(num_tags)) | |
| self._init_transitions() | |
| def _init_transitions(self): | |
| """Initialize transition parameters.""" | |
| nn.init.uniform_(self.transitions, -0.1, 0.1) | |
| nn.init.uniform_(self.start_transitions, -0.1, 0.1) | |
| nn.init.uniform_(self.end_transitions, -0.1, 0.1) | |
| def forward( | |
| self, | |
| emissions: torch.Tensor, | |
| tags: torch.LongTensor, | |
| mask: torch.ByteTensor | None = None, | |
| reduction: str = "mean", | |
| ) -> torch.Tensor: | |
| """ | |
| Compute negative log-likelihood loss. | |
| Args: | |
| emissions: Emission scores from BERT (batch, seq, num_tags) | |
| tags: Gold standard tags (batch, seq) | |
| mask: Mask for valid tokens (batch, seq) | |
| reduction: 'mean', 'sum', or 'none' | |
| Returns: | |
| Negative log-likelihood loss | |
| """ | |
| if mask is None: | |
| mask = torch.ones_like(tags, dtype=torch.bool) | |
| if self.batch_first: | |
| emissions = emissions.transpose(0, 1) | |
| tags = tags.transpose(0, 1) | |
| mask = mask.transpose(0, 1) | |
| # Compute log-likelihood | |
| numerator = self._compute_score(emissions, tags, mask) | |
| denominator = self._compute_normalizer(emissions, mask) | |
| llh = numerator - denominator | |
| if reduction == "mean": | |
| return -llh.mean() | |
| elif reduction == "sum": | |
| return -llh.sum() | |
| else: | |
| return -llh | |
| def decode( | |
| self, | |
| emissions: torch.Tensor, | |
| mask: torch.ByteTensor | None = None, | |
| ) -> list[list[int]]: | |
| """ | |
| Find the most likely tag sequence using Viterbi algorithm. | |
| Args: | |
| emissions: Emission scores (batch, seq, num_tags) | |
| mask: Mask for valid tokens (batch, seq) | |
| Returns: | |
| List of best tag sequences for each sample | |
| """ | |
| if mask is None: | |
| mask = torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device) | |
| if self.batch_first: | |
| emissions = emissions.transpose(0, 1) | |
| mask = mask.transpose(0, 1) | |
| return self._viterbi_decode(emissions, mask) | |
| def _compute_score( | |
| self, | |
| emissions: torch.Tensor, | |
| tags: torch.LongTensor, | |
| mask: torch.BoolTensor | |
| ) -> torch.Tensor: | |
| """Compute the score of a given tag sequence.""" | |
| seq_length, batch_size = tags.shape | |
| mask = mask.float() | |
| # Start transition score | |
| score = self.start_transitions[tags[0]] | |
| for i in range(seq_length - 1): | |
| current_tag = tags[i] | |
| next_tag = tags[i + 1] | |
| # Emission score | |
| score += emissions[i, torch.arange(batch_size), current_tag] * mask[i] | |
| # Transition score | |
| score += self.transitions[current_tag, next_tag] * mask[i + 1] | |
| # Last emission score | |
| last_tag_idx = mask.long().sum(dim=0) - 1 | |
| last_tags = tags.gather(0, last_tag_idx.unsqueeze(0)).squeeze(0) | |
| score += emissions[last_tag_idx, torch.arange(batch_size), last_tags] | |
| # End transition score | |
| score += self.end_transitions[last_tags] | |
| return score | |
| def _compute_normalizer( | |
| self, | |
| emissions: torch.Tensor, | |
| mask: torch.BoolTensor | |
| ) -> torch.Tensor: | |
| """Compute log-sum-exp of all possible tag sequences (partition function).""" | |
| seq_length = emissions.shape[0] | |
| # Initialize with start transitions | |
| score = self.start_transitions + emissions[0] | |
| for i in range(1, seq_length): | |
| # Broadcast score and transitions for all combinations | |
| broadcast_score = score.unsqueeze(2) | |
| broadcast_emissions = emissions[i].unsqueeze(1) | |
| # Compute next scores | |
| next_score = broadcast_score + self.transitions + broadcast_emissions | |
| # Log-sum-exp | |
| next_score = torch.logsumexp(next_score, dim=1) | |
| # Mask | |
| score = torch.where(mask[i].unsqueeze(1), next_score, score) | |
| # Add end transitions | |
| score += self.end_transitions | |
| return torch.logsumexp(score, dim=1) | |
| def _viterbi_decode( | |
| self, | |
| emissions: torch.Tensor, | |
| mask: torch.BoolTensor | |
| ) -> list[list[int]]: | |
| """Viterbi decoding to find best tag sequence.""" | |
| seq_length, batch_size, num_tags = emissions.shape | |
| # Initialize | |
| score = self.start_transitions + emissions[0] | |
| history = [] | |
| for i in range(1, seq_length): | |
| broadcast_score = score.unsqueeze(2) | |
| broadcast_emissions = emissions[i].unsqueeze(1) | |
| next_score = broadcast_score + self.transitions + broadcast_emissions | |
| # Find best previous tag for each current tag | |
| next_score, indices = next_score.max(dim=1) | |
| # Apply mask | |
| score = torch.where(mask[i].unsqueeze(1), next_score, score) | |
| history.append(indices) | |
| # Add end transitions | |
| score += self.end_transitions | |
| # Backtrack | |
| seq_ends = mask.long().sum(dim=0) - 1 | |
| best_tags_list = [] | |
| for batch_idx in range(batch_size): | |
| # Best last tag | |
| _, best_last_tag = score[batch_idx].max(dim=0) | |
| best_tags = [best_last_tag.item()] | |
| # Backtrack through history | |
| for hist in reversed(history[:seq_ends[batch_idx]]): | |
| best_last_tag = hist[batch_idx][best_tags[-1]] | |
| best_tags.append(best_last_tag.item()) | |
| best_tags.reverse() | |
| best_tags_list.append(best_tags) | |
| return best_tags_list | |
| class BertCRFForTokenClassification(nn.Module): | |
| """ | |
| BERT model with CRF layer for token classification. | |
| This combines a multilingual BERT encoder with a CRF layer | |
| for improved sequence labeling on NER tasks. | |
| """ | |
| def __init__(self, config: ModelConfig): | |
| """ | |
| Initialize BERT-CRF model. | |
| Args: | |
| config: Model configuration | |
| """ | |
| super().__init__() | |
| self.config = config | |
| self.num_labels = config.num_labels | |
| # Load pretrained BERT | |
| self.bert = AutoModel.from_pretrained( | |
| config.model_name, | |
| cache_dir=config.cache_dir, | |
| ) | |
| # Dropout | |
| self.dropout = nn.Dropout(config.classifier_dropout) | |
| # Classification head | |
| self.classifier = nn.Linear(config.hidden_size, config.num_labels) | |
| # CRF layer | |
| if config.use_crf: | |
| self.crf = CRF(num_tags=config.num_labels, batch_first=True) | |
| else: | |
| self.crf = None | |
| # Label mappings | |
| self.id2label = ID2LABEL | |
| self.label2id = LABEL2ID | |
| # PyTorch 2.9: Lazy compilation for optimized inference | |
| self._compiled_forward: nn.Module | None = None | |
| def _get_compiled_forward(self): | |
| """Lazy compile forward pass on first inference call.""" | |
| # Skip torch.compile on Windows without MSVC or when explicitly disabled | |
| # The inductor backend requires a C++ compiler (cl on Windows, gcc/clang on Linux) | |
| import os | |
| import sys | |
| skip_compile = ( | |
| os.environ.get("TORCH_COMPILE_DISABLE", "0") == "1" | |
| or sys.platform == "win32" # Skip on Windows to avoid cl requirement | |
| ) | |
| if self._compiled_forward is None: | |
| if not skip_compile and hasattr(torch, "compile"): | |
| try: | |
| self._compiled_forward = torch.compile( | |
| self.forward, | |
| backend="inductor", | |
| mode="reduce-overhead", | |
| dynamic=True, | |
| ) | |
| except Exception: | |
| self._compiled_forward = self.forward | |
| else: | |
| self._compiled_forward = self.forward | |
| return self._compiled_forward | |
| def forward( | |
| self, | |
| input_ids: torch.Tensor, | |
| attention_mask: torch.Tensor | None = None, | |
| token_type_ids: torch.Tensor | None = None, | |
| labels: torch.Tensor | None = None, | |
| return_dict: bool = True, | |
| ): | |
| """ | |
| Forward pass. | |
| Args: | |
| input_ids: Input token IDs (batch, seq) | |
| attention_mask: Attention mask (batch, seq) | |
| token_type_ids: Token type IDs (batch, seq) | |
| labels: Gold standard labels for training (batch, seq) | |
| return_dict: Return as dict or tuple | |
| Returns: | |
| TokenClassifierOutput with loss, logits, hidden states | |
| """ | |
| # BERT encoding | |
| outputs = self.bert( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids, | |
| ) | |
| sequence_output = outputs.last_hidden_state | |
| sequence_output = self.dropout(sequence_output) | |
| # Classification logits | |
| logits = self.classifier(sequence_output) | |
| loss = None | |
| if labels is not None: | |
| if self.crf is not None: | |
| # CRF loss - need to handle -100 (ignore_index) labels | |
| mask = attention_mask.bool() if attention_mask is not None else None | |
| # Replace -100 with 0 (will be masked out anyway) | |
| crf_labels = labels.clone() | |
| crf_labels[crf_labels == -100] = 0 | |
| loss = self.crf(logits, crf_labels, mask=mask, reduction=self.config.crf_reduction) | |
| else: | |
| # Standard cross-entropy | |
| loss_fct = nn.CrossEntropyLoss() | |
| active_loss = attention_mask.view(-1) == 1 | |
| active_logits = logits.view(-1, self.num_labels)[active_loss] | |
| active_labels = labels.view(-1)[active_loss] | |
| loss = loss_fct(active_logits, active_labels) | |
| if not return_dict: | |
| output = (logits,) + outputs[2:] | |
| return ((loss,) + output) if loss is not None else output | |
| return TokenClassifierOutput( | |
| loss=loss, | |
| logits=logits, | |
| hidden_states=outputs.hidden_states, | |
| attentions=outputs.attentions, | |
| ) | |
| def decode( | |
| self, | |
| input_ids: torch.Tensor, | |
| attention_mask: torch.Tensor | None = None, | |
| token_type_ids: torch.Tensor | None = None, | |
| ) -> list[list[int]]: | |
| """ | |
| Decode input to tag sequences using compiled forward pass. | |
| Args: | |
| input_ids: Input token IDs (batch, seq) | |
| attention_mask: Attention mask (batch, seq) | |
| token_type_ids: Token type IDs (batch, seq) | |
| Returns: | |
| List of predicted tag sequences | |
| """ | |
| self.eval() | |
| with torch.no_grad(): | |
| # Use compiled forward for optimized inference (PyTorch 2.9+) | |
| forward_fn = self._get_compiled_forward() | |
| outputs = forward_fn( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids, | |
| ) | |
| logits = outputs.logits | |
| if self.crf is not None: | |
| mask = attention_mask.bool() if attention_mask is not None else None | |
| predictions = self.crf.decode(logits, mask=mask) | |
| else: | |
| predictions = logits.argmax(dim=-1).tolist() | |
| return predictions | |
| def save_pretrained(self, save_directory: str): | |
| """Save model to directory.""" | |
| import json | |
| import os | |
| os.makedirs(save_directory, exist_ok=True) | |
| # Save model weights | |
| torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin")) | |
| # Save config | |
| config_dict = { | |
| "model_name": self.config.model_name, | |
| "num_labels": self.config.num_labels, | |
| "use_crf": self.config.use_crf, | |
| "hidden_size": self.config.hidden_size, | |
| "classifier_dropout": self.config.classifier_dropout, | |
| "id2label": self.id2label, | |
| "label2id": self.label2id, | |
| } | |
| with open(os.path.join(save_directory, "config.json"), "w") as f: | |
| json.dump(config_dict, f, indent=2) | |
| def from_pretrained(cls, model_path: str, device: str = "cpu"): | |
| """Load model from directory.""" | |
| import json | |
| with open(f"{model_path}/config.json") as f: | |
| config_dict = json.load(f) | |
| config = ModelConfig( | |
| model_name=config_dict["model_name"], | |
| num_labels=config_dict["num_labels"], | |
| use_crf=config_dict["use_crf"], | |
| hidden_size=config_dict["hidden_size"], | |
| classifier_dropout=config_dict["classifier_dropout"], | |
| ) | |
| model = cls(config) | |
| state_dict = torch.load(f"{model_path}/pytorch_model.bin", map_location=device) | |
| model.load_state_dict(state_dict) | |
| model.to(device) | |
| return model | |