# Copyright (C) Miðeind ehf.
# This file is part of IceBERT POS model conversion.
import logging
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoConfig, AutoModel, PreTrainedModel, RobertaModel
from .configuration import IceBertPosConfig
logger = logging.getLogger(__name__)
class MultiLabelTokenClassificationHead(nn.Module):
"""Head for multilabel word-level classification tasks."""
def __init__(self, config: IceBertPosConfig):
super().__init__()
self.num_categories = config.num_categories
self.num_labels = config.num_labels
self.hidden_size = config.hidden_size
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
self.activation_fn = F.relu
self.dropout = nn.Dropout(p=config.classifier_dropout)
self.layer_norm = nn.LayerNorm(self.hidden_size)
# Category projection: hidden_size -> num_categories
self.cat_proj = nn.Linear(self.hidden_size, self.num_categories)
# Attribute projection: (hidden_size + num_categories) -> num_labels
self.out_proj = nn.Linear(self.hidden_size + self.num_categories, self.num_labels)
def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
features: Word-level features of shape (total_words, hidden_size)
Returns:
cat_logits: Category logits of shape (total_words, num_categories)
attr_logits: Attribute logits of shape (total_words, num_labels)
"""
x = self.dropout(features)
x = self.dense(x)
x = self.layer_norm(x)
x = self.activation_fn(x)
# Predict categories
cat_logits = self.cat_proj(x)
cat_probs = torch.softmax(cat_logits, dim=-1)
# Predict attributes using concatenated features
attr_input = torch.cat((cat_probs, x), dim=-1)
attr_logits = self.out_proj(attr_input)
return cat_logits, attr_logits
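# Shape sketch for the head above (illustrative numbers, not taken from any
# real config): with hidden_size=768, num_categories=35, num_labels=100 and
# 10 words in the batch, features (10, 768) pass through dropout, dense,
# layer_norm and relu to x (10, 768); cat_proj yields cat_logits (10, 35);
# softmax(cat_logits) is concatenated with x into (10, 803), which out_proj
# maps to attr_logits (10, 100).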
class IceBertPosForTokenClassification(PreTrainedModel):
"""
IceBERT model for multilabel token classification (POS tagging).
This model performs word-level POS tagging by:
1. Encoding input with RoBERTa
2. Aggregating subword tokens to word-level representations
3. Predicting both categories and attributes for each word
"""
config_class = IceBertPosConfig
def __init__(self, config: IceBertPosConfig):
super().__init__(config)
self.config = config
self.num_categories = config.num_categories
self.num_labels = config.num_labels
self.roberta = RobertaModel(config, add_pooling_layer=False)
self.classifier = MultiLabelTokenClassificationHead(config)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
word_mask: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
input_ids: Token indices of shape (batch_size, sequence_length)
attention_mask: Attention mask of shape (batch_size, sequence_length)
word_mask: Binary mask indicating word boundaries (1 = word start)
Returns:
cat_logits: Category logits of shape (batch_size, max_words, num_categories)
attr_logits: Attribute logits of shape (batch_size, max_words, num_labels)
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Get RoBERTa outputs
outputs = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0] # (batch_size, seq_len, hidden_size)
# Aggregate subword tokens to word-level representations using word_mask
word_features, nwords = self._aggregate_subword_tokens(sequence_output, word_mask)
# Apply classification head
cat_logits, attr_logits = self.classifier(word_features)
# Reshape back to batch format using word counts
cat_logits_batch, attr_logits_batch = self._reshape_to_batch_format(cat_logits, attr_logits, nwords)
return cat_logits_batch, attr_logits_batch
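    # Illustrative output shapes: if the word_mask slicing yields word counts
    # [5, 3] for a batch of 2 sequences, cat_logits_batch has shape
    # (2, 5, num_categories) and attr_logits_batch (2, 5, num_labels), with the
    # second sequence zero-padded along the word dimension.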
def _aggregate_subword_tokens(
self, sequence_output: torch.Tensor, word_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Aggregate subword token representations to word-level representations.
Following the original fairseq approach by averaging subword tokens within each word.
Args:
sequence_output: subword token representations (batch_size, seq_len, hidden_size)
word_mask: Binary mask where 1 indicates start of word (batch_size, seq_len)
Returns:
word_features: Word-level features (total_words, hidden_size)
nwords: Number of words per sequence (batch_size,)
"""
        # TODO: Verify that BOS and EOS are handled correctly. Slicing [1:-1]
        # only strips EOS for the longest sequence in the batch; for shorter,
        # padded sequences EOS sits earlier and its word_mask entry produces a
        # spurious trailing "word", which is only discarded later when
        # _logits_to_labels slices predictions to the true word count.
# Remove BOS and EOS tokens (first and last positions)
x = sequence_output[:, 1:-1, :] # (batch_size, seq_len-2, hidden_size)
starts = word_mask[:, 1:-1] # (batch_size, seq_len-2)
# Count words per sequence
nwords = starts.sum(dim=-1) # (batch_size,)
# Find word boundaries and average tokens within each word
mean_words = []
batch_size, seq_len, hidden_size = x.shape
for batch_idx in range(batch_size):
seq_starts = starts[batch_idx] # (seq_len-2,)
seq_x = x[batch_idx] # (seq_len-2, hidden_size)
# Find start positions of words
start_positions = seq_starts.nonzero(as_tuple=True)[0] # positions where words start
if len(start_positions) == 0:
continue
# Calculate end positions (start of next word or end of sequence)
end_positions = torch.cat([start_positions[1:], torch.tensor([seq_len], device=start_positions.device)])
# Average tokens within each word
for start_pos, end_pos in zip(start_positions, end_positions):
word_tokens = seq_x[start_pos:end_pos] # tokens in this word
word_repr = word_tokens.mean(dim=0) # average representation
mean_words.append(word_repr)
if len(mean_words) == 0:
return torch.empty(0, sequence_output.size(-1), device=sequence_output.device), nwords
return torch.stack(mean_words), nwords
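    # Worked example for _aggregate_subword_tokens (hypothetical mask): after
    # stripping the first and last positions, starts = [1, 0, 1, 0] puts word 0
    # at token positions 0-1 and word 1 at positions 2-3; each word feature is
    # the mean of its subword rows and nwords = 2 for that sequence.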
def _reshape_to_batch_format(
self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, nwords: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Reshape word-level predictions back to batch format.
Following the original fairseq approach with pad_sequence.
Args:
cat_logits: Category logits (total_words, num_categories)
attr_logits: Attribute logits (total_words, num_labels)
nwords: Number of words per sequence (batch_size,)
Returns:
cat_logits_batch: (batch_size, max_words, num_categories)
attr_logits_batch: (batch_size, max_words, num_labels)
"""
# Split logits by sequence using word counts
words_per_seq = nwords.tolist()
cat_logits_split = cat_logits.split(words_per_seq)
attr_logits_split = attr_logits.split(words_per_seq)
# Pad to same length (matching original fairseq approach)
cat_logits_batch = pad_sequence(cat_logits_split, batch_first=True, padding_value=0)
attr_logits_batch = pad_sequence(attr_logits_split, batch_first=True, padding_value=0)
return cat_logits_batch, attr_logits_batch
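    # Example: nwords = [3, 1] splits a (4, num_categories) cat_logits tensor
    # into chunks of shape (3, num_categories) and (1, num_categories), which
    # pad_sequence stacks into a zero-padded (2, 3, num_categories) batch.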
@torch.no_grad()
def predict_labels(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, word_ids: List[List[Optional[int]]]
) -> List[List[Tuple[str, List[str]]]]:
"""
Predict POS labels for input sequences.
Args:
input_ids: Token indices
attention_mask: Attention mask
            word_ids: Per-token word indices from a fast tokenizer (None for special tokens)
Returns:
List of sequences, each containing (category, [attributes]) per word
"""
        # Convert word_ids to word_mask and move it onto the same device as the inputs
        word_mask = self._word_ids_to_word_mask(word_ids, input_ids.shape).to(input_ids.device)
        cat_logits, attr_logits = self.forward(input_ids=input_ids, attention_mask=attention_mask, word_mask=word_mask)
return self._logits_to_labels(cat_logits, attr_logits, word_ids)
    def _word_ids_to_word_mask(self, word_ids: List[List[Optional[int]]], input_shape: torch.Size) -> torch.Tensor:
"""
Convert word_ids to word_mask (binary mask indicating word boundaries).
Args:
word_ids: List of word id sequences
input_shape: Shape of input_ids tensor (batch_size, seq_len)
Returns:
word_mask: Binary tensor where 1 indicates start of word
"""
batch_size, seq_len = input_shape
word_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)
for batch_idx, seq_word_ids in enumerate(word_ids):
prev_word_id = None
for token_idx, word_id in enumerate(seq_word_ids):
if word_id != prev_word_id:
word_mask[batch_idx, token_idx] = 1
prev_word_id = word_id
return word_mask
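    # Worked example (hypothetical tokenization): word_ids [[None, 0, 0, 1, None]]
    # (BOS, two subwords of word 0, word 1, EOS) yields word_mask [[0, 1, 0, 1, 1]]:
    # BOS stays 0 (None == None), each new word id flips to 1, and the trailing
    # None at EOS also flips to 1, acting as a sentinel boundary after the last word.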
def predict_labels_from_text(self, sentences: List[str], tokenizer) -> List[List[Tuple[str, List[str]]]]:
"""
        Predict POS labels from raw text.
        Args:
            sentences: List of input sentences
            tokenizer: HuggingFace tokenizer (must be a fast tokenizer, since word_ids() is required)
        Returns:
            List of sequences, each containing (category, [attributes]) per word
        """
        # Tokenize each sentence separately so that per-sentence word_ids() are available
        encodings = [tokenizer(sent, return_tensors="pt") for sent in sentences]
word_ids_list = [encoding.word_ids() for encoding in encodings]
# Batch the inputs
max_len = max(encoding["input_ids"].shape[1] for encoding in encodings)
batch_input_ids = []
batch_attention_mask = []
for encoding in encodings:
input_ids = encoding["input_ids"][0]
attention_mask = encoding["attention_mask"][0]
            # Pad to max length using the tokenizer's pad token (RoBERTa's default pad id is 1)
            pad_len = max_len - len(input_ids)
            if pad_len > 0:
                pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 1
                input_ids = torch.cat([input_ids, torch.full((pad_len,), pad_id, dtype=torch.long)])
                attention_mask = torch.cat([attention_mask, torch.zeros(pad_len, dtype=torch.long)])
batch_input_ids.append(input_ids)
batch_attention_mask.append(attention_mask)
        batch_input_ids = torch.stack(batch_input_ids).to(self.device)
        batch_attention_mask = torch.stack(batch_attention_mask).to(self.device)
return self.predict_labels(batch_input_ids, batch_attention_mask, word_ids_list)
def _make_group_name_to_group_attr_vec_idxs(self):
"""Create mapping from group names to their attribute vector indices"""
group_name_to_group_attr_vec_idxs = {}
labels = self.config.label_schema["labels"]
nspecial = 0 # Number of special tokens in label dictionary (like <SEP>)
for group_name, group_labels in self.config.label_schema["group_name_to_labels"].items():
vec_idxs = []
for label in group_labels:
if label in labels:
# Find index in labels list, but subtract nspecial to get vector index
label_dict_idx = labels.index(label)
if label_dict_idx >= nspecial: # Skip special tokens
vec_idxs.append(label_dict_idx - nspecial)
group_name_to_group_attr_vec_idxs[group_name] = torch.tensor(vec_idxs)
return group_name_to_group_attr_vec_idxs
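    # Example (hypothetical schema): with labels = ["nf", "þf", "et", "ft"] and
    # group_name_to_labels = {"case": ["nf", "þf"], "number": ["et", "ft"]},
    # this returns {"case": tensor([0, 1]), "number": tensor([2, 3])}; since
    # nspecial == 0, dictionary and vector indices coincide.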
def _make_group_masks(self):
"""Create group masks for each category"""
label_categories = self.config.label_schema["label_categories"]
group_names = self.config.label_schema["group_names"]
category_to_group_names = self.config.label_schema["category_to_group_names"]
num_cats = len(label_categories)
num_groups = len(group_names)
group_mask = torch.zeros(num_cats, num_groups, dtype=torch.bool)
for cat_idx, category in enumerate(label_categories):
if category in category_to_group_names:
for group_name in category_to_group_names[category]:
if group_name in group_names:
group_idx = group_names.index(group_name)
group_mask[cat_idx, group_idx] = True
return group_mask
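    # Example (hypothetical schema): with label_categories = ["no", "so"],
    # group_names = ["case", "number", "tense"] and category_to_group_names =
    # {"no": ["case", "number"], "so": ["tense"]}, the resulting mask is
    # [[True, True, False], [False, False, True]].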
def _make_category_mappings(self):
"""Create mappings between category vector indices and dictionary indices"""
labels = self.config.label_schema["labels"]
label_categories = self.config.label_schema["label_categories"]
# Create mapping from category names to vector indices (0-based)
cat_dict_idx_to_vec_idx = torch.zeros(len(labels), dtype=torch.long)
cat_vec_idx_to_dict_idx = torch.zeros(len(label_categories), dtype=torch.long)
for vec_idx, category in enumerate(label_categories):
if category in labels:
dict_idx = labels.index(category)
cat_dict_idx_to_vec_idx[dict_idx] = vec_idx
cat_vec_idx_to_dict_idx[vec_idx] = dict_idx
return cat_dict_idx_to_vec_idx, cat_vec_idx_to_dict_idx
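    # Example (hypothetical schema): with labels = ["no", "so", "nf"] and
    # label_categories = ["no", "so"], cat_vec_idx_to_dict_idx becomes [0, 1],
    # mapping category vector positions back to their slots in the label list.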
    def _count_words_per_sequence(self, word_ids: List[List[Optional[int]]]) -> List[int]:
"""Count the number of unique words in each sequence."""
words_per_seq = []
for seq_word_ids in word_ids:
unique_word_ids = set(word_id for word_id in seq_word_ids if word_id is not None)
words_per_seq.append(len(unique_word_ids))
return words_per_seq
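    # Example: word_ids [[None, 0, 0, 1, None], [None, 0, None]] -> [2, 1];
    # None entries (special tokens) are ignored when counting unique words.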
def _predict_categories_for_sequence(
self, cat_logits: torch.Tensor, seq_idx: int, seq_nwords: int, cat_vec_idx_to_dict_idx: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predict categories for a single sequence and return both vector and dictionary indices."""
pred_cat_vec_idxs = cat_logits[seq_idx, :seq_nwords].max(dim=-1).indices
pred_cats = cat_vec_idx_to_dict_idx[pred_cat_vec_idxs]
return pred_cat_vec_idxs, pred_cats
def _predict_attributes_for_group(
self,
attr_logits: torch.Tensor,
seq_idx: int,
seq_nwords: int,
group_vec_idxs: torch.Tensor,
seq_group_mask: torch.Tensor,
group_idx: int,
) -> torch.Tensor:
"""Predict attributes for a single group."""
if len(group_vec_idxs) == 0:
return torch.zeros(seq_nwords, dtype=torch.long)
# Get logits for this group
group_logits = attr_logits[seq_idx, :seq_nwords, group_vec_idxs]
if len(group_vec_idxs) == 1:
# Single element group: use sigmoid > 0.5
group_pred = group_logits.sigmoid().ge(0.5).long()
group_pred_dict_idxs = (group_pred.squeeze() * group_vec_idxs.item()) * seq_group_mask[:, group_idx]
else:
# Multi element group: use argmax
group_pred_vec_idxs = group_logits.max(dim=-1).indices
group_pred_dict_idxs = group_vec_idxs[group_pred_vec_idxs] * seq_group_mask[:, group_idx]
return group_pred_dict_idxs
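    # Example of the two branches (group names illustrative): a one-member
    # group such as {"gr"} is scored per word with sigmoid(logit) > 0.5, while
    # a four-member case group {"nf", "þf", "þgf", "ef"} takes the argmax over
    # its four logit columns; in both branches seq_group_mask zeroes out groups
    # that a word's predicted category does not license.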
def _predict_all_attributes_for_sequence(
self,
attr_logits: torch.Tensor,
seq_idx: int,
seq_nwords: int,
pred_cat_vec_idxs: torch.Tensor,
group_name_to_group_attr_vec_idxs: dict,
group_mask: torch.Tensor,
group_names: List[str],
) -> torch.Tensor:
"""Predict all attributes for a single sequence."""
seq_group_mask = group_mask[pred_cat_vec_idxs]
pred_attrs = []
for group_idx, group_name in enumerate(group_names):
if group_name not in group_name_to_group_attr_vec_idxs:
pred_attrs.append(torch.zeros(seq_nwords, dtype=torch.long))
continue
group_vec_idxs = group_name_to_group_attr_vec_idxs[group_name]
group_pred_dict_idxs = self._predict_attributes_for_group(
attr_logits, seq_idx, seq_nwords, group_vec_idxs, seq_group_mask, group_idx
)
pred_attrs.append(group_pred_dict_idxs)
# Stack predictions
if pred_attrs:
return torch.stack([p.squeeze() if p.dim() > 1 else p for p in pred_attrs]).t()
else:
return torch.zeros(seq_nwords, len(group_names), dtype=torch.long)
def _convert_predictions_to_labels(
self, pred_cats: torch.Tensor, pred_attrs_tensor: torch.Tensor, labels: List[str], group_names: List[str]
) -> List[Tuple[str, List[str]]]:
"""Convert prediction tensors to human-readable labels."""
seq_nwords = pred_cats.size(0)
seq_predictions = []
for word_idx in range(seq_nwords):
# Category (convert from dictionary index to string)
cat_dict_idx = pred_cats[word_idx].item()
if cat_dict_idx < len(labels):
category = labels[cat_dict_idx]
else:
category = "UNK"
# Attributes (convert from dictionary indices to strings)
attributes = []
for group_idx in range(len(group_names)):
attr_dict_idx = pred_attrs_tensor[word_idx, group_idx].item()
                if 0 < attr_dict_idx < len(labels):  # Skip 0 (empty) and out-of-range indices
attributes.append(labels[attr_dict_idx])
seq_predictions.append((category, attributes))
return seq_predictions
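    # Example (hypothetical indices): with labels = ["no", "nf", "et"],
    # pred_cats = [0] and attribute columns [1, 2] for a single word, the
    # method returns [("no", ["nf", "et"])].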
def _logits_to_labels(
        self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_ids: List[List[Optional[int]]]
) -> List[List[Tuple[str, List[str]]]]:
"""
Convert logits to human-readable labels using fairseq's group-based logic.
"""
# Create necessary mappings
group_name_to_group_attr_vec_idxs = self._make_group_name_to_group_attr_vec_idxs()
group_mask = self._make_group_masks()
cat_dict_idx_to_vec_idx, cat_vec_idx_to_dict_idx = self._make_category_mappings()
label_schema = self.config.label_schema
labels = label_schema["labels"]
group_names = label_schema["group_names"]
batch_size = cat_logits.size(0)
words_per_seq = self._count_words_per_sequence(word_ids)
batch_predictions = []
for seq_idx in range(batch_size):
seq_nwords = words_per_seq[seq_idx]
# Predict categories
pred_cat_vec_idxs, pred_cats = self._predict_categories_for_sequence(
cat_logits, seq_idx, seq_nwords, cat_vec_idx_to_dict_idx
)
# Predict attributes
pred_attrs_tensor = self._predict_all_attributes_for_sequence(
attr_logits,
seq_idx,
seq_nwords,
pred_cat_vec_idxs,
group_name_to_group_attr_vec_idxs,
group_mask,
group_names,
)
# Convert to labels
seq_predictions = self._convert_predictions_to_labels(pred_cats, pred_attrs_tensor, labels, group_names)
batch_predictions.append(seq_predictions)
return batch_predictions
AutoConfig.register("icebert-pos", IceBertPosConfig)
AutoModel.register(IceBertPosConfig, IceBertPosForTokenClassification)
IceBertPosConfig.register_for_auto_class()
IceBertPosForTokenClassification.register_for_auto_class("AutoModel")
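# Usage sketch (illustrative; "<repo_id>" is a placeholder for the actual Hub
# repo id, and a *fast* tokenizer is required because predict_labels_from_text()
# relies on word_ids()):
#
#     from transformers import AutoModel, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("<repo_id>")
#     model = AutoModel.from_pretrained("<repo_id>", trust_remote_code=True)
#     model.eval()
#     for sentence in model.predict_labels_from_text(["Hún las bókina."], tokenizer):
#         for category, attributes in sentence:
#             print(category, attributes)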