Add model code for FiD
Browse files- fid_model.py +84 -0
fid_model.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from transformers import AutoModelForSeq2SeqLM, BartConfig, BartForConditionalGeneration, BertModel, BertPreTrainedModel
|
| 4 |
+
from transformers.modeling_outputs import BaseModelOutput
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class FiDEncoder(nn.Module):
|
| 8 |
+
def __init__(self, encoder):
|
| 9 |
+
super().__init__()
|
| 10 |
+
self.encoder = encoder
|
| 11 |
+
|
| 12 |
+
def forward(self, input_ids=None, attention_mask=None, **kwargs):
|
| 13 |
+
# passage_length = input_ids.size(-1)
|
| 14 |
+
bsz, total_length = input_ids.shape
|
| 15 |
+
passage_length = total_length // self.n_passages
|
| 16 |
+
input_ids = input_ids.view(bsz * self.n_passages, passage_length)
|
| 17 |
+
attention_mask = attention_mask.view(bsz * self.n_passages, passage_length)
|
| 18 |
+
encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
|
| 19 |
+
|
| 20 |
+
encoder_outputs = BaseModelOutput(
|
| 21 |
+
last_hidden_state=encoder_outputs[0].view(bsz, self.n_passages * passage_length, -1),
|
| 22 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 23 |
+
attentions=encoder_outputs.attentions,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
return encoder_outputs
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class FiD(BartForConditionalGeneration):
|
| 30 |
+
def __init__(self, model_config: BartConfig):
|
| 31 |
+
if model_config:
|
| 32 |
+
super().__init__(model_config)
|
| 33 |
+
else:
|
| 34 |
+
super().__init__()
|
| 35 |
+
self.wrap_encoder(self.model.encoder)
|
| 36 |
+
|
| 37 |
+
def load_pretrained_params(self, basemodel_name):
|
| 38 |
+
basemodel = AutoModelForSeq2SeqLM.from_pretrained(basemodel_name)
|
| 39 |
+
self.model.encoder.encoder.load_state_dict(basemodel.get_encoder().state_dict())
|
| 40 |
+
self.model.decoder.load_state_dict(basemodel.get_decoder().state_dict())
|
| 41 |
+
self.lm_head.load_state_dict(basemodel.lm_head.state_dict())
|
| 42 |
+
print(f"loaded {basemodel_name} parameters.")
|
| 43 |
+
|
| 44 |
+
def wrap_encoder(self, encoder):
|
| 45 |
+
self.model.encoder = FiDEncoder(encoder)
|
| 46 |
+
|
| 47 |
+
@classmethod
|
| 48 |
+
def from_path(cls, model_config):
|
| 49 |
+
|
| 50 |
+
return cls(model_config=model_config)
|
| 51 |
+
|
| 52 |
+
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor = None, **kwargs):
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
input_ids (torch.Tensor): (batch_size, topk, seq_max_len) shape tensor
|
| 57 |
+
attention_mask (torch.Tensor): (batch_size, topk, seq_max_len) shape tensor
|
| 58 |
+
labels (torch.Tensor): (batch_size, seq_max_len) shape summarization Tensor
|
| 59 |
+
|
| 60 |
+
Return:
|
| 61 |
+
{
|
| 62 |
+
logit (torch.Tensor):
|
| 63 |
+
(batch_size, max_seq_len, vocab_size) shape logit
|
| 64 |
+
loss_fn:
|
| 65 |
+
logit과 label간의 loss_fn
|
| 66 |
+
last_hidden_state (torch.Tensor):
|
| 67 |
+
(batch_size, max_seq_len, hidden_dim) shape tensor
|
| 68 |
+
}
|
| 69 |
+
"""
|
| 70 |
+
if input_ids is not None:
|
| 71 |
+
if input_ids.ndim == 3:
|
| 72 |
+
self.model.encoder.n_passages = input_ids.size(1)
|
| 73 |
+
input_ids = input_ids.view(input_ids.size(0), -1)
|
| 74 |
+
if attention_mask is not None:
|
| 75 |
+
attention_mask = attention_mask.view(attention_mask.size(0), -1)
|
| 76 |
+
|
| 77 |
+
return super().forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
|
| 78 |
+
|
| 79 |
+
def generate(self, input_ids, attention_mask, max_length=256):
|
| 80 |
+
return super().generate(
|
| 81 |
+
inputs=input_ids.view(input_ids.size(0), -1),
|
| 82 |
+
attention_mask=attention_mask.view(attention_mask.size(0), -1),
|
| 83 |
+
max_length=max_length,
|
| 84 |
+
)
|