htg2501 commited on
Commit
4ab5986
·
verified ·
1 Parent(s): 897e8e5

Update summarization.py

Browse files
Files changed (1) hide show
  1. summarization.py +615 -615
summarization.py CHANGED
@@ -1,616 +1,616 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import pickle
5
- import numpy as np
6
- from rouge import Rouge
7
- import string
8
- import re
9
- from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
10
- from underthesea import sent_tokenize, word_tokenize
11
-
12
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
- abstract_tokenizer_path = "vinai/bartpho-syllable-base"
14
- abstract_model_path = "htg2501/checkpoint"
15
- extractive_model_path = "./e_25_0.3071.mdl"
16
- contrastive_model_path = "./c_25_0.3071.mdl"
17
-
18
- stopword_path = "./vietnamese-stopwords-dash.txt"
19
- LDA_model_path = "./LDA_models.pkl"
20
-
21
- phobert = AutoModel.from_pretrained("vinai/phobert-base-v2").to(device)
22
- phobert_tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
23
- model_summarization = AutoModelForSeq2SeqLM.from_pretrained(abstract_model_path).to(device)
24
- tokenizer_summarization = AutoTokenizer.from_pretrained(abstract_tokenizer_path)
25
-
26
- """# Extractive model"""
27
-
28
-
29
- def getRouge2(ref, pred, kind):
30
- try:
31
- return round(Rouge().get_scores(pred.lower(), ref.lower())[0]['rouge-2'][kind], 4)
32
- except ValueError:
33
- return 0.0
34
-
35
-
36
- class MLP(nn.Module):
37
- def __init__(self, dims: list, layers=2, act=nn.LeakyReLU(), dropout_p=0.1, keep_last_layer=False):
38
- super(MLP, self).__init__()
39
- assert len(dims) == layers + 1
40
- self.layers = layers
41
- self.act = act
42
- self.dropout = nn.Dropout(dropout_p)
43
- self.keep_last = keep_last_layer
44
-
45
- self.mlp_layers = nn.ModuleList([])
46
- for i in range(self.layers):
47
- self.mlp_layers.append(nn.Linear(dims[i], dims[i + 1]))
48
-
49
- def forward(self, x):
50
- for i in range(len(self.mlp_layers) - 1):
51
- x = self.dropout(self.act(self.mlp_layers[i](x)))
52
- if self.keep_last:
53
- x = self.mlp_layers[-1](x)
54
- else:
55
- x = self.act(self.mlp_layers[-1](x))
56
- return x
57
-
58
-
59
- class GraphAttentionLayer(nn.Module):
60
- def __init__(self, in_features: int, out_features: int, n_heads: int,
61
- is_concat: bool = True,
62
- dropout: float = 0.6,
63
- leaky_relu_negative_slope: float = 0.2):
64
- super().__init__()
65
-
66
- self.is_concat = is_concat
67
- self.n_heads = n_heads
68
-
69
- # Calculate the number of dimensions per head
70
- if is_concat:
71
- assert out_features % n_heads == 0
72
- self.n_hidden = out_features // n_heads
73
- else:
74
- self.n_hidden = out_features
75
-
76
- self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
77
- self.attn = nn.Linear(self.n_hidden * 2, 1, bias=False)
78
- self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
79
- self.softmax = nn.Softmax(dim=1)
80
- self.dropout = nn.Dropout(dropout)
81
-
82
- def forward(self, h: torch.Tensor, adj_mat: torch.Tensor, docnum, secnum):
83
- n_nodes = h.shape[0]
84
- g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)
85
- g_repeat = g.repeat(n_nodes, 1, 1)
86
- g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
87
- g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)
88
- g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)
89
- e = self.activation(self.attn(g_concat))
90
-
91
- e = e.squeeze(-1)
92
-
93
- # The adjacency matrix should have shape
94
- # `[n_nodes, n_nodes, n_heads]` or`[n_nodes, n_nodes, 1]`
95
- assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
96
- assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
97
- assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads
98
- # Mask $e_{ij}$ based on adjacency matrix.
99
- # $e_{ij}$ is set to $- \infty$ if there is no edge from $i$ to $j$.
100
- e = e.masked_fill(adj_mat == 0, float(-1e9))
101
- a = self.softmax(e)
102
- a = self.dropout(a)
103
- attn_res = torch.einsum('ijh,jhf->ihf', a, g)
104
-
105
- # Concatenate the heads
106
- if self.is_concat:
107
- return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
108
- # Take the mean of the heads
109
- else:
110
- return attn_res.mean(dim=1)
111
-
112
-
113
- class GAT(nn.Module):
114
- def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
115
- super().__init__()
116
- self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)
117
- self.activation = nn.ELU()
118
- self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)
119
- self.dropout = nn.Dropout(dropout)
120
-
121
- def forward(self, x: torch.Tensor, adj_mat: torch.Tensor, docnum, secnum):
122
- x = x.squeeze(0)
123
- adj_mat = adj_mat.squeeze(0)
124
- adj_x = adj_mat.clone().sum(dim=1, keepdim=True).repeat(1, x.shape[1]).bool()
125
- adj_mat = adj_mat.unsqueeze(-1).bool()
126
- x = self.dropout(x)
127
- x = self.layer1(x, adj_mat, docnum, secnum)
128
- x = self.activation(x)
129
- x = self.dropout(x)
130
- x = self.output(x, adj_mat, docnum, secnum).masked_fill(adj_x == 0, float(0))
131
- return x.unsqueeze(0)
132
-
133
-
134
- class StepWiseGraphConvLayer(nn.Module):
135
- def __init__(self, in_dim, hid_dim, dropout_p=0.1, act=nn.LeakyReLU(), nheads=6, iter=1, final="att"):
136
- super().__init__()
137
- self.act = act
138
- self.dropout = nn.Dropout(dropout_p)
139
- self.iter = iter
140
- self.in_dim = in_dim
141
- self.gat = nn.ModuleList([GAT(in_features=in_dim, n_hidden=hid_dim, n_classes=in_dim,
142
- dropout=dropout_p, n_heads=nheads) for _ in range(iter)])
143
- self.gat2 = nn.ModuleList([GAT(in_features=in_dim, n_hidden=hid_dim, n_classes=in_dim,
144
- dropout=dropout_p, n_heads=nheads) for _ in range(iter)])
145
- self.gat3 = nn.ModuleList([GAT(in_features=in_dim, n_hidden=hid_dim, n_classes=in_dim,
146
- dropout=dropout_p, n_heads=nheads) for _ in range(iter)])
147
-
148
- self.out_ffn = MLP([in_dim * 3, hid_dim, hid_dim, in_dim], layers=3, dropout_p=dropout_p)
149
-
150
- def forward(self, feature, adj, docnum, secnum):
151
- sen_adj = adj.clone()
152
- sen_adj[:, -docnum - secnum - 1:, :] = sen_adj[:, :, -docnum - secnum - 1:] = 0
153
- sec_adj = adj.clone()
154
- sec_adj[:, :-docnum - secnum - 1, :] = sec_adj[:, -docnum - 1:, :] = sec_adj[:, :, -docnum - 1:] = 0
155
- doc_adj = adj.clone()
156
- doc_adj[:, :-docnum - 1, :] = 0
157
-
158
- feature_sen = feature.clone()
159
- feature_resi = feature
160
-
161
- feature_sen_re = feature_sen.clone()
162
- for i in range(0, self.iter):
163
- feature_sen = self.gat[i](feature_sen, sen_adj, docnum, secnum)
164
- feature_sen = F.layer_norm(feature_sen + feature_sen_re, [self.in_dim])
165
-
166
- feature_sec = feature_sen.clone()
167
- feature_sec_re = feature_sec.clone()
168
- for i in range(0, self.iter):
169
- feature_sec = self.gat2[i](feature_sec, sec_adj, docnum, secnum)
170
- feature_sec = F.layer_norm(feature_sec + feature_sec_re, [self.in_dim])
171
-
172
- feature_doc = feature_sec.clone()
173
- feature_doc_re = feature_doc.clone()
174
- for i in range(0, self.iter):
175
- feature_doc = self.gat3[i](feature_doc, doc_adj, docnum, secnum)
176
- feature_doc = F.layer_norm(feature_doc + feature_doc_re, [self.in_dim])
177
-
178
- feature_sec[:, :-docnum - secnum - 1, :] = adj[:, :-docnum - secnum - 1,
179
- -docnum - secnum - 1:-docnum - 1] @ feature_sec[:,
180
- -docnum - secnum - 1:-docnum - 1,
181
- :]
182
- feature_doc[:, -docnum - secnum - 1:-docnum - 1, :] = adj[:, -docnum - secnum - 1:-docnum - 1,
183
- -docnum - 1:] @ feature_doc[:, -docnum - 1:, :]
184
- feature_doc[:, :-docnum - secnum - 1, :] = adj[:, :-docnum - secnum - 1,
185
- -docnum - secnum - 1:-docnum - 1] @ feature_doc[:,
186
- -docnum - secnum - 1:-docnum - 1,
187
- :]
188
- feature = torch.concat([feature_doc, feature_sec, feature_sen], dim=-1)
189
- feature = F.layer_norm(self.out_ffn(feature) + feature_resi, [self.in_dim])
190
- return feature
191
-
192
-
193
- class Contrast_Encoder(nn.Module):
194
- def __init__(self, input_dim, hidden_dim, heads, act=nn.LeakyReLU(0.1), dropout_p=0.1):
195
- super(Contrast_Encoder, self).__init__()
196
- self.graph_encoder = StepWiseGraphConvLayer(in_dim=input_dim, hid_dim=hidden_dim,
197
- dropout_p=dropout_p, act=act, nheads=heads, iter=1)
198
- self.common_proj_mlp = MLP([input_dim, hidden_dim, input_dim], layers=2, dropout_p=dropout_p, act=act,
199
- keep_last_layer=False)
200
-
201
- def forward(self, p_gfeature, doc_lens, p_adj, docnum, secnum):
202
- posVec = torch.cat(
203
- [PositionVec[:l] for l in doc_lens] + [torch.zeros(secnum + docnum + 1, 768).float().to(device)], dim=0)
204
- p_gfeature = p_gfeature + posVec.unsqueeze(0)
205
- pg = self.graph_encoder(p_gfeature, p_adj, docnum, secnum)
206
- pg = self.common_proj_mlp(pg)
207
- return pg
208
-
209
-
210
- class End2End_Encoder(nn.Module):
211
- def __init__(self, input_dim, hidden_dim, heads, act=nn.LeakyReLU(0.1), dropout_p=0.3):
212
- super(End2End_Encoder, self).__init__()
213
- self.graph_encoder = StepWiseGraphConvLayer(in_dim=input_dim, hid_dim=hidden_dim,
214
- dropout_p=dropout_p, act=act, nheads=heads, iter=1)
215
- self.dropout = nn.Dropout(dropout_p)
216
- self.out_proj_layer_mlp = MLP([input_dim, hidden_dim, input_dim], layers=2, dropout_p=dropout_p, act=act,
217
- keep_last_layer=False)
218
- self.linear = MLP([input_dim, 1], layers=1, dropout_p=dropout_p, act=act, keep_last_layer=True)
219
-
220
- def forward(self, x, doc_lens, adj, docnum, secnum):
221
- x = self.graph_encoder(x, adj, docnum, secnum)
222
- x = self.out_proj_layer_mlp(x)
223
- return self.linear(x)[:, :-docnum - secnum - 1, :]
224
-
225
-
226
- def _similarity(h1: torch.Tensor, h2: torch.Tensor):
227
- h1 = F.normalize(h1)
228
- h2 = F.normalize(h2)
229
- return h1 @ h2.t()
230
-
231
-
232
- class InfoNCE(nn.Module):
233
- def __init__(self, tau):
234
- super(InfoNCE, self).__init__()
235
- self.tau = tau
236
-
237
- def forward(self, anchor, sample, pos_mask, *args, **kwargs):
238
- sim = _similarity(anchor, sample) / self.tau
239
- if len(anchor) > 1:
240
- sim, _ = torch.max(sim, dim=0, keepdim=True)
241
- exp_sim = torch.exp(sim)
242
- loss = torch.log((exp_sim * pos_mask).sum(dim=1)) - torch.log(exp_sim.sum(dim=1))
243
- return -loss.mean()
244
-
245
-
246
- class Cluster:
247
- def __init__(self, sent_texts, sent_vecs, doc_lens, doc_sec_mask, sec_sen_mask):
248
- assert len(sent_vecs) == len(sent_texts)
249
- self.docnum = len(doc_sec_mask)
250
- self.secnum = len(sec_sen_mask)
251
- self.feature = torch.cat(
252
- (torch.stack(sent_vecs, dim=0), torch.zeros((self.secnum + self.docnum + 1, sent_vecs[0].shape[0]))),
253
- dim=0).to(device)
254
- self.adj = torch.from_numpy(self.mask_to_adj(doc_sec_mask, sec_sen_mask)).float().to(device)
255
- self.sent_text = np.array(sent_texts)
256
- self.doc_lens = doc_lens
257
- self.init_node_vec()
258
- self.feature = self.feature.float()
259
-
260
- def init_node_vec(self):
261
- docnum, secnum = self.docnum, self.secnum
262
- for i in range(-secnum - docnum - 1, -docnum - 1):
263
- mask = self.adj[i].clone()
264
- mask[-secnum - docnum - 1:] = 0
265
- self.feature[i] = torch.mean(self.feature[mask.bool()], dim=0)
266
- for i in range(-docnum - 1, -1):
267
- mask = self.adj[i].clone()
268
- mask[-docnum - 1:] = 0
269
- self.feature[i] = torch.mean(self.feature[mask.bool()], dim=0)
270
- self.feature[-1] = torch.mean(self.feature[-docnum - 1:-1], dim=0)
271
-
272
- def mask_to_adj(self, doc_sec_mask, sec_sen_mask):
273
- sen_num = sec_sen_mask.shape[1]
274
- sec_num = sec_sen_mask.shape[0]
275
- doc_num = doc_sec_mask.shape[0]
276
- adj = np.zeros((sen_num + sec_num + doc_num + 1, sen_num + sec_num + doc_num + 1))
277
- # section connection
278
- adj[-sec_num - doc_num - 1:-doc_num - 1, 0:-sec_num - doc_num - 1] = sec_sen_mask
279
- adj[0:-sec_num - doc_num - 1, -sec_num - doc_num - 1:-doc_num - 1] = sec_sen_mask.T
280
- for i in range(0, doc_num):
281
- doc_mask = doc_sec_mask[i]
282
- doc_mask = doc_mask.reshape((1, len(doc_mask)))
283
- adj[sen_num:-doc_num - 1, sen_num:-doc_num - 1] += doc_mask * doc_mask.T
284
- # doc connection
285
- adj[-doc_num - 1:-1, -sec_num - doc_num - 1:-doc_num - 1] = doc_sec_mask
286
- adj[-sec_num - doc_num - 1:-doc_num - 1, -doc_num - 1:-1] = doc_sec_mask.T
287
- adj[-doc_num - 1:, -doc_num - 1:] = 1
288
-
289
- #build sentence connection
290
- for i in range(0, sec_num):
291
- sec_mask = sec_sen_mask[i]
292
- sec_mask = sec_mask.reshape((1, len(sec_mask)))
293
- adj[:sen_num, :sen_num] += sec_mask * sec_mask.T
294
- return adj
295
-
296
-
297
- def meanTokenVecs(text):
298
- sent = text.lower()
299
- input_ids = torch.tensor([phobert_tokenizer.encode(sent)])
300
- tokenized_text = phobert_tokenizer.tokenize(sent)
301
- with torch.no_grad():
302
- features = phobert(input_ids.to(device))
303
- wordVecs, buffer, buffer_str = {}, [], ''
304
- for token in zip(tokenized_text, features.last_hidden_state[0, 1:-1, :]):
305
- if token[0][-2:] == '@@':
306
- buffer.append(token[1])
307
- buffer_str += token[0][:-2]
308
- continue
309
- if buffer:
310
- buffer.append(token[1])
311
- buffer_str += token[0]
312
- wordVecs[buffer_str] = torch.mean(torch.stack(buffer), dim=0)
313
- buffer, buffer_str = [], ''
314
- else:
315
- wordVecs[token[0]] = token[1]
316
-
317
- return torch.mean(torch.stack([vec for w, vec in wordVecs.items() if w not in string.punctuation]), dim=0).to(
318
- torch.device('cpu'))
319
-
320
-
321
- def getPositionEncoding(pos, d=768, n=10000):
322
- P = np.zeros(d)
323
- for i in np.arange(int(d / 2)):
324
- denominator = np.power(n, 2 * i / d)
325
- P[2 * i] = np.sin(pos / denominator)
326
- P[2 * i + 1] = np.cos(pos / denominator)
327
- return P
328
-
329
-
330
- def removeRedundant(text):
331
- text = text.lower()
332
- words = [w for w in text.split(' ') if w not in stop_w]
333
- return ' '.join(words)
334
-
335
-
336
- def divideSection(doc_text, category='Giáo dục'):
337
- sent_para, para_sec, sent_sec = {}, {}, {}
338
-
339
- paras = [para for para in doc_text.split('\n') if para != '']
340
- all_sents = []
341
- # prepare sent_Para
342
- sentcnt = 0
343
- for i, para in enumerate(paras):
344
- sents = [word_tokenize(sent, format="text") for sent in sent_tokenize(para) if sent != '' and len(sent) > 4]
345
- all_sents.extend(sents)
346
- for ii, sent in enumerate(sents):
347
- sent_para[sentcnt + ii] = i
348
- sent = removeRedundant(sent)
349
- sentcnt += len(sents)
350
-
351
- # prepare para_sec
352
- paras = [removeRedundant(para) for para in paras]
353
- tf, lda_model = cate_models[category]
354
- X = tf.transform(paras)
355
- lda_top = lda_model.transform(X)
356
- for i, para_top in enumerate(lda_top):
357
- para_sec[i] = para_top.argmax()
358
-
359
- # output sent_sec
360
- for k, v in sent_para.items():
361
- sent_sec[k] = para_sec[v]
362
- return sent_sec, all_sents
363
-
364
-
365
- def loadClusterData(docs_org, category='Giáo dục'): # docs_org: list of text for each document
366
- seclist, docs = {}, []
367
- for d, doc in enumerate(docs_org):
368
- seclist[d], sentTexts = divideSection(doc, category)
369
- docs.append(sentTexts)
370
-
371
- secnum = 0
372
- for k, val_dict in seclist.items():
373
- vals = set(val_dict.values())
374
- for ki, vi in val_dict.items():
375
- for i, v in enumerate(vals):
376
- if vi == v:
377
- val_dict[ki] = i + secnum
378
- break
379
- seclist[k] = val_dict
380
- secnum += len(vals)
381
-
382
- sents, sentVecs, secIDs, doc_lens = [], [], [], []
383
- sentnum = sum([len(doc.values()) for doc in seclist.values()])
384
- doc_sec_mask = np.zeros((len(docs), secnum))
385
- sec_sen_mask = np.zeros((secnum, sentnum))
386
- cursec, cursent = 0, 0
387
-
388
- for d, doc in enumerate(docs):
389
- doc_lens.append(len(doc))
390
- doc_endsec = max(seclist[d].values())
391
- doc_sec_mask[d][cursec:doc_endsec + 1] = 1
392
- cursec = doc_endsec + 1
393
- for s, sent in enumerate(doc):
394
- sents.append(sent)
395
- sentVecs.append(meanTokenVecs(sent))
396
- sec_sen_mask[seclist[d][s], cursent] = 1
397
- cursent += 1
398
-
399
- return Cluster(sents, sentVecs, doc_lens, doc_sec_mask, sec_sen_mask)
400
-
401
-
402
- def val_e2e(data):
403
- feature = data.feature.unsqueeze(0)
404
- doc_lens = data.doc_lens
405
- adj = data.adj.unsqueeze(0)
406
- docnum = data.docnum
407
- secnum = data.secnum
408
-
409
- with torch.no_grad():
410
- feature = c_model(feature, doc_lens, adj, docnum, secnum)
411
- x = model(feature, doc_lens, adj, docnum, secnum)
412
- scores = torch.sigmoid(x.squeeze(-1))
413
-
414
- return scores, data.sent_text
415
-
416
-
417
- def normalize_text(text):
418
- text = str(text).replace('_', ' ')
419
- text = re.sub(r'\s+', ' ', text)
420
- text = re.sub(r'\s+([.,;:?)/!?”])', r'\1', text)
421
- text = re.sub(r'([\(“])\s+', r'\1', text)
422
- return text
423
-
424
-
425
- def track_changes(old_words, new_words):
426
- # Find the longest common subsequence (LCS) between the two word sequences
427
- def get_lcs_matrix(words1, words2):
428
- m, n = len(words1), len(words2)
429
- dp = [[0] * (n + 1) for _ in range(m + 1)]
430
-
431
- for i in range(1, m + 1):
432
- for j in range(1, n + 1):
433
- if words1[i - 1] == words2[j - 1]:
434
- dp[i][j] = dp[i - 1][j - 1] + 1
435
- else:
436
- dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
437
-
438
- return dp
439
-
440
- def get_lcs(words1, words2, dp):
441
- i, j = len(words1), len(words2)
442
- lcs = []
443
-
444
- while i > 0 and j > 0:
445
- if words1[i - 1] == words2[j - 1]:
446
- lcs.append((i - 1, j - 1))
447
- i -= 1
448
- j -= 1
449
- elif dp[i - 1][j] > dp[i][j - 1]:
450
- i -= 1
451
- else:
452
- j -= 1
453
-
454
- return sorted(lcs)
455
-
456
- # Find the changed segments at word level
457
- dp_matrix = get_lcs_matrix(old_words, new_words)
458
- lcs_positions = get_lcs(old_words, new_words, dp_matrix)
459
-
460
- changes = []
461
- old_pos = 0
462
- new_pos = 0
463
-
464
- # Process matching and non-matching segments
465
- for old_idx, new_idx in lcs_positions:
466
- # If there's a gap before this match, it's a change
467
- if old_idx > old_pos or new_idx > new_pos:
468
- changes.append((old_pos, old_idx, new_pos, new_idx))
469
-
470
- # Move positions after the match
471
- old_pos = old_idx + 1
472
- new_pos = new_idx + 1
473
-
474
- # Check if there's a change at the end
475
- if old_pos < len(old_words) or new_pos < len(new_words):
476
- changes.append((old_pos, len(old_words), new_pos, len(new_words)))
477
-
478
- return changes
479
-
480
-
481
- class Abstractive_Summarization:
482
- @staticmethod
483
- def generateSummaryBySent(texts, batch=32):
484
- model_summarization.eval()
485
- predictions = []
486
- with torch.no_grad():
487
- for i in range(0, len(texts), batch):
488
- batch_texts = texts[i:i + batch]
489
- inputs = tokenizer_summarization(batch_texts, padding=True, max_length=1024, truncation=True,
490
- return_tensors='pt').to(device)
491
- outputs = model_summarization.generate(**inputs, num_beams=5,
492
- early_stopping=True, no_repeat_ngram_size=3)
493
- prediction = tokenizer_summarization.batch_decode(outputs, skip_special_tokens=True)
494
- predictions.extend(prediction)
495
- return predictions
496
-
497
-
498
- PositionVec = torch.stack([torch.from_numpy(getPositionEncoding(i, d=768)) for i in range(200)], dim=0).float().to(
499
- device)
500
- stop_w = ['...']
501
- # with open(stopword_path, 'r', encoding='utf-8') as f:
502
- # for w in f.readlines():
503
- # stop_w.append(w.strip())
504
- stop_w.extend([c for c in '!"#$%&\'()*+,./:;<=>?@[\\]^`{|}~…“”’‘'])
505
-
506
- with open(LDA_model_path, mode='rb') as fp:
507
- cate_models = pickle.load(fp)
508
-
509
- c_model = Contrast_Encoder(768, 1024, 4).to(device)
510
- model = End2End_Encoder(768, 1024, 4).to(device)
511
- model.load_state_dict(torch.load(extractive_model_path, map_location=device), strict=False)
512
- c_model.load_state_dict(torch.load(contrastive_model_path, map_location=device), strict=False)
513
- model.eval()
514
- c_model.eval()
515
-
516
- def get_summary(scores, sents, max_sent=5):
517
- ranked_score_idxs = torch.argsort(scores[0], dim=0, descending=True)
518
- sents = [s.replace('_', ' ') for s in sents]
519
- summSentIDList = []
520
- for i in ranked_score_idxs:
521
- if len(summSentIDList) >= max_sent: break
522
- s = sents[i]
523
-
524
- replicated, delIDs = False, []
525
- for chosedID in summSentIDList:
526
- if getRouge2(s, sents[chosedID], 'p') >= 0.45:
527
- delIDs.append(chosedID)
528
- if getRouge2(sents[chosedID], s, 'p') >= 0.45:
529
- replicated = True
530
- break
531
- if replicated: continue
532
-
533
- for delID in delIDs:
534
- del summSentIDList[summSentIDList.index(delID)]
535
- summSentIDList.append(i)
536
- summSentIDList = sorted(summSentIDList)
537
- return [s for i, s in enumerate(sents) if i in summSentIDList]
538
-
539
- def MultiDocSummarizationAPI(texts, compress_ratio):
540
- """
541
- Summarizes a list of documents using both extractive and abstractive methods.
542
-
543
- Parameters:
544
- - texts (list of str): A list of document texts to be summarized.
545
- - compress_ratio (float): A ratio or count determining the number of sentences in the summary.
546
- If less than 1, it represents the fraction of the original sentences to include in the summary.
547
- If 1 or greater, it represents the exact number of sentences to include in the summary.
548
-
549
- Returns:
550
- - dict: A dictionary containing:
551
- - 'extractive_summ' (str): The extractive summary of the documents.
552
- - 'abstractive_summ' (str): The abstractive summary of the documents.
553
- """
554
- assert compress_ratio > 0, "Compress ratio need to be greater than 0."
555
- docs = [text.strip() for text in texts]
556
- data_tree = loadClusterData(docs)
557
- scores, sents = val_e2e(data_tree)
558
-
559
- output_sent_cnt = int(len(sents) * compress_ratio) if compress_ratio < 1 else int(compress_ratio)
560
- print('Expected sentence count:', output_sent_cnt)
561
-
562
- extractive_summ_sents = [normalize_text(sent) for sent in get_summary(scores, sents, max_sent=output_sent_cnt)]
563
- extractive_summ = ' '.join(extractive_summ_sents)
564
-
565
- abstractive_summ_sents = Abstractive_Summarization.generateSummaryBySent(extractive_summ_sents)
566
- abstractive_summ_sents = [normalize_text(s) for s in abstractive_summ_sents]
567
- final_sents = []
568
- for ii, (ext, abs) in enumerate(zip(extractive_summ_sents, abstractive_summ_sents)):
569
- if ii == 0:
570
- final_sents.append(ext)
571
- continue
572
- abs_splits, ext_splits = word_tokenize(abs), word_tokenize(ext)
573
- abs_splits_cop, ext_splits_cop = abs_splits.copy(), ext_splits.copy()
574
- if len(abs_splits_cop):
575
- abs_splits_cop[-1] = abs_splits[-1][:-1] if len(abs_splits[-1]) and abs_splits[-1][-1] == '.' else abs_splits[-1]
576
- if len(ext_splits_cop):
577
- ext_splits_cop[-1] = ext_splits[-1][:-1] if len(ext_splits[-1]) and ext_splits[-1][-1] == '.' else ext_splits[-1]
578
-
579
- changes, abs_parts = track_changes(ext_splits_cop, abs_splits_cop), [(0, len(abs_splits))]
580
- for start_old, end_old, start_new, end_new in changes:
581
- old_part = ' '.join(ext_splits[start_old:end_old])
582
- # Revert change in the cases of spelling errors
583
- revert, ignoreFirstSentWord = False, 1 if start_old == 0 else 0
584
- old_names = {}
585
- for w in ext_splits_cop[start_old + ignoreFirstSentWord:end_old]:
586
- if len(w) == 0: continue
587
- if 'A'<=w[0]<='Z' or w[0] in ['Ä‚', 'Ă‚', 'Đ', 'Ê', 'Ă”', 'Æ ', 'Ư']:
588
- if w in old_names:
589
- old_names[w] += 1
590
- else:
591
- old_names[w] = 1
592
-
593
- for w in abs_splits_cop[start_new + ignoreFirstSentWord:end_new]:
594
- if len(w) == 0: continue
595
- if 'A'<=w[0]<='Z' or w[0] in ['Ä‚', 'Ă‚', 'Đ', 'Ê', 'Ă”', 'Æ ', 'Ư']:
596
- if w in old_names:
597
- old_names[w] -= 1
598
- if old_names[w] < 0:
599
- revert = True
600
- break
601
- else:
602
- revert = True
603
- break
604
- if revert:
605
- pop_part = abs_parts[-1]
606
- abs_parts.pop()
607
- abs_parts.extend([(pop_part[0], start_new), old_part, (end_new, pop_part[1])])
608
- # print('\nOLD:', old_part, '\n', ' '.join(abs_splits[start_new:end_new]))
609
- # print(ext, '\n', abs)
610
-
611
- abs = ' '.join([part if isinstance(part, str) else ' '.join(abs_splits[part[0]:part[1]]) for part in abs_parts])
612
- final_sents.append(normalize_text(abs))
613
- abstract_summ = ' '.join(final_sents)
614
-
615
- return {'extractive_summ': extractive_summ,
616
  'abstractive_summ': abstract_summ}
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import pickle
5
+ import numpy as np
6
+ from rouge import Rouge
7
+ import string
8
+ import re
9
+ from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
10
+ from underthesea import sent_tokenize, word_tokenize
11
+
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ abstract_tokenizer_path = "vinai/bartpho-syllable-base"
14
+ abstract_model_path = "htg2501/Checkpoint-2200"
15
+ extractive_model_path = "./e_25_0.3071.mdl"
16
+ contrastive_model_path = "./c_25_0.3071.mdl"
17
+
18
+ stopword_path = "./vietnamese-stopwords-dash.txt"
19
+ LDA_model_path = "./LDA_models.pkl"
20
+
21
+ phobert = AutoModel.from_pretrained("vinai/phobert-base-v2").to(device)
22
+ phobert_tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
23
+ model_summarization = AutoModelForSeq2SeqLM.from_pretrained(abstract_model_path).to(device)
24
+ tokenizer_summarization = AutoTokenizer.from_pretrained(abstract_tokenizer_path)
25
+
26
+ """# Extractive model"""
27
+
28
+
29
+ def getRouge2(ref, pred, kind):
30
+ try:
31
+ return round(Rouge().get_scores(pred.lower(), ref.lower())[0]['rouge-2'][kind], 4)
32
+ except ValueError:
33
+ return 0.0
34
+
35
+
36
+ class MLP(nn.Module):
37
+ def __init__(self, dims: list, layers=2, act=nn.LeakyReLU(), dropout_p=0.1, keep_last_layer=False):
38
+ super(MLP, self).__init__()
39
+ assert len(dims) == layers + 1
40
+ self.layers = layers
41
+ self.act = act
42
+ self.dropout = nn.Dropout(dropout_p)
43
+ self.keep_last = keep_last_layer
44
+
45
+ self.mlp_layers = nn.ModuleList([])
46
+ for i in range(self.layers):
47
+ self.mlp_layers.append(nn.Linear(dims[i], dims[i + 1]))
48
+
49
+ def forward(self, x):
50
+ for i in range(len(self.mlp_layers) - 1):
51
+ x = self.dropout(self.act(self.mlp_layers[i](x)))
52
+ if self.keep_last:
53
+ x = self.mlp_layers[-1](x)
54
+ else:
55
+ x = self.act(self.mlp_layers[-1](x))
56
+ return x
57
+
58
+
59
+ class GraphAttentionLayer(nn.Module):
60
+ def __init__(self, in_features: int, out_features: int, n_heads: int,
61
+ is_concat: bool = True,
62
+ dropout: float = 0.6,
63
+ leaky_relu_negative_slope: float = 0.2):
64
+ super().__init__()
65
+
66
+ self.is_concat = is_concat
67
+ self.n_heads = n_heads
68
+
69
+ # Calculate the number of dimensions per head
70
+ if is_concat:
71
+ assert out_features % n_heads == 0
72
+ self.n_hidden = out_features // n_heads
73
+ else:
74
+ self.n_hidden = out_features
75
+
76
+ self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
77
+ self.attn = nn.Linear(self.n_hidden * 2, 1, bias=False)
78
+ self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
79
+ self.softmax = nn.Softmax(dim=1)
80
+ self.dropout = nn.Dropout(dropout)
81
+
82
+ def forward(self, h: torch.Tensor, adj_mat: torch.Tensor, docnum, secnum):
83
+ n_nodes = h.shape[0]
84
+ g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)
85
+ g_repeat = g.repeat(n_nodes, 1, 1)
86
+ g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
87
+ g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)
88
+ g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)
89
+ e = self.activation(self.attn(g_concat))
90
+
91
+ e = e.squeeze(-1)
92
+
93
+ # The adjacency matrix should have shape
94
+ # `[n_nodes, n_nodes, n_heads]` or`[n_nodes, n_nodes, 1]`
95
+ assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
96
+ assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
97
+ assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads
98
+ # Mask $e_{ij}$ based on adjacency matrix.
99
+ # $e_{ij}$ is set to $- \infty$ if there is no edge from $i$ to $j$.
100
+ e = e.masked_fill(adj_mat == 0, float(-1e9))
101
+ a = self.softmax(e)
102
+ a = self.dropout(a)
103
+ attn_res = torch.einsum('ijh,jhf->ihf', a, g)
104
+
105
+ # Concatenate the heads
106
+ if self.is_concat:
107
+ return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
108
+ # Take the mean of the heads
109
+ else:
110
+ return attn_res.mean(dim=1)
111
+
112
+
113
+ class GAT(nn.Module):
114
+ def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
115
+ super().__init__()
116
+ self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)
117
+ self.activation = nn.ELU()
118
+ self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)
119
+ self.dropout = nn.Dropout(dropout)
120
+
121
+ def forward(self, x: torch.Tensor, adj_mat: torch.Tensor, docnum, secnum):
122
+ x = x.squeeze(0)
123
+ adj_mat = adj_mat.squeeze(0)
124
+ adj_x = adj_mat.clone().sum(dim=1, keepdim=True).repeat(1, x.shape[1]).bool()
125
+ adj_mat = adj_mat.unsqueeze(-1).bool()
126
+ x = self.dropout(x)
127
+ x = self.layer1(x, adj_mat, docnum, secnum)
128
+ x = self.activation(x)
129
+ x = self.dropout(x)
130
+ x = self.output(x, adj_mat, docnum, secnum).masked_fill(adj_x == 0, float(0))
131
+ return x.unsqueeze(0)
132
+
133
+
134
+ class StepWiseGraphConvLayer(nn.Module):
135
+ def __init__(self, in_dim, hid_dim, dropout_p=0.1, act=nn.LeakyReLU(), nheads=6, iter=1, final="att"):
136
+ super().__init__()
137
+ self.act = act
138
+ self.dropout = nn.Dropout(dropout_p)
139
+ self.iter = iter
140
+ self.in_dim = in_dim
141
+ self.gat = nn.ModuleList([GAT(in_features=in_dim, n_hidden=hid_dim, n_classes=in_dim,
142
+ dropout=dropout_p, n_heads=nheads) for _ in range(iter)])
143
+ self.gat2 = nn.ModuleList([GAT(in_features=in_dim, n_hidden=hid_dim, n_classes=in_dim,
144
+ dropout=dropout_p, n_heads=nheads) for _ in range(iter)])
145
+ self.gat3 = nn.ModuleList([GAT(in_features=in_dim, n_hidden=hid_dim, n_classes=in_dim,
146
+ dropout=dropout_p, n_heads=nheads) for _ in range(iter)])
147
+
148
+ self.out_ffn = MLP([in_dim * 3, hid_dim, hid_dim, in_dim], layers=3, dropout_p=dropout_p)
149
+
150
+ def forward(self, feature, adj, docnum, secnum):
151
+ sen_adj = adj.clone()
152
+ sen_adj[:, -docnum - secnum - 1:, :] = sen_adj[:, :, -docnum - secnum - 1:] = 0
153
+ sec_adj = adj.clone()
154
+ sec_adj[:, :-docnum - secnum - 1, :] = sec_adj[:, -docnum - 1:, :] = sec_adj[:, :, -docnum - 1:] = 0
155
+ doc_adj = adj.clone()
156
+ doc_adj[:, :-docnum - 1, :] = 0
157
+
158
+ feature_sen = feature.clone()
159
+ feature_resi = feature
160
+
161
+ feature_sen_re = feature_sen.clone()
162
+ for i in range(0, self.iter):
163
+ feature_sen = self.gat[i](feature_sen, sen_adj, docnum, secnum)
164
+ feature_sen = F.layer_norm(feature_sen + feature_sen_re, [self.in_dim])
165
+
166
+ feature_sec = feature_sen.clone()
167
+ feature_sec_re = feature_sec.clone()
168
+ for i in range(0, self.iter):
169
+ feature_sec = self.gat2[i](feature_sec, sec_adj, docnum, secnum)
170
+ feature_sec = F.layer_norm(feature_sec + feature_sec_re, [self.in_dim])
171
+
172
+ feature_doc = feature_sec.clone()
173
+ feature_doc_re = feature_doc.clone()
174
+ for i in range(0, self.iter):
175
+ feature_doc = self.gat3[i](feature_doc, doc_adj, docnum, secnum)
176
+ feature_doc = F.layer_norm(feature_doc + feature_doc_re, [self.in_dim])
177
+
178
+ feature_sec[:, :-docnum - secnum - 1, :] = adj[:, :-docnum - secnum - 1,
179
+ -docnum - secnum - 1:-docnum - 1] @ feature_sec[:,
180
+ -docnum - secnum - 1:-docnum - 1,
181
+ :]
182
+ feature_doc[:, -docnum - secnum - 1:-docnum - 1, :] = adj[:, -docnum - secnum - 1:-docnum - 1,
183
+ -docnum - 1:] @ feature_doc[:, -docnum - 1:, :]
184
+ feature_doc[:, :-docnum - secnum - 1, :] = adj[:, :-docnum - secnum - 1,
185
+ -docnum - secnum - 1:-docnum - 1] @ feature_doc[:,
186
+ -docnum - secnum - 1:-docnum - 1,
187
+ :]
188
+ feature = torch.concat([feature_doc, feature_sec, feature_sen], dim=-1)
189
+ feature = F.layer_norm(self.out_ffn(feature) + feature_resi, [self.in_dim])
190
+ return feature
191
+
192
+
193
+ class Contrast_Encoder(nn.Module):
194
+ def __init__(self, input_dim, hidden_dim, heads, act=nn.LeakyReLU(0.1), dropout_p=0.1):
195
+ super(Contrast_Encoder, self).__init__()
196
+ self.graph_encoder = StepWiseGraphConvLayer(in_dim=input_dim, hid_dim=hidden_dim,
197
+ dropout_p=dropout_p, act=act, nheads=heads, iter=1)
198
+ self.common_proj_mlp = MLP([input_dim, hidden_dim, input_dim], layers=2, dropout_p=dropout_p, act=act,
199
+ keep_last_layer=False)
200
+
201
+ def forward(self, p_gfeature, doc_lens, p_adj, docnum, secnum):
202
+ posVec = torch.cat(
203
+ [PositionVec[:l] for l in doc_lens] + [torch.zeros(secnum + docnum + 1, 768).float().to(device)], dim=0)
204
+ p_gfeature = p_gfeature + posVec.unsqueeze(0)
205
+ pg = self.graph_encoder(p_gfeature, p_adj, docnum, secnum)
206
+ pg = self.common_proj_mlp(pg)
207
+ return pg
208
+
209
+
210
+ class End2End_Encoder(nn.Module):
211
+ def __init__(self, input_dim, hidden_dim, heads, act=nn.LeakyReLU(0.1), dropout_p=0.3):
212
+ super(End2End_Encoder, self).__init__()
213
+ self.graph_encoder = StepWiseGraphConvLayer(in_dim=input_dim, hid_dim=hidden_dim,
214
+ dropout_p=dropout_p, act=act, nheads=heads, iter=1)
215
+ self.dropout = nn.Dropout(dropout_p)
216
+ self.out_proj_layer_mlp = MLP([input_dim, hidden_dim, input_dim], layers=2, dropout_p=dropout_p, act=act,
217
+ keep_last_layer=False)
218
+ self.linear = MLP([input_dim, 1], layers=1, dropout_p=dropout_p, act=act, keep_last_layer=True)
219
+
220
+ def forward(self, x, doc_lens, adj, docnum, secnum):
221
+ x = self.graph_encoder(x, adj, docnum, secnum)
222
+ x = self.out_proj_layer_mlp(x)
223
+ return self.linear(x)[:, :-docnum - secnum - 1, :]
224
+
225
+
226
+ def _similarity(h1: torch.Tensor, h2: torch.Tensor):
227
+ h1 = F.normalize(h1)
228
+ h2 = F.normalize(h2)
229
+ return h1 @ h2.t()
230
+
231
+
232
+ class InfoNCE(nn.Module):
233
+ def __init__(self, tau):
234
+ super(InfoNCE, self).__init__()
235
+ self.tau = tau
236
+
237
+ def forward(self, anchor, sample, pos_mask, *args, **kwargs):
238
+ sim = _similarity(anchor, sample) / self.tau
239
+ if len(anchor) > 1:
240
+ sim, _ = torch.max(sim, dim=0, keepdim=True)
241
+ exp_sim = torch.exp(sim)
242
+ loss = torch.log((exp_sim * pos_mask).sum(dim=1)) - torch.log(exp_sim.sum(dim=1))
243
+ return -loss.mean()
244
+
245
+
246
+ class Cluster:
247
+ def __init__(self, sent_texts, sent_vecs, doc_lens, doc_sec_mask, sec_sen_mask):
248
+ assert len(sent_vecs) == len(sent_texts)
249
+ self.docnum = len(doc_sec_mask)
250
+ self.secnum = len(sec_sen_mask)
251
+ self.feature = torch.cat(
252
+ (torch.stack(sent_vecs, dim=0), torch.zeros((self.secnum + self.docnum + 1, sent_vecs[0].shape[0]))),
253
+ dim=0).to(device)
254
+ self.adj = torch.from_numpy(self.mask_to_adj(doc_sec_mask, sec_sen_mask)).float().to(device)
255
+ self.sent_text = np.array(sent_texts)
256
+ self.doc_lens = doc_lens
257
+ self.init_node_vec()
258
+ self.feature = self.feature.float()
259
+
260
+ def init_node_vec(self):
261
+ docnum, secnum = self.docnum, self.secnum
262
+ for i in range(-secnum - docnum - 1, -docnum - 1):
263
+ mask = self.adj[i].clone()
264
+ mask[-secnum - docnum - 1:] = 0
265
+ self.feature[i] = torch.mean(self.feature[mask.bool()], dim=0)
266
+ for i in range(-docnum - 1, -1):
267
+ mask = self.adj[i].clone()
268
+ mask[-docnum - 1:] = 0
269
+ self.feature[i] = torch.mean(self.feature[mask.bool()], dim=0)
270
+ self.feature[-1] = torch.mean(self.feature[-docnum - 1:-1], dim=0)
271
+
272
+ def mask_to_adj(self, doc_sec_mask, sec_sen_mask):
273
+ sen_num = sec_sen_mask.shape[1]
274
+ sec_num = sec_sen_mask.shape[0]
275
+ doc_num = doc_sec_mask.shape[0]
276
+ adj = np.zeros((sen_num + sec_num + doc_num + 1, sen_num + sec_num + doc_num + 1))
277
+ # section connection
278
+ adj[-sec_num - doc_num - 1:-doc_num - 1, 0:-sec_num - doc_num - 1] = sec_sen_mask
279
+ adj[0:-sec_num - doc_num - 1, -sec_num - doc_num - 1:-doc_num - 1] = sec_sen_mask.T
280
+ for i in range(0, doc_num):
281
+ doc_mask = doc_sec_mask[i]
282
+ doc_mask = doc_mask.reshape((1, len(doc_mask)))
283
+ adj[sen_num:-doc_num - 1, sen_num:-doc_num - 1] += doc_mask * doc_mask.T
284
+ # doc connection
285
+ adj[-doc_num - 1:-1, -sec_num - doc_num - 1:-doc_num - 1] = doc_sec_mask
286
+ adj[-sec_num - doc_num - 1:-doc_num - 1, -doc_num - 1:-1] = doc_sec_mask.T
287
+ adj[-doc_num - 1:, -doc_num - 1:] = 1
288
+
289
+ #build sentence connection
290
+ for i in range(0, sec_num):
291
+ sec_mask = sec_sen_mask[i]
292
+ sec_mask = sec_mask.reshape((1, len(sec_mask)))
293
+ adj[:sen_num, :sen_num] += sec_mask * sec_mask.T
294
+ return adj
295
+
296
+
297
+ def meanTokenVecs(text):
298
+ sent = text.lower()
299
+ input_ids = torch.tensor([phobert_tokenizer.encode(sent)])
300
+ tokenized_text = phobert_tokenizer.tokenize(sent)
301
+ with torch.no_grad():
302
+ features = phobert(input_ids.to(device))
303
+ wordVecs, buffer, buffer_str = {}, [], ''
304
+ for token in zip(tokenized_text, features.last_hidden_state[0, 1:-1, :]):
305
+ if token[0][-2:] == '@@':
306
+ buffer.append(token[1])
307
+ buffer_str += token[0][:-2]
308
+ continue
309
+ if buffer:
310
+ buffer.append(token[1])
311
+ buffer_str += token[0]
312
+ wordVecs[buffer_str] = torch.mean(torch.stack(buffer), dim=0)
313
+ buffer, buffer_str = [], ''
314
+ else:
315
+ wordVecs[token[0]] = token[1]
316
+
317
+ return torch.mean(torch.stack([vec for w, vec in wordVecs.items() if w not in string.punctuation]), dim=0).to(
318
+ torch.device('cpu'))
319
+
320
+
321
+ def getPositionEncoding(pos, d=768, n=10000):
322
+ P = np.zeros(d)
323
+ for i in np.arange(int(d / 2)):
324
+ denominator = np.power(n, 2 * i / d)
325
+ P[2 * i] = np.sin(pos / denominator)
326
+ P[2 * i + 1] = np.cos(pos / denominator)
327
+ return P
328
+
329
+
330
+ def removeRedundant(text):
331
+ text = text.lower()
332
+ words = [w for w in text.split(' ') if w not in stop_w]
333
+ return ' '.join(words)
334
+
335
+
336
+ def divideSection(doc_text, category='Giáo dục'):
337
+ sent_para, para_sec, sent_sec = {}, {}, {}
338
+
339
+ paras = [para for para in doc_text.split('\n') if para != '']
340
+ all_sents = []
341
+ # prepare sent_Para
342
+ sentcnt = 0
343
+ for i, para in enumerate(paras):
344
+ sents = [word_tokenize(sent, format="text") for sent in sent_tokenize(para) if sent != '' and len(sent) > 4]
345
+ all_sents.extend(sents)
346
+ for ii, sent in enumerate(sents):
347
+ sent_para[sentcnt + ii] = i
348
+ sent = removeRedundant(sent)
349
+ sentcnt += len(sents)
350
+
351
+ # prepare para_sec
352
+ paras = [removeRedundant(para) for para in paras]
353
+ tf, lda_model = cate_models[category]
354
+ X = tf.transform(paras)
355
+ lda_top = lda_model.transform(X)
356
+ for i, para_top in enumerate(lda_top):
357
+ para_sec[i] = para_top.argmax()
358
+
359
+ # output sent_sec
360
+ for k, v in sent_para.items():
361
+ sent_sec[k] = para_sec[v]
362
+ return sent_sec, all_sents
363
+
364
+
365
+ def loadClusterData(docs_org, category='Giáo dục'): # docs_org: list of text for each document
366
+ seclist, docs = {}, []
367
+ for d, doc in enumerate(docs_org):
368
+ seclist[d], sentTexts = divideSection(doc, category)
369
+ docs.append(sentTexts)
370
+
371
+ secnum = 0
372
+ for k, val_dict in seclist.items():
373
+ vals = set(val_dict.values())
374
+ for ki, vi in val_dict.items():
375
+ for i, v in enumerate(vals):
376
+ if vi == v:
377
+ val_dict[ki] = i + secnum
378
+ break
379
+ seclist[k] = val_dict
380
+ secnum += len(vals)
381
+
382
+ sents, sentVecs, secIDs, doc_lens = [], [], [], []
383
+ sentnum = sum([len(doc.values()) for doc in seclist.values()])
384
+ doc_sec_mask = np.zeros((len(docs), secnum))
385
+ sec_sen_mask = np.zeros((secnum, sentnum))
386
+ cursec, cursent = 0, 0
387
+
388
+ for d, doc in enumerate(docs):
389
+ doc_lens.append(len(doc))
390
+ doc_endsec = max(seclist[d].values())
391
+ doc_sec_mask[d][cursec:doc_endsec + 1] = 1
392
+ cursec = doc_endsec + 1
393
+ for s, sent in enumerate(doc):
394
+ sents.append(sent)
395
+ sentVecs.append(meanTokenVecs(sent))
396
+ sec_sen_mask[seclist[d][s], cursent] = 1
397
+ cursent += 1
398
+
399
+ return Cluster(sents, sentVecs, doc_lens, doc_sec_mask, sec_sen_mask)
400
+
401
+
402
+ def val_e2e(data):
403
+ feature = data.feature.unsqueeze(0)
404
+ doc_lens = data.doc_lens
405
+ adj = data.adj.unsqueeze(0)
406
+ docnum = data.docnum
407
+ secnum = data.secnum
408
+
409
+ with torch.no_grad():
410
+ feature = c_model(feature, doc_lens, adj, docnum, secnum)
411
+ x = model(feature, doc_lens, adj, docnum, secnum)
412
+ scores = torch.sigmoid(x.squeeze(-1))
413
+
414
+ return scores, data.sent_text
415
+
416
+
417
+ def normalize_text(text):
418
+ text = str(text).replace('_', ' ')
419
+ text = re.sub(r'\s+', ' ', text)
420
+ text = re.sub(r'\s+([.,;:?)/!?”])', r'\1', text)
421
+ text = re.sub(r'([\(“])\s+', r'\1', text)
422
+ return text
423
+
424
+
425
+ def track_changes(old_words, new_words):
426
+ # Find the longest common subsequence (LCS) between the two word sequences
427
+ def get_lcs_matrix(words1, words2):
428
+ m, n = len(words1), len(words2)
429
+ dp = [[0] * (n + 1) for _ in range(m + 1)]
430
+
431
+ for i in range(1, m + 1):
432
+ for j in range(1, n + 1):
433
+ if words1[i - 1] == words2[j - 1]:
434
+ dp[i][j] = dp[i - 1][j - 1] + 1
435
+ else:
436
+ dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
437
+
438
+ return dp
439
+
440
+ def get_lcs(words1, words2, dp):
441
+ i, j = len(words1), len(words2)
442
+ lcs = []
443
+
444
+ while i > 0 and j > 0:
445
+ if words1[i - 1] == words2[j - 1]:
446
+ lcs.append((i - 1, j - 1))
447
+ i -= 1
448
+ j -= 1
449
+ elif dp[i - 1][j] > dp[i][j - 1]:
450
+ i -= 1
451
+ else:
452
+ j -= 1
453
+
454
+ return sorted(lcs)
455
+
456
+ # Find the changed segments at word level
457
+ dp_matrix = get_lcs_matrix(old_words, new_words)
458
+ lcs_positions = get_lcs(old_words, new_words, dp_matrix)
459
+
460
+ changes = []
461
+ old_pos = 0
462
+ new_pos = 0
463
+
464
+ # Process matching and non-matching segments
465
+ for old_idx, new_idx in lcs_positions:
466
+ # If there's a gap before this match, it's a change
467
+ if old_idx > old_pos or new_idx > new_pos:
468
+ changes.append((old_pos, old_idx, new_pos, new_idx))
469
+
470
+ # Move positions after the match
471
+ old_pos = old_idx + 1
472
+ new_pos = new_idx + 1
473
+
474
+ # Check if there's a change at the end
475
+ if old_pos < len(old_words) or new_pos < len(new_words):
476
+ changes.append((old_pos, len(old_words), new_pos, len(new_words)))
477
+
478
+ return changes
479
+
480
+
481
+ class Abstractive_Summarization:
482
+ @staticmethod
483
+ def generateSummaryBySent(texts, batch=32):
484
+ model_summarization.eval()
485
+ predictions = []
486
+ with torch.no_grad():
487
+ for i in range(0, len(texts), batch):
488
+ batch_texts = texts[i:i + batch]
489
+ inputs = tokenizer_summarization(batch_texts, padding=True, max_length=1024, truncation=True,
490
+ return_tensors='pt').to(device)
491
+ outputs = model_summarization.generate(**inputs, num_beams=5,
492
+ early_stopping=True, no_repeat_ngram_size=3)
493
+ prediction = tokenizer_summarization.batch_decode(outputs, skip_special_tokens=True)
494
+ predictions.extend(prediction)
495
+ return predictions
496
+
497
+
498
+ PositionVec = torch.stack([torch.from_numpy(getPositionEncoding(i, d=768)) for i in range(200)], dim=0).float().to(
499
+ device)
500
+ stop_w = ['...']
501
+ # with open(stopword_path, 'r', encoding='utf-8') as f:
502
+ # for w in f.readlines():
503
+ # stop_w.append(w.strip())
504
+ stop_w.extend([c for c in '!"#$%&\'()*+,./:;<=>?@[\\]^`{|}~…“”’‘'])
505
+
506
+ with open(LDA_model_path, mode='rb') as fp:
507
+ cate_models = pickle.load(fp)
508
+
509
+ c_model = Contrast_Encoder(768, 1024, 4).to(device)
510
+ model = End2End_Encoder(768, 1024, 4).to(device)
511
+ model.load_state_dict(torch.load(extractive_model_path, map_location=device), strict=False)
512
+ c_model.load_state_dict(torch.load(contrastive_model_path, map_location=device), strict=False)
513
+ model.eval()
514
+ c_model.eval()
515
+
516
+ def get_summary(scores, sents, max_sent=5):
517
+ ranked_score_idxs = torch.argsort(scores[0], dim=0, descending=True)
518
+ sents = [s.replace('_', ' ') for s in sents]
519
+ summSentIDList = []
520
+ for i in ranked_score_idxs:
521
+ if len(summSentIDList) >= max_sent: break
522
+ s = sents[i]
523
+
524
+ replicated, delIDs = False, []
525
+ for chosedID in summSentIDList:
526
+ if getRouge2(s, sents[chosedID], 'p') >= 0.45:
527
+ delIDs.append(chosedID)
528
+ if getRouge2(sents[chosedID], s, 'p') >= 0.45:
529
+ replicated = True
530
+ break
531
+ if replicated: continue
532
+
533
+ for delID in delIDs:
534
+ del summSentIDList[summSentIDList.index(delID)]
535
+ summSentIDList.append(i)
536
+ summSentIDList = sorted(summSentIDList)
537
+ return [s for i, s in enumerate(sents) if i in summSentIDList]
538
+
539
+ def MultiDocSummarizationAPI(texts, compress_ratio):
540
+ """
541
+ Summarizes a list of documents using both extractive and abstractive methods.
542
+
543
+ Parameters:
544
+ - texts (list of str): A list of document texts to be summarized.
545
+ - compress_ratio (float): A ratio or count determining the number of sentences in the summary.
546
+ If less than 1, it represents the fraction of the original sentences to include in the summary.
547
+ If 1 or greater, it represents the exact number of sentences to include in the summary.
548
+
549
+ Returns:
550
+ - dict: A dictionary containing:
551
+ - 'extractive_summ' (str): The extractive summary of the documents.
552
+ - 'abstractive_summ' (str): The abstractive summary of the documents.
553
+ """
554
+ assert compress_ratio > 0, "Compress ratio need to be greater than 0."
555
+ docs = [text.strip() for text in texts]
556
+ data_tree = loadClusterData(docs)
557
+ scores, sents = val_e2e(data_tree)
558
+
559
+ output_sent_cnt = int(len(sents) * compress_ratio) if compress_ratio < 1 else int(compress_ratio)
560
+ print('Expected sentence count:', output_sent_cnt)
561
+
562
+ extractive_summ_sents = [normalize_text(sent) for sent in get_summary(scores, sents, max_sent=output_sent_cnt)]
563
+ extractive_summ = ' '.join(extractive_summ_sents)
564
+
565
+ abstractive_summ_sents = Abstractive_Summarization.generateSummaryBySent(extractive_summ_sents)
566
+ abstractive_summ_sents = [normalize_text(s) for s in abstractive_summ_sents]
567
+ final_sents = []
568
+ for ii, (ext, abs) in enumerate(zip(extractive_summ_sents, abstractive_summ_sents)):
569
+ if ii == 0:
570
+ final_sents.append(ext)
571
+ continue
572
+ abs_splits, ext_splits = word_tokenize(abs), word_tokenize(ext)
573
+ abs_splits_cop, ext_splits_cop = abs_splits.copy(), ext_splits.copy()
574
+ if len(abs_splits_cop):
575
+ abs_splits_cop[-1] = abs_splits[-1][:-1] if len(abs_splits[-1]) and abs_splits[-1][-1] == '.' else abs_splits[-1]
576
+ if len(ext_splits_cop):
577
+ ext_splits_cop[-1] = ext_splits[-1][:-1] if len(ext_splits[-1]) and ext_splits[-1][-1] == '.' else ext_splits[-1]
578
+
579
+ changes, abs_parts = track_changes(ext_splits_cop, abs_splits_cop), [(0, len(abs_splits))]
580
+ for start_old, end_old, start_new, end_new in changes:
581
+ old_part = ' '.join(ext_splits[start_old:end_old])
582
+ # Revert change in the cases of spelling errors
583
+ revert, ignoreFirstSentWord = False, 1 if start_old == 0 else 0
584
+ old_names = {}
585
+ for w in ext_splits_cop[start_old + ignoreFirstSentWord:end_old]:
586
+ if len(w) == 0: continue
587
+ if 'A'<=w[0]<='Z' or w[0] in ['Ä‚', 'Ă‚', 'Đ', 'Ê', 'Ă”', 'Æ ', 'Ư']:
588
+ if w in old_names:
589
+ old_names[w] += 1
590
+ else:
591
+ old_names[w] = 1
592
+
593
+ for w in abs_splits_cop[start_new + ignoreFirstSentWord:end_new]:
594
+ if len(w) == 0: continue
595
+ if 'A'<=w[0]<='Z' or w[0] in ['Ä‚', 'Ă‚', 'Đ', 'Ê', 'Ă”', 'Æ ', 'Ư']:
596
+ if w in old_names:
597
+ old_names[w] -= 1
598
+ if old_names[w] < 0:
599
+ revert = True
600
+ break
601
+ else:
602
+ revert = True
603
+ break
604
+ if revert:
605
+ pop_part = abs_parts[-1]
606
+ abs_parts.pop()
607
+ abs_parts.extend([(pop_part[0], start_new), old_part, (end_new, pop_part[1])])
608
+ # print('\nOLD:', old_part, '\n', ' '.join(abs_splits[start_new:end_new]))
609
+ # print(ext, '\n', abs)
610
+
611
+ abs = ' '.join([part if isinstance(part, str) else ' '.join(abs_splits[part[0]:part[1]]) for part in abs_parts])
612
+ final_sents.append(normalize_text(abs))
613
+ abstract_summ = ' '.join(final_sents)
614
+
615
+ return {'extractive_summ': extractive_summ,
616
  'abstractive_summ': abstract_summ}