import logging
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoConfig, AutoModel, PreTrainedModel, RobertaModel

from .configuration import IceBertPosConfig
from .old_label_utils import (
    SimpleLabelDictionary,
    clean_cats_attrs,
    create_label_dictionary_from_schema,
    make_dict_idx_to_vec_idx,
    make_group_masks,
    make_group_name_to_group_attr_vec_idxs,
    make_vec_idx_to_dict_idx,
)

logger = logging.getLogger(__name__)


class MultiLabelTokenClassificationHead(nn.Module):
    """Head for multi-label word-level classification tasks."""

    def __init__(self, config: IceBertPosConfig):
        super().__init__()
        self.num_categories = config.num_categories
        self.num_labels = config.num_labels
        self.hidden_size = config.hidden_size

        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        self.activation_fn = F.relu
        self.dropout = nn.Dropout(p=config.classifier_dropout)
        self.layer_norm = nn.LayerNorm(self.hidden_size)

        # Projects word features to category logits.
        self.cat_proj = nn.Linear(self.hidden_size, self.num_categories)

        # Predicts attributes from the word features concatenated with the
        # softmaxed category distribution.
        self.out_proj = nn.Linear(self.hidden_size + self.num_categories, self.num_labels)

    def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            features: Word-level features of shape (batch_size, max_words, hidden_size)

        Returns:
            cat_logits: Category logits of shape (batch_size, max_words, num_categories)
            attr_logits: Attribute logits of shape (batch_size, max_words, num_labels)
        """
        x = self.dropout(features)
        x = self.dense(x)
        x = self.layer_norm(x)
        x = self.activation_fn(x)

        cat_logits = self.cat_proj(x)
        cat_probs = torch.softmax(cat_logits, dim=-1)

        attr_input = torch.cat((cat_probs, x), dim=-1)
        attr_logits = self.out_proj(attr_input)

        return cat_logits, attr_logits
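
# Shape walkthrough for the head above: with hidden size H and C categories, word
# features of shape (B, W, H) become cat_logits of shape (B, W, C); the softmaxed
# category probabilities are concatenated back onto the features, so out_proj sees
# (B, W, H + C) and emits attr_logits of shape (B, W, num_labels).
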


class IceBertPosForTokenClassification(PreTrainedModel):
    """
    IceBERT model for multi-label token classification (POS tagging).

    This model performs word-level POS tagging by:
    1. Encoding the input with RoBERTa
    2. Aggregating subword tokens into word-level representations
    3. Predicting both categories and attributes for each word
    """

    config_class = IceBertPosConfig

    def __init__(self, config: IceBertPosConfig):
        super().__init__(config)
        self.config = config
        self.num_categories = config.num_categories
        self.num_labels = config.num_labels

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.classifier = MultiLabelTokenClassificationHead(config)

        # Build the label dictionary and index mappings from the label schema.
        self.label_dictionary = create_label_dictionary_from_schema(config.label_schema)
        self._setup_label_mappings()

        # Initialize weights and apply final processing.
        self.post_init()

    def _setup_label_mappings(self):
        """Set up label mappings, mirroring the old fairseq model."""
        schema = self.config.label_schema

        self.group_name_to_group_attr_vec_idxs = make_group_name_to_group_attr_vec_idxs(self.label_dictionary, schema)
        self.cat_dict_idx_to_vec_idx = make_dict_idx_to_vec_idx(self.label_dictionary, schema.label_categories)
        self.cat_vec_idx_to_dict_idx = make_vec_idx_to_dict_idx(self.label_dictionary, schema.label_categories)
        self.group_mask = make_group_masks(self.label_dictionary, schema)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        word_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            input_ids: Token indices of shape (batch_size, sequence_length)
            attention_mask: Attention mask of shape (batch_size, sequence_length)
            word_mask: Binary mask indicating word boundaries (1 = word start) of shape (batch_size, sequence_length)

        Returns:
            cat_logits: Category logits of shape (batch_size, max_words, num_categories)
            attr_logits: Attribute logits of shape (batch_size, max_words, num_labels)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        x = outputs[0]

        _, _, inner_dim = x.shape

        # Drop the <s> and </s> positions and average the subword vectors of each
        # word into a single word vector (fairseq-style aggregation).
        x = x[:, 1:-1, :]
        starts = word_mask[:, 1:-1]
        ends = starts.roll(-1, dims=[-1]).nonzero()[:, -1] + 1
        starts = starts.nonzero().tolist()
        mean_words = []
        for (seq_idx, token_idx), end in zip(starts, ends):
            mean_words.append(x[seq_idx, token_idx:end, :].mean(dim=0))
        mean_words = torch.stack(mean_words)
        words = mean_words

        nwords = word_mask.sum(dim=-1)
        (cat_logits, attr_logits) = self.classifier(words)

        # Split the flat word dimension back into per-sequence chunks and pad to
        # (batch_size, max_words, ...).
        cat_logits = pad_sequence(cat_logits.split(nwords.tolist()), padding_value=0, batch_first=True)
        attr_logits = pad_sequence(
            attr_logits.split(nwords.tolist()),
            padding_value=0,
            batch_first=True,
        )
        return cat_logits, attr_logits
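
    # Worked illustration of the aggregation above (the subword split is made up):
    #   tokens:    <s>  Hún  er  kenn  ari  </s>
    #   word_mask: 0    1    1   1     0    0
    # After dropping <s>/</s>, starts = [1, 1, 1, 0]; rolling it left by one and
    # taking the nonzero positions (+1) gives ends = [1, 2, 4], so the three words
    # are mean-pooled over token spans [0:1], [1:2], and [2:4] before the classifier.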

    def _aggregate_subword_tokens(
        self, sequence_output: torch.Tensor, word_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Aggregate subword token representations into word-level representations.
        Follows the original fairseq approach of averaging the subword tokens within each word.

        Args:
            sequence_output: Subword token representations (batch_size, seq_len, hidden_size)
            word_mask: Binary mask where 1 indicates the start of a word (batch_size, seq_len)

        Returns:
            word_features: Word-level features (total_words, hidden_size)
            nwords: Number of words per sequence (batch_size,)
        """
        # Drop the <s> and </s> positions.
        x = sequence_output[:, 1:-1, :]
        starts = word_mask[:, 1:-1]

        # Number of words per sequence.
        nwords = starts.sum(dim=-1)

        mean_words = []
        batch_size, seq_len, hidden_size = x.shape

        for batch_idx in range(batch_size):
            seq_starts = starts[batch_idx]
            seq_x = x[batch_idx]

            # Token positions where a new word starts.
            start_positions = seq_starts.nonzero(as_tuple=True)[0]

            if len(start_positions) == 0:
                continue

            # Each word ends where the next one starts (or at the end of the sequence).
            end_positions = torch.cat([start_positions[1:], torch.tensor([seq_len], device=start_positions.device)])

            for start_pos, end_pos in zip(start_positions, end_positions):
                word_tokens = seq_x[start_pos:end_pos]
                word_repr = word_tokens.mean(dim=0)
                mean_words.append(word_repr)

        if len(mean_words) == 0:
            return torch.empty(0, sequence_output.size(-1), device=sequence_output.device), nwords

        return torch.stack(mean_words), nwords

    def _reshape_to_batch_format(
        self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, nwords: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Reshape word-level predictions back to batch format.
        Follows the original fairseq approach of using pad_sequence.

        Args:
            cat_logits: Category logits (total_words, num_categories)
            attr_logits: Attribute logits (total_words, num_labels)
            nwords: Number of words per sequence (batch_size,)

        Returns:
            cat_logits_batch: (batch_size, max_words, num_categories)
            attr_logits_batch: (batch_size, max_words, num_labels)
        """
        words_per_seq = nwords.tolist()
        cat_logits_split = cat_logits.split(words_per_seq)
        attr_logits_split = attr_logits.split(words_per_seq)

        # Pad each sequence's words to the longest sequence in the batch.
        cat_logits_batch = pad_sequence(cat_logits_split, batch_first=True, padding_value=0)
        attr_logits_batch = pad_sequence(attr_logits_split, batch_first=True, padding_value=0)

        return cat_logits_batch, attr_logits_batch

    @torch.no_grad()
    def predict_labels(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, word_ids: List[List[int]]
    ) -> List[List[Tuple[str, List[str]]]]:
        """
        Predict POS labels for input sequences.

        Args:
            input_ids: Token indices
            attention_mask: Attention mask
            word_ids: Word boundaries

        Returns:
            List of sequences, each containing (category, [attributes]) per word
        """
        word_mask = self._word_ids_to_word_mask(word_ids, input_ids.shape)

        cat_logits, attr_logits = self.forward(input_ids=input_ids, attention_mask=attention_mask, word_mask=word_mask)

        return self._logits_to_labels(cat_logits, attr_logits, word_mask)

    def _word_ids_to_word_mask(self, word_ids: List[List[int]], input_shape: torch.Size) -> torch.Tensor:
        """
        Convert word_ids to a word_mask (binary mask indicating word boundaries).

        Args:
            word_ids: List of word id sequences
            input_shape: Shape of the input_ids tensor (batch_size, seq_len)

        Returns:
            word_mask: Binary tensor where 1 indicates the start of a word (batch_size, seq_len)
        """
        batch_size, seq_len = input_shape
        word_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)

        for batch_idx, seq_word_ids in enumerate(word_ids):
            # Skip the special tokens at the start and end of the sequence.
            truncated_word_ids = seq_word_ids[1:-1]
            prev_word_id = None
            for token_idx, word_id in enumerate(truncated_word_ids):
                # Only real words (word_id is not None) can start a word; special and
                # padding tokens must not be marked.
                if word_id is not None and word_id != prev_word_id:
                    word_mask[batch_idx, token_idx + 1] = 1
                prev_word_id = word_id

            logger.debug(f"Word mask: {word_mask[batch_idx].tolist()}")

        return word_mask
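
    # Illustration of the conversion above (the subword split is made up):
    #   tokens:   <s>   Hún  er  kenn  ari  </s>
    #   word_ids: None  0    1   2     2    None
    # The leading and trailing special tokens are skipped and the first subword of
    # each word is marked, giving word_mask = [0, 1, 1, 1, 0, 0].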

    def predict_labels_from_text(self, sentences: List[str], tokenizer) -> List[List[Tuple[str, List[str]]]]:
        """
        Predict POS labels from raw text using fairseq-style preprocessing.

        Args:
            sentences: List of input sentences
            tokenizer: HuggingFace tokenizer

        Returns:
            List of sequences, each containing (category, [attributes]) per word
        """
        # Split on whitespace, matching the fairseq preprocessing.
        sentences_split = [sentence.split() for sentence in sentences]

        # Tokenize the pre-split words so the word alignment is preserved.
        encoding = tokenizer.batch_encode_plus(
            sentences_split,
            return_tensors="pt",
            padding=True,
            is_split_into_words=True,
            add_special_tokens=True,
        )

        batch_input_ids = encoding["input_ids"]
        batch_attention_mask = encoding["attention_mask"]
        word_ids_list = [encoding.word_ids(i) for i in range(len(sentences))]

        for i in range(len(sentences)):
            logger.debug(f"Encoded tokens: {batch_input_ids[i]}")
            logger.debug(f"Decoded tokens: {tokenizer.convert_ids_to_tokens(batch_input_ids[i].tolist())}")
            logger.debug(f"Word IDs: {word_ids_list[i]}")

        return self.predict_labels(batch_input_ids, batch_attention_mask, word_ids_list)

    def _logits_to_labels(
        self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_mask: torch.Tensor
    ) -> List[List[Tuple[str, List[str]]]]:
        """
        Convert logits to human-readable labels using fairseq's group-based logic.
        Copied from the old model's logits_to_labels method.
        """
        bsz, _, num_cats = cat_logits.shape
        _, _, num_attrs = attr_logits.shape
        nwords = word_mask.sum(-1)

        assert num_attrs == len(self.config.label_schema.labels)
        assert num_cats == len(self.config.label_schema.label_categories)

        batch_cats = []
        batch_attrs = []
        for seq_idx in range(bsz):
            seq_nwords = nwords[seq_idx]
            pred_cat_vec_idxs = cat_logits[seq_idx, :seq_nwords].max(dim=-1).indices
            pred_cats = self.cat_vec_idx_to_dict_idx[pred_cat_vec_idxs]

            # Which attribute groups are licensed by each word's predicted category.
            group_mask = self.group_mask[pred_cat_vec_idxs]
            offset = self.label_dictionary.nspecial
            pred_attrs = []
            for group_idx, group_name in enumerate(self.config.label_schema.group_names):
                group_vec_idxs = self.group_name_to_group_attr_vec_idxs[group_name]

                group_logits = attr_logits[seq_idx, :seq_nwords, group_vec_idxs]
                if len(group_vec_idxs) == 1:
                    # Single-attribute group: treat it as a binary presence decision.
                    group_pred = group_logits.sigmoid().ge(0.5).long()
                    group_pred_dict_idxs = (group_pred.squeeze() * (group_vec_idxs.item() + offset)).T.to(
                        "cpu"
                    ) * group_mask[:, group_idx]
                else:
                    # Multi-attribute group: pick the highest-scoring attribute.
                    group_pred_vec_idxs = group_logits.max(dim=-1).indices
                    group_pred_dict_idxs = (group_vec_idxs[group_pred_vec_idxs] + offset) * group_mask[:, group_idx]
                pred_attrs.append(group_pred_dict_idxs)

            pred_attrs = torch.stack([p.squeeze() for p in pred_attrs]).t()

            batch_cats.append(pred_cats)
            batch_attrs.append(pred_attrs)

        predictions = [
            clean_cats_attrs(
                self.label_dictionary,
                self.config.label_schema,
                seq_cats,
                seq_attrs,
            )
            for seq_cats, seq_attrs in zip(batch_cats, batch_attrs)
        ]

        return predictions
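
    # Decoding sketch for the method above: each word gets the argmax category, and
    # that category selects (via group_mask) which attribute groups are licensed for
    # it. Within each licensed group, a single-attribute group is thresholded with
    # sigmoid >= 0.5 while a larger group takes the argmax over its attributes, and
    # clean_cats_attrs turns the resulting dictionary indices into label strings.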


def make_vec_idx_to_dict_idx(dictionary, labels, device="cpu", fill_value=-100):
    # Maps vector (column) indices back to label-dictionary indices.
    vec_idx_to_dict_idx = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
    for vec_idx, label in enumerate(labels):
        vec_idx_to_dict_idx[vec_idx] = dictionary.index(label)
    return vec_idx_to_dict_idx


def make_group_masks(dictionary, schema, device="cpu"):
    # ret_mask[cat_vec_idx, group_idx] == 1 iff the category licenses that attribute group.
    num_groups = len(schema.group_names)
    offset = dictionary.nspecial
    num_labels = len(dictionary) - offset
    ret_mask = torch.zeros(num_labels, num_groups, dtype=torch.int64, device=device)
    for cat, cat_group_names in schema.category_to_group_names.items():
        cat_label_idx = dictionary.index(cat)
        cat_vec_idx = schema.label_categories.index(cat)
        for group_name in cat_group_names:
            ret_mask[cat_vec_idx, schema.group_names.index(group_name)] = 1
        assert cat_label_idx != dictionary.unk()
    for cat in schema.label_categories:
        # Every category must be present in the label dictionary.
        cat_label_idx = dictionary.index(cat)
        assert cat_label_idx != dictionary.unk()
    return ret_mask
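
# Small illustration of make_group_masks under a hypothetical schema: if group_names
# is ["gender", "tense"] and category_to_group_names maps a verb-like category to
# ["tense"] only, that category's row in the returned mask is [0, 1], so the gender
# group is suppressed for words predicted with that category.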


def make_group_name_to_group_attr_vec_idxs(dict_, schema):
    offset = dict_.nspecial
    group_names = schema.group_name_to_labels.keys()
    name_to_labels = schema.group_name_to_labels
    group_name_to_group_attr_vec_idxs = {
        name: torch.tensor([dict_.index(item) - offset for item in name_to_labels[name]]) for name in group_names
    }
    return group_name_to_group_attr_vec_idxs


def make_dict_idx_to_vec_idx(dictionary, cats, device="cpu", fill_value=-100):
    map_tgt = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
    for vec_idx, label in enumerate(cats):
        map_tgt[dictionary.index(label)] = vec_idx
    return map_tgt
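
# Illustration of the two index mappings under a hypothetical dictionary: with
# nspecial == 4 and the categories stored at dictionary indices [4, 5, 6],
# make_dict_idx_to_vec_idx maps dictionary index 5 -> vector index 1, while
# make_vec_idx_to_dict_idx (above) inverts this, mapping vector index 1 -> 5;
# unmapped slots keep fill_value (-100).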


AutoConfig.register("icebert-pos", IceBertPosConfig)
AutoModel.register(IceBertPosConfig, IceBertPosForTokenClassification)
IceBertPosConfig.register_for_auto_class()
IceBertPosForTokenClassification.register_for_auto_class("AutoModel")
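

# Minimal usage sketch. The checkpoint path is a placeholder for a local directory
# that holds this model's config, weights, and tokenizer files; point it at wherever
# the trained model is stored.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "path/to/icebert-pos-checkpoint"  # placeholder path
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = IceBertPosForTokenClassification.from_pretrained(checkpoint)
    model.eval()

    # predict_labels_from_text splits each sentence on whitespace and returns one
    # (category, [attributes]) pair per word.
    predictions = model.predict_labels_from_text(["Hún er kennari ."], tokenizer)
    print(predictions)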