# Copyright (C) Miðeind ehf.
# This file is part of IceBERT POS model conversion.

import logging
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoConfig, AutoModel, PreTrainedModel, RobertaModel

from .configuration import IceBertPosConfig

logger = logging.getLogger(__name__)


class MultiLabelTokenClassificationHead(nn.Module):
    """Head for multilabel word-level classification tasks."""

    def __init__(self, config: IceBertPosConfig):
        super().__init__()
        self.num_categories = config.num_categories
        self.num_labels = config.num_labels
        self.hidden_size = config.hidden_size

        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        self.activation_fn = F.relu
        self.dropout = nn.Dropout(p=config.classifier_dropout)
        self.layer_norm = nn.LayerNorm(self.hidden_size)

        # Category projection: hidden_size -> num_categories
        self.cat_proj = nn.Linear(self.hidden_size, self.num_categories)
        # Attribute projection: (hidden_size + num_categories) -> num_labels
        self.out_proj = nn.Linear(self.hidden_size + self.num_categories, self.num_labels)

    def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            features: Word-level features of shape (total_words, hidden_size)

        Returns:
            cat_logits: Category logits of shape (total_words, num_categories)
            attr_logits: Attribute logits of shape (total_words, num_labels)
        """
        x = self.dropout(features)
        x = self.dense(x)
        x = self.layer_norm(x)
        x = self.activation_fn(x)

        # Predict categories
        cat_logits = self.cat_proj(x)
        cat_probs = torch.softmax(cat_logits, dim=-1)

        # Predict attributes using concatenated features
        attr_input = torch.cat((cat_probs, x), dim=-1)
        attr_logits = self.out_proj(attr_input)

        return cat_logits, attr_logits
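# Shape walk-through of the head above, with illustrative sizes (not taken
# from any released checkpoint): hidden_size=768, num_categories=34,
# num_labels=128, and a batch of 10 word vectors.
#
#   features (10, 768)
#     -> dropout/dense/layer_norm/relu          -> x           (10, 768)
#     -> cat_proj(x)                            -> cat_logits  (10, 34)
#     -> cat([softmax(cat_logits), x], dim=-1)  -> attr_input  (10, 802)
#     -> out_proj(attr_input)                   -> attr_logits (10, 128)
#
# Feeding the category distribution back into the attribute projection is
# what lets the attribute predictions condition on the predicted word class.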
class IceBertPosForTokenClassification(PreTrainedModel):
    """
    IceBERT model for multilabel token classification (POS tagging).

    This model performs word-level POS tagging by:
    1. Encoding input with RoBERTa
    2. Aggregating subword tokens to word-level representations
    3. Predicting both categories and attributes for each word
    """

    config_class = IceBertPosConfig

    def __init__(self, config: IceBertPosConfig):
        super().__init__(config)
        self.config = config
        self.num_categories = config.num_categories
        self.num_labels = config.num_labels

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.classifier = MultiLabelTokenClassificationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        word_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            input_ids: Token indices of shape (batch_size, sequence_length)
            attention_mask: Attention mask of shape (batch_size, sequence_length)
            word_mask: Binary mask indicating word boundaries (1 = word start)

        Returns:
            cat_logits: Category logits of shape (batch_size, max_words, num_categories)
            attr_logits: Attribute logits of shape (batch_size, max_words, num_labels)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get RoBERTa outputs
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]  # (batch_size, seq_len, hidden_size)

        # Aggregate subword tokens to word-level representations using word_mask
        word_features, nwords = self._aggregate_subword_tokens(sequence_output, word_mask)

        # Apply classification head
        cat_logits, attr_logits = self.classifier(word_features)

        # Reshape back to batch format using word counts
        cat_logits_batch, attr_logits_batch = self._reshape_to_batch_format(cat_logits, attr_logits, nwords)

        return cat_logits_batch, attr_logits_batch
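    # Toy illustration of the aggregation performed below (the tokenization
    # is invented, not taken from the real IceBERT vocabulary):
    #
    #   tokens:     <s>  "Hest"  "urinn"  "hleypur"  </s>
    #   word_mask:   0      1        0         1       1
    #
    # After the first and last positions are dropped, "Hest" and "urinn"
    # are mean-pooled into a single word vector and "hleypur" stays as-is,
    # giving two word-level representations for this row.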
    def _aggregate_subword_tokens(
        self, sequence_output: torch.Tensor, word_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Aggregate subword token representations to word-level representations.

        Following the original fairseq approach by averaging subword tokens within each word.

        Args:
            sequence_output: Subword token representations (batch_size, seq_len, hidden_size)
            word_mask: Binary mask where 1 indicates start of word (batch_size, seq_len)

        Returns:
            word_features: Word-level features (total_words, hidden_size)
            nwords: Number of words per sequence (batch_size,)
        """
        # TODO: Verify that BOS and EOS are handled correctly - I'm worried that this
        # does not correctly handle padding

        # Remove BOS and EOS tokens (first and last positions)
        x = sequence_output[:, 1:-1, :]  # (batch_size, seq_len-2, hidden_size)
        starts = word_mask[:, 1:-1]  # (batch_size, seq_len-2)

        # Count words per sequence
        nwords = starts.sum(dim=-1)  # (batch_size,)

        # Find word boundaries and average tokens within each word
        mean_words = []
        batch_size, seq_len, hidden_size = x.shape

        for batch_idx in range(batch_size):
            seq_starts = starts[batch_idx]  # (seq_len-2,)
            seq_x = x[batch_idx]  # (seq_len-2, hidden_size)

            # Find start positions of words
            start_positions = seq_starts.nonzero(as_tuple=True)[0]  # positions where words start
            if len(start_positions) == 0:
                continue

            # Calculate end positions (start of next word or end of sequence)
            end_positions = torch.cat(
                [start_positions[1:], torch.tensor([seq_len], device=start_positions.device)]
            )

            # Average tokens within each word
            for start_pos, end_pos in zip(start_positions, end_positions):
                word_tokens = seq_x[start_pos:end_pos]  # tokens in this word
                word_repr = word_tokens.mean(dim=0)  # average representation
                mean_words.append(word_repr)

        if len(mean_words) == 0:
            return torch.empty(0, sequence_output.size(-1), device=sequence_output.device), nwords

        return torch.stack(mean_words), nwords

    def _reshape_to_batch_format(
        self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, nwords: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Reshape word-level predictions back to batch format.

        Following the original fairseq approach with pad_sequence.

        Args:
            cat_logits: Category logits (total_words, num_categories)
            attr_logits: Attribute logits (total_words, num_labels)
            nwords: Number of words per sequence (batch_size,)

        Returns:
            cat_logits_batch: (batch_size, max_words, num_categories)
            attr_logits_batch: (batch_size, max_words, num_labels)
        """
        # Split logits by sequence using word counts
        words_per_seq = nwords.tolist()
        cat_logits_split = cat_logits.split(words_per_seq)
        attr_logits_split = attr_logits.split(words_per_seq)

        # Pad to same length (matching original fairseq approach)
        cat_logits_batch = pad_sequence(cat_logits_split, batch_first=True, padding_value=0)
        attr_logits_batch = pad_sequence(attr_logits_split, batch_first=True, padding_value=0)

        return cat_logits_batch, attr_logits_batch

    @torch.no_grad()
    def predict_labels(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, word_ids: List[List[int]]
    ) -> List[List[Tuple[str, List[str]]]]:
        """
        Predict POS labels for input sequences.

        Args:
            input_ids: Token indices
            attention_mask: Attention mask
            word_ids: Word boundaries

        Returns:
            List of sequences, each containing (category, [attributes]) per word
        """
        # Convert word_ids to word_mask
        word_mask = self._word_ids_to_word_mask(word_ids, input_ids.shape)

        cat_logits, attr_logits = self.forward(
            input_ids=input_ids, attention_mask=attention_mask, word_mask=word_mask
        )

        return self._logits_to_labels(cat_logits, attr_logits, word_ids)
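    # Worked example for _word_ids_to_word_mask below, using toy values:
    #
    #   word_ids  = [None, 0, 0, 1, 2, 2, None]   # None marks <s>/</s>
    #   word_mask = [0,    1, 0, 1, 1, 0, 1]      # 1 wherever word_id changes
    #
    # The trailing 1 on </s> looks spurious, but _aggregate_subword_tokens
    # strips the first and last positions, and _logits_to_labels only reads
    # the first nwords predictions per sequence (see the TODO above).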
    def _word_ids_to_word_mask(self, word_ids: List[List[int]], input_shape: torch.Size) -> torch.Tensor:
        """
        Convert word_ids to word_mask (binary mask indicating word boundaries).

        Args:
            word_ids: List of word id sequences
            input_shape: Shape of input_ids tensor (batch_size, seq_len)

        Returns:
            word_mask: Binary tensor where 1 indicates start of word
        """
        batch_size, seq_len = input_shape
        word_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)

        for batch_idx, seq_word_ids in enumerate(word_ids):
            prev_word_id = None
            for token_idx, word_id in enumerate(seq_word_ids):
                if word_id != prev_word_id:
                    word_mask[batch_idx, token_idx] = 1
                prev_word_id = word_id

        return word_mask

    def predict_labels_from_text(self, sentences: List[str], tokenizer) -> List[List[Tuple[str, List[str]]]]:
        """
        Predict POS labels from raw text using fairseq-style preprocessing.

        Args:
            sentences: List of input sentences
            tokenizer: HuggingFace tokenizer

        Returns:
            List of sequences, each containing (category, [attributes]) per word
        """
        # Tokenize with fairseq-style preprocessing
        encodings = [tokenizer(sent, return_tensors="pt") for sent in sentences]
        word_ids_list = [encoding.word_ids() for encoding in encodings]

        # Batch the inputs
        max_len = max(encoding["input_ids"].shape[1] for encoding in encodings)

        batch_input_ids = []
        batch_attention_mask = []
        for encoding in encodings:
            input_ids = encoding["input_ids"][0]
            attention_mask = encoding["attention_mask"][0]

            # Pad to max length, using the tokenizer's pad id (1 for IceBERT/RoBERTa)
            # rather than hardcoding it
            pad_len = max_len - len(input_ids)
            if pad_len > 0:
                input_ids = torch.cat(
                    [input_ids, torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)]
                )
                attention_mask = torch.cat([attention_mask, torch.zeros(pad_len, dtype=torch.long)])

            batch_input_ids.append(input_ids)
            batch_attention_mask.append(attention_mask)

        batch_input_ids = torch.stack(batch_input_ids)
        batch_attention_mask = torch.stack(batch_attention_mask)

        return self.predict_labels(batch_input_ids, batch_attention_mask, word_ids_list)

    def _make_group_name_to_group_attr_vec_idxs(self):
        """Create mapping from group names to their attribute vector indices."""
        group_name_to_group_attr_vec_idxs = {}
        labels = self.config.label_schema["labels"]
        # Number of special tokens in the label dictionary (fairseq-style
        # entries such as "<s>"); none here
        nspecial = 0

        for group_name, group_labels in self.config.label_schema["group_name_to_labels"].items():
            vec_idxs = []
            for label in group_labels:
                if label in labels:
                    # Find index in labels list, but subtract nspecial to get vector index
                    label_dict_idx = labels.index(label)
                    if label_dict_idx >= nspecial:  # Skip special tokens
                        vec_idxs.append(label_dict_idx - nspecial)
            group_name_to_group_attr_vec_idxs[group_name] = torch.tensor(vec_idxs)

        return group_name_to_group_attr_vec_idxs

    def _make_group_masks(self):
        """Create group masks for each category."""
        label_categories = self.config.label_schema["label_categories"]
        group_names = self.config.label_schema["group_names"]
        category_to_group_names = self.config.label_schema["category_to_group_names"]

        num_cats = len(label_categories)
        num_groups = len(group_names)
        group_mask = torch.zeros(num_cats, num_groups, dtype=torch.bool)

        for cat_idx, category in enumerate(label_categories):
            if category in category_to_group_names:
                for group_name in category_to_group_names[category]:
                    if group_name in group_names:
                        group_idx = group_names.index(group_name)
                        group_mask[cat_idx, group_idx] = True

        return group_mask
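    # Illustrative label_schema fragment consumed by the mapping helpers
    # above and below. The keys are the ones this module actually reads;
    # the values are invented toy entries, not the real IceBERT tag set:
    #
    #   {
    #       "labels": ["no", "so", "nf", "et", ...],
    #       "label_categories": ["no", "so"],
    #       "group_names": ["case", "number"],
    #       "group_name_to_labels": {"case": ["nf"], "number": ["et"]},
    #       "category_to_group_names": {"no": ["case", "number"], "so": []},
    #   }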
    def _make_category_mappings(self):
        """Create mappings between category vector indices and dictionary indices."""
        labels = self.config.label_schema["labels"]
        label_categories = self.config.label_schema["label_categories"]

        # Create mapping from category names to vector indices (0-based)
        cat_dict_idx_to_vec_idx = torch.zeros(len(labels), dtype=torch.long)
        cat_vec_idx_to_dict_idx = torch.zeros(len(label_categories), dtype=torch.long)

        for vec_idx, category in enumerate(label_categories):
            if category in labels:
                dict_idx = labels.index(category)
                cat_dict_idx_to_vec_idx[dict_idx] = vec_idx
                cat_vec_idx_to_dict_idx[vec_idx] = dict_idx

        return cat_dict_idx_to_vec_idx, cat_vec_idx_to_dict_idx

    def _count_words_per_sequence(self, word_ids: List[List[int]]) -> List[int]:
        """Count the number of unique words in each sequence."""
        words_per_seq = []
        for seq_word_ids in word_ids:
            unique_word_ids = set(word_id for word_id in seq_word_ids if word_id is not None)
            words_per_seq.append(len(unique_word_ids))
        return words_per_seq

    def _predict_categories_for_sequence(
        self, cat_logits: torch.Tensor, seq_idx: int, seq_nwords: int, cat_vec_idx_to_dict_idx: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict categories for a single sequence and return both vector and dictionary indices."""
        pred_cat_vec_idxs = cat_logits[seq_idx, :seq_nwords].max(dim=-1).indices
        pred_cats = cat_vec_idx_to_dict_idx[pred_cat_vec_idxs]
        return pred_cat_vec_idxs, pred_cats

    def _predict_attributes_for_group(
        self,
        attr_logits: torch.Tensor,
        seq_idx: int,
        seq_nwords: int,
        group_vec_idxs: torch.Tensor,
        seq_group_mask: torch.Tensor,
        group_idx: int,
    ) -> torch.Tensor:
        """Predict attributes for a single group."""
        if len(group_vec_idxs) == 0:
            return torch.zeros(seq_nwords, dtype=torch.long)

        # Get logits for this group
        group_logits = attr_logits[seq_idx, :seq_nwords, group_vec_idxs]

        if len(group_vec_idxs) == 1:
            # Single-element group: use sigmoid > 0.5
            group_pred = group_logits.sigmoid().ge(0.5).long()
            group_pred_dict_idxs = (group_pred.squeeze() * group_vec_idxs.item()) * seq_group_mask[:, group_idx]
        else:
            # Multi-element group: use argmax
            group_pred_vec_idxs = group_logits.max(dim=-1).indices
            group_pred_dict_idxs = group_vec_idxs[group_pred_vec_idxs] * seq_group_mask[:, group_idx]

        return group_pred_dict_idxs

    def _predict_all_attributes_for_sequence(
        self,
        attr_logits: torch.Tensor,
        seq_idx: int,
        seq_nwords: int,
        pred_cat_vec_idxs: torch.Tensor,
        group_name_to_group_attr_vec_idxs: dict,
        group_mask: torch.Tensor,
        group_names: List[str],
    ) -> torch.Tensor:
        """Predict all attributes for a single sequence."""
        seq_group_mask = group_mask[pred_cat_vec_idxs]

        pred_attrs = []
        for group_idx, group_name in enumerate(group_names):
            if group_name not in group_name_to_group_attr_vec_idxs:
                pred_attrs.append(torch.zeros(seq_nwords, dtype=torch.long))
                continue

            group_vec_idxs = group_name_to_group_attr_vec_idxs[group_name]
            group_pred_dict_idxs = self._predict_attributes_for_group(
                attr_logits, seq_idx, seq_nwords, group_vec_idxs, seq_group_mask, group_idx
            )
            pred_attrs.append(group_pred_dict_idxs)

        # Stack predictions
        if pred_attrs:
            return torch.stack([p.squeeze() if p.dim() > 1 else p for p in pred_attrs]).t()
        else:
            return torch.zeros(seq_nwords, len(group_names), dtype=torch.long)
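    # Decoding sketch for one word (toy indices): if the "case" group maps
    # to attribute vector indices [7, 8, 9, 10], a multi-element group takes
    # the argmax over those four logits and maps the winner back through
    # group_vec_idxs, while a one-element group fires only when
    # sigmoid(logit) >= 0.5. Either way the result is zeroed out whenever
    # the predicted category does not license the group (seq_group_mask).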
    def _convert_predictions_to_labels(
        self, pred_cats: torch.Tensor, pred_attrs_tensor: torch.Tensor, labels: List[str], group_names: List[str]
    ) -> List[Tuple[str, List[str]]]:
        """Convert prediction tensors to human-readable labels."""
        seq_nwords = pred_cats.size(0)
        seq_predictions = []

        for word_idx in range(seq_nwords):
            # Category (convert from dictionary index to string)
            cat_dict_idx = pred_cats[word_idx].item()
            if cat_dict_idx < len(labels):
                category = labels[cat_dict_idx]
            else:
                category = "UNK"

            # Attributes (convert from dictionary indices to strings)
            attributes = []
            for group_idx in range(len(group_names)):
                attr_dict_idx = pred_attrs_tensor[word_idx, group_idx].item()
                if 0 < attr_dict_idx < len(labels):  # Skip 0 (empty) and out-of-bounds indices
                    attributes.append(labels[attr_dict_idx])

            seq_predictions.append((category, attributes))

        return seq_predictions

    def _logits_to_labels(
        self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_ids: List[List[int]]
    ) -> List[List[Tuple[str, List[str]]]]:
        """Convert logits to human-readable labels using fairseq's group-based logic."""
        # Create necessary mappings
        group_name_to_group_attr_vec_idxs = self._make_group_name_to_group_attr_vec_idxs()
        group_mask = self._make_group_masks()
        cat_dict_idx_to_vec_idx, cat_vec_idx_to_dict_idx = self._make_category_mappings()

        label_schema = self.config.label_schema
        labels = label_schema["labels"]
        group_names = label_schema["group_names"]

        batch_size = cat_logits.size(0)
        words_per_seq = self._count_words_per_sequence(word_ids)

        batch_predictions = []
        for seq_idx in range(batch_size):
            seq_nwords = words_per_seq[seq_idx]

            # Predict categories
            pred_cat_vec_idxs, pred_cats = self._predict_categories_for_sequence(
                cat_logits, seq_idx, seq_nwords, cat_vec_idx_to_dict_idx
            )

            # Predict attributes
            pred_attrs_tensor = self._predict_all_attributes_for_sequence(
                attr_logits,
                seq_idx,
                seq_nwords,
                pred_cat_vec_idxs,
                group_name_to_group_attr_vec_idxs,
                group_mask,
                group_names,
            )

            # Convert to labels
            seq_predictions = self._convert_predictions_to_labels(pred_cats, pred_attrs_tensor, labels, group_names)
            batch_predictions.append(seq_predictions)

        return batch_predictions


AutoConfig.register("icebert-pos", IceBertPosConfig)
AutoModel.register(IceBertPosConfig, IceBertPosForTokenClassification)
IceBertPosConfig.register_for_auto_class()
IceBertPosForTokenClassification.register_for_auto_class("AutoModel")
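# Minimal usage sketch, assuming a hypothetical checkpoint repository
# "example/icebert-pos" that ships this module alongside its config;
# substitute the actual model id once the conversion is published:
#
#   from transformers import AutoModel, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("example/icebert-pos")
#   model = AutoModel.from_pretrained("example/icebert-pos", trust_remote_code=True)
#   tags = model.predict_labels_from_text(["Hún les bók ."], tokenizer)
#   # -> one list of (category, [attributes]) tuples per input sentence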