# FiD-kor-bart/fid_model.py
import torch
import torch.nn as nn
from transformers import AutoModelForSeq2SeqLM, BartConfig, BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput

class FiDEncoder(nn.Module):
    """Encoder wrapper for Fusion-in-Decoder.

    Encodes each retrieved passage independently, then concatenates the passage
    representations along the sequence dimension so that the decoder can attend
    over all passages jointly.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        self.n_passages = None  # set by FiD.forward / FiD.generate before encoding

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        # input_ids arrive flattened as (batch_size, n_passages * passage_length).
        bsz, total_length = input_ids.shape
        passage_length = total_length // self.n_passages

        # Encode every passage independently: (bsz * n_passages, passage_length).
        input_ids = input_ids.view(bsz * self.n_passages, passage_length)
        attention_mask = attention_mask.view(bsz * self.n_passages, passage_length)
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

        # Concatenate the passage encodings along the sequence dimension:
        # (bsz, n_passages * passage_length, hidden_dim).
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0].view(bsz, self.n_passages * passage_length, -1),
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
        return encoder_outputs


class FiD(BartForConditionalGeneration):
    def __init__(self, model_config: BartConfig):
        super().__init__(model_config)
        # Replace the plain BART encoder with the passage-wise FiD encoder.
        self.wrap_encoder(self.model.encoder)

    def load_pretrained_params(self, basemodel_name):
        """Copy encoder, decoder, and LM-head weights from a pretrained seq2seq model."""
        basemodel = AutoModelForSeq2SeqLM.from_pretrained(basemodel_name)
        self.model.encoder.encoder.load_state_dict(basemodel.get_encoder().state_dict())
        self.model.decoder.load_state_dict(basemodel.get_decoder().state_dict())
        self.lm_head.load_state_dict(basemodel.lm_head.state_dict())
        print(f"loaded {basemodel_name} parameters.")

    def wrap_encoder(self, encoder):
        self.model.encoder = FiDEncoder(encoder)

    @classmethod
    def from_path(cls, model_config):
        return cls(model_config=model_config)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor = None, **kwargs):
        """
        Args:
            input_ids (torch.Tensor): tensor of shape (batch_size, topk, seq_max_len)
            attention_mask (torch.Tensor): tensor of shape (batch_size, topk, seq_max_len)
            labels (torch.Tensor): summarization target tensor of shape (batch_size, seq_max_len)
        Returns:
            Seq2SeqLMOutput containing the logits of shape (batch_size, max_seq_len, vocab_size),
            the cross-entropy loss between the logits and the labels (when labels are given),
            and the model hidden states.
        """
        if input_ids is not None:
            if input_ids.ndim == 3:
                # Remember how many passages (topk) each example has, then flatten to
                # (batch_size, topk * seq_max_len) so the FiD encoder can split them again.
                self.model.encoder.n_passages = input_ids.size(1)
            input_ids = input_ids.view(input_ids.size(0), -1)
        if attention_mask is not None:
            attention_mask = attention_mask.view(attention_mask.size(0), -1)
        return super().forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)

    def generate(self, input_ids, attention_mask, max_length=256):
        # input_ids / attention_mask are (batch_size, topk, seq_max_len); record topk and
        # flatten them the same way as in forward before delegating to BART's generate.
        self.model.encoder.n_passages = input_ids.size(1)
        return super().generate(
            inputs=input_ids.view(input_ids.size(0), -1),
            attention_mask=attention_mask.view(attention_mask.size(0), -1),
            max_length=max_length,
        )
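

# A minimal smoke-test sketch (not from the original repository): it builds a tiny,
# randomly initialized FiD model and runs one forward and one generate pass on dummy
# tensors. The config values and shapes below are illustrative assumptions, not the
# settings used for FiD-kor-bart; a real run would call load_pretrained_params() with
# an actual Korean BART checkpoint instead.
if __name__ == "__main__":
    config = BartConfig(
        vocab_size=1000,
        d_model=64,
        encoder_layers=1,
        decoder_layers=1,
        encoder_attention_heads=2,
        decoder_attention_heads=2,
        encoder_ffn_dim=128,
        decoder_ffn_dim=128,
        max_position_embeddings=128,
    )
    model = FiD.from_path(config)

    batch_size, topk, seq_max_len = 2, 3, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, topk, seq_max_len))
    attention_mask = torch.ones_like(input_ids)
    labels = torch.randint(0, config.vocab_size, (batch_size, seq_max_len))

    # Training-style call: 3D inputs are flattened internally, loss is computed vs. labels.
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print(outputs.loss.item(), outputs.logits.shape)  # logits: (batch_size, seq_max_len, vocab_size)

    # Inference-style call: generation over the fused passage representations.
    generated = model.generate(input_ids, attention_mask, max_length=20)
    print(generated.shape)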