pibot-jointbert / modeling_jointbert.py
smenaaliaga's picture
Upload PIBot Joint BERT model with full reproducibility
4a7e280 verified
"""
Usage:
    python load_local_model.py --model_dir model_out/pibot_model_v3 --text "cual fue el pib del ultimo trimestre"
"""
import torch
import torch.nn as nn
from transformers import BertPreTrainedModel, BertModel, BertConfig
from torchcrf import CRF
from module import IndicatorClassifier, MetricTypeClassifier, CalcModeClassifier, SeasonalClassifier, ReqFormClassifier, FrequencyClassifier, ActivityClassifier #, SlotClassifier
class JointBERT(BertPreTrainedModel):
    """BERT encoder with seven sentence-level classification heads.

    Every head (indicator, metric type, calc mode, seasonal, req form,
    frequency, activity) is fed the encoder's pooled output and produces its
    own logits. When label ids are supplied, the per-head cross-entropy
    losses are summed, unweighted, into a single training loss.

    NOTE: a token-level slot head (``SlotClassifier`` over the sequence
    output, optionally followed by a CRF and weighted by
    ``args.slot_loss_coef``) existed in an earlier revision and is currently
    disabled; ``forward`` therefore only consumes the pooled output.
    """

    def __init__(self, config, args, indicator_label_lst, metric_type_label_lst, calc_mode_label_lst,
                 seasonal_label_lst, req_form_label_lst, frequency_label_lst, activity_label_lst):
        """Build the encoder and one classifier head per label set.

        Args:
            config: BERT configuration; ``config.hidden_size`` is the input
                width of every head.
            args: Run-time arguments object; ``args.dropout_rate`` is passed
                to every head. Stored on the model as ``self.args``.
            indicator_label_lst: Labels of the indicator head.
            metric_type_label_lst: Labels of the metric-type head.
            calc_mode_label_lst: Labels of the calc-mode head.
            seasonal_label_lst: Labels of the seasonal head.
            req_form_label_lst: Labels of the req-form head.
            frequency_label_lst: Labels of the frequency head.
            activity_label_lst: Labels of the activity head.

        Only the length of each label list is used (as the number of classes
        of the corresponding head).
        """
        super().__init__(config)
        self.args = args

        self.num_indicator_labels = len(indicator_label_lst)
        self.num_metric_type_labels = len(metric_type_label_lst)
        self.num_calc_mode_labels = len(calc_mode_label_lst)
        self.num_seasonal_labels = len(seasonal_label_lst)
        self.num_req_form_labels = len(req_form_label_lst)
        self.num_frequency_labels = len(frequency_label_lst)
        self.num_activity_labels = len(activity_label_lst)

        # The encoder is instantiated from ``config`` only; no weights are
        # loaded on this line.
        self.bert = BertModel(config=config)

        # Attribute names and creation order are kept stable on purpose:
        # they determine the parameter names stored in saved checkpoints.
        hidden_size = config.hidden_size
        dropout_rate = args.dropout_rate
        self.indicator_classifier = IndicatorClassifier(hidden_size, self.num_indicator_labels, dropout_rate)
        self.metric_type_classifier = MetricTypeClassifier(hidden_size, self.num_metric_type_labels, dropout_rate)
        self.calc_mode_classifier = CalcModeClassifier(hidden_size, self.num_calc_mode_labels, dropout_rate)
        self.seasonal_classifier = SeasonalClassifier(hidden_size, self.num_seasonal_labels, dropout_rate)
        self.req_form_classifier = ReqFormClassifier(hidden_size, self.num_req_form_labels, dropout_rate)
        self.frequency_classifier = FrequencyClassifier(hidden_size, self.num_frequency_labels, dropout_rate)
        self.activity_classifier = ActivityClassifier(hidden_size, self.num_activity_labels, dropout_rate)

    def forward(self, input_ids, attention_mask, token_type_ids=None, indicator_label_ids=None,
                metric_type_label_ids=None, calc_mode_label_ids=None, seasonal_label_ids=None,
                req_form_label_ids=None, frequency_label_ids=None, activity_label_ids=None):
        """Run the encoder and all heads; optionally compute the joint loss.

        Args:
            input_ids: Token ids forwarded to the BERT encoder.
            attention_mask: Attention mask forwarded to the BERT encoder.
            token_type_ids: Optional segment ids forwarded to the encoder.
            indicator_label_ids: Optional gold class ids for the indicator head.
            metric_type_label_ids: Optional gold class ids for the metric-type head.
            calc_mode_label_ids: Optional gold class ids for the calc-mode head.
            seasonal_label_ids: Optional gold class ids for the seasonal head.
            req_form_label_ids: Optional gold class ids for the req-form head.
            frequency_label_ids: Optional gold class ids for the frequency head.
            activity_label_ids: Optional gold class ids for the activity head.

        Label tensors are flattened with ``view(-1)`` before the loss, so
        they are presumably one class id per example -- TODO confirm against
        the data loader.

        Returns:
            A tuple ``(total_loss, logits, *extra)`` where

            * ``total_loss`` is the sum of the cross-entropy losses of every
              head whose label ids were given; it is the plain integer ``0``
              when no label ids are passed at all;
            * ``logits`` is a 7-tuple ordered as (indicator, metric_type,
              calc_mode, seasonal, req_form, frequency, activity);
            * ``extra`` is whatever the encoder returned after its first two
              outputs (hidden states / attentions, if enabled).
        """
        encoder_outputs = self.bert(input_ids, attention_mask=attention_mask,
                                    token_type_ids=token_type_ids)
        # Index 1 is the encoder's pooled ([CLS]) representation; the
        # per-token sequence output at index 0 is unused while the slot head
        # is disabled.
        pooled_output = encoder_outputs[1]

        # One (classifier, number of classes, optional gold labels) entry per
        # head. The order fixes both the order of the returned logits and the
        # order in which the losses are accumulated.
        heads = (
            (self.indicator_classifier, self.num_indicator_labels, indicator_label_ids),
            (self.metric_type_classifier, self.num_metric_type_labels, metric_type_label_ids),
            (self.calc_mode_classifier, self.num_calc_mode_labels, calc_mode_label_ids),
            (self.seasonal_classifier, self.num_seasonal_labels, seasonal_label_ids),
            (self.req_form_classifier, self.num_req_form_labels, req_form_label_ids),
            (self.frequency_classifier, self.num_frequency_labels, frequency_label_ids),
            (self.activity_classifier, self.num_activity_labels, activity_label_ids),
        )

        # All heads run before any loss is computed, in declaration order.
        all_logits = tuple(classifier(pooled_output) for classifier, _, _ in heads)

        # CrossEntropyLoss holds no per-call state here, so a single instance
        # serves every head.
        loss_fct = nn.CrossEntropyLoss()
        total_loss = 0
        for head_logits, (_, num_labels, label_ids) in zip(all_logits, heads):
            if label_ids is None:
                continue  # heads without gold labels do not contribute
            total_loss = total_loss + loss_fct(head_logits.view(-1, num_labels), label_ids.view(-1))

        # (loss), logits, (hidden_states), (attentions)
        return (total_loss, all_logits) + encoder_outputs[2:]