alaajabari committed on
Commit
0b0c965
·
verified ·
1 Parent(s): 9858485

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +379 -0
main.py CHANGED
@@ -207,6 +207,385 @@ def predict(request: NERRequest):
207
  status_code=200,
208
  )
209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  from fastapi.staticfiles import StaticFiles
211
  from fastapi.responses import FileResponse
212
 
 
207
  status_code=200,
208
  )
209
 
210
+
211
+ # ============ Relation Extraction ==============
212
+ import torch.nn as nn
213
+ import torch.nn.functional as F
214
+ from transformers import PreTrainedTokenizerFast, BertModel
215
+ from itertools import permutations
216
+ from collections import defaultdict
217
+
218
+
219
+ # =========================
220
+ # Relation Extraction Model
221
+ # =========================
222
# Hugging Face Hub repository that the tokenizer file, label vocabulary and
# model weights for relation extraction are downloaded from.
repo_id = "aaljabari/arabic-relation-extraction-v1"

# tokenizer
relation_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file=hf_hub_download(repo_id, "tokenizer.json")
)

# vocab
# SECURITY NOTE(review): pickle.load() can execute arbitrary code embedded in
# the file it reads, and this file is fetched from a remote repository. Keep
# this only while the repository is fully trusted; a JSON label map would
# remove the risk (requires re-publishing the vocabulary in that format).
rel_vocab_path = hf_hub_download(repo_id, "tag_vocab.pkl")
with open(rel_vocab_path, "rb") as f:
    vocab = pickle.load(f)

# Label <-> index maps read from the pickled vocabulary.
rel2id = vocab["rel2id"]
id2rel = vocab["id2rel"]
236
+
237
+
238
class BertRE(nn.Module):
    """BERT encoder followed by a linear classifier over two token states.

    The classifier input is the concatenation of the encoder hidden state at
    ``sub_pos`` and the one at ``obj_pos`` for every row of the batch.
    """

    def __init__(self, num_labels):
        """Build the encoder, dropout and classification head.

        Args:
            num_labels: size of the classifier output (number of relation
                labels).
        """
        super().__init__()
        # NOTE: the attribute names bert / dropout / classifier become the
        # state_dict keys that load_state_dict() matches against, so they
        # must not be renamed.
        self.bert = BertModel.from_pretrained(repo_id)

        config = self.bert.config
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Two token states are concatenated, hence twice the hidden size.
        self.classifier = nn.Linear(2 * config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, sub_pos, obj_pos):
        """Return classifier logits for each row of the batch.

        Args:
            input_ids: token ids passed to the encoder.
            attention_mask: attention mask passed to the encoder.
            sub_pos: per-row token index whose hidden state is used as the
                subject representation.
            obj_pos: per-row token index whose hidden state is used as the
                object representation.
        """
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        token_states = encoded.last_hidden_state

        # One row index per batch element, paired with the given positions.
        rows = torch.arange(token_states.shape[0])
        subject_state = token_states[rows, sub_pos]
        object_state = token_states[rows, obj_pos]

        features = torch.cat([subject_state, object_state], dim=1)
        return self.classifier(self.dropout(features))
263
+
264
weights_path = hf_hub_download(repo_id, "pytorch_model.bin")

re_model = BertRE(num_labels=len(rel2id))
# weights_only=True restricts unpickling to tensors and plain containers, so
# a tampered checkpoint downloaded from the Hub cannot run arbitrary code
# while being loaded. A state_dict contains nothing else, so loading is
# otherwise unchanged.
re_model.load_state_dict(
    torch.load(weights_path, map_location="cpu", weights_only=True)
)
# The model is only used for inference; eval() switches off dropout.
re_model.eval()
269
+
270
def entities_and_types(sentence):
    """Return a ``{entity_name: entity_type}`` dict for *sentence*.

    Runs the NER step (``extract``) and ``distill_entities`` on its output.
    Each distilled entity is unpacked as a 4-tuple whose first two items are
    the name and the type; the last two items are ignored here. If the same
    name appears more than once, the type of the last occurrence wins.
    """
    ner_output = extract(sentence)  # NER step
    return {
        name: entity_type
        for name, entity_type, _, _ in distill_entities(ner_output)
    }
