File size: 4,273 Bytes
32b405b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""

Korean Financial Report Extractive Summarization Model

Extracts representative sentences from a paragraph and classifies the role
of each sentence (outlook, event, financial, risk).

- Based on klue/roberta-base
- Per-sentence [CLS] encoding + Inter-sentence Transformer
- Binary classification of representative sentences + multi-class role
  classification (multi-task)

"""

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, AutoTokenizer, PretrainedConfig, PreTrainedModel

# Role names a sentence can be assigned; a label's list position is its class index.
ROLE_LABELS = ["outlook", "event", "financial", "risk"]
NUM_ROLES = len(ROLE_LABELS)

# Lookup tables between class index and role name, both derived from ROLE_LABELS.
IDX_TO_ROLE = dict(enumerate(ROLE_LABELS))
ROLE_TO_IDX = {name: position for position, name in IDX_TO_ROLE.items()}


class DocumentEncoderConfig(PretrainedConfig):
    """Configuration for the document-level extractive summarization model.

    Args:
        base_model_name: Name or path of the pretrained checkpoint used as the
            per-sentence encoder.
        hidden_size: Width of a sentence embedding. The model class in this
            file uses it as the inter-sentence transformer's ``d_model`` and
            as the input size of both classification heads.
        num_transformer_layers: Number of inter-sentence transformer layers.
        num_roles: Number of classes produced by the role head.
        max_length: Maximum token length per sentence. Stored on the config
            only; presumably consumed by tokenization code elsewhere
            (TODO confirm).
        max_sentences: Maximum number of sentences per document. Stored on the
            config only; presumably consumed by data preparation elsewhere
            (TODO confirm).
        role_labels: Role names. When omitted (or empty) a copy of the
            module-level ``ROLE_LABELS`` is used.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "document_encoder"

    def __init__(
        self,
        base_model_name: str = "klue/roberta-base",
        hidden_size: int = 768,
        num_transformer_layers: int = 2,
        num_roles: int = NUM_ROLES,
        max_length: int = 128,
        max_sentences: int = 30,
        role_labels: list = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.hidden_size = hidden_size
        self.num_transformer_layers = num_transformer_layers
        self.num_roles = num_roles
        self.max_length = max_length
        self.max_sentences = max_sentences
        # Store a copy: without it, the config would alias the module-level
        # ROLE_LABELS (or the caller's list), and editing config.role_labels
        # would silently change that shared object too.
        self.role_labels = list(role_labels or ROLE_LABELS)


class DocumentEncoderForExtractiveSummarization(PreTrainedModel):
    """Multi-task model that scores sentences and classifies their roles.

    Every sentence is encoded independently by a pretrained encoder (the
    hidden state at token position 0 is kept as the sentence vector). The
    sentence vectors of one document are then contextualized by an
    inter-sentence transformer, and two heads produce, for each sentence,
    a representative-sentence score and role logits.
    """

    config_class = DocumentEncoderConfig

    def __init__(self, config: DocumentEncoderConfig):
        """Build the sentence encoder, the inter-sentence transformer and both heads.

        Args:
            config: Provides the base checkpoint name, the hidden size, the
                number of inter-sentence layers and the number of role classes.
        """
        super().__init__(config)

        # NOTE(review): this loads the base checkpoint each time the model is
        # constructed - confirm that is intended when a fine-tuned checkpoint
        # of this model is being restored.
        self.sentence_encoder = AutoModel.from_pretrained(config.base_model_name)

        # The sentence encoder's output is fed directly into this transformer,
        # so config.hidden_size has to match the encoder's output width.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=8,
            dim_feedforward=2048,
            dropout=0.1,
            batch_first=True,
        )
        self.inter_sentence_transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.num_transformer_layers,
        )

        # Representative-sentence head: one sigmoid-activated value per sentence.
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

        # Role head: raw (unnormalized) logits, one per role class.
        self.role_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, config.num_roles),
        )

    def encode_sentences(self, input_ids, attention_mask):
        """Encode a flat batch of sentences into one vector per sentence.

        Args:
            input_ids: (num_sequences, seq_len) token ids.
            attention_mask: (num_sequences, seq_len) token-level mask.

        Returns:
            (num_sequences, hidden) tensor holding the encoder's last hidden
            state at token position 0 of every sequence.
        """
        outputs = self.sentence_encoder(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]

    def forward(self, sentences_input_ids, sentences_attention_mask, document_mask=None):
        """Score every sentence of every document and predict its role.

        Args:
            sentences_input_ids: (batch_size, num_sentences, seq_len)
            sentences_attention_mask: (batch_size, num_sentences, seq_len)
            document_mask: (batch_size, num_sentences), optional. Truthy
                entries mark real sentences; falsy entries are treated as
                padding by the inter-sentence transformer.

        Returns:
            scores: (batch_size, num_sentences) representative-sentence
                scores (sigmoid outputs).
            role_logits: (batch_size, num_sentences, num_roles) role logits.

            Values are returned for padded positions as well; this method does
            not mask them out.
        """
        batch_size, num_sentences, seq_len = sentences_input_ids.shape

        # reshape (not view): merging the batch and sentence axes with view
        # raises on non-contiguous inputs, e.g. a batch that was truncated
        # along the sentence axis by slicing.
        flat_ids = sentences_input_ids.reshape(-1, seq_len)
        flat_mask = sentences_attention_mask.reshape(-1, seq_len)

        embeddings = self.encode_sentences(flat_ids, flat_mask)
        hidden_size = embeddings.shape[-1]
        embeddings = embeddings.reshape(batch_size, num_sentences, hidden_size)

        # src_key_padding_mask uses the opposite convention (True = ignore),
        # hence the inversion of document_mask.
        src_key_padding_mask = None
        if document_mask is not None:
            src_key_padding_mask = ~document_mask.bool()

        contextualized = self.inter_sentence_transformer(
            embeddings, src_key_padding_mask=src_key_padding_mask
        )

        scores = self.classifier(contextualized).squeeze(-1)
        role_logits = self.role_classifier(contextualized)

        return scores, role_logits


# Register the custom classes with the transformers Auto* factories:
# the "document_encoder" model_type is mapped to DocumentEncoderConfig, and
# that config class is mapped to the model class for AutoModel.
AutoConfig.register("document_encoder", DocumentEncoderConfig)
AutoModel.register(DocumentEncoderConfig, DocumentEncoderForExtractiveSummarization)