| | """
|
| | Use:
|
| | python load_local_model.py --model_dir model_out/pibot_model_v3 --text "cual fue el pib del ultimo trimestre"
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | from transformers import BertPreTrainedModel, BertModel, BertConfig
|
| | from torchcrf import CRF
|
| | from module import IndicatorClassifier, MetricTypeClassifier, CalcModeClassifier, SeasonalClassifier, ReqFormClassifier, FrequencyClassifier, ActivityClassifier
|
| |
|
| |
|
class JointBERT(BertPreTrainedModel):
    """Shared BERT encoder with seven independent sentence-level classification heads.

    Every head (indicator, metric_type, calc_mode, seasonal, req_form,
    frequency, activity) reads the encoder's pooled output (``outputs[1]`` of
    ``BertModel``). During ``forward``, each head whose gold labels are given
    contributes an unweighted cross-entropy term to the total loss.
    """

    def __init__(self, config, args, indicator_label_lst, metric_type_label_lst, calc_mode_label_lst,
                 seasonal_label_lst, req_form_label_lst, frequency_label_lst, activity_label_lst):
        """Build the encoder and the classification heads.

        Args:
            config: BERT configuration. It is passed to ``BertModel``, and
                ``config.hidden_size`` sets the input width of every head.
            args: run arguments. Only ``args.dropout_rate`` is read here; the
                object is kept on ``self.args``.
            indicator_label_lst, metric_type_label_lst, calc_mode_label_lst,
            seasonal_label_lst, req_form_label_lst, frequency_label_lst,
            activity_label_lst: label vocabularies. Only their lengths are
                used, to size the output of the matching head.
        """
        super().__init__(config)
        self.args = args

        self.num_indicator_labels = len(indicator_label_lst)
        self.num_metric_type_labels = len(metric_type_label_lst)
        self.num_calc_mode_labels = len(calc_mode_label_lst)
        self.num_seasonal_labels = len(seasonal_label_lst)
        self.num_req_form_labels = len(req_form_label_lst)
        self.num_frequency_labels = len(frequency_label_lst)
        self.num_activity_labels = len(activity_label_lst)

        self.bert = BertModel(config=config)

        # Attribute names below form the state_dict keys; keep them stable so
        # previously saved checkpoints still load.
        hidden_size = config.hidden_size
        dropout_rate = args.dropout_rate
        self.indicator_classifier = IndicatorClassifier(hidden_size, self.num_indicator_labels, dropout_rate)
        self.metric_type_classifier = MetricTypeClassifier(hidden_size, self.num_metric_type_labels, dropout_rate)
        self.calc_mode_classifier = CalcModeClassifier(hidden_size, self.num_calc_mode_labels, dropout_rate)
        self.seasonal_classifier = SeasonalClassifier(hidden_size, self.num_seasonal_labels, dropout_rate)
        self.req_form_classifier = ReqFormClassifier(hidden_size, self.num_req_form_labels, dropout_rate)
        self.frequency_classifier = FrequencyClassifier(hidden_size, self.num_frequency_labels, dropout_rate)
        self.activity_classifier = ActivityClassifier(hidden_size, self.num_activity_labels, dropout_rate)

    def forward(self, input_ids, attention_mask, token_type_ids=None, indicator_label_ids=None,
                metric_type_label_ids=None, calc_mode_label_ids=None, seasonal_label_ids=None,
                req_form_label_ids=None, frequency_label_ids=None, activity_label_ids=None):
        """Run the encoder and every head, and optionally compute the joint loss.

        Args:
            input_ids, attention_mask, token_type_ids: forwarded unchanged to
                ``self.bert``.
            *_label_ids: optional gold labels for the matching head. They are
                flattened with ``view(-1)`` before the loss; presumably one
                class index per sequence (TODO confirm against the data loader).

        Returns:
            A tuple ``(total_loss, logits, *extra)``, where:

            - ``total_loss`` is the plain int ``0`` when no labels are given.
              Otherwise it is the sum of the cross-entropy losses of the heads
              that received labels.
            - ``logits`` is a 7-tuple in the order indicator, metric_type,
              calc_mode, seasonal, req_form, frequency, activity.
            - ``extra`` is ``outputs[2:]`` from the encoder.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        pooled_output = outputs[1]

        # One entry per head: (logits, gold labels or None, number of classes).
        # All heads run before any loss is computed, in the same order as the
        # returned logits tuple.
        heads = (
            (self.indicator_classifier(pooled_output), indicator_label_ids, self.num_indicator_labels),
            (self.metric_type_classifier(pooled_output), metric_type_label_ids, self.num_metric_type_labels),
            (self.calc_mode_classifier(pooled_output), calc_mode_label_ids, self.num_calc_mode_labels),
            (self.seasonal_classifier(pooled_output), seasonal_label_ids, self.num_seasonal_labels),
            (self.req_form_classifier(pooled_output), req_form_label_ids, self.num_req_form_labels),
            (self.frequency_classifier(pooled_output), frequency_label_ids, self.num_frequency_labels),
            (self.activity_classifier(pooled_output), activity_label_ids, self.num_activity_labels),
        )

        # CrossEntropyLoss with default arguments is stateless, so one
        # instance can serve every head.
        loss_fct = nn.CrossEntropyLoss()
        total_loss = 0
        for logits, label_ids, num_labels in heads:
            if label_ids is not None:
                total_loss += loss_fct(logits.view(-1, num_labels), label_ids.view(-1))

        all_logits = tuple(logits for logits, _, _ in heads)
        # outputs[2:] carries any extra encoder outputs through to the caller.
        outputs = (all_logits,) + outputs[2:]
        return (total_loss,) + outputs