# NOTE: page residue captured together with this file (not program code):
#   Spaces: Runtime error / File size: 2,083 Bytes / fd20677 61185d9 fd20677
import torch
import torch.nn.functional as F
from itertools import permutations
from huggingface_hub import hf_hub_download
import pickle
from .model import BertRE
from .utils import encode, insert_markers
from .ontology import build_relation_lookup
from ner import entities_and_types
# Hugging Face Hub repository from which the label vocabulary and the model
# weights are downloaded at import time.
repo_id = "aaljabari/arabic-relation-extraction-v1"

# load vocab
# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# This is only acceptable while the Hub repository above is trusted; a
# data-only format (e.g. JSON) would remove that risk.
rel_vocab_path = hf_hub_download(repo_id, "tag_vocab.pkl")
with open(rel_vocab_path, "rb") as f:
    vocab = pickle.load(f)
# Two mappings between relation labels and class indices; id2rel is indexed
# with the predicted class index in predict_relation.
rel2id = vocab["rel2id"]
id2rel = vocab["id2rel"]

# model
weights_path = hf_hub_download(repo_id, "pytorch_model.bin")
# One output class per relation label in the vocabulary.
model = BertRE(num_labels=len(rel2id))
# weights_only=True restricts unpickling to tensors and plain containers, so
# a tampered checkpoint cannot run code while being loaded.
model.load_state_dict(
    torch.load(weights_path, map_location="cpu", weights_only=True)
)
# Evaluation mode: this module only runs inference.
model.eval()

# Consulted as relation_lookup[subject_type][object_type] -> allowed relation
# names when candidate entity pairs are filtered.
relation_lookup = build_relation_lookup()
def predict_relation(sentence):
    """Predict the relation label for a single sentence.

    The sentence is handed to ``encode``, which returns the input ids, a
    mask, and the subject and object positions that the model consumes.

    Returns:
        A ``(label, confidence)`` tuple. ``label`` is looked up in
        ``id2rel`` and ``confidence`` is the softmax probability of that
        label. If ``encode`` yields no subject position or no object
        position, ``(None, 0.0)`` is returned and the model is not run.
    """
    token_ids, input_mask, subject_positions, object_positions = encode(sentence)

    # len() rather than truthiness: the positions may be tensors, whose
    # truth value is not defined for more than one element.
    both_marked = len(subject_positions) > 0 and len(object_positions) > 0
    if not both_marked:
        return None, 0.0

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        scores = model(token_ids, input_mask, subject_positions, object_positions)
        distribution = F.softmax(scores, dim=-1)
        best_index = distribution.argmax(dim=-1).item()
        # Row 0: the result is read from the first (only) row of the output.
        confidence = distribution[0, best_index].item()

    return id2rel[best_index], confidence
def relation_extractor(sentence, *, min_confidence: float = 0.80):
    """Extract relations between every ordered pair of entities in a sentence.

    Entities come from ``entities_and_types(sentence)``, which is consumed
    here as a mapping of entity text to entity type. Each ordered
    (subject, object) pair is kept only if ``relation_lookup`` lists at least
    one relation for its type combination; the pair is then marked in the
    sentence with ``insert_markers`` and classified by ``predict_relation``.

    Args:
        sentence: The input sentence.
        min_confidence: A prediction is reported only when its confidence is
            strictly greater than this value. Defaults to 0.80.

    Returns:
        A list of dicts with the keys ``"Subject"``, ``"Relation"``,
        ``"Object"`` and ``"Confidence"`` (rounded to 4 decimals). The list
        is empty when nothing passes the filters.
    """
    entities = entities_and_types(sentence)
    output = []

    # Ordered pairs: (a, b) and (b, a) are both tried, because the subject
    # and object roles are not interchangeable.
    for (ent1, type1), (ent2, type2) in permutations(entities.items(), 2):
        # Skip the model call entirely when no relation is allowed for this
        # subject-type / object-type combination.
        valid_rels = relation_lookup.get(type1, {}).get(type2, [])
        if not valid_rels:
            continue

        marked = insert_markers(sentence, ent1, ent2)
        if marked is None:
            continue

        rel, conf = predict_relation(marked)
        if rel is None:
            continue

        # The allowed-relation list is matched against the last
        # dot-separated segment of the predicted label.
        rel_clean = rel.split(".")[-1]
        if conf > min_confidence and rel != "no_relation" and rel_clean in valid_rels:
            output.append({
                "Subject": {"Type": type1, "Label": ent1},
                "Relation": rel,
                "Object": {"Type": type2, "Label": ent2},
                "Confidence": round(conf, 4),
            })

    return output