pibert / modeling_jointbert.py
smenaaliaga's picture
Upload PIBot Joint BERT model package
d568351 verified
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel
from .module import CalcModeClassifier, ActivityClassifier, RegionClassifier, InvestmentClassifier, ReqFormClassifier, SlotClassifier
try:
from torchcrf import CRF
except ImportError:
CRF = None
class JointBERT(PreTrainedModel):
def __init__(self, config, args, calc_mode_label_lst, activity_label_lst, region_label_lst, investment_label_lst, req_form_label_lst, slot_label_lst):
super(JointBERT, self).__init__(config)
self.args = args
self.num_calc_mode_labels = len(calc_mode_label_lst)
self.num_activity_labels = len(activity_label_lst)
self.num_region_labels = len(region_label_lst)
self.num_investment_labels = len(investment_label_lst)
self.num_req_form_labels = len(req_form_label_lst)
self.num_slot_labels = len(slot_label_lst)
# Usar AutoModel para soportar cualquier encoder transformer
self.encoder = AutoModel.from_pretrained(args.model_name_or_path, config=config)
self.calc_mode_classifier = CalcModeClassifier(config.hidden_size, self.num_calc_mode_labels, args.dropout_rate)
self.activity_classifier = ActivityClassifier(config.hidden_size, self.num_activity_labels, args.dropout_rate)
self.region_classifier = RegionClassifier(config.hidden_size, self.num_region_labels, args.dropout_rate)
self.investment_classifier = InvestmentClassifier(config.hidden_size, self.num_investment_labels, args.dropout_rate)
self.req_form_classifier = ReqFormClassifier(config.hidden_size, self.num_req_form_labels, args.dropout_rate)
self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels, args.dropout_rate)
if args.use_crf:
if CRF is None:
raise ImportError("torchcrf no está instalado. Instala con: pip install pytorch-crf o ejecuta sin --use_crf")
crf_init_errors = []
for init_fn in (
lambda: CRF(self.num_slot_labels, pad_idx=None, use_gpu=False),
lambda: CRF(self.num_slot_labels, batch_first=True),
lambda: CRF(num_tags=self.num_slot_labels, batch_first=True),
lambda: CRF(self.num_slot_labels),
lambda: CRF(num_tags=self.num_slot_labels),
):
try:
self.crf = init_fn()
break
except TypeError as e:
crf_init_errors.append(str(e))
else:
raise TypeError("No se pudo inicializar CRF con las firmas conocidas: " + " | ".join(crf_init_errors))
def forward(self, input_ids, attention_mask, token_type_ids=None,
calc_mode_label_ids=None, activity_label_ids=None, region_label_ids=None, investment_label_ids=None, req_form_label_ids=None, slot_labels_ids=None):
outputs = self.encoder(input_ids, attention_mask=attention_mask,
token_type_ids=token_type_ids) # sequence_output, pooled_output, (hidden_states), (attentions)
sequence_output = outputs[0]
pooled_output = getattr(outputs, "pooler_output", None)
if pooled_output is None:
if len(outputs) > 1 and outputs[1] is not None and getattr(outputs[1], "dim", lambda: 0)() == 2:
pooled_output = outputs[1]
else:
pooled_output = sequence_output[:, 0]
calc_mode_logits = self.calc_mode_classifier(pooled_output)
activity_logits = self.activity_classifier(pooled_output)
region_logits = self.region_classifier(pooled_output)
investment_logits = self.investment_classifier(pooled_output)
req_form_logits = self.req_form_classifier(pooled_output)
slot_logits = self.slot_classifier(sequence_output)
total_loss = 0
def _get_weight(head_name):
"""Retorna class weights registrados como buffer, o None."""
buf_name = f"{head_name}_class_weights"
w = getattr(self, buf_name, None)
return w
# 1. Calc Mode CrossEntropy
if calc_mode_label_ids is not None:
calc_mode_loss_fct = nn.CrossEntropyLoss(weight=_get_weight('calc_mode'))
calc_mode_loss = calc_mode_loss_fct(calc_mode_logits.view(-1, self.num_calc_mode_labels), calc_mode_label_ids.view(-1))
total_loss += calc_mode_loss
# 2. Activity CrossEntropy
if activity_label_ids is not None:
activity_loss_fct = nn.CrossEntropyLoss(weight=_get_weight('activity'))
activity_loss = activity_loss_fct(activity_logits.view(-1, self.num_activity_labels), activity_label_ids.view(-1))
total_loss += activity_loss
# 3. Region CrossEntropy
if region_label_ids is not None:
region_loss_fct = nn.CrossEntropyLoss(weight=_get_weight('region'))
region_loss = region_loss_fct(region_logits.view(-1, self.num_region_labels), region_label_ids.view(-1))
total_loss += region_loss
# 4. Investment CrossEntropy
if investment_label_ids is not None:
investment_loss_fct = nn.CrossEntropyLoss(weight=_get_weight('investment'))
investment_loss = investment_loss_fct(investment_logits.view(-1, self.num_investment_labels), investment_label_ids.view(-1))
total_loss += investment_loss
# 5. Req Form CrossEntropy
if req_form_label_ids is not None:
req_form_loss_fct = nn.CrossEntropyLoss(weight=_get_weight('req_form'))
req_form_loss = req_form_loss_fct(req_form_logits.view(-1, self.num_req_form_labels), req_form_label_ids.view(-1))
total_loss += req_form_loss
# 6. Slot Softmax
if slot_labels_ids is not None and self.args.slot_loss_coef != 0:
if self.args.use_crf:
# CRF doesn't handle ignore_index (-100), so we replace it with PAD (0)
slot_labels_ids_crf = slot_labels_ids.clone()
slot_labels_ids_crf[slot_labels_ids_crf == self.args.ignore_index] = 0
if hasattr(self.crf, 'viterbi_decode'):
# TorchCRF API: forward returns log-likelihood per batch item
slot_loss = -self.crf(slot_logits, slot_labels_ids_crf, attention_mask.bool()).mean()
else:
# pytorch-crf API
slot_loss = self.crf(slot_logits, slot_labels_ids_crf, mask=attention_mask.bool(), reduction='mean')
slot_loss = -1 * slot_loss # negative log-likelihood
else:
slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
active_labels = slot_labels_ids.view(-1)[active_loss]
slot_loss = slot_loss_fct(active_logits, active_labels)
else:
slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
total_loss += self.args.slot_loss_coef * slot_loss
outputs = ((calc_mode_logits, activity_logits, region_logits, investment_logits, req_form_logits, slot_logits),) + outputs[2:] # add hidden states and attention if they are here
outputs = (total_loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions) # Logits is a tuple of all classifier logits