|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
from typing import List, Optional, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
from transformers import AutoConfig, AutoModel, PreTrainedModel, RobertaModel |
|
|
|
|
|
from .configuration import IceBertPosConfig |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class MultiLabelTokenClassificationHead(nn.Module): |
|
|
"""Head for multilabel word-level classification tasks.""" |
|
|
|
|
|
def __init__(self, config: IceBertPosConfig): |
|
|
super().__init__() |
|
|
self.num_categories = config.num_categories |
|
|
self.num_labels = config.num_labels |
|
|
self.hidden_size = config.hidden_size |
|
|
|
|
|
self.dense = nn.Linear(self.hidden_size, self.hidden_size) |
|
|
self.activation_fn = F.relu |
|
|
self.dropout = nn.Dropout(p=config.classifier_dropout) |
|
|
self.layer_norm = nn.LayerNorm(self.hidden_size) |
|
|
|
|
|
|
|
|
self.cat_proj = nn.Linear(self.hidden_size, self.num_categories) |
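        # The attribute projection below consumes the transformed features
        # concatenated with the category probabilities, so attribute predictions
        # are conditioned on the predicted category distribution.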
|
|
|
|
|
|
|
|
self.out_proj = nn.Linear(self.hidden_size + self.num_categories, self.num_labels) |
|
|
|
|
|
def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Args: |
|
|
features: Word-level features of shape (total_words, hidden_size) |
|
|
|
|
|
Returns: |
|
|
cat_logits: Category logits of shape (total_words, num_categories) |
|
|
attr_logits: Attribute logits of shape (total_words, num_labels) |
|
|
""" |
|
|
x = self.dropout(features) |
|
|
x = self.dense(x) |
|
|
x = self.layer_norm(x) |
|
|
x = self.activation_fn(x) |
|
|
|
|
|
|
|
|
cat_logits = self.cat_proj(x) |
|
|
cat_probs = torch.softmax(cat_logits, dim=-1) |
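        # Use the soft (differentiable) category distribution rather than a hard argmax.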
|
|
|
|
|
|
|
|
attr_input = torch.cat((cat_probs, x), dim=-1) |
|
|
attr_logits = self.out_proj(attr_input) |
|
|
|
|
|
return cat_logits, attr_logits |
|
|
|
|
|
|
|
|
class IceBertPosForTokenClassification(PreTrainedModel): |
|
|
""" |
|
|
IceBERT model for multilabel token classification (POS tagging). |
|
|
|
|
|
This model performs word-level POS tagging by: |
|
|
1. Encoding input with RoBERTa |
|
|
2. Aggregating subword tokens to word-level representations |
|
|
3. Predicting both categories and attributes for each word |
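
    Example (a minimal sketch, not taken from the original codebase; the
    checkpoint path is a placeholder and must point to a repository that ships
    this modeling code together with a fast tokenizer):

        >>> from transformers import AutoModel, AutoTokenizer
        >>> tokenizer = AutoTokenizer.from_pretrained("path/to/icebert-pos")
        >>> model = AutoModel.from_pretrained("path/to/icebert-pos", trust_remote_code=True)
        >>> model.predict_labels_from_text(["Hún las bókina ."], tokenizer)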
|
|
""" |
|
|
|
|
|
config_class = IceBertPosConfig |
|
|
|
|
|
def __init__(self, config: IceBertPosConfig): |
|
|
super().__init__(config) |
|
|
self.config = config |
|
|
self.num_categories = config.num_categories |
|
|
self.num_labels = config.num_labels |
|
|
|
|
|
self.roberta = RobertaModel(config, add_pooling_layer=False) |
|
|
self.classifier = MultiLabelTokenClassificationHead(config) |
|
|
|
|
|
|
|
|
self.post_init() |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: torch.Tensor, |
|
|
attention_mask: torch.Tensor, |
|
|
word_mask: torch.Tensor, |
|
|
token_type_ids: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.Tensor] = None, |
|
|
head_mask: Optional[torch.Tensor] = None, |
|
|
inputs_embeds: Optional[torch.Tensor] = None, |
|
|
output_attentions: Optional[bool] = None, |
|
|
output_hidden_states: Optional[bool] = None, |
|
|
return_dict: Optional[bool] = None, |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Args: |
|
|
input_ids: Token indices of shape (batch_size, sequence_length) |
|
|
attention_mask: Attention mask of shape (batch_size, sequence_length) |
|
|
            word_mask: Binary word-boundary mask of shape (batch_size, sequence_length); 1 marks the first subword of each word
|
|
|
|
|
Returns: |
|
|
cat_logits: Category logits of shape (batch_size, max_words, num_categories) |
|
|
attr_logits: Attribute logits of shape (batch_size, max_words, num_labels) |
|
|
""" |
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
|
|
|
outputs = self.roberta( |
|
|
input_ids, |
|
|
attention_mask=attention_mask, |
|
|
token_type_ids=token_type_ids, |
|
|
position_ids=position_ids, |
|
|
head_mask=head_mask, |
|
|
inputs_embeds=inputs_embeds, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict=return_dict, |
|
|
) |
|
|
|
|
|
sequence_output = outputs[0] |
|
|
|
|
|
|
|
|
word_features, nwords = self._aggregate_subword_tokens(sequence_output, word_mask) |
|
|
|
|
|
|
|
|
cat_logits, attr_logits = self.classifier(word_features) |
|
|
|
|
|
|
|
|
cat_logits_batch, attr_logits_batch = self._reshape_to_batch_format(cat_logits, attr_logits, nwords) |
|
|
|
|
|
return cat_logits_batch, attr_logits_batch |
|
|
|
|
|
def _aggregate_subword_tokens( |
|
|
self, sequence_output: torch.Tensor, word_mask: torch.Tensor |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Aggregate subword token representations to word-level representations. |
|
|
Following the original fairseq approach by averaging subword tokens within each word. |
|
|
|
|
|
Args: |
|
|
sequence_output: subword token representations (batch_size, seq_len, hidden_size) |
|
|
word_mask: Binary mask where 1 indicates start of word (batch_size, seq_len) |
|
|
|
|
|
Returns: |
|
|
word_features: Word-level features (total_words, hidden_size) |
|
|
nwords: Number of words per sequence (batch_size,) |
|
|
""" |
|
|
|
|
|
|
|
|
x = sequence_output[:, 1:-1, :] |
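        # The first and last positions (BOS and the final slot, i.e. EOS or padding)
        # are stripped before aggregation, following the original fairseq pipeline.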
|
|
starts = word_mask[:, 1:-1] |
|
|
|
|
|
|
|
|
nwords = starts.sum(dim=-1) |
|
|
|
|
|
|
|
|
mean_words = [] |
|
|
batch_size, seq_len, hidden_size = x.shape |
|
|
|
|
|
for batch_idx in range(batch_size): |
|
|
seq_starts = starts[batch_idx] |
|
|
seq_x = x[batch_idx] |
|
|
|
|
|
|
|
|
start_positions = seq_starts.nonzero(as_tuple=True)[0] |
|
|
|
|
|
if len(start_positions) == 0: |
|
|
continue |
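            # Each word spans from its flagged start position up to the next flagged
            # start; the last flagged span runs to the end of the trimmed sequence.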
|
|
|
|
|
|
|
|
end_positions = torch.cat([start_positions[1:], torch.tensor([seq_len], device=start_positions.device)]) |
|
|
|
|
|
|
|
|
for start_pos, end_pos in zip(start_positions, end_positions): |
|
|
word_tokens = seq_x[start_pos:end_pos] |
|
|
word_repr = word_tokens.mean(dim=0) |
|
|
mean_words.append(word_repr) |
|
|
|
|
|
if len(mean_words) == 0: |
|
|
return torch.empty(0, sequence_output.size(-1), device=sequence_output.device), nwords |
|
|
|
|
|
return torch.stack(mean_words), nwords |
|
|
|
|
|
def _reshape_to_batch_format( |
|
|
self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, nwords: torch.Tensor |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Reshape word-level predictions back to batch format. |
|
|
Following the original fairseq approach with pad_sequence. |
|
|
|
|
|
Args: |
|
|
cat_logits: Category logits (total_words, num_categories) |
|
|
attr_logits: Attribute logits (total_words, num_labels) |
|
|
nwords: Number of words per sequence (batch_size,) |
|
|
|
|
|
Returns: |
|
|
cat_logits_batch: (batch_size, max_words, num_categories) |
|
|
attr_logits_batch: (batch_size, max_words, num_labels) |
|
|
""" |
|
|
|
|
|
|
|
|
words_per_seq = nwords.tolist() |
|
|
cat_logits_split = cat_logits.split(words_per_seq) |
|
|
attr_logits_split = attr_logits.split(words_per_seq) |
|
|
|
|
|
|
|
|
cat_logits_batch = pad_sequence(cat_logits_split, batch_first=True, padding_value=0) |
|
|
attr_logits_batch = pad_sequence(attr_logits_split, batch_first=True, padding_value=0) |
|
|
|
|
|
return cat_logits_batch, attr_logits_batch |
|
|
|
|
|
@torch.no_grad() |
|
|
def predict_labels( |
|
|
self, input_ids: torch.Tensor, attention_mask: torch.Tensor, word_ids: List[List[int]] |
|
|
) -> List[List[Tuple[str, List[str]]]]: |
|
|
""" |
|
|
Predict POS labels for input sequences. |
|
|
|
|
|
Args: |
|
|
input_ids: Token indices |
|
|
attention_mask: Attention mask |
|
|
word_ids: Word boundaries |
|
|
|
|
|
Returns: |
|
|
List of sequences, each containing (category, [attributes]) per word |
|
|
""" |
|
|
|
|
|
word_mask = self._word_ids_to_word_mask(word_ids, input_ids.shape) |
|
|
|
|
|
        cat_logits, attr_logits = self(input_ids=input_ids, attention_mask=attention_mask, word_mask=word_mask)
|
|
|
|
|
return self._logits_to_labels(cat_logits, attr_logits, word_ids) |
|
|
|
|
|
def _word_ids_to_word_mask(self, word_ids: List[List[int]], input_shape: torch.Size) -> torch.Tensor: |
|
|
""" |
|
|
Convert word_ids to word_mask (binary mask indicating word boundaries). |
|
|
|
|
|
Args: |
|
|
word_ids: List of word id sequences |
|
|
input_shape: Shape of input_ids tensor (batch_size, seq_len) |
|
|
|
|
|
Returns: |
|
|
word_mask: Binary tensor where 1 indicates start of word |
|
|
""" |
|
|
batch_size, seq_len = input_shape |
|
|
word_mask = torch.zeros(batch_size, seq_len, dtype=torch.long) |
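        # Special tokens carry a word_id of None. The leading None (BOS) is not
        # flagged because prev_word_id starts as None, but the switch from the last
        # real word_id back to None flags the trailing special-token position. That
        # extra flag caps the final word during subword aggregation, and the spurious
        # "word" it produces is dropped later, when predictions are sliced with the
        # true per-sequence word counts.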
|
|
|
|
|
for batch_idx, seq_word_ids in enumerate(word_ids): |
|
|
prev_word_id = None |
|
|
for token_idx, word_id in enumerate(seq_word_ids): |
|
|
if word_id != prev_word_id: |
|
|
word_mask[batch_idx, token_idx] = 1 |
|
|
prev_word_id = word_id |
|
|
|
|
|
return word_mask |
|
|
|
|
|
def predict_labels_from_text(self, sentences: List[str], tokenizer) -> List[List[Tuple[str, List[str]]]]: |
|
|
""" |
|
|
Predict POS labels from raw text using fairseq-style preprocessing. |
|
|
|
|
|
Args: |
|
|
sentences: List of input sentences |
|
|
tokenizer: HuggingFace tokenizer |
|
|
|
|
|
Returns: |
|
|
List of sequences, each containing (category, [attributes]) per word |
|
|
""" |
|
|
|
|
|
encodings = [tokenizer(sent, return_tensors="pt") for sent in sentences] |
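        # Each sentence is tokenized on its own so that word_ids() stays available
        # per sequence; this requires a fast tokenizer.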
|
|
word_ids_list = [encoding.word_ids() for encoding in encodings] |
|
|
|
|
|
|
|
|
max_len = max(encoding["input_ids"].shape[1] for encoding in encodings) |
|
|
batch_input_ids = [] |
|
|
batch_attention_mask = [] |
|
|
|
|
|
for encoding in encodings: |
|
|
input_ids = encoding["input_ids"][0] |
|
|
attention_mask = encoding["attention_mask"][0] |
|
|
|
|
|
|
|
|
pad_len = max_len - len(input_ids) |
|
|
if pad_len > 0: |
|
|
                input_ids = torch.cat([input_ids, torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)])
|
|
attention_mask = torch.cat([attention_mask, torch.zeros(pad_len, dtype=torch.long)]) |
|
|
|
|
|
batch_input_ids.append(input_ids) |
|
|
batch_attention_mask.append(attention_mask) |
|
|
|
|
|
batch_input_ids = torch.stack(batch_input_ids) |
|
|
batch_attention_mask = torch.stack(batch_attention_mask) |
|
|
|
|
|
return self.predict_labels(batch_input_ids, batch_attention_mask, word_ids_list) |
|
|
|
|
|
def _make_group_name_to_group_attr_vec_idxs(self): |
|
|
"""Create mapping from group names to their attribute vector indices""" |
|
|
group_name_to_group_attr_vec_idxs = {} |
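        # The label_schema dict is expected to provide "labels",
        # "group_name_to_labels", "label_categories", "group_names" and
        # "category_to_group_names"; see the _make_* helpers in this class.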
|
|
labels = self.config.label_schema["labels"] |
|
|
nspecial = 0 |
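        # nspecial mirrors the reserved special symbols of a fairseq dictionary;
        # the Hugging Face label list has none, so the offset is zero.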
|
|
|
|
|
for group_name, group_labels in self.config.label_schema["group_name_to_labels"].items(): |
|
|
vec_idxs = [] |
|
|
for label in group_labels: |
|
|
if label in labels: |
|
|
|
|
|
label_dict_idx = labels.index(label) |
|
|
if label_dict_idx >= nspecial: |
|
|
vec_idxs.append(label_dict_idx - nspecial) |
|
|
group_name_to_group_attr_vec_idxs[group_name] = torch.tensor(vec_idxs) |
|
|
|
|
|
return group_name_to_group_attr_vec_idxs |
|
|
|
|
|
def _make_group_masks(self): |
|
|
"""Create group masks for each category""" |
|
|
label_categories = self.config.label_schema["label_categories"] |
|
|
group_names = self.config.label_schema["group_names"] |
|
|
category_to_group_names = self.config.label_schema["category_to_group_names"] |
|
|
|
|
|
num_cats = len(label_categories) |
|
|
num_groups = len(group_names) |
|
|
|
|
|
group_mask = torch.zeros(num_cats, num_groups, dtype=torch.bool) |
|
|
|
|
|
for cat_idx, category in enumerate(label_categories): |
|
|
if category in category_to_group_names: |
|
|
for group_name in category_to_group_names[category]: |
|
|
if group_name in group_names: |
|
|
group_idx = group_names.index(group_name) |
|
|
group_mask[cat_idx, group_idx] = True |
|
|
|
|
|
return group_mask |
|
|
|
|
|
def _make_category_mappings(self): |
|
|
"""Create mappings between category vector indices and dictionary indices""" |
|
|
labels = self.config.label_schema["labels"] |
|
|
label_categories = self.config.label_schema["label_categories"] |
|
|
|
|
|
|
|
|
cat_dict_idx_to_vec_idx = torch.zeros(len(labels), dtype=torch.long) |
|
|
cat_vec_idx_to_dict_idx = torch.zeros(len(label_categories), dtype=torch.long) |
|
|
|
|
|
for vec_idx, category in enumerate(label_categories): |
|
|
if category in labels: |
|
|
dict_idx = labels.index(category) |
|
|
cat_dict_idx_to_vec_idx[dict_idx] = vec_idx |
|
|
cat_vec_idx_to_dict_idx[vec_idx] = dict_idx |
|
|
|
|
|
return cat_dict_idx_to_vec_idx, cat_vec_idx_to_dict_idx |
|
|
|
|
|
def _count_words_per_sequence(self, word_ids: List[List[int]]) -> List[int]: |
|
|
"""Count the number of unique words in each sequence.""" |
|
|
words_per_seq = [] |
|
|
for seq_word_ids in word_ids: |
|
|
unique_word_ids = set(word_id for word_id in seq_word_ids if word_id is not None) |
|
|
words_per_seq.append(len(unique_word_ids)) |
|
|
return words_per_seq |
|
|
|
|
|
def _predict_categories_for_sequence( |
|
|
self, cat_logits: torch.Tensor, seq_idx: int, seq_nwords: int, cat_vec_idx_to_dict_idx: torch.Tensor |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
"""Predict categories for a single sequence and return both vector and dictionary indices.""" |
|
|
pred_cat_vec_idxs = cat_logits[seq_idx, :seq_nwords].max(dim=-1).indices |
|
|
pred_cats = cat_vec_idx_to_dict_idx[pred_cat_vec_idxs] |
|
|
return pred_cat_vec_idxs, pred_cats |
|
|
|
|
|
def _predict_attributes_for_group( |
|
|
self, |
|
|
attr_logits: torch.Tensor, |
|
|
seq_idx: int, |
|
|
seq_nwords: int, |
|
|
group_vec_idxs: torch.Tensor, |
|
|
seq_group_mask: torch.Tensor, |
|
|
group_idx: int, |
|
|
) -> torch.Tensor: |
|
|
"""Predict attributes for a single group.""" |
|
|
if len(group_vec_idxs) == 0: |
|
|
return torch.zeros(seq_nwords, dtype=torch.long) |
|
|
|
|
|
|
|
|
group_logits = attr_logits[seq_idx, :seq_nwords, group_vec_idxs] |
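        # Single-attribute groups are decided as a binary choice (sigmoid threshold
        # at 0.5); larger groups take the argmax attribute. Either way the result is
        # zeroed for words whose predicted category does not license this group.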
|
|
|
|
|
if len(group_vec_idxs) == 1: |
|
|
|
|
|
group_pred = group_logits.sigmoid().ge(0.5).long() |
|
|
group_pred_dict_idxs = (group_pred.squeeze() * group_vec_idxs.item()) * seq_group_mask[:, group_idx] |
|
|
else: |
|
|
|
|
|
group_pred_vec_idxs = group_logits.max(dim=-1).indices |
|
|
group_pred_dict_idxs = group_vec_idxs[group_pred_vec_idxs] * seq_group_mask[:, group_idx] |
|
|
|
|
|
return group_pred_dict_idxs |
|
|
|
|
|
def _predict_all_attributes_for_sequence( |
|
|
self, |
|
|
attr_logits: torch.Tensor, |
|
|
seq_idx: int, |
|
|
seq_nwords: int, |
|
|
pred_cat_vec_idxs: torch.Tensor, |
|
|
group_name_to_group_attr_vec_idxs: dict, |
|
|
group_mask: torch.Tensor, |
|
|
group_names: List[str], |
|
|
) -> torch.Tensor: |
|
|
"""Predict all attributes for a single sequence.""" |
|
|
seq_group_mask = group_mask[pred_cat_vec_idxs] |
|
|
pred_attrs = [] |
|
|
|
|
|
for group_idx, group_name in enumerate(group_names): |
|
|
if group_name not in group_name_to_group_attr_vec_idxs: |
|
|
pred_attrs.append(torch.zeros(seq_nwords, dtype=torch.long)) |
|
|
continue |
|
|
|
|
|
group_vec_idxs = group_name_to_group_attr_vec_idxs[group_name] |
|
|
group_pred_dict_idxs = self._predict_attributes_for_group( |
|
|
attr_logits, seq_idx, seq_nwords, group_vec_idxs, seq_group_mask, group_idx |
|
|
) |
|
|
pred_attrs.append(group_pred_dict_idxs) |
|
|
|
|
|
|
|
|
if pred_attrs: |
|
|
return torch.stack([p.squeeze() if p.dim() > 1 else p for p in pred_attrs]).t() |
|
|
else: |
|
|
return torch.zeros(seq_nwords, len(group_names), dtype=torch.long) |
|
|
|
|
|
def _convert_predictions_to_labels( |
|
|
self, pred_cats: torch.Tensor, pred_attrs_tensor: torch.Tensor, labels: List[str], group_names: List[str] |
|
|
) -> List[Tuple[str, List[str]]]: |
|
|
"""Convert prediction tensors to human-readable labels.""" |
|
|
seq_nwords = pred_cats.size(0) |
|
|
seq_predictions = [] |
|
|
|
|
|
for word_idx in range(seq_nwords): |
|
|
|
|
|
cat_dict_idx = pred_cats[word_idx].item() |
|
|
if cat_dict_idx < len(labels): |
|
|
category = labels[cat_dict_idx] |
|
|
else: |
|
|
category = "UNK" |
|
|
|
|
|
|
|
|
attributes = [] |
|
|
for group_idx in range(len(group_names)): |
|
|
attr_dict_idx = pred_attrs_tensor[word_idx, group_idx].item() |
|
|
                if 0 < attr_dict_idx < len(labels):
|
|
attributes.append(labels[attr_dict_idx]) |
|
|
|
|
|
seq_predictions.append((category, attributes)) |
|
|
|
|
|
return seq_predictions |
|
|
|
|
|
def _logits_to_labels( |
|
|
self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_ids: List[List[int]] |
|
|
) -> List[List[Tuple[str, List[str]]]]: |
|
|
""" |
|
|
Convert logits to human-readable labels using fairseq's group-based logic. |
|
|
""" |
|
|
|
|
|
        # The index lookup tables built below live on the CPU, so move the logits
        # off any accelerator before index-based decoding.
        cat_logits = cat_logits.detach().cpu()
        attr_logits = attr_logits.detach().cpu()

        group_name_to_group_attr_vec_idxs = self._make_group_name_to_group_attr_vec_idxs()
|
|
group_mask = self._make_group_masks() |
|
|
cat_dict_idx_to_vec_idx, cat_vec_idx_to_dict_idx = self._make_category_mappings() |
|
|
|
|
|
label_schema = self.config.label_schema |
|
|
labels = label_schema["labels"] |
|
|
group_names = label_schema["group_names"] |
|
|
|
|
|
batch_size = cat_logits.size(0) |
|
|
words_per_seq = self._count_words_per_sequence(word_ids) |
|
|
batch_predictions = [] |
|
|
|
|
|
for seq_idx in range(batch_size): |
|
|
seq_nwords = words_per_seq[seq_idx] |
|
|
|
|
|
|
|
|
pred_cat_vec_idxs, pred_cats = self._predict_categories_for_sequence( |
|
|
cat_logits, seq_idx, seq_nwords, cat_vec_idx_to_dict_idx |
|
|
) |
|
|
|
|
|
|
|
|
pred_attrs_tensor = self._predict_all_attributes_for_sequence( |
|
|
attr_logits, |
|
|
seq_idx, |
|
|
seq_nwords, |
|
|
pred_cat_vec_idxs, |
|
|
group_name_to_group_attr_vec_idxs, |
|
|
group_mask, |
|
|
group_names, |
|
|
) |
|
|
|
|
|
|
|
|
seq_predictions = self._convert_predictions_to_labels(pred_cats, pred_attrs_tensor, labels, group_names) |
|
|
batch_predictions.append(seq_predictions) |
|
|
|
|
|
return batch_predictions |
|
|
|
|
|
|
|
|
AutoConfig.register("icebert-pos", IceBertPosConfig) |
|
|
AutoModel.register(IceBertPosConfig, IceBertPosForTokenClassification) |
|
|
IceBertPosConfig.register_for_auto_class() |
|
|
IceBertPosForTokenClassification.register_for_auto_class("AutoModel") |
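

# A minimal smoke-test sketch (not part of the original module). The checkpoint
# path below is a placeholder; it is assumed to hold weights saved from
# IceBertPosForTokenClassification together with a fast IceBERT tokenizer.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "path/to/icebert-pos"  # placeholder, not a published model id
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=True)
    model = IceBertPosForTokenClassification.from_pretrained(checkpoint)
    model.eval()

    sentences = ["Hún las bókina .", "Þetta er próf ."]
    encoding = tokenizer(sentences, return_tensors="pt", padding=True)
    word_ids = [encoding.word_ids(i) for i in range(len(sentences))]

    # predict_labels returns, for every sentence, one (category, [attributes]) pair per word.
    for sentence_preds in model.predict_labels(encoding["input_ids"], encoding["attention_mask"], word_ids):
        print(sentence_preds)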
|
|
|