File size: 2,083 Bytes
fd20677
 
 
 
 
 
 
 
 
61185d9
fd20677
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import torch.nn.functional as F
from itertools import permutations
from huggingface_hub import hf_hub_download
import pickle

from .model import BertRE
from .utils import encode, insert_markers
from .ontology import build_relation_lookup
from ner import entities_and_types   

# Hugging Face Hub repository that hosts both the label vocabulary and the
# model weights used by this module.
repo_id = "aaljabari/arabic-relation-extraction-v1"

# Load the relation-label vocabulary. hf_hub_download returns a local file
# path for the requested file in the repository.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; this
# is only safe as long as the repository above is trusted.
rel_vocab_path = hf_hub_download(repo_id, "tag_vocab.pkl")
with open(rel_vocab_path, "rb") as f:
    vocab = pickle.load(f)

# rel2id is used below only for its size (number of output labels);
# id2rel is indexed with the predicted class index in predict_relation.
rel2id = vocab["rel2id"]
id2rel = vocab["id2rel"]

# Build the classifier, load the downloaded weights onto the CPU and switch
# the model to evaluation mode; this module only runs inference.
# NOTE(review): depending on the installed torch version, torch.load may
# unpickle arbitrary objects — consider weights_only=True if supported.
weights_path = hf_hub_download(repo_id, "pytorch_model.bin")
model = BertRE(num_labels=len(rel2id))
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

# Nested lookup consulted by relation_extractor as
# relation_lookup[subject_type][object_type] -> allowed relation names.
relation_lookup = build_relation_lookup()


def predict_relation(sentence):
    """Classify the relation expressed by a marker-annotated sentence.

    Args:
        sentence: Sentence text passed to ``encode``; callers in this file
            pass the output of ``insert_markers``.

    Returns:
        A ``(relation, confidence)`` tuple. ``relation`` is the ``id2rel``
        entry for the highest-scoring class and ``confidence`` is its
        softmax probability. ``(None, 0.0)`` is returned when ``encode``
        yields no subject positions or no object positions.
    """
    token_ids, attention_mask, subject_span, object_span = encode(sentence)

    # Both entity position lists must be non-empty for the model to be run.
    has_both_spans = len(subject_span) != 0 and len(object_span) != 0
    if not has_both_spans:
        return None, 0.0

    # Inference only: skip gradient tracking for the forward pass.
    with torch.no_grad():
        scores = model(token_ids, attention_mask, subject_span, object_span)

    distribution = F.softmax(scores, dim=-1)

    # .item() requires a single winning index, and the probability is read
    # from row 0, so this handles exactly one example per call.
    best_class = distribution.argmax(dim=-1).item()
    best_prob = distribution[0, best_class].item()

    return id2rel[best_class], best_prob


def relation_extractor(sentence, *, min_confidence=0.80):
    """Extract relations between every ordered pair of entities in a sentence.

    Entities come from ``entities_and_types``. Each ordered (subject, object)
    pair whose types have at least one allowed relation in ``relation_lookup``
    is marked up with ``insert_markers`` and classified with
    ``predict_relation``.

    Args:
        sentence: The raw sentence to analyse.
        min_confidence: Keyword-only. A prediction is kept only when its
            confidence is strictly greater than this value. Defaults to
            0.80, the cut-off that was previously hard-coded.

    Returns:
        A list of dicts with the keys ``"Subject"``, ``"Relation"``,
        ``"Object"`` and ``"Confidence"`` (rounded to 4 decimals), one per
        accepted prediction. The list is empty when nothing qualifies.
    """
    entities = entities_and_types(sentence)

    output = []

    # Ordered pairs: (a, b) and (b, a) are different subject/object
    # assignments. permutations() is consumed lazily, so no intermediate
    # list of pairs is built.
    for (ent1, type1), (ent2, type2) in permutations(entities.items(), 2):

        # Skip type combinations with no allowed relation before paying for
        # a model forward pass.
        valid_rels = relation_lookup.get(type1, {}).get(type2, [])
        if not valid_rels:
            continue

        marked = insert_markers(sentence, ent1, ent2)
        if marked is None:
            continue

        rel, conf = predict_relation(marked)
        if rel is None:
            continue

        # The allowed-relation check uses only the text after the last "."
        # of the predicted label; the full label is what gets reported.
        rel_clean = rel.split(".")[-1]

        if conf > min_confidence and rel != "no_relation" and rel_clean in valid_rels:
            output.append({
                "Subject": {"Type": type1, "Label": ent1},
                "Relation": rel,
                "Object": {"Type": type2, "Label": ent2},
                "Confidence": round(conf, 4)
            })

    return output