# Copyright (C) Miðeind ehf.
# This file is part of IceBERT POS model conversion.
import logging
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoConfig, AutoModel, PreTrainedModel, RobertaModel
from .configuration import IceBertPosConfig
# NOTE: the make_* mapping helpers used by _setup_label_mappings are defined at
# module level at the bottom of this file, so only the remaining utilities are
# imported here.
from .old_label_utils import (
    clean_cats_attrs,
    create_label_dictionary_from_schema,
)
logger = logging.getLogger(__name__)
class MultiLabelTokenClassificationHead(nn.Module):
"""Head for multilabel word-level classification tasks."""
def __init__(self, config: IceBertPosConfig):
super().__init__()
self.num_categories = config.num_categories
self.num_labels = config.num_labels
self.hidden_size = config.hidden_size
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
self.activation_fn = F.relu
self.dropout = nn.Dropout(p=config.classifier_dropout)
self.layer_norm = nn.LayerNorm(self.hidden_size)
# Category projection: hidden_size -> num_categories
self.cat_proj = nn.Linear(self.hidden_size, self.num_categories)
# Attribute projection: (hidden_size + num_categories) -> num_labels
self.out_proj = nn.Linear(self.hidden_size + self.num_categories, self.num_labels)
def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
features: Word-level features of shape (batch_size, max_words, hidden_size)
Returns:
cat_logits: Category logits of shape (batch_size, max_words, num_categories)
attr_logits: Attribute logits of shape (batch_size, max_words, num_labels)
"""
x = self.dropout(features)
x = self.dense(x)
x = self.layer_norm(x)
x = self.activation_fn(x)
# Predict categories
cat_logits = self.cat_proj(x)
cat_probs = torch.softmax(cat_logits, dim=-1)
# Predict attributes using concatenated features
attr_input = torch.cat((cat_probs, x), dim=-1)
attr_logits = self.out_proj(attr_input)
return cat_logits, attr_logits
class IceBertPosForTokenClassification(PreTrainedModel):
"""
IceBERT model for multilabel token classification (POS tagging).
This model performs word-level POS tagging by:
1. Encoding input with RoBERTa
2. Aggregating subword tokens to word-level representations
3. Predicting both categories and attributes for each word
"""
config_class = IceBertPosConfig
def __init__(self, config: IceBertPosConfig):
super().__init__(config)
self.config = config
self.num_categories = config.num_categories
self.num_labels = config.num_labels
self.roberta = RobertaModel(config, add_pooling_layer=False)
self.classifier = MultiLabelTokenClassificationHead(config)
# Create label dictionary and mappings (mimicking old fairseq model)
self.label_dictionary = create_label_dictionary_from_schema(config.label_schema)
self._setup_label_mappings()
# Initialize weights and apply final processing
self.post_init()
def _setup_label_mappings(self):
"""Setup label mappings similar to the old fairseq model."""
schema = self.config.label_schema
self.group_name_to_group_attr_vec_idxs = make_group_name_to_group_attr_vec_idxs(self.label_dictionary, schema)
self.cat_dict_idx_to_vec_idx = make_dict_idx_to_vec_idx(self.label_dictionary, schema.label_categories)
self.cat_vec_idx_to_dict_idx = make_vec_idx_to_dict_idx(self.label_dictionary, schema.label_categories)
self.group_mask = make_group_masks(self.label_dictionary, schema)
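        # NOTE: these mapping tensors are plain attributes created on the CPU (see the
        # helpers at the bottom of this file), not registered buffers, so model.to(device)
        # will not move them; running prediction on a CUDA device may therefore require
        # moving the prediction tensors (or these mappings) to a common device first.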
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
word_mask: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
input_ids: Token indices of shape (batch_size, sequence_length)
attention_mask: Attention mask of shape (batch_size, sequence_length)
word_mask: Binary mask indicating word boundaries (1 = word start) of shape (batch_size, sequence_length)
Returns:
cat_logits: Category logits of shape (batch_size, max_words, num_categories)
attr_logits: Attribute logits of shape (batch_size, max_words, num_labels)
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Get RoBERTa outputs
outputs = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=True,
return_dict=return_dict,
)
x = outputs[0] # (batch_size, seq_len, hidden)
        # Word aggregation copied verbatim from the old fairseq model:
        # average the subword (BPE) vectors of each word into a single word vector.
x = x[:, 1:-1, :]
starts = word_mask[:, 1:-1] # remove bos, eos
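        # Rolling the start mask left by one makes each word-start position "see" the
        # next word's start; the nonzero column indices + 1 are then the exclusive end
        # of each word. The circular wrap pairs the last word with the sequence
        # boundary (this relies on the first post-BOS token being a word start).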
ends = starts.roll(-1, dims=[-1]).nonzero()[:, -1] + 1
starts = starts.nonzero().tolist()
mean_words = []
for (seq_idx, token_idx), end in zip(starts, ends):
mean_words.append(x[seq_idx, token_idx:end, :].mean(dim=0))
mean_words = torch.stack(mean_words)
words = mean_words
        # word_mask is 1 at the first subword token of each word, so the row sum is
        # the number of words per sequence.
        nwords = word_mask.sum(dim=-1)
(cat_logits, attr_logits) = self.classifier(words)
        # (Batch * Words) x Depth -> Batch x Words x Depth
        cat_logits = pad_sequence(cat_logits.split(nwords.tolist()), padding_value=0, batch_first=True)
        attr_logits = pad_sequence(attr_logits.split(nwords.tolist()), padding_value=0, batch_first=True)
return cat_logits, attr_logits
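    # The two helpers below re-implement the inline aggregation/reshaping used in
    # forward() in a more readable form; forward() keeps the inline copy of the old
    # fairseq logic and does not call them.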
def _aggregate_subword_tokens(
self, sequence_output: torch.Tensor, word_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Aggregate subword token representations to word-level representations.
Following the original fairseq approach by averaging subword tokens within each word.
Args:
sequence_output: subword token representations (batch_size, seq_len, hidden_size)
word_mask: Binary mask where 1 indicates start of word (batch_size, seq_len)
Returns:
word_features: Word-level features (batch_size, max_words, hidden_size)
nwords: Number of words per sequence (batch_size,)
"""
# TODO: Verify that BOS and EOS are handled correctly - I'm worried that this does not correctly handle padding
# Remove BOS and EOS tokens (first and last positions)
x = sequence_output[:, 1:-1, :] # (batch_size, seq_len-2, hidden_size)
starts = word_mask[:, 1:-1] # (batch_size, seq_len-2)
# Count words per sequence
nwords = starts.sum(dim=-1) # (batch_size,)
# Find word boundaries and average tokens within each word
mean_words = []
batch_size, seq_len, hidden_size = x.shape
for batch_idx in range(batch_size):
seq_starts = starts[batch_idx] # (seq_len-2,)
seq_x = x[batch_idx] # (seq_len-2, hidden_size)
# Find start positions of words
start_positions = seq_starts.nonzero(as_tuple=True)[0] # positions where words start
if len(start_positions) == 0:
continue
# Calculate end positions (start of next word or end of sequence)
end_positions = torch.cat([start_positions[1:], torch.tensor([seq_len], device=start_positions.device)])
# Average tokens within each word
for start_pos, end_pos in zip(start_positions, end_positions):
word_tokens = seq_x[start_pos:end_pos] # tokens in this word
word_repr = word_tokens.mean(dim=0) # average representation
mean_words.append(word_repr)
if len(mean_words) == 0:
return torch.empty(0, sequence_output.size(-1), device=sequence_output.device), nwords
return torch.stack(mean_words), nwords
def _reshape_to_batch_format(
self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, nwords: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Reshape word-level predictions back to batch format.
Following the original fairseq approach with pad_sequence.
Args:
cat_logits: Category logits (total_words, num_categories)
attr_logits: Attribute logits (total_words, num_labels)
nwords: Number of words per sequence (batch_size,)
Returns:
cat_logits_batch: (batch_size, max_words, num_categories)
attr_logits_batch: (batch_size, max_words, num_labels)
"""
# Split logits by sequence using word counts
words_per_seq = nwords.tolist()
cat_logits_split = cat_logits.split(words_per_seq)
attr_logits_split = attr_logits.split(words_per_seq)
# Pad to same length (matching original fairseq approach)
cat_logits_batch = pad_sequence(cat_logits_split, batch_first=True, padding_value=0)
attr_logits_batch = pad_sequence(attr_logits_split, batch_first=True, padding_value=0)
return cat_logits_batch, attr_logits_batch
@torch.no_grad()
    def predict_labels(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, word_ids: List[List[Optional[int]]]
    ) -> List[List[Tuple[str, List[str]]]]:
"""
Predict POS labels for input sequences.
Args:
input_ids: Token indices
attention_mask: Attention mask
            word_ids: Per-token word indices from a fast tokenizer (None for special/padding tokens)
Returns:
List of sequences, each containing (category, [attributes]) per word
"""
# Convert word_ids to word_mask
word_mask = self._word_ids_to_word_mask(word_ids, input_ids.shape)
cat_logits, attr_logits = self.forward(input_ids=input_ids, attention_mask=attention_mask, word_mask=word_mask)
return self._logits_to_labels(cat_logits, attr_logits, word_mask)
    def _word_ids_to_word_mask(self, word_ids: List[List[Optional[int]]], input_shape: torch.Size) -> torch.Tensor:
"""
Convert word_ids to word_mask (binary mask indicating word boundaries).
Args:
word_ids: List of word id sequences
input_shape: Shape of input_ids tensor (batch_size, seq_len)
Returns:
word_mask: Binary tensor where 1 indicates start of word (batch_size, seq_len)
"""
batch_size, seq_len = input_shape
word_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)
for batch_idx, seq_word_ids in enumerate(word_ids):
            # Drop the first and last positions (BOS, and the final EOS or pad)
            truncated_word_ids = seq_word_ids[1:-1]
            prev_word_id = None
            for token_idx, word_id in enumerate(truncated_word_ids):
                # Skip special/padding positions (word_id is None) so that the EOS and
                # trailing pads of shorter sequences are not marked as word starts.
                if word_id is not None and word_id != prev_word_id:
                    word_mask[batch_idx, token_idx + 1] = 1  # +1 to account for BOS
                prev_word_id = word_id
# Debug logging to match fairseq model
logger.debug(f"Word mask: {word_mask[batch_idx].tolist()}")
return word_mask
def predict_labels_from_text(self, sentences: List[str], tokenizer) -> List[List[Tuple[str, List[str]]]]:
"""
Predict POS labels from raw text using fairseq-style preprocessing.
Args:
sentences: List of input sentences
            tokenizer: HuggingFace fast tokenizer (word_ids() requires a fast tokenizer)
Returns:
List of sequences, each containing (category, [attributes]) per word
"""
# Split sentences by spaces to get proper word boundaries
# This fixes the issue where tokens like "Kl." get split incorrectly
sentences_split = [sentence.split() for sentence in sentences]
        # Call the tokenizer directly (batch_encode_plus is deprecated) with
        # is_split_into_words=True to preserve word boundaries
        encoding = tokenizer(
            sentences_split,
            return_tensors="pt",
            padding=True,
            is_split_into_words=True,
            add_special_tokens=True,
        )
batch_input_ids = encoding["input_ids"]
batch_attention_mask = encoding["attention_mask"]
word_ids_list = [encoding.word_ids(i) for i in range(len(sentences))]
# Debug logging to match fairseq model
for i in range(len(sentences)):
logger.debug(f"Encoded tokens: {batch_input_ids[i]}")
logger.debug(f"Decoded tokens: {tokenizer.convert_ids_to_tokens(batch_input_ids[i].tolist())}")
logger.debug(f"Word IDs: {word_ids_list[i]}")
return self.predict_labels(batch_input_ids, batch_attention_mask, word_ids_list)
def _logits_to_labels(
self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_mask: torch.Tensor
) -> List[List[Tuple[str, List[str]]]]:
"""
Convert logits to human-readable labels using fairseq's group-based logic.
Copied from the old model's logits_to_labels method.
"""
# logits: Batch x Time x Labels
bsz, _, num_cats = cat_logits.shape
_, _, num_attrs = attr_logits.shape
nwords = word_mask.sum(-1)
assert num_attrs == len(self.config.label_schema.labels)
assert num_cats == len(self.config.label_schema.label_categories)
batch_cats = []
batch_attrs = []
for seq_idx in range(bsz):
seq_nwords = nwords[seq_idx]
pred_cat_vec_idxs = cat_logits[seq_idx, :seq_nwords].max(dim=-1).indices
pred_cats = self.cat_vec_idx_to_dict_idx[pred_cat_vec_idxs]
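            # Each row of group_mask flags which attribute groups are licensed by the
            # predicted category (see make_group_masks below).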
group_mask = self.group_mask[pred_cat_vec_idxs]
offset = self.label_dictionary.nspecial
pred_attrs = []
for group_idx, group_name in enumerate(self.config.label_schema.group_names):
group_vec_idxs = self.group_name_to_group_attr_vec_idxs[group_name]
                # group_logits: seq_nwords x len(group_vec_idxs)
                group_logits = attr_logits[seq_idx, :seq_nwords, group_vec_idxs]
                if len(group_vec_idxs) == 1:
                    group_pred = group_logits.sigmoid().ge(0.5).long()
                    group_pred_dict_idxs = (
                        group_pred.squeeze() * (group_vec_idxs.item() + offset)
                    ).T.to("cpu") * group_mask[:, group_idx]
else:
group_pred_vec_idxs = group_logits.max(dim=-1).indices
group_pred_dict_idxs = (group_vec_idxs[group_pred_vec_idxs] + offset) * group_mask[:, group_idx]
pred_attrs.append(group_pred_dict_idxs)
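            # Stack per-group predictions into (num_groups, nwords), then transpose
            # to (nwords, num_groups) so each row holds one word's attribute indices.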
pred_attrs = torch.stack([p.squeeze() for p in pred_attrs]).t()
batch_cats.append(pred_cats)
batch_attrs.append(pred_attrs)
        predictions = [
            clean_cats_attrs(
                self.label_dictionary,
                self.config.label_schema,
                seq_cats,
                seq_attrs,
            )
            for seq_cats, seq_attrs in zip(batch_cats, batch_attrs)
        ]
        return predictions
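# Label-mapping helpers. These module-level definitions are the ones used by
# _setup_label_mappings above.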
def make_vec_idx_to_dict_idx(dictionary, labels, device="cpu", fill_value=-100):
vec_idx_to_dict_idx = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
for vec_idx, label in enumerate(labels):
vec_idx_to_dict_idx[vec_idx] = dictionary.index(label)
return vec_idx_to_dict_idx
def make_group_masks(dictionary, schema, device="cpu"):
num_groups = len(schema.group_names)
offset = dictionary.nspecial
num_labels = len(dictionary) - offset
ret_mask = torch.zeros(num_labels, num_groups, dtype=torch.int64, device=device)
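    # NOTE: only the first len(schema.label_categories) rows are ever indexed (by
    # category vector index, both here and in _logits_to_labels); sizing by
    # num_labels over-allocates and leaves the trailing rows zero.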
for cat, cat_group_names in schema.category_to_group_names.items():
cat_label_idx = dictionary.index(cat)
cat_vec_idx = schema.label_categories.index(cat)
for group_name in cat_group_names:
ret_mask[cat_vec_idx, schema.group_names.index(group_name)] = 1
assert cat_label_idx != dictionary.unk()
for cat in schema.label_categories:
cat_label_idx = dictionary.index(cat)
assert cat_label_idx != dictionary.unk()
return ret_mask
def make_group_name_to_group_attr_vec_idxs(dict_, schema):
offset = dict_.nspecial
group_names = schema.group_name_to_labels.keys()
name_to_labels = schema.group_name_to_labels
group_name_to_group_attr_vec_idxs = {
name: torch.tensor([dict_.index(item) - offset for item in name_to_labels[name]]) for name in group_names
}
return group_name_to_group_attr_vec_idxs
def make_dict_idx_to_vec_idx(dictionary, cats, device="cpu", fill_value=-100):
# NOTE: when target is not in label_categories, the error is silent
map_tgt = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
for vec_idx, label in enumerate(cats):
map_tgt[dictionary.index(label)] = vec_idx
return map_tgt
AutoConfig.register("icebert-pos", IceBertPosConfig)
AutoModel.register(IceBertPosConfig, IceBertPosForTokenClassification)
IceBertPosConfig.register_for_auto_class()
IceBertPosForTokenClassification.register_for_auto_class("AutoModel")
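# A minimal usage sketch (not executed on import). Assumptions, flagged loudly: the
# checkpoint id below is hypothetical, and a *fast* tokenizer is required because
# predict_labels_from_text relies on encoding.word_ids().
#
#   from transformers import AutoModel, AutoTokenizer
#
#   model_id = "your-namespace/icebert-pos"  # hypothetical checkpoint id
#   tokenizer = AutoTokenizer.from_pretrained(model_id)
#   model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
#   model.eval()
#   predictions = model.predict_labels_from_text(["Hann borðar fisk ."], tokenizer)
#   for sentence in predictions:
#       print(sentence)  # list of (category, [attributes]) tuples, one per word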