import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel
from .module import CalcModeClassifier, ActivityClassifier, RegionClassifier, InvestmentClassifier, ReqFormClassifier, SlotClassifier
try:
from torchcrf import CRF
except ImportError:
CRF = None
class JointBERT(PreTrainedModel):
    """Joint intent + slot-filling model.

    A shared transformer encoder feeds five sentence-level classification
    heads (calc_mode, activity, region, investment, req_form) plus one
    token-level slot head, optionally decoded with a CRF.

    ``forward`` returns ``(total_loss, (calc_mode_logits, activity_logits,
    region_logits, investment_logits, req_form_logits, slot_logits),
    hidden_states?, attentions?)`` — loss is the sum of all head losses for
    which labels were supplied, with the slot loss scaled by
    ``args.slot_loss_coef``.
    """

    def __init__(self, config, args, calc_mode_label_lst, activity_label_lst,
                 region_label_lst, investment_label_lst, req_form_label_lst,
                 slot_label_lst):
        super().__init__(config)
        self.args = args
        self.num_calc_mode_labels = len(calc_mode_label_lst)
        self.num_activity_labels = len(activity_label_lst)
        self.num_region_labels = len(region_label_lst)
        self.num_investment_labels = len(investment_label_lst)
        self.num_req_form_labels = len(req_form_label_lst)
        self.num_slot_labels = len(slot_label_lst)
        # AutoModel so any transformer encoder (BERT, RoBERTa, ...) works.
        self.encoder = AutoModel.from_pretrained(args.model_name_or_path, config=config)
        self.calc_mode_classifier = CalcModeClassifier(config.hidden_size, self.num_calc_mode_labels, args.dropout_rate)
        self.activity_classifier = ActivityClassifier(config.hidden_size, self.num_activity_labels, args.dropout_rate)
        self.region_classifier = RegionClassifier(config.hidden_size, self.num_region_labels, args.dropout_rate)
        self.investment_classifier = InvestmentClassifier(config.hidden_size, self.num_investment_labels, args.dropout_rate)
        self.req_form_classifier = ReqFormClassifier(config.hidden_size, self.num_req_form_labels, args.dropout_rate)
        self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels, args.dropout_rate)
        if args.use_crf:
            if CRF is None:
                raise ImportError("torchcrf no está instalado. Instala con: pip install pytorch-crf o ejecuta sin --use_crf")
            self.crf = self._build_crf()

    def _build_crf(self):
        """Instantiate a CRF, probing the known constructor signatures.

        Both the ``TorchCRF`` and ``pytorch-crf`` packages export a ``CRF``
        class but with incompatible ``__init__`` signatures; try each known
        form in order and keep the first that constructs.

        Raises:
            TypeError: if none of the known signatures match, with all
                collected constructor errors joined into the message.
        """
        init_errors = []
        candidates = (
            lambda: CRF(self.num_slot_labels, pad_idx=None, use_gpu=False),
            lambda: CRF(self.num_slot_labels, batch_first=True),
            lambda: CRF(num_tags=self.num_slot_labels, batch_first=True),
            lambda: CRF(self.num_slot_labels),
            lambda: CRF(num_tags=self.num_slot_labels),
        )
        for build in candidates:
            try:
                return build()
            except TypeError as e:
                init_errors.append(str(e))
        raise TypeError("No se pudo inicializar CRF con las firmas conocidas: " + " | ".join(init_errors))

    def forward(self, input_ids, attention_mask, token_type_ids=None,
                calc_mode_label_ids=None, activity_label_ids=None, region_label_ids=None,
                investment_label_ids=None, req_form_label_ids=None, slot_labels_ids=None):
        """Run the encoder and all heads; compute losses for supplied labels.

        Any label argument left as ``None`` simply skips that head's loss.
        Per-head class weights are picked up from buffers registered
        externally under the name ``"<head>_class_weights"`` (if absent,
        unweighted cross-entropy is used).
        """
        outputs = self.encoder(input_ids, attention_mask=attention_mask,
                               token_type_ids=token_type_ids)  # sequence_output, pooled_output, (hidden_states), (attentions)
        sequence_output = outputs[0]
        pooled_output = getattr(outputs, "pooler_output", None)
        if pooled_output is None:
            # Not every encoder exposes a pooler. Accept outputs[1] only when
            # it is a 2-D pooled tensor; otherwise fall back to the
            # first-token ([CLS]) embedding.
            if len(outputs) > 1 and outputs[1] is not None and getattr(outputs[1], "dim", lambda: 0)() == 2:
                pooled_output = outputs[1]
            else:
                pooled_output = sequence_output[:, 0]

        calc_mode_logits = self.calc_mode_classifier(pooled_output)
        activity_logits = self.activity_classifier(pooled_output)
        region_logits = self.region_classifier(pooled_output)
        investment_logits = self.investment_classifier(pooled_output)
        req_form_logits = self.req_form_classifier(pooled_output)
        slot_logits = self.slot_classifier(sequence_output)

        total_loss = 0

        # 1-5. Sentence-level heads: one cross-entropy per head whose labels
        # were provided (replaces five duplicated loss sections).
        intent_heads = (
            ('calc_mode', calc_mode_logits, calc_mode_label_ids, self.num_calc_mode_labels),
            ('activity', activity_logits, activity_label_ids, self.num_activity_labels),
            ('region', region_logits, region_label_ids, self.num_region_labels),
            ('investment', investment_logits, investment_label_ids, self.num_investment_labels),
            ('req_form', req_form_logits, req_form_label_ids, self.num_req_form_labels),
        )
        for head_name, logits, label_ids, num_labels in intent_heads:
            if label_ids is None:
                continue
            # Optional class weights registered as buffer "<head>_class_weights".
            weight = getattr(self, f"{head_name}_class_weights", None)
            loss_fct = nn.CrossEntropyLoss(weight=weight)
            total_loss += loss_fct(logits.view(-1, num_labels), label_ids.view(-1))

        # 6. Slot loss: CRF negative log-likelihood or masked softmax CE.
        if slot_labels_ids is not None and self.args.slot_loss_coef != 0:
            if self.args.use_crf:
                # CRF doesn't handle ignore_index (-100), so replace it with PAD (0).
                slot_labels_ids_crf = slot_labels_ids.clone()
                slot_labels_ids_crf[slot_labels_ids_crf == self.args.ignore_index] = 0
                if hasattr(self.crf, 'viterbi_decode'):
                    # TorchCRF API: forward returns log-likelihood per batch item.
                    slot_loss = -self.crf(slot_logits, slot_labels_ids_crf, attention_mask.bool()).mean()
                else:
                    # pytorch-crf API: forward returns summed/mean log-likelihood.
                    slot_loss = self.crf(slot_logits, slot_labels_ids_crf, mask=attention_mask.bool(), reduction='mean')
                    slot_loss = -1 * slot_loss  # negative log-likelihood
            else:
                slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
                if attention_mask is not None:
                    # Only keep active (non-padding) positions in the loss.
                    active_loss = attention_mask.view(-1) == 1
                    active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
                    active_labels = slot_labels_ids.view(-1)[active_loss]
                    slot_loss = slot_loss_fct(active_logits, active_labels)
                else:
                    slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
            total_loss += self.args.slot_loss_coef * slot_loss

        # Prepend the logits tuple; keep hidden states / attentions if present.
        outputs = ((calc_mode_logits, activity_logits, region_logits, investment_logits, req_form_logits, slot_logits),) + outputs[2:]
        outputs = (total_loss,) + outputs
        return outputs  # (loss), logits, (hidden_states), (attentions) — logits is a tuple of all classifier logits