aaljabari committed on
Commit
82e50c2
·
verified ·
1 Parent(s): c580da0

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +611 -0
main.py ADDED
@@ -0,0 +1,611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ import torch
3
+ import pickle
4
+ from huggingface_hub import hf_hub_download, snapshot_download
5
+ from Nested.nn.BertSeqTagger import BertSeqTagger
6
+ from transformers import AutoTokenizer, AutoModel
7
+ import inspect
8
+ from collections import namedtuple
9
+ from Nested.utils.helpers import load_checkpoint
10
+ from Nested.utils.data import get_dataloaders, text2segments
11
+ import json
12
+ from pydantic import BaseModel
13
+ from fastapi.responses import JSONResponse
14
+ from IBO_to_XML import IBO_to_XML
15
+ from XML_to_HTML import NER_XML_to_HTML
16
+ from NER_Distiller import distill_entities
17
+
18
# FastAPI application object; the route decorators further down register on it.
app = FastAPI()

# Base encoder checkpoint name.
pretrained_path = "aubmindlab/bert-base-arabertv2"  # must match training
# NOTE(review): `tokenizer` and `encoder` are not referenced anywhere else in
# the visible file; they are kept because other code may import them.
tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
encoder = AutoModel.from_pretrained(pretrained_path).eval()


# Download the files matching "checkpoints/" and the args.json file from the
# SinaLab/Nested repository on the Hugging Face Hub.
checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")

args_path = hf_hub_download(
    repo_id="SinaLab/Nested",
    filename="args.json",
)

# JSON text is UTF-8; decode it explicitly instead of relying on the
# platform's locale-dependent default encoding.
with open(args_path, "r", encoding="utf-8") as f:
    args_data = json.load(f)

# Load the tag vocabulary shipped with the project.
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# this is only acceptable because the file is a trusted, local project asset.
with open("Nested/utils/tag_vocab.pkl", "rb") as f:
    label_vocab = pickle.load(f)

label_vocab = label_vocab[0]  # the list loaded from pickle
# Map each integer index to its entry in the vocabulary's `itos` sequence.
id2label = dict(enumerate(label_vocab.itos))
41
+
42
def split_text_into_groups_of_Ns(sentence, max_words_per_sentence):
    """Split *sentence* into chunks of at most ``max_words_per_sentence`` words.

    The sentence is split on whitespace and the words are re-joined with a
    single space, so runs of whitespace are collapsed.

    Args:
        sentence: Text to split.
        max_words_per_sentence: Maximum number of words per chunk. Values
            below 1 are treated as 1 so the function always makes progress
            and never emits chunks with a stray leading space.

    Returns:
        A list of chunk strings in original order; empty when *sentence*
        contains no words.
    """
    words = sentence.split()
    chunk_size = max(1, max_words_per_sentence)
    return [
        " ".join(words[start:start + chunk_size])
        for start in range(0, len(words), chunk_size)
    ]
70
+
71
+
72
+
73
def remove_empty_values(sentences):
    """Return a new list holding every item of *sentences* that is not ``''``.

    Only the exact empty string is dropped; whitespace-only strings are kept.
    """
    return list(filter(lambda entry: entry != '', sentences))
75
+
76
+
77
def sentence_tokenizer(text, dot=True, new_line=True, question_mark=True, exclamation_mark=True):
    """Split *text* into fragments at the enabled separator characters.

    Separators are applied one after another in this order: newline, ``.``,
    ``?`` and ``؟`` (both controlled by *question_mark*), then ``!``. On each
    pass every fragment is split at the separator; all pieces except the last
    keep the separator appended, and the last piece is whitespace-stripped.
    Empty strings are removed from the final result.

    Args:
        text: Text to split.
        dot: Split on ``.`` when equal to ``True``.
        new_line: Split on newline when equal to ``True``.
        question_mark: Split on ``?`` and ``؟`` when equal to ``True``.
        exclamation_mark: Split on ``!`` when equal to ``True``.

    Returns:
        List of non-empty fragments in original order.
    """
    toggles = (
        (new_line, ('\n',)),
        (dot, ('.',)),
        (question_mark, ('?', '؟')),
        (exclamation_mark, ('!',)),
    )
    # The explicit `== True` comparison is intentional: it keeps the exact
    # flag semantics (only values equal to True enable a separator).
    separators = [mark for flag, marks in toggles if flag == True for mark in marks]

    fragments = [text]
    for mark in separators:
        resplit = []
        for fragment in fragments:
            # str.split always yields at least one piece, so `tail` exists.
            *head, tail = fragment.split(mark)
            resplit.extend(piece + mark for piece in head)
            resplit.append(tail.strip())
        fragments = resplit

    return [fragment for fragment in fragments if fragment != '']
