# Scraped page header (Hugging Face Hub file view):
#   author: alaajabari
#   commit: 61185d9 "Update relation_module/inference.py" (verified)
import torch
import torch.nn.functional as F
from itertools import permutations
from huggingface_hub import hf_hub_download
import pickle
from .model import BertRE
from .utils import encode, insert_markers
from .ontology import build_relation_lookup
from ner import entities_and_types
# Hugging Face Hub repository holding both artifacts loaded below
# (the relation vocabulary pickle and the fine-tuned model weights).
repo_id = "aaljabari/arabic-relation-extraction-v1"

# --- Relation vocabulary ------------------------------------------------------
# Everything below runs once, at import time. hf_hub_download fetches the file
# (or reuses the local Hub cache) and returns its local path.
# The pickle holds a dict containing at least the keys "rel2id" and "id2rel".
# NOTE(review): pickle.load can execute arbitrary code embedded in the file, so
# this fully trusts the contents of `repo_id`. A JSON vocab would avoid that risk.
rel_vocab_path = hf_hub_download(repo_id, "tag_vocab.pkl")
with open(rel_vocab_path, "rb") as f:
    vocab = pickle.load(f)
# rel2id: presumably relation label -> class id; its size sets the classifier width.
# id2rel: class id -> relation label (indexed by the argmax id in predict_relation).
rel2id = vocab["rel2id"]
id2rel = vocab["id2rel"]

# --- Model ----------------------------------------------------------------------
weights_path = hf_hub_download(repo_id, "pytorch_model.bin")
# One output per relation label in the vocabulary.
model = BertRE(num_labels=len(rel2id))
# map_location="cpu" places the checkpoint tensors on CPU regardless of the
# device they were saved from; the code shown here never moves the model elsewhere.
# NOTE(review): torch.load also unpickles. Only a state dict is needed here, so
# weights_only=True (torch >= 1.13) would be safer; confirm the pinned torch version.
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
# Switch to inference behaviour (e.g. dropout off). Gradient tracking is
# disabled separately with torch.no_grad() in predict_relation.
model.eval()

# Ontology constraints. relation_extractor reads this as
# relation_lookup[subject_type][object_type] -> allowed (bare) relation names.
relation_lookup = build_relation_lookup()
def predict_relation(sentence):
    """Classify the relation expressed by an entity-marked sentence.

    The sentence is tokenised with ``encode``, which yields the token ids, the
    attention mask and the subject/object position collections the model
    consumes.

    Args:
        sentence: A sentence that already carries subject/object markers
            (as produced by ``insert_markers``).

    Returns:
        ``(label, probability)``: the highest-scoring relation label from
        ``id2rel`` and its softmax probability. Returns ``(None, 0.0)`` when
        either position collection is empty.

    Note:
        Only a single-sentence batch is supported. ``.item()`` on the argmax
        raises if the model returns more than one row.
    """
    input_ids, mask, sub_pos, obj_pos = encode(sentence)

    # Presumably empty positions mean a marker was not found after
    # tokenisation (e.g. truncated away); confirm against encode().
    both_located = len(sub_pos) > 0 and len(obj_pos) > 0
    if not both_located:
        return None, 0.0

    with torch.no_grad():
        scores = model(input_ids, mask, sub_pos, obj_pos)
        distribution = F.softmax(scores, dim=-1)
        best_id = distribution.argmax(dim=-1).item()
        best_prob = distribution[0, best_id].item()

    return id2rel[best_id], best_prob
def relation_extractor(sentence, *, min_confidence=0.80):
    """Extract typed (subject, relation, object) triples from *sentence*.

    Each ordered pair of distinct recognised entities is considered. A pair is
    sent to the relation classifier only when ``relation_lookup`` allows at
    least one relation between the two entity types. A prediction is kept
    only when all of the following hold:

    * its confidence is strictly above *min_confidence*;
    * it is not ``"no_relation"``;
    * its final dotted segment is one of the relations allowed for that type pair.

    Args:
        sentence: Raw input sentence.
        min_confidence: Keyword-only. Softmax probability that a prediction
            must strictly exceed to be kept. Defaults to 0.80, the previously
            hard-coded cut-off.

    Returns:
        A list of dicts with keys ``"Subject"`` and ``"Object"`` (each a dict
        with ``"Type"`` and ``"Label"``), ``"Relation"`` (the full predicted
        label) and ``"Confidence"`` (rounded to 4 decimals), in pair order.
    """
    entities = entities_and_types(sentence)
    output = []

    # permutations (not combinations): Subject and Object are distinct roles,
    # so both (A, B) and (B, A) are tried. Consumed lazily, with no list copy.
    for (ent1, type1), (ent2, type2) in permutations(entities.items(), 2):
        valid_rels = relation_lookup.get(type1, {}).get(type2, [])
        if not valid_rels:
            # The ontology allows nothing for this type pair; skip the model call.
            continue

        marked = insert_markers(sentence, ent1, ent2)
        if marked is None:
            continue

        rel, conf = predict_relation(marked)
        if rel is None:
            continue

        # Predicted labels may carry a dotted prefix; the ontology is compared
        # against the last segment only.
        rel_clean = rel.split(".")[-1]
        if conf > min_confidence and rel != "no_relation" and rel_clean in valid_rels:
            output.append({
                "Subject": {"Type": type1, "Label": ent1},
                "Relation": rel,
                "Object": {"Type": type2, "Label": ent2},
                "Confidence": round(conf, 4),
            })
    return output