279
+
280
# Type constraints for every relation: the subject entity type must be one of
# "domain" and the object entity type one of "range".
# NOTE(review): relation names (including the spellings "has_compititor" and
# "subsidary") are compared against predicted labels in relation_extractor,
# so they presumably have to match the model's label strings exactly; do not
# "fix" them here without confirming against the label vocabulary.
relation_domain_range = [
    {"relation": name, "domain": domain, "range": range_}
    for name, domain, range_ in (
        ("manager_of", ["PERS"], ["ORG", "FAC"]),
        ("birth_date", ["PERS"], ["DATE"]),
        ("has_parent", ["PERS"], ["PERS"]),
        ("has_sibling", ["PERS"], ["PERS"]),
        ("has_spouse", ["PERS"], ["PERS"]),
        ("has_relative", ["PERS"], ["PERS"]),
        ("death_date", ["PERS"], ["DATE"]),
        ("birth_place", ["PERS"], ["GPE", "LOC"]),
        ("has_occupation", ["PERS"], ["OCC"]),
        ("has_conflict_with", ["ORG", "NORP", "GPE"], ["ORG", "NORP", "GPE"]),
        ("has_compititor", ["PERS", "ORG"], ["PERS", "ORG"]),
        ("has_partner_with", ["ORG"], ["ORG"]),
        ("president_of", ["PERS"], ["ORG", "GPE"]),
        ("leader_of", ["PERS"], ["ORG"]),
        ("geopolitical_division", ["GPE", "LOC"], ["GPE", "LOC"]),
        ("member_of", ["PERS"], ["ORG", "NORP"]),
        ("subsidary", ["ORG"], ["ORG"]),
        ("employee_of", ["PERS"], ["ORG", "FAC"]),
        ("student_at", ["PERS"], ["ORG"]),
        ("owner_of", ["PERS"], ["ORG", "FAC"]),
        ("inventor_of", ["PERS"], ["PRODUCT"]),
        ("manufacturer_of", ["ORG"], ["PRODUCT"]),
        ("builder_of", ["PERS", "NORP"], ["FAC"]),
        ("founder_of", ["PERS"], ["ORG"]),
        ("lives_in", ["PERS", "NORP"], ["GPE", "LOC"]),
        ("located_in", ["FAC", "ORG"], ["GPE", "LOC"]),
        ("headquartered_in", ["ORG"], ["GPE", "LOC"]),
        ("has_border_with", ["LOC", "GPE"], ["LOC", "GPE"]),
        ("nearby", ["GPE", "LOC", "ORG", "FAC"], ["GPE", "LOC", "ORG", "FAC"]),
        ("has_property", ["ORG"], ["PRODUCT"]),
        ("branch_count", ["ORG"], ["CARDINAL"]),
        ("has_revenue", ["ORG"], ["MONEY"]),
        ("employs", ["ORG"], ["CARDINAL"]),
        ("found_on", ["ORG"], ["DATE", "TIME"]),
        ("has_alternate_name", ["ORG", "FAC"], ["ORG", "FAC"]),
        ("has_area", ["GPE", "LOC"], ["QUANTITY"]),
        ("official_language", ["GPE", "LOC"], ["LANGUAGE"]),
        ("has_currency", ["GPE", "LOC"], ["CURR"]),
        ("has_population", ["GPE"], ["CARDINAL"]),
        ("capital_of", ["GPE"], ["GPE"]),
    )
]
482
+
483
# relation_lookup[subject_type][object_type] -> list of relation names whose
# domain contains subject_type and whose range contains object_type.
relation_lookup = defaultdict(lambda: defaultdict(list))

for spec in relation_domain_range:
    relation_name = spec["relation"]
    for subject_type in spec["domain"]:
        by_object_type = relation_lookup[subject_type]
        for object_type in spec["range"]:
            by_object_type[object_type].append(relation_name)
489
+
490
+
491
def _find_spans(text, needle):
    """Yield ``(start, end)`` for every occurrence of *needle* in *text*."""
    start = text.find(needle)
    while start != -1:
        yield start, start + len(needle)
        start = text.find(needle, start + 1)


def insert_markers(sentence, ent1, ent2):
    """Wrap *ent1* in ``[Sub] … [/Sub]`` and *ent2* in ``[Obj] … [/Obj]``.

    Both entities are located in the ORIGINAL sentence and the chosen spans
    never overlap. Searching the already-marked text instead (as a second
    ``str.replace`` would) can match the object inside the subject span or
    inside the literal marker text, producing nested or corrupted markers.

    The earliest occurrence of *ent1* that admits a non-overlapping
    occurrence of *ent2* is used, together with the earliest such *ent2*.

    Args:
        sentence: the raw sentence.
        ent1: subject entity text.
        ent2: object entity text.

    Returns:
        The sentence with both marker pairs inserted, or ``None`` when either
        entity is empty, missing from the sentence, or cannot be placed
        without overlapping the other.
    """
    if not ent1 or not ent2:
        return None

    obj_spans = list(_find_spans(sentence, ent2))
    if not obj_spans:
        return None

    chosen = None
    for sub_start, sub_end in _find_spans(sentence, ent1):
        for obj_start, obj_end in obj_spans:
            # Two half-open spans are disjoint when one ends before the
            # other begins.
            if obj_end <= sub_start or sub_end <= obj_start:
                chosen = ((sub_start, sub_end, "Sub"), (obj_start, obj_end, "Obj"))
                break
        if chosen is not None:
            break

    if chosen is None:
        return None

    # Rebuild the sentence left to right so earlier insertions cannot shift
    # the offsets of later ones.
    pieces = []
    cursor = 0
    for start, end, tag in sorted(chosen):
        pieces.append(sentence[cursor:start])
        pieces.append(f"[{tag}] {sentence[start:end]} [/{tag}]")
        cursor = end
    pieces.append(sentence[cursor:])

    return "".join(pieces)