101
+
102
def jsons_to_list_of_lists(json_list):
    """Convert dicts with ``token`` and ``tags`` keys into ``[token, tags]`` pairs."""
    pairs = []
    for record in json_list:
        pairs.append([record['token'], record['tags']])
    return pairs
104
+
105
# Restore the tagger, its tag vocabulary and the training configuration from
# the downloaded checkpoint directory; `extract` below uses the first two.
tagger, tag_vocab, train_config = load_checkpoint(checkpoint_path)


def extract(sentence):
    """Run the tagger on *sentence* and return one record per token.

    Args:
        sentence: Text handed to ``text2segments``.

    Returns:
        A list of dicts, each with ``"token"`` (the token text) and ``"tags"``
        (the predicted tags joined by single spaces, or ``"O"`` when no tag
        other than ``"O"``, ``" "`` or ``""`` was predicted).
    """
    dataset, token_vocab = text2segments(sentence)

    vocab_type = namedtuple("Vocab", ["tags", "tokens"])
    vocab = vocab_type(tokens=token_vocab, tags=tag_vocab)

    # get_dataloaders returns one loader per dataset; only one is passed in.
    loaders = get_dataloaders(
        (dataset,),
        vocab,
        args_data,
        batch_size=32,
        shuffle=(False,),
    )
    segments = tagger.infer(loaders[0])

    ignored_tags = ("O", " ", "")
    records = []
    for segment in segments:
        for token in segment:
            text = token.text
            kept = [pred["tag"] for pred in token.pred_tag if pred["tag"] not in ignored_tags]
            records.append({"token": text, "tags": " ".join(kept) if kept else "O"})
    return records
139
+
140
+
141
def NER(sentence, mode):
    """Tag *sentence* and return the result in the format selected by *mode*.

    Args:
        sentence: Text to tag.
        mode: Output selector; surrounding whitespace is ignored.
            ``"1"`` returns the ``[token, tags]`` pairs, ``"2"`` the result of
            ``IBO_to_XML`` on those pairs, ``"3"`` the result of
            ``NER_XML_to_HTML`` on that XML, and ``"4"`` the result of
            ``distill_entities`` on the pairs.

    Returns:
        The value for the selected mode, or ``None`` for any other mode (the
        tagger is not invoked in that case).
    """
    mode = mode.strip()
    if mode not in ("1", "2", "3", "4"):
        # Unknown selector: nothing to compute.
        return None

    output_list = jsons_to_list_of_lists(extract(sentence))
    if mode == "1":
        return output_list
    if mode == "4":  # json short
        return distill_entities(output_list)

    xml = IBO_to_XML(output_list)
    if mode == "2":
        return xml
    return NER_XML_to_HTML(xml)
174
+
175
+
176
+
177
# Request body for POST /predict. Plain comments are used instead of
# docstrings so the generated OpenAPI descriptions stay unchanged.
class NERRequest(BaseModel):
    text: str  # raw input text
    mode: str  # output selector forwarded to NER()


@app.post("/predict")
def predict(request: NERRequest):
    # Split the text on newlines only, cut every line into chunks of at most
    # 300 words, and run NER on each chunk in order.
    mode = request.mode
    lines = sentence_tokenizer(
        request.text,
        dot=False,
        new_line=True,
        question_mark=False,
        exclamation_mark=False,
    )

    results = [
        NER(chunk, mode)
        for line in lines
        for chunk in split_text_into_groups_of_Ns(line, max_words_per_sentence=300)
    ]

    payload = {"resp": results, "statusText": "OK", "statusCode": 0}
    return JSONResponse(content=payload, media_type="application/json", status_code=200)
209
+
210
+
211
+ # ============ Relation Extraction ==============
212
+ import torch.nn as nn
213
+ import torch.nn.functional as F
214
+ from transformers import PreTrainedTokenizerFast, BertModel
215
+ from itertools import permutations
216
+ from collections import defaultdict
217
+
218
+
219
+ # =========================
220
+ # Relation Extraction Model
221
+ # =========================
222
# Hub repository holding the relation-extraction tokenizer, vocabulary and weights.
repo_id = "aaljabari/arabic-relation-extraction-v1"

