from fastapi import FastAPI import torch import pickle from huggingface_hub import hf_hub_download, snapshot_download from Nested.nn.BertSeqTagger import BertSeqTagger from transformers import AutoTokenizer, AutoModel import inspect from collections import namedtuple from Nested.utils.helpers import load_checkpoint from Nested.utils.data import get_dataloaders, text2segments import json from pydantic import BaseModel from fastapi.responses import JSONResponse from IBO_to_XML import IBO_to_XML from XML_to_HTML import NER_XML_to_HTML from NER_Distiller import distill_entities app = FastAPI() pretrained_path = "aubmindlab/bert-base-arabertv2" # must match training tokenizer = AutoTokenizer.from_pretrained(pretrained_path) encoder = AutoModel.from_pretrained(pretrained_path).eval() checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/") args_path = hf_hub_download( repo_id="SinaLab/Nested", filename="args.json" ) with open(args_path, 'r') as f: args_data = json.load(f) # Load model with open("Nested/utils/tag_vocab.pkl", "rb") as f: label_vocab = pickle.load(f) label_vocab = label_vocab[0] # the list loaded from pickle id2label = {i: s for i, s in enumerate(label_vocab.itos)} def split_text_into_groups_of_Ns(sentence, max_words_per_sentence): # Split the text into words words = sentence.split() # Initialize variables groups = [] current_group = "" group_size = 0 # Iterate through the words for word in words: if group_size < max_words_per_sentence - 1: if len(current_group) == 0: current_group = word else: current_group += " " + word group_size += 1 else: current_group += " " + word groups.append(current_group) current_group = "" group_size = 0 # Add the last group if it contains less than n words if current_group: groups.append(current_group) return groups def remove_empty_values(sentences): return [value for value in sentences if value != ''] def sentence_tokenizer(text, dot=True, new_line=True, question_mark=True, exclamation_mark=True): separators = [] split_text = [text] if new_line==True: separators.append('\n') if dot==True: separators.append('.') if question_mark==True: separators.append('?') separators.append('؟') if exclamation_mark==True: separators.append('!') for sep in separators: new_split_text = [] for part in split_text: tokens = part.split(sep) tokens_with_separator = [token + sep for token in tokens[:-1]] tokens_with_separator.append(tokens[-1].strip()) new_split_text.extend(tokens_with_separator) split_text = new_split_text split_text = remove_empty_values(split_text) return split_text def jsons_to_list_of_lists(json_list): return [[d['token'], d['tags']] for d in json_list] tagger, tag_vocab, train_config = load_checkpoint(checkpoint_path) def extract(sentence): dataset, token_vocab = text2segments(sentence) vocabs = namedtuple("Vocab", ["tags", "tokens"]) vocab = vocabs(tokens=token_vocab, tags=tag_vocab) dataloader = get_dataloaders( (dataset,), vocab, args_data, batch_size=32, shuffle=(False,), )[0] segments = tagger.infer(dataloader) lists = [] for segment in segments: for token in segment: item = {} item["token"] = token.text list_of_tags = [t["tag"] for t in token.pred_tag] list_of_tags = [i for i in list_of_tags if i not in ("O", " ", "")] if not list_of_tags: item["tags"] = "O" else: item["tags"] = " ".join(list_of_tags) lists.append(item) return lists def NER(sentence, mode): output_list = [] xml = "" if mode.strip() == "1": output_list = jsons_to_list_of_lists(extract(sentence)) return output_list elif mode.strip() == "2": if output_list != []: xml = IBO_to_XML(output_list) return xml else: output_list = jsons_to_list_of_lists(extract(sentence)) xml = IBO_to_XML(output_list) return xml elif mode.strip() == "3": if xml != "": html = NER_XML_to_HTML(xml) return html else: output_list = jsons_to_list_of_lists(extract(sentence)) xml = IBO_to_XML(output_list) html = NER_XML_to_HTML(xml) return html elif mode.strip() == "4": # json short if output_list != []: json_short = distill_entities(output_list) return json_short else: output_list = jsons_to_list_of_lists(extract(sentence)) json_short = distill_entities(output_list) return json_short class NERRequest(BaseModel): text: str mode: str @app.post("/predict") def predict(request: NERRequest): # Load tagger text = request.text mode = request.mode sentences = sentence_tokenizer( text, dot=False, new_line=True, question_mark=False, exclamation_mark=False ) lists = [] for sentence in sentences: se = split_text_into_groups_of_Ns(sentence, max_words_per_sentence=300) for s in se: output_list = NER(s, mode) lists.append(output_list) content = { "resp": lists, "statusText": "OK", "statusCode": 0, } return JSONResponse( content=content, media_type="application/json", status_code=200, ) # ============ Relation Extraction ============== import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedTokenizerFast, BertModel from itertools import permutations from collections import defaultdict # ========================= # Relation Extraction Model # ========================= repo_id = "aaljabari/arabic-relation-extraction-v1" # tokenizer relation_tokenizer = PreTrainedTokenizerFast( tokenizer_file=hf_hub_download(repo_id, "tokenizer.json") ) # vocab rel_vocab_path = hf_hub_download(repo_id, "tag_vocab.pkl") with open(rel_vocab_path, "rb") as f: vocab = pickle.load(f) rel2id = vocab["rel2id"] id2rel = vocab["id2rel"] class BertRE(nn.Module): def __init__(self, num_labels): super().__init__() self.bert = BertModel.from_pretrained(repo_id) hidden = self.bert.config.hidden_size self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob) self.classifier = nn.Linear(hidden * 2, num_labels) def forward(self, input_ids, attention_mask, sub_pos, obj_pos): outputs = self.bert( input_ids=input_ids, attention_mask=attention_mask ) hidden = outputs.last_hidden_state batch = hidden.shape[0] sub_vec = hidden[torch.arange(batch), sub_pos] obj_vec = hidden[torch.arange(batch), obj_pos] pair = torch.cat([sub_vec, obj_vec], dim=1) pair = self.dropout(pair) return self.classifier(pair) weights_path = hf_hub_download(repo_id, "pytorch_model.bin") re_model = BertRE(num_labels=len(rel2id)) re_model.load_state_dict(torch.load(weights_path, map_location="cpu")) re_model.eval() def convert_ner_format(ner_output): return [[item["token"], item["tags"]] for item in ner_output] def entities_and_types(sentence): ner_output = extract(sentence) converted = convert_ner_format(ner_output) entities = distill_entities(converted) entity_dict = {} for name, entity_type, _, _ in entities: entity_dict[name] = entity_type return entity_dict relation_domain_range=[ { "relation": "manager_of", "domain": ["PERS"], "range": ["ORG", "FAC"] }, { "relation": "birth_date", "domain": ["PERS"], "range": ["DATE"] }, { "relation": "has_parent", "domain": ["PERS"], "range": ["PERS"] }, { "relation": "has_sibling", "domain": ["PERS"], "range": ["PERS"] }, { "relation": "has_spouse", "domain": ["PERS"], "range": ["PERS"] }, { "relation": "has_relative", "domain": ["PERS"], "range": ["PERS"] }, { "relation": "death_date", "domain": ["PERS"], "range": ["DATE"] }, { "relation": "birth_place", "domain": ["PERS"], "range": ["GPE", "LOC"] }, { "relation": "has_occupation", "domain": ["PERS"], "range": ["OCC"] }, { "relation": "has_conflict_with", "domain": ["ORG", "NORP", "GPE"], "range": ["ORG", "NORP", "GPE"] }, { "relation": "has_compititor", "domain": ["PERS", "ORG"], "range": ["PERS", "ORG"] }, { "relation": "has_partner_with", "domain": ["ORG"], "range": ["ORG"] }, { "relation": "president_of", "domain": ["PERS"], "range": ["ORG", "GPE"] }, { "relation": "leader_of", "domain": ["PERS"], "range": ["ORG"] }, { "relation": "geopolitical_division", "domain": ["GPE", "LOC"], "range": ["GPE", "LOC"] }, { "relation": "member_of", "domain": ["PERS"], "range": ["ORG", "NORP"] }, { "relation": "subsidary", "domain": ["ORG"], "range": ["ORG"] }, { "relation": "employee_of", "domain": ["PERS"], "range": ["ORG", "FAC"] }, { "relation": "student_at", "domain": ["PERS"], "range": ["ORG"] }, { "relation": "owner_of", "domain": ["PERS"], "range": ["ORG", "FAC"] }, { "relation": "inventor_of", "domain": ["PERS"], "range": ["PRODUCT"] }, { "relation": "manufacturer_of", "domain": ["ORG"], "range": ["PRODUCT"] }, { "relation": "builder_of", "domain": ["PERS", "NORP"], "range": ["FAC"] }, { "relation": "founder_of", "domain": ["PERS"], "range": ["ORG"] }, { "relation": "lives_in", "domain": ["PERS", "NORP"], "range": ["GPE", "LOC"] }, { "relation": "located_in", "domain": ["FAC", "ORG"], "range": ["GPE", "LOC"] }, { "relation": "headquartered_in", "domain": ["ORG"], "range": ["GPE", "LOC"] }, { "relation": "has_border_with", "domain": ["LOC", "GPE"], "range": ["LOC", "GPE"] }, { "relation": "nearby", "domain": ["GPE", "LOC", "ORG", "FAC"], "range": ["GPE", "LOC", "ORG", "FAC"] }, { "relation": "has_property", "domain": ["ORG"], "range": ["PRODUCT"] }, { "relation": "branch_count", "domain": ["ORG"], "range": ["CARDINAL"] }, { "relation": "has_revenue", "domain": ["ORG"], "range": ["MONEY"] }, { "relation": "employs", "domain": ["ORG"], "range": ["CARDINAL"] }, { "relation": "found_on", "domain": ["ORG"], "range": ["DATE", "TIME"] }, { "relation": "has_alternate_name", "domain": ["ORG", "FAC"], "range": ["ORG", "FAC"] }, { "relation": "has_area", "domain": ["GPE", "LOC"], "range": ["QUANTITY"] }, { "relation": "official_language", "domain": ["GPE", "LOC"], "range": ["LANGUAGE"] }, { "relation": "has_currency", "domain": ["GPE", "LOC"], "range": ["CURR"] }, { "relation": "has_population", "domain": ["GPE"], "range": ["CARDINAL"] }, { "relation": "capital_of", "domain": ["GPE"], "range": ["GPE"] } ] relation_lookup = defaultdict(lambda: defaultdict(list)) for rel in relation_domain_range: for d in rel["domain"]: for r in rel["range"]: relation_lookup[d][r].append(rel["relation"]) def insert_markers(sentence, ent1, ent2): if ent1 not in sentence or ent2 not in sentence: return None marked = sentence marked = marked.replace(ent1, f"[Sub] {ent1} [/Sub]", 1) marked = marked.replace(ent2, f"[Obj] {ent2} [/Obj]", 1) return marked def encode(sentence): enc = relation_tokenizer( sentence, max_length=128, padding="max_length", truncation=True, return_tensors="pt" ) input_ids = enc["input_ids"] attention_mask = enc["attention_mask"] sub_id = relation_tokenizer.convert_tokens_to_ids("[Sub]") obj_id = relation_tokenizer.convert_tokens_to_ids("[Obj]") sub_pos = (input_ids == sub_id).nonzero(as_tuple=True)[1] obj_pos = (input_ids == obj_id).nonzero(as_tuple=True)[1] return input_ids, attention_mask, sub_pos, obj_pos def predict_relation(sentence): input_ids, mask, sub_pos, obj_pos = encode(sentence) if len(sub_pos) == 0 or len(obj_pos) == 0: return None, 0.0 with torch.no_grad(): logits = re_model(input_ids, mask, sub_pos, obj_pos) probs = F.softmax(logits, dim=-1) pred = torch.argmax(probs, dim=-1).item() conf = probs[0, pred].item() return id2rel[pred], conf def relation_extractor(sentence): entities = entities_and_types(sentence) output = [] entity_items = list(entities.items()) pairs = [(e1, e2) for e1, e2 in permutations(entity_items, 2)] for (ent1, type1), (ent2, type2) in pairs: valid_rels = relation_lookup.get(type1, {}).get(type2, []) if not valid_rels: continue marked_sentence = insert_markers(sentence, ent1, ent2) if marked_sentence is None: continue rel, conf = predict_relation(marked_sentence) if rel is None: continue if conf > 0.80 and rel != "no_relation" and rel.split(".")[-1] in valid_rels: output.append({ "Subject": { "Type": type1, "Label": ent1 }, "Relation": rel, "Object": { "Type": type2, "Label": ent2 }, "Confidence": float(round(conf, 4)) }) return output class RERequest(BaseModel): text: str @app.post("/predict_re") def predict_re(request: RERequest): try: results = relation_extractor(request.text) return JSONResponse( content={ "resp": results, "statusText": "OK", "statusCode": 0, }, media_type="application/json", status_code=200, ) except Exception as e: return {"error": str(e)} # =========== Front End ============================= from fastapi.staticfiles import StaticFiles from fastapi.responses import FileResponse # mount frontend app.mount("/static", StaticFiles(directory="static"), name="static") @app.get("/") def home(): return FileResponse("static/index.html")