langquantof committed on
Commit
32b405b
·
verified ·
1 Parent(s): e454024

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +119 -0
model.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Korean Financial Report Extractive Summarization Model
3
+
4
+ ๋ฌธ๋‹จ์—์„œ ๋Œ€ํ‘œ๋ฌธ์žฅ์„ ์ถ”์ถœํ•˜๊ณ  ์—ญํ• (outlook, event, financial, risk)์„ ๋ถ„๋ฅ˜ํ•˜๋Š” ๋ชจ๋ธ
5
+ - klue/roberta-base ๊ธฐ๋ฐ˜
6
+ - ๋ฌธ์žฅ๋ณ„ [CLS] ์ธ์ฝ”๋”ฉ + Inter-sentence Transformer
7
+ - ๋Œ€ํ‘œ๋ฌธ์žฅ ์ด์ง„ ๋ถ„๋ฅ˜ + ์—ญํ•  ๋‹ค์ค‘ ๋ถ„๋ฅ˜ (Multi-task)
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from transformers import AutoConfig, AutoModel, AutoTokenizer, PretrainedConfig, PreTrainedModel
13
+
# Role taxonomy for sentence classification; list order defines label indices.
ROLE_LABELS = ["outlook", "event", "financial", "risk"]
NUM_ROLES = len(ROLE_LABELS)
# Bidirectional lookup tables between role names and class indices.
IDX_TO_ROLE = dict(enumerate(ROLE_LABELS))
ROLE_TO_IDX = {label: position for position, label in IDX_TO_ROLE.items()}
class DocumentEncoderConfig(PretrainedConfig):
    """Configuration for the extractive-summarization document encoder.

    Args:
        base_model_name: Hugging Face model id of the per-sentence encoder
            backbone (default "klue/roberta-base").
        hidden_size: Hidden width of the backbone, reused as the width of the
            inter-sentence transformer and the classification heads.
        num_transformer_layers: Number of inter-sentence transformer layers.
        num_roles: Number of role classes produced by the role head.
        max_length: Maximum token length per sentence.
        max_sentences: Maximum number of sentences per document.
        role_labels: Role label names; defaults to a copy of the module-level
            ``ROLE_LABELS``. NOTE(review): ``num_roles`` is not re-derived from
            a custom ``role_labels`` — callers passing one should keep the two
            consistent.
    """

    model_type = "document_encoder"

    def __init__(
        self,
        base_model_name: str = "klue/roberta-base",
        hidden_size: int = 768,
        num_transformer_layers: int = 2,
        num_roles: int = NUM_ROLES,
        max_length: int = 128,
        max_sentences: int = 30,
        role_labels: list = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.hidden_size = hidden_size
        self.num_transformer_layers = num_transformer_layers
        self.num_roles = num_roles
        self.max_length = max_length
        self.max_sentences = max_sentences
        # Bug fix: the previous `role_labels or ROLE_LABELS` also replaced an
        # explicitly passed empty list with the defaults, and aliased the
        # module-level ROLE_LABELS list so mutating one config's labels would
        # corrupt every other config. Test for None and copy instead.
        if role_labels is None:
            self.role_labels = list(ROLE_LABELS)
        else:
            self.role_labels = list(role_labels)
class DocumentEncoderForExtractiveSummarization(PreTrainedModel):
    """Multi-task extractive summarizer for Korean financial reports.

    Each sentence is encoded independently by a pretrained backbone (the
    [CLS] position is used as the sentence vector), sentences are then
    contextualized against each other by a small Transformer encoder, and two
    heads run on the contextualized vectors: a sigmoid head scoring each
    sentence for summary inclusion, and a linear head emitting role logits.
    """

    config_class = DocumentEncoderConfig

    def __init__(self, config: DocumentEncoderConfig):
        super().__init__(config)

        # Per-sentence backbone (e.g. klue/roberta-base).
        self.sentence_encoder = AutoModel.from_pretrained(config.base_model_name)

        # Shallow Transformer mixing information across sentence vectors.
        context_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=8,
            dim_feedforward=2048,
            dropout=0.1,
            batch_first=True,
        )
        self.inter_sentence_transformer = nn.TransformerEncoder(
            context_layer,
            num_layers=config.num_transformer_layers,
        )

        # "Is this a representative sentence?" head; Sigmoid keeps scores in [0, 1].
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

        # Role head emits raw logits; normalization is left to the loss/caller.
        self.role_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, config.num_roles),
        )

    def encode_sentences(self, input_ids, attention_mask):
        """Return the [CLS] hidden state for each sentence in the flat batch."""
        backbone_out = self.sentence_encoder(
            input_ids=input_ids, attention_mask=attention_mask
        )
        return backbone_out.last_hidden_state[:, 0, :]

    def forward(self, sentences_input_ids, sentences_attention_mask, document_mask=None):
        """Score and role-classify every sentence of every document.

        Args:
            sentences_input_ids: (batch_size, num_sentences, seq_len)
            sentences_attention_mask: (batch_size, num_sentences, seq_len)
            document_mask: optional (batch_size, num_sentences); truthy entries
                mark real sentences, falsy entries mark padding.

        Returns:
            scores: (batch_size, num_sentences) representative-sentence scores
            role_logits: (batch_size, num_sentences, num_roles) role logits
        """
        n_docs, n_sents, n_tokens = sentences_input_ids.shape

        # Collapse (document, sentence) so the backbone sees one flat batch.
        sentence_vecs = self.encode_sentences(
            sentences_input_ids.view(-1, n_tokens),
            sentences_attention_mask.view(-1, n_tokens),
        )
        sentence_vecs = sentence_vecs.view(n_docs, n_sents, sentence_vecs.shape[-1])

        # Transformer padding mask: True marks positions to be ignored.
        padding = None
        if document_mask is not None:
            padding = ~document_mask.bool()

        contextual = self.inter_sentence_transformer(
            sentence_vecs, src_key_padding_mask=padding
        )

        scores = self.classifier(contextual).squeeze(-1)
        role_logits = self.role_classifier(contextual)
        return scores, role_logits
115
+
# Register the custom config/model with the transformers Auto* factories so
# AutoConfig / AutoModel can resolve the "document_encoder" model_type.
AutoConfig.register("document_encoder", DocumentEncoderConfig)
AutoModel.register(DocumentEncoderConfig, DocumentEncoderForExtractiveSummarization)