Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
| from itertools import permutations | |
| from huggingface_hub import hf_hub_download | |
| import pickle | |
| from .model import BertRE | |
| from .utils import encode, insert_markers | |
| from .ontology import build_relation_lookup | |
| from ner import entities_and_types | |
# Hugging Face Hub repository holding the relation vocabulary and weights.
repo_id = "aaljabari/arabic-relation-extraction-v1"

# Load the relation vocabulary: label -> id and id -> label mappings.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file,
# so this fully trusts the remote repo. Consider publishing the vocab as JSON.
rel_vocab_path = hf_hub_download(repo_id, "tag_vocab.pkl")
with open(rel_vocab_path, "rb") as f:
    vocab = pickle.load(f)
rel2id = vocab["rel2id"]
id2rel = vocab["id2rel"]

# Build the classifier with one output per relation label and load weights.
weights_path = hf_hub_download(repo_id, "pytorch_model.bin")
model = BertRE(num_labels=len(rel2id))
# weights_only=True limits unpickling to tensors and plain containers. That is
# all a state dict needs, and it stops a tampered checkpoint from running code.
model.load_state_dict(
    torch.load(weights_path, map_location="cpu", weights_only=True)
)
model.eval()  # inference mode for the lifetime of the module

# Ontology of allowed relations, read below as
# relation_lookup[subject_type][object_type] -> allowed relation names.
relation_lookup = build_relation_lookup()
def predict_relation(sentence):
    """Classify the relation expressed in a marker-annotated sentence.

    The sentence is tokenized with ``encode``. If either the subject or the
    object position list comes back empty, there is nothing to classify and
    ``(None, 0.0)`` is returned.

    Otherwise the model is run without gradient tracking. The result is a pair
    of the ``id2rel`` label for the highest-probability class and that class's
    softmax probability as a Python float.
    """
    input_ids, mask, sub_pos, obj_pos = encode(sentence)
    # Both entity markers must have been located for a prediction to make sense.
    if min(len(sub_pos), len(obj_pos)) == 0:
        return None, 0.0

    with torch.no_grad():
        scores = model(input_ids, mask, sub_pos, obj_pos)
        distribution = F.softmax(scores, dim=-1)

    # NOTE(review): row 0 is read below and .item() needs a single element,
    # so this presumably expects a one-example batch from encode; confirm.
    best_id = distribution.argmax(dim=-1).item()
    best_prob = distribution[0, best_id].item()
    return id2rel[best_id], best_prob
def relation_extractor(sentence, *, threshold=0.80):
    """Extract typed relation triples between the named entities of *sentence*.

    Every ordered pair of distinct entities from ``entities_and_types`` is
    considered. A pair is skipped when any of these holds:

    - the ontology allows no relation between the two entity types;
    - ``insert_markers`` cannot mark the two mentions;
    - the model yields no prediction.

    A prediction is kept only if all of these hold:

    - its confidence is strictly greater than ``threshold``;
    - its label is not ``"no_relation"``;
    - the last dot-separated component of the label is one of the relations
      allowed for that type pair.

    Args:
        sentence: Input text.
        threshold: Exclusive lower bound on the softmax confidence for a
            prediction to be reported. Defaults to 0.80, the value that was
            previously hard-coded.

    Returns:
        A list of dicts with keys "Subject", "Relation", "Object" and
        "Confidence". "Confidence" is rounded to 4 decimals. The list is
        empty when no pair qualifies.
    """
    # Mapping of entity text -> entity type.
    entities = entities_and_types(sentence)
    output = []
    # Ordered pairs: (A, B) and (B, A) are distinct subject/object assignments.
    # permutations() snapshots its input, so no intermediate list is needed.
    for (ent1, type1), (ent2, type2) in permutations(entities.items(), 2):
        valid_rels = relation_lookup.get(type1, {}).get(type2, [])
        if not valid_rels:
            continue  # ontology forbids any relation for this type pair
        marked = insert_markers(sentence, ent1, ent2)
        if marked is None:
            continue  # mentions could not be marked in the sentence
        rel, conf = predict_relation(marked)
        if rel is None:
            continue  # encoder found no subject/object positions
        # The model label may carry a dotted prefix; valid_rels is checked
        # against the bare trailing name only.
        rel_clean = rel.split(".")[-1]
        if conf > threshold and rel != "no_relation" and rel_clean in valid_rels:
            output.append({
                "Subject": {"Type": type1, "Label": ent1},
                "Relation": rel,
                "Object": {"Type": type2, "Label": ent2},
                "Confidence": round(conf, 4),
            })
    return output