# Tokenizer built directly from the repository's tokenizer.json file.
relation_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file=hf_hub_download(repo_id=repo_id, filename="tokenizer.json"),
)

# Relation label vocabulary.
# NOTE(review): this unpickles a file downloaded from the Hub; pickle can run
# arbitrary code, so the repository must be trusted.
rel_vocab_path = hf_hub_download(repo_id=repo_id, filename="tag_vocab.pkl")
with open(rel_vocab_path, "rb") as f:
    vocab = pickle.load(f)

rel2id, id2rel = vocab["rel2id"], vocab["id2rel"]
236
+
237
+
238
class BertRE(nn.Module):
    """BERT encoder with a linear classifier over a (subject, object) pair.

    The classifier input is the concatenation of the encoder's hidden vectors
    at the subject position and at the object position of each row.
    """

    def __init__(self, num_labels):
        """Build the encoder, dropout and classifier.

        Args:
            num_labels: Number of output classes of the linear classifier.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(repo_id)

        config = self.bert.config
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Two hidden vectors are concatenated, hence twice the hidden size.
        self.classifier = nn.Linear(2 * config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, sub_pos, obj_pos):
        """Return classifier logits for each row of the batch.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Attention mask passed to the encoder.
            sub_pos: Per-row index of the subject position.
            obj_pos: Per-row index of the object position.
        """
        states = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state

        # Pair every row index with its own subject / object position.
        rows = torch.arange(states.shape[0])
        features = torch.cat((states[rows, sub_pos], states[rows, obj_pos]), dim=1)

        return self.classifier(self.dropout(features))
263
+
264
# Fetch the trained weights, load them into a fresh BertRE on CPU and switch
# the model to evaluation mode.
weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")

re_model = BertRE(num_labels=len(rel2id))
re_model.load_state_dict(
    torch.load(weights_path, map_location="cpu")
)
re_model.eval()
269
+
270
+
271
def convert_ner_format(ner_output):
    """Turn ``{"token": ..., "tags": ...}`` records into ``[token, tags]`` lists."""
    return list(map(lambda record: [record["token"], record["tags"]], ner_output))
273
+
274
def entities_and_types(sentence):
    """Return a ``{entity name: entity type}`` dict for *sentence*.

    The sentence is tagged with ``extract``, converted to ``[token, tags]``
    pairs and passed to ``distill_entities``. Each distilled entity is
    unpacked as a 4-tuple whose first two items are the name and the type.
    When the same name occurs more than once, the last type wins.
    """
    token_tag_pairs = convert_ner_format(extract(sentence))
    return {
        name: entity_type
        for name, entity_type, _, _ in distill_entities(token_tag_pairs)
    }
283
+
284
# Type constraints per relation: "domain" lists the allowed subject entity
# types and "range" the allowed object entity types. The relation names are
# compared against model output at runtime, so their spelling must not change.
relation_domain_range = [
    {"relation": "manager_of", "domain": ["PERS"], "range": ["ORG", "FAC"]},
    {"relation": "birth_date", "domain": ["PERS"], "range": ["DATE"]},
    {"relation": "has_parent", "domain": ["PERS"], "range": ["PERS"]},
    {"relation": "has_sibling", "domain": ["PERS"], "range": ["PERS"]},
    {"relation": "has_spouse", "domain": ["PERS"], "range": ["PERS"]},
    {"relation": "has_relative", "domain": ["PERS"], "range": ["PERS"]},
    {"relation": "death_date", "domain": ["PERS"], "range": ["DATE"]},
    {"relation": "birth_place", "domain": ["PERS"], "range": ["GPE", "LOC"]},
    {"relation": "has_occupation", "domain": ["PERS"], "range": ["OCC"]},
    {"relation": "has_conflict_with", "domain": ["ORG", "NORP", "GPE"], "range": ["ORG", "NORP", "GPE"]},
    {"relation": "has_compititor", "domain": ["PERS", "ORG"], "range": ["PERS", "ORG"]},
    {"relation": "has_partner_with", "domain": ["ORG"], "range": ["ORG"]},
    {"relation": "president_of", "domain": ["PERS"], "range": ["ORG", "GPE"]},
    {"relation": "leader_of", "domain": ["PERS"], "range": ["ORG"]},
    {"relation": "geopolitical_division", "domain": ["GPE", "LOC"], "range": ["GPE", "LOC"]},
    {"relation": "member_of", "domain": ["PERS"], "range": ["ORG", "NORP"]},
    {"relation": "subsidary", "domain": ["ORG"], "range": ["ORG"]},
    {"relation": "employee_of", "domain": ["PERS"], "range": ["ORG", "FAC"]},
    {"relation": "student_at", "domain": ["PERS"], "range": ["ORG"]},
    {"relation": "owner_of", "domain": ["PERS"], "range": ["ORG", "FAC"]},
    {"relation": "inventor_of", "domain": ["PERS"], "range": ["PRODUCT"]},
    {"relation": "manufacturer_of", "domain": ["ORG"], "range": ["PRODUCT"]},
    {"relation": "builder_of", "domain": ["PERS", "NORP"], "range": ["FAC"]},
    {"relation": "founder_of", "domain": ["PERS"], "range": ["ORG"]},
    {"relation": "lives_in", "domain": ["PERS", "NORP"], "range": ["GPE", "LOC"]},
    {"relation": "located_in", "domain": ["FAC", "ORG"], "range": ["GPE", "LOC"]},
    {"relation": "headquartered_in", "domain": ["ORG"], "range": ["GPE", "LOC"]},
    {"relation": "has_border_with", "domain": ["LOC", "GPE"], "range": ["LOC", "GPE"]},
    {"relation": "nearby", "domain": ["GPE", "LOC", "ORG", "FAC"], "range": ["GPE", "LOC", "ORG", "FAC"]},
    {"relation": "has_property", "domain": ["ORG"], "range": ["PRODUCT"]},
    {"relation": "branch_count", "domain": ["ORG"], "range": ["CARDINAL"]},
    {"relation": "has_revenue", "domain": ["ORG"], "range": ["MONEY"]},
    {"relation": "employs", "domain": ["ORG"], "range": ["CARDINAL"]},
    {"relation": "found_on", "domain": ["ORG"], "range": ["DATE", "TIME"]},
    {"relation": "has_alternate_name", "domain": ["ORG", "FAC"], "range": ["ORG", "FAC"]},
    {"relation": "has_area", "domain": ["GPE", "LOC"], "range": ["QUANTITY"]},
    {"relation": "official_language", "domain": ["GPE", "LOC"], "range": ["LANGUAGE"]},
    {"relation": "has_currency", "domain": ["GPE", "LOC"], "range": ["CURR"]},
    {"relation": "has_population", "domain": ["GPE"], "range": ["CARDINAL"]},
    {"relation": "capital_of", "domain": ["GPE"], "range": ["GPE"]},
]

# Index: subject type -> object type -> relation names allowed for that pair,
# in the order the relations appear in the table above.
relation_lookup = defaultdict(lambda: defaultdict(list))

for rel in relation_domain_range:
    for d in rel["domain"]:
        for r in rel["range"]:
            relation_lookup[d][r].append(rel["relation"])
493
+
494
def insert_markers(sentence, ent1, ent2):
    """Wrap one occurrence of each entity in subject / object markers.

    *ent1* is wrapped as ``[Sub] ent1 [/Sub]`` and *ent2* as
    ``[Obj] ent2 [/Obj]``. Both spans are located in the original sentence
    before anything is inserted, and they must not overlap. This prevents an
    object marker from landing inside the subject (or the reverse) when one
    entity is a substring of the other. The earliest non-overlapping pair of
    occurrences is used, preferring the earliest subject occurrence.

    Args:
        sentence: Original sentence text.
        ent1: Subject entity text.
        ent2: Object entity text.

    Returns:
        The marked sentence, or ``None`` when either entity is empty, is
        missing from the sentence, or no non-overlapping pair of occurrences
        exists.
    """
    if not ent1 or not ent2:
        return None

    def _occurrences(needle):
        # Yield (start, end) for every occurrence, including overlapping ones.
        start = sentence.find(needle)
        while start != -1:
            yield start, start + len(needle)
            start = sentence.find(needle, start + 1)

    for sub_start, sub_end in _occurrences(ent1):
        for obj_start, obj_end in _occurrences(ent2):
            if obj_end <= sub_start or obj_start >= sub_end:
                # Insert the right-most span first so earlier offsets stay valid.
                spans = sorted(
                    [(sub_start, sub_end, "Sub"), (obj_start, obj_end, "Obj")],
                    reverse=True,
                )
                marked = sentence
                for start, end, tag in spans:
                    marked = f"{marked[:start]}[{tag}] {marked[start:end]} [/{tag}]{marked[end:]}"
                return marked

    return None
503
+
504
def encode(sentence):
    """Tokenize a marked sentence and locate its ``[Sub]`` / ``[Obj]`` tokens.

    The sentence is padded / truncated to 128 tokens and returned as PyTorch
    tensors.

    Returns:
        ``(input_ids, attention_mask, sub_pos, obj_pos)`` where ``sub_pos`` and
        ``obj_pos`` hold the column indices at which the ``[Sub]`` and
        ``[Obj]`` token ids occur in ``input_ids``. Either may be empty.
    """
    encoded = relation_tokenizer(
        sentence,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids, attention_mask = encoded["input_ids"], encoded["attention_mask"]

    def _marker_columns(marker):
        # Column index of every position whose id equals the marker's id.
        marker_id = relation_tokenizer.convert_tokens_to_ids(marker)
        _, columns = (input_ids == marker_id).nonzero(as_tuple=True)
        return columns

    return input_ids, attention_mask, _marker_columns("[Sub]"), _marker_columns("[Obj]")
523
+
524
+
525
def predict_relation(sentence):
    """Classify the relation expressed in a marked sentence.

    Returns:
        ``(label, confidence)`` where ``label`` is looked up in ``id2rel`` and
        ``confidence`` is the softmax probability of the predicted class, or
        ``(None, 0.0)`` when the encoded sentence lacks a ``[Sub]`` or
        ``[Obj]`` position.
    """
    input_ids, mask, sub_pos, obj_pos = encode(sentence)

    markers_found = len(sub_pos) != 0 and len(obj_pos) != 0
    if not markers_found:
        return None, 0.0

    with torch.no_grad():
        logits = re_model(input_ids, mask, sub_pos, obj_pos)

    probabilities = F.softmax(logits, dim=-1)
    best = torch.argmax(probabilities, dim=-1).item()
    confidence = probabilities[0, best].item()

    return id2rel[best], confidence
540
+
541
def relation_extractor(sentence):
    """Extract typed relations between the entities found in *sentence*.

    Every ordered pair of distinct entities is considered. A pair is scored
    only when ``relation_lookup`` lists at least one relation for its
    (subject type, object type) combination. A prediction is kept when its
    confidence exceeds 0.80, it is not ``"no_relation"``, and the part of the
    label after the last ``"."`` is one of the relations allowed for the pair.

    Returns:
        A list of dicts with ``"Subject"``, ``"Relation"``, ``"Object"`` and
        ``"Confidence"`` (rounded to 4 decimals).
    """
    typed_entities = entities_and_types(sentence)
    triples = []

    for (ent1, type1), (ent2, type2) in permutations(typed_entities.items(), 2):
        # .get() is used so the lookup does not insert keys into the defaultdict.
        valid_rels = relation_lookup.get(type1, {}).get(type2, [])
        if not valid_rels:
            continue

        marked_sentence = insert_markers(sentence, ent1, ent2)
        if marked_sentence is None:
            continue

        rel, conf = predict_relation(marked_sentence)
        if rel is None:
            continue

        accepted = (
            conf > 0.80
            and rel != "no_relation"
            and rel.split(".")[-1] in valid_rels
        )
        if not accepted:
            continue

        triples.append({
            "Subject": {"Type": type1, "Label": ent1},
            "Relation": rel,
            "Object": {"Type": type2, "Label": ent2},
            "Confidence": float(round(conf, 4)),
        })

    return triples
579
+
580
+
581
# Request body for POST /predict_re. Plain comments are used instead of
# docstrings so the generated OpenAPI descriptions stay unchanged.
class RERequest(BaseModel):
    text: str  # sentence to run relation extraction on


@app.post("/predict_re")
def predict_re(request: RERequest):
    # Run relation extraction and wrap the result in the same envelope that
    # /predict uses. A failure keeps the {"error": ...} body but is reported
    # with HTTP 500 instead of a misleading 200.
    try:
        results = relation_extractor(request.text)
    except Exception as e:
        return JSONResponse(
            content={"error": str(e)},
            media_type="application/json",
            status_code=500,
        )

    return JSONResponse(
        content={
            "resp": results,
            "statusText": "OK",
            "statusCode": 0,
        },
        media_type="application/json",
        status_code=200,
    )
601
+
602
+ # =========== Front End =============================
603
+ from fastapi.staticfiles import StaticFiles
604
+ from fastapi.responses import FileResponse
605
+
606
+ # mount frontend
607
# Serve everything in the local "static" directory under the /static URL prefix.
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)


@app.get("/")
def home():
    # The root URL returns the front-end entry page.
    index_page = "static/index.html"
    return FileResponse(index_page)