| import torch.nn as nn
|
| from transformers import PreTrainedModel, AutoModel
|
| from .module import CalcModeClassifier, ActivityClassifier, RegionClassifier, InvestmentClassifier, ReqFormClassifier, SlotClassifier
|
|
|
| try:
|
| from torchcrf import CRF
|
| except ImportError:
|
| CRF = None
|
|
|
| class JointBERT(PreTrainedModel):
|
| def __init__(self, config, args, calc_mode_label_lst, activity_label_lst, region_label_lst, investment_label_lst, req_form_label_lst, slot_label_lst):
|
| super(JointBERT, self).__init__(config)
|
| self.args = args
|
|
|
| self.num_calc_mode_labels = len(calc_mode_label_lst)
|
| self.num_activity_labels = len(activity_label_lst)
|
| self.num_region_labels = len(region_label_lst)
|
| self.num_investment_labels = len(investment_label_lst)
|
| self.num_req_form_labels = len(req_form_label_lst)
|
| self.num_slot_labels = len(slot_label_lst)
|
|
|
|
|
| self.encoder = AutoModel.from_pretrained(args.model_name_or_path, config=config)
|
|
|
| self.calc_mode_classifier = CalcModeClassifier(config.hidden_size, self.num_calc_mode_labels, args.dropout_rate)
|
| self.activity_classifier = ActivityClassifier(config.hidden_size, self.num_activity_labels, args.dropout_rate)
|
| self.region_classifier = RegionClassifier(config.hidden_size, self.num_region_labels, args.dropout_rate)
|
| self.investment_classifier = InvestmentClassifier(config.hidden_size, self.num_investment_labels, args.dropout_rate)
|
| self.req_form_classifier = ReqFormClassifier(config.hidden_size, self.num_req_form_labels, args.dropout_rate)
|
| self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels, args.dropout_rate)
|
|
|
| if args.use_crf:
|
| if CRF is None:
|
| raise ImportError("torchcrf no está instalado. Instala con: pip install pytorch-crf o ejecuta sin --use_crf")
|
| crf_init_errors = []
|
| for init_fn in (
|
| lambda: CRF(self.num_slot_labels, pad_idx=None, use_gpu=False),
|
| lambda: CRF(self.num_slot_labels, batch_first=True),
|
| lambda: CRF(num_tags=self.num_slot_labels, batch_first=True),
|
| lambda: CRF(self.num_slot_labels),
|
| lambda: CRF(num_tags=self.num_slot_labels),
|
| ):
|
| try:
|
| self.crf = init_fn()
|
| break
|
| except TypeError as e:
|
| crf_init_errors.append(str(e))
|
| else:
|
| raise TypeError("No se pudo inicializar CRF con las firmas conocidas: " + " | ".join(crf_init_errors))
|
|
|
| def forward(self, input_ids, attention_mask, token_type_ids=None,
|
| calc_mode_label_ids=None, activity_label_ids=None, region_label_ids=None, investment_label_ids=None, req_form_label_ids=None, slot_labels_ids=None):
|
| outputs = self.encoder(input_ids, attention_mask=attention_mask,
|
| token_type_ids=token_type_ids)
|
| sequence_output = outputs[0]
|
| pooled_output = outputs[1]
|
|
|
| calc_mode_logits = self.calc_mode_classifier(pooled_output)
|
| activity_logits = self.activity_classifier(pooled_output)
|
| region_logits = self.region_classifier(pooled_output)
|
| investment_logits = self.investment_classifier(pooled_output)
|
| req_form_logits = self.req_form_classifier(pooled_output)
|
| slot_logits = self.slot_classifier(sequence_output)
|
|
|
| total_loss = 0
|
|
|
| if calc_mode_label_ids is not None:
|
| calc_mode_loss_fct = nn.CrossEntropyLoss()
|
| calc_mode_loss = calc_mode_loss_fct(calc_mode_logits.view(-1, self.num_calc_mode_labels), calc_mode_label_ids.view(-1))
|
| total_loss += calc_mode_loss
|
|
|
|
|
| if activity_label_ids is not None:
|
| activity_loss_fct = nn.CrossEntropyLoss()
|
| activity_loss = activity_loss_fct(activity_logits.view(-1, self.num_activity_labels), activity_label_ids.view(-1))
|
| total_loss += activity_loss
|
|
|
|
|
| if region_label_ids is not None:
|
| region_loss_fct = nn.CrossEntropyLoss()
|
| region_loss = region_loss_fct(region_logits.view(-1, self.num_region_labels), region_label_ids.view(-1))
|
| total_loss += region_loss
|
|
|
|
|
| if investment_label_ids is not None:
|
| investment_loss_fct = nn.CrossEntropyLoss()
|
| investment_loss = investment_loss_fct(investment_logits.view(-1, self.num_investment_labels), investment_label_ids.view(-1))
|
| total_loss += investment_loss
|
|
|
|
|
| if req_form_label_ids is not None:
|
| req_form_loss_fct = nn.CrossEntropyLoss()
|
| req_form_loss = req_form_loss_fct(req_form_logits.view(-1, self.num_req_form_labels), req_form_label_ids.view(-1))
|
| total_loss += req_form_loss
|
|
|
|
|
| if slot_labels_ids is not None and self.args.slot_loss_coef != 0:
|
| if self.args.use_crf:
|
|
|
| slot_labels_ids_crf = slot_labels_ids.clone()
|
| slot_labels_ids_crf[slot_labels_ids_crf == self.args.ignore_index] = 0
|
| if hasattr(self.crf, 'viterbi_decode'):
|
|
|
| slot_loss = -self.crf(slot_logits, slot_labels_ids_crf, attention_mask.bool()).mean()
|
| else:
|
|
|
| slot_loss = self.crf(slot_logits, slot_labels_ids_crf, mask=attention_mask.bool(), reduction='mean')
|
| slot_loss = -1 * slot_loss
|
| else:
|
| slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
|
|
|
| if attention_mask is not None:
|
| active_loss = attention_mask.view(-1) == 1
|
| active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
|
| active_labels = slot_labels_ids.view(-1)[active_loss]
|
| slot_loss = slot_loss_fct(active_logits, active_labels)
|
| else:
|
| slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
|
| total_loss += self.args.slot_loss_coef * slot_loss
|
|
|
| outputs = ((calc_mode_logits, activity_logits, region_logits, investment_logits, req_form_logits, slot_logits),) + outputs[2:]
|
|
|
| outputs = (total_loss,) + outputs
|
|
|
| return outputs |