File size: 3,435 Bytes
360d81f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
import torch.nn as nn
from transformers import AutoModelForSeq2SeqLM, BartConfig, BartForConditionalGeneration, BertModel, BertPreTrainedModel
from transformers.modeling_outputs import BaseModelOutput


class FiDEncoder(nn.Module):
    """Encoder wrapper that encodes each passage of a flattened input separately.

    Every row of ``input_ids`` is split into ``n_passages`` equal-length
    chunks, the wrapped encoder is run on all chunks as one enlarged batch,
    and the resulting last hidden states are concatenated back along the
    sequence axis so the output has one row per original example.

    ``n_passages`` is a plain attribute; callers may overwrite it before each
    forward pass to match the number of passages in the current batch.
    """

    def __init__(self, encoder, n_passages=1):
        """Store the wrapped encoder.

        Args:
            encoder: Module called as
                ``encoder(input_ids=..., attention_mask=..., **kwargs)``.
            n_passages: Initial number of passages per example. Defaults to 1
                so the wrapper is usable before a caller sets the real value.
        """
        super().__init__()
        self.encoder = encoder
        self.n_passages = n_passages

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Encode passages independently and fuse their hidden states.

        Args:
            input_ids: 2-D tensor of shape
                ``(batch_size, n_passages * passage_length)``.
            attention_mask: Optional tensor with the same shape as
                ``input_ids``.
            **kwargs: Forwarded to the wrapped encoder. ``return_dict`` is
                intercepted: the wrapped encoder is always asked for a
                model-output object and the caller's preference is applied
                to the value returned from this method.

        Returns:
            A ``BaseModelOutput`` whose ``last_hidden_state`` has shape
            ``(batch_size, n_passages * passage_length, hidden)``, or the
            equivalent tuple when ``return_dict=False`` was requested.
            ``hidden_states`` and ``attentions`` are passed through from the
            wrapped encoder without reshaping, so they keep the per-passage
            batch layout.

        Raises:
            ValueError: If ``input_ids`` is missing, or its length is not a
                multiple of ``n_passages``.
        """
        if input_ids is None:
            raise ValueError("FiDEncoder requires input_ids.")

        n_passages = self.n_passages
        bsz, total_length = input_ids.shape
        if n_passages < 1 or total_length % n_passages != 0:
            raise ValueError(
                f"Input length {total_length} cannot be split evenly into {n_passages} passages."
            )
        passage_length = total_length // n_passages

        # reshape (rather than view) also accepts non-contiguous tensors.
        input_ids = input_ids.reshape(bsz * n_passages, passage_length)
        if attention_mask is not None:
            attention_mask = attention_mask.reshape(bsz * n_passages, passage_length)

        # Attribute access below needs a model-output object, so the wrapped
        # encoder is always called with return_dict=True.
        return_dict = kwargs.pop("return_dict", None)
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            **kwargs,
        )

        fused_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0].reshape(bsz, n_passages * passage_length, -1),
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

        if return_dict is False:
            return fused_outputs.to_tuple()
        return fused_outputs


class FiD(BartForConditionalGeneration):
    """BART conditional-generation model whose encoder is wrapped in ``FiDEncoder``.

    Inputs may carry an extra passage axis ``(batch_size, topk, seq_max_len)``;
    the passage count is recorded on the wrapped encoder and the inputs are
    flattened to 2-D before being handed to the parent BART implementation.
    """

    def __init__(self, model_config: BartConfig):
        """Build the BART model and wrap its encoder.

        Args:
            model_config: BART configuration. A falsy value (e.g. ``None``)
                falls back to a default ``BartConfig()``.
        """
        # The parent constructor cannot be called without a config, so a
        # default one is substituted instead of calling it with no arguments.
        if not model_config:
            model_config = BartConfig()
        super().__init__(model_config)
        self.wrap_encoder(self.model.encoder)

    def load_pretrained_params(self, basemodel_name):
        """Copy encoder, decoder and LM-head weights from a pretrained model.

        Args:
            basemodel_name: Identifier passed to
                ``AutoModelForSeq2SeqLM.from_pretrained``.
        """
        basemodel = AutoModelForSeq2SeqLM.from_pretrained(basemodel_name)
        # self.model.encoder is the FiDEncoder wrapper; the encoder that owns
        # the parameters is its ``encoder`` attribute.
        self.model.encoder.encoder.load_state_dict(basemodel.get_encoder().state_dict())
        self.model.decoder.load_state_dict(basemodel.get_decoder().state_dict())
        self.lm_head.load_state_dict(basemodel.lm_head.state_dict())
        print(f"loaded {basemodel_name} parameters.")

    def wrap_encoder(self, encoder):
        """Replace ``self.model.encoder`` with a ``FiDEncoder`` around ``encoder``."""
        wrapped = FiDEncoder(encoder)
        # Start from one passage so already-flattened 2-D inputs work even
        # before a 3-D batch has recorded the real passage count.
        wrapped.n_passages = 1
        self.model.encoder = wrapped

    @classmethod
    def from_path(cls, model_config):
        """Alternate constructor: build an instance from ``model_config``."""
        return cls(model_config=model_config)

    def _flatten_passages(self, input_ids, attention_mask):
        """Record the passage count from 3-D ``input_ids`` and flatten both tensors to 2-D."""
        if input_ids is not None:
            if input_ids.ndim == 3:
                self.model.encoder.n_passages = input_ids.size(1)
            # reshape (rather than view) also accepts non-contiguous tensors.
            input_ids = input_ids.reshape(input_ids.size(0), -1)
        if attention_mask is not None:
            attention_mask = attention_mask.reshape(attention_mask.size(0), -1)
        return input_ids, attention_mask

    def forward(
        self,
        input_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        **kwargs,
    ):
        """Run the parent BART forward pass on passage-flattened inputs.

        Args:
            input_ids (torch.Tensor): (batch_size, topk, seq_max_len) tensor,
                or an already flattened 2-D tensor. When 3-D, ``topk`` is
                stored as the encoder's ``n_passages``; a 2-D tensor leaves
                the previously stored value in place. May be ``None``.
            attention_mask (torch.Tensor): Tensor matching ``input_ids``;
                flattened the same way. May be ``None``.
            labels (torch.Tensor): Target tensor forwarded unchanged to the
                parent ``forward`` (the original documentation describes it
                as (batch_size, seq_max_len)).
            **kwargs: Forwarded unchanged to the parent ``forward``.

        Return:
            Whatever ``BartForConditionalGeneration.forward`` returns for the
            flattened inputs (per the original documentation this carries the
            logits, the loss between the logits and the labels when labels
            are given, and the hidden states).
        """
        input_ids, attention_mask = self._flatten_passages(input_ids, attention_mask)
        return super().forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)

    def generate(self, input_ids, attention_mask=None, max_length=256, **kwargs):
        """Generate sequences from passage-structured inputs.

        Args:
            input_ids: (batch_size, topk, seq_max_len) tensor or an already
                flattened 2-D tensor. When 3-D, ``topk`` is recorded as the
                encoder's ``n_passages`` so generation works without a prior
                ``forward`` call.
            attention_mask: Optional tensor matching ``input_ids``.
            max_length: Forwarded to the parent ``generate``.
            **kwargs: Any further generation options, forwarded to the parent
                ``generate``.

        Returns:
            The result of the parent ``generate`` call.
        """
        input_ids, attention_mask = self._flatten_passages(input_ids, attention_mask)
        # Only pass the mask when one was supplied so the parent can apply
        # its own default handling otherwise.
        if attention_mask is not None:
            kwargs["attention_mask"] = attention_mask
        return super().generate(inputs=input_ids, max_length=max_length, **kwargs)