|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import AutoModelForSeq2SeqLM, BartConfig, BartForConditionalGeneration, BertModel, BertPreTrainedModel |
|
|
from transformers.modeling_outputs import BaseModelOutput |
|
|
|
|
|
|
|
|
class FiDEncoder(nn.Module):
    """Encoder wrapper that encodes every passage as its own sequence.

    Each row of the incoming ``input_ids`` is split into ``n_passages``
    equally sized chunks, the wrapped encoder is run on the chunks as
    independent sequences, and the resulting hidden states are laid back
    side by side along the sequence axis.

    Attributes:
        encoder: The wrapped encoder; it is called with ``input_ids`` and
            ``attention_mask`` keyword arguments plus any extra kwargs.
        n_passages: Number of passages packed into each flattened row. It
            may be passed to the constructor or assigned later by the
            owner of this module; it must be set before ``forward`` runs.
    """

    def __init__(self, encoder, n_passages=None):
        """Store the wrapped encoder and the (optional) passage count.

        Args:
            encoder: Encoder module/callable to wrap.
            n_passages: Passages per row, or ``None`` when the owner will
                assign ``self.n_passages`` before the first forward pass.
        """
        super().__init__()
        self.encoder = encoder
        # Always define the attribute so that a missing value can be
        # reported clearly in forward() instead of as an AttributeError.
        self.n_passages = n_passages

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Encode each passage separately and concatenate the results.

        Args:
            input_ids: 2-D tensor whose second dimension is
                ``n_passages * passage_length``.
            attention_mask: Tensor with the same number of elements as
                ``input_ids``, or ``None`` to pass no mask to the encoder.
            **kwargs: Forwarded unchanged to the wrapped encoder.

        Returns:
            A ``BaseModelOutput`` whose ``last_hidden_state`` is reshaped to
            ``(batch, n_passages * passage_length, -1)``; ``hidden_states``
            and ``attentions`` are passed through as the wrapped encoder
            produced them. When the wrapped encoder returns a plain tuple,
            a tuple is returned with only its first element reshaped.

        Raises:
            ValueError: If ``input_ids`` is missing, ``n_passages`` has not
                been set or is smaller than one, or the row length is not
                a multiple of ``n_passages``.
        """
        if input_ids is None:
            raise ValueError("FiDEncoder.forward requires input_ids.")

        n_passages = self.n_passages
        if n_passages is None:
            raise ValueError(
                "FiDEncoder.n_passages is not set; assign it before calling the encoder."
            )

        bsz, total_length = input_ids.shape
        if n_passages < 1 or total_length % n_passages != 0:
            raise ValueError(
                f"input length {total_length} cannot be split into {n_passages} passages."
            )
        passage_length = total_length // n_passages

        # reshape() behaves like view() for contiguous tensors and also
        # accepts non-contiguous ones.
        input_ids = input_ids.reshape(bsz * n_passages, passage_length)
        if attention_mask is not None:
            attention_mask = attention_mask.reshape(bsz * n_passages, passage_length)

        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

        # Put the passages of one example back next to each other on the
        # sequence axis.
        fused = outputs[0].reshape(bsz, n_passages * passage_length, -1)

        if isinstance(outputs, tuple):
            # A plain tuple has no .hidden_states / .attentions attributes;
            # keep the tuple form and swap in the fused first element.
            return (fused,) + tuple(outputs[1:])

        return BaseModelOutput(
            last_hidden_state=fused,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
|
|
class FiD(BartForConditionalGeneration):
    """BART conditional-generation model whose encoder is wrapped in ``FiDEncoder``.

    Inputs may be given per passage as 3-D tensors; they are flattened to
    2-D before being handed to the parent model, and the passage count is
    recorded on the wrapped encoder so it can split the rows again.
    """

    def __init__(self, model_config: BartConfig = None):
        """Build the BART model and wrap its encoder.

        Args:
            model_config: Configuration for the parent model. When ``None``
                a default ``BartConfig()`` is used, because the parent
                constructor cannot be called without a config.
        """
        if model_config is None:
            model_config = BartConfig()
        super().__init__(model_config)
        self.wrap_encoder(self.model.encoder)

    def load_pretrained_params(self, basemodel_name):
        """Copy encoder, decoder and LM-head weights from a pretrained model.

        Args:
            basemodel_name: Identifier passed to
                ``AutoModelForSeq2SeqLM.from_pretrained``.
        """
        basemodel = AutoModelForSeq2SeqLM.from_pretrained(basemodel_name)
        # self.model.encoder is the FiDEncoder wrapper; the actual encoder
        # weights live one level below, in its .encoder attribute.
        self.model.encoder.encoder.load_state_dict(basemodel.get_encoder().state_dict())
        self.model.decoder.load_state_dict(basemodel.get_decoder().state_dict())
        self.lm_head.load_state_dict(basemodel.lm_head.state_dict())
        print(f"loaded {basemodel_name} parameters.")

    def wrap_encoder(self, encoder):
        """Replace ``self.model.encoder`` with a ``FiDEncoder`` around ``encoder``."""
        self.model.encoder = FiDEncoder(encoder)

    @classmethod
    def from_path(cls, model_config):
        """Alternate constructor: build an instance from ``model_config``."""
        return cls(model_config=model_config)

    def forward(
        self,
        input_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        **kwargs,
    ):
        """Flatten per-passage inputs and run the parent forward pass.

        Args:
            input_ids (torch.Tensor): (batch_size, topk, seq_max_len) tensor.
                A tensor that is already 2-D is passed through unchanged.
                May be ``None``.
            attention_mask (torch.Tensor): (batch_size, topk, seq_max_len)
                tensor, flattened the same way. May be ``None``.
            labels (torch.Tensor): (batch_size, seq_max_len) target tensor,
                forwarded to the parent model unchanged.
            **kwargs: Forwarded unchanged to the parent ``forward``.

        Returns:
            Whatever ``BartForConditionalGeneration.forward`` returns for
            the flattened inputs (presumably the loss between logits and
            labels, the logits and the hidden states -- confirm against the
            installed transformers version).
        """
        if input_ids is not None:
            if input_ids.ndim == 3:
                # Record the passage count so the wrapped encoder can split
                # each flattened row back into passages.
                self.model.encoder.n_passages = input_ids.size(1)
            input_ids = input_ids.reshape(input_ids.size(0), -1)
        if attention_mask is not None:
            attention_mask = attention_mask.reshape(attention_mask.size(0), -1)

        return super().forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)

    def generate(self, input_ids, attention_mask, max_length=256, **kwargs):
        """Generate sequences from per-passage inputs.

        Args:
            input_ids: Tensor of token ids; a 3-D tensor is treated as
                (batch_size, topk, seq_max_len) and flattened to 2-D.
            attention_mask: Mask tensor, flattened the same way.
            max_length: Forwarded to the parent ``generate``.
            **kwargs: Extra generation options forwarded to the parent
                ``generate``.

        Returns:
            The result of the parent ``generate`` call.
        """
        if input_ids.ndim == 3:
            # The encoder runs inside the parent generate() on the flattened
            # input, so the passage count has to be recorded here; otherwise
            # it would be missing or left over from an earlier call.
            self.model.encoder.n_passages = input_ids.size(1)
        return super().generate(
            inputs=input_ids.reshape(input_ids.size(0), -1),
            attention_mask=attention_mask.reshape(attention_mask.size(0), -1),
            max_length=max_length,
            **kwargs,
        )