"""Arabic relation extraction: load a fine-tuned BERT RE model from the
Hugging Face Hub and expose sentence-level relation extraction helpers."""

import pickle
from itertools import permutations

import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

from .model import BertRE
from .ontology import build_relation_lookup
from .utils import encode, insert_markers
from ner import entities_and_types

# Hub repo holding the fine-tuned weights and the relation-label vocabulary.
repo_id = "aaljabari/arabic-relation-extraction-v1"

# Load the relation-label vocabulary (rel2id / id2rel mappings).
# NOTE(review): pickle.load on a downloaded artifact executes arbitrary code
# if the repo is compromised — trusted-repo assumption; consider safetensors.
rel_vocab_path = hf_hub_download(repo_id, "tag_vocab.pkl")
with open(rel_vocab_path, "rb") as f:
    vocab = pickle.load(f)
rel2id = vocab["rel2id"]
id2rel = vocab["id2rel"]

# Restore model weights onto CPU and switch to inference mode.
weights_path = hf_hub_download(repo_id, "pytorch_model.bin")
model = BertRE(num_labels=len(rel2id))
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

# Ontology lookup: subject type -> object type -> allowed relation names.
relation_lookup = build_relation_lookup()


def predict_relation(sentence):
    """Predict the relation expressed in a marker-annotated sentence.

    Parameters
    ----------
    sentence : str
        Sentence with subject/object entity markers already inserted
        (see ``insert_markers``).

    Returns
    -------
    tuple[str | None, float]
        ``(relation_label, confidence)``, or ``(None, 0.0)`` when either
        entity marker could not be located after encoding.
    """
    input_ids, mask, sub_pos, obj_pos = encode(sentence)
    # Markers can be lost to truncation/tokenization; bail out gracefully.
    if len(sub_pos) == 0 or len(obj_pos) == 0:
        return None, 0.0
    with torch.no_grad():
        logits = model(input_ids, mask, sub_pos, obj_pos)
        probs = F.softmax(logits, dim=-1)
        pred = torch.argmax(probs, dim=-1).item()
        conf = probs[0, pred].item()
    return id2rel[pred], conf


def relation_extractor(sentence, *, min_conf=0.80):
    """Extract ontology-valid relations between entities in a sentence.

    Runs NER over the sentence, then scores every ordered entity pair
    whose (subject type, object type) combination the ontology permits,
    keeping predictions strictly above ``min_conf``.

    Parameters
    ----------
    sentence : str
        Raw input sentence.
    min_conf : float, optional
        Confidence threshold a prediction must exceed (default 0.80,
        matching the original hard-coded cutoff).

    Returns
    -------
    list[dict]
        One record per accepted relation with ``Subject``, ``Relation``,
        ``Object`` and rounded ``Confidence`` keys.
    """
    entities = entities_and_types(sentence)
    output = []
    # Ordered pairs: relations are directional (subject -> object), so
    # iterate permutations rather than combinations.
    for (ent1, type1), (ent2, type2) in permutations(entities.items(), 2):
        valid_rels = relation_lookup.get(type1, {}).get(type2, [])
        if not valid_rels:
            continue  # ontology permits no relation for this type pair
        marked = insert_markers(sentence, ent1, ent2)
        if marked is None:
            continue  # entity surface form not found in the sentence
        rel, conf = predict_relation(marked)
        # Cheap rejections first; threshold semantics unchanged (strict >).
        if rel is None or rel == "no_relation" or conf <= min_conf:
            continue
        # Labels may be namespaced ("prefix.relation"); the ontology stores
        # the bare relation name, so compare against the final segment.
        if rel.split(".")[-1] not in valid_rels:
            continue
        output.append({
            "Subject": {"Type": type1, "Label": ent1},
            "Relation": rel,
            "Object": {"Type": type2, "Label": ent2},
            "Confidence": round(conf, 4),
        })
    return output