mountinyy committed on
Commit
360d81f
·
1 Parent(s): d2f5cd5

Add model code for FiD

Browse files
Files changed (1) hide show
  1. fid_model.py +84 -0
fid_model.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoModelForSeq2SeqLM, BartConfig, BartForConditionalGeneration, BertModel, BertPreTrainedModel
4
+ from transformers.modeling_outputs import BaseModelOutput
5
+
6
+
7
class FiDEncoder(nn.Module):
    """Fusion-in-Decoder (FiD) encoder wrapper.

    Splits a flattened batch of ``n_passages`` concatenated passages, runs the
    wrapped encoder on each passage independently, then re-concatenates the
    per-passage encodings along the sequence axis so the decoder can attend
    over all passages jointly.
    """

    def __init__(self, encoder, n_passages=1):
        """
        Args:
            encoder: the underlying (e.g. BART) encoder module to wrap.
            n_passages (int): number of passages per example. Defaults to 1 so
                forward() is safe even before the owning model updates it —
                previously this attribute was only set externally by
                FiD.forward(), and calling the encoder first raised
                AttributeError.
        """
        super().__init__()
        self.encoder = encoder
        self.n_passages = n_passages

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Encode each passage separately, then fuse along the sequence axis.

        Args:
            input_ids: (bsz, n_passages * passage_length) flattened token ids.
            attention_mask: same shape as ``input_ids``.

        Returns:
            BaseModelOutput whose ``last_hidden_state`` has shape
            (bsz, n_passages * passage_length, hidden_dim).
        """
        bsz, total_length = input_ids.shape
        passage_length = total_length // self.n_passages
        # (bsz, n_passages * passage_length) -> (bsz * n_passages, passage_length)
        # so the encoder sees each passage as an independent sequence.
        input_ids = input_ids.view(bsz * self.n_passages, passage_length)
        attention_mask = attention_mask.view(bsz * self.n_passages, passage_length)
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

        # Concatenate the per-passage encodings back along the sequence axis.
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0].view(bsz, self.n_passages * passage_length, -1),
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

        return encoder_outputs
27
+
28
+
29
class FiD(BartForConditionalGeneration):
    """Fusion-in-Decoder model built on BART.

    Wraps the BART encoder in :class:`FiDEncoder` so that each of the top-k
    retrieved passages is encoded independently and the decoder attends over
    the concatenation of all passage encodings.
    """

    def __init__(self, model_config: BartConfig):
        if model_config:
            super().__init__(model_config)
        else:
            # NOTE(review): BartForConditionalGeneration requires a config;
            # this branch will raise — confirm it is intentionally a dead path.
            super().__init__()
        self.wrap_encoder(self.model.encoder)

    def load_pretrained_params(self, basemodel_name):
        """Copy encoder/decoder/lm_head weights from a pretrained seq2seq model.

        Args:
            basemodel_name (str): model id or path for
                AutoModelForSeq2SeqLM.from_pretrained.
        """
        basemodel = AutoModelForSeq2SeqLM.from_pretrained(basemodel_name)
        # The encoder is wrapped, so the raw encoder lives at .encoder.encoder.
        self.model.encoder.encoder.load_state_dict(basemodel.get_encoder().state_dict())
        self.model.decoder.load_state_dict(basemodel.get_decoder().state_dict())
        self.lm_head.load_state_dict(basemodel.lm_head.state_dict())
        print(f"loaded {basemodel_name} parameters.")

    def wrap_encoder(self, encoder):
        """Replace the plain BART encoder with the FiD passage-wise wrapper."""
        self.model.encoder = FiDEncoder(encoder)

    @classmethod
    def from_path(cls, model_config):
        """Alternate constructor: build a FiD model from a BartConfig."""
        return cls(model_config=model_config)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor = None, **kwargs):
        """
        Args:
            input_ids (torch.Tensor): (batch_size, topk, seq_max_len) shape tensor
            attention_mask (torch.Tensor): (batch_size, topk, seq_max_len) shape tensor
            labels (torch.Tensor): (batch_size, seq_max_len) shape summarization Tensor

        Return:
            {
                logit (torch.Tensor):
                    (batch_size, max_seq_len, vocab_size) shape logit
                loss_fn:
                    loss between the logits and the labels
                last_hidden_state (torch.Tensor):
                    (batch_size, max_seq_len, hidden_dim) shape tensor
            }
        """
        if input_ids is not None:
            if input_ids.ndim == 3:
                # Tell the wrapped encoder how many passages each example has,
                # then flatten (bsz, topk, seq) -> (bsz, topk * seq).
                self.model.encoder.n_passages = input_ids.size(1)
            input_ids = input_ids.view(input_ids.size(0), -1)
        if attention_mask is not None:
            attention_mask = attention_mask.view(attention_mask.size(0), -1)

        return super().forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)

    def generate(self, input_ids, attention_mask, max_length=256):
        """Generate with passage-fused encoding.

        Args:
            input_ids: (batch_size, topk, seq_max_len) passage token ids.
            attention_mask: same shape as ``input_ids``.
            max_length (int): maximum generated sequence length.
        """
        if input_ids.ndim == 3:
            # Fix: previously generate() relied on a prior forward() call to
            # have set n_passages; set it explicitly so standalone generation
            # reshapes the passages correctly.
            self.model.encoder.n_passages = input_ids.size(1)
        return super().generate(
            inputs=input_ids.view(input_ids.size(0), -1),
            attention_mask=attention_mask.view(attention_mask.size(0), -1),
            max_length=max_length,
        )