500
+
501
def encode(sentence):
    """Tokenize a marked sentence and locate its ``[Sub]`` / ``[Obj]`` tokens.

    The sentence is padded/truncated to 128 tokens and returned as PyTorch
    tensors.

    Args:
        sentence: text that already contains the ``[Sub]`` and ``[Obj]``
            markers.

    Returns:
        ``(input_ids, attention_mask, sub_pos, obj_pos)``. ``sub_pos`` and
        ``obj_pos`` are 1-D tensors holding at most one column index each;
        they are empty when the marker token is not present in the encoded
        ids (for example because it was truncated away).
    """
    enc = relation_tokenizer(
        sentence,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    input_ids = enc["input_ids"]
    attention_mask = enc["attention_mask"]

    sub_id = relation_tokenizer.convert_tokens_to_ids("[Sub]")
    obj_id = relation_tokenizer.convert_tokens_to_ids("[Obj]")

    # Keep only the first match of each marker. If the input text itself
    # contains an extra "[Sub]" / "[Obj]" token, the two position tensors
    # would otherwise have different lengths and the model's concatenation
    # of subject/object states would fail with a shape mismatch.
    sub_pos = (input_ids == sub_id).nonzero(as_tuple=True)[1][:1]
    obj_pos = (input_ids == obj_id).nonzero(as_tuple=True)[1][:1]

    return input_ids, attention_mask, sub_pos, obj_pos
520
+
521
+
522
def predict_relation(sentence):
    """Predict the relation label for a marked sentence.

    Args:
        sentence: text containing the ``[Sub]`` and ``[Obj]`` markers.

    Returns:
        ``(label, confidence)`` where ``label`` is looked up in ``id2rel``
        and ``confidence`` is the softmax probability of that label, or
        ``(None, 0.0)`` when either marker position could not be found.
    """
    input_ids, mask, sub_pos, obj_pos = encode(sentence)

    # Without both marker positions there is nothing to classify.
    markers_found = len(sub_pos) != 0 and len(obj_pos) != 0
    if not markers_found:
        return None, 0.0

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = re_model(input_ids, mask, sub_pos, obj_pos)

    probs = F.softmax(logits, dim=-1)
    best_index = torch.argmax(probs, dim=-1).item()
    confidence = probs[0, best_index].item()

    return id2rel[best_index], confidence
537
+
538
def relation_extractor(sentence):
    """Extract ``[subject, relation, object, confidence]`` rows from a sentence.

    Every ordered pair of distinct entities found in *sentence* is considered.
    A pair is classified only if the type constraints in ``relation_lookup``
    allow at least one relation for (subject type, object type) and both
    entities can be marked in the sentence. A prediction is kept when its
    confidence exceeds 0.80, it is not ``"no_relation"``, and the part of the
    label after the last ``"."`` is one of the allowed relations.
    """
    min_confidence = 0.80
    typed_entities = entities_and_types(sentence)

    triples = []
    for (ent1, type1), (ent2, type2) in permutations(typed_entities.items(), 2):
        # .get() keeps the nested defaultdict from growing on misses.
        allowed = relation_lookup.get(type1, {}).get(type2, [])
        if not allowed:
            continue

        marked_sentence = insert_markers(sentence, ent1, ent2)
        if marked_sentence is None:
            continue

        rel, conf = predict_relation(marked_sentence)
        if rel is None:
            continue

        if not conf > min_confidence:
            continue
        if rel == "no_relation":
            continue
        if rel.split(".")[-1] not in allowed:
            continue

        triples.append([ent1, rel, ent2, conf])

    return triples
565
+
566
+
567
# Request body for the POST /predict_re endpoint.
class RERequest(BaseModel):
    # Raw input text; predict_re passes it to relation_extractor unchanged.
    text: str
569
+
570
@app.post("/predict_re")
def predict_re(request: RERequest):
    """Run relation extraction on ``request.text`` and return the result.

    On success the response is HTTP 200 with ``resp`` (the rows returned by
    ``relation_extractor``), ``statusText`` and ``statusCode``. On failure
    the response body is still ``{"error": <message>}``, but it is sent with
    HTTP 500 so clients and monitoring can tell a failure from a success
    without parsing the body.
    """
    # Keep the try body limited to the call that can actually fail.
    try:
        results = relation_extractor(request.text)
    except Exception as e:
        # Top-level API boundary: report the failure instead of letting it
        # propagate, but do not disguise it as a 200 OK.
        return JSONResponse(
            content={"error": str(e)},
            media_type="application/json",
            status_code=500,
        )

    return JSONResponse(
        content={
            "resp": results,
            "statusText": "OK",
            "statusCode": 0,
        },
        media_type="application/json",
        status_code=200,
    )
587
+
588
+ # =========== Front End =============================
589
  from fastapi.staticfiles import StaticFiles
590
  from fastapi.responses import FileResponse
591