# --- Imports ---------------------------------------------------------------
# The original repeated several import lines (fastapi, huggingface_hub, os,
# pydantic, JSONResponse) and printed the version banner twice; duplicates
# are removed here with no behavioral effect beyond the single banner print.
import json
import os
import pickle
import shutil
from collections import namedtuple

from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import AutoTokenizer, AutoModel

from Nested.nn.BertSeqTagger import BertSeqTagger  # kept: presumably required so checkpoint/unpickling can resolve the class — TODO confirm
from Nested.utils.helpers import load_checkpoint
from Nested.utils.data import get_dataloaders, text2segments
from IBO_to_XML import IBO_to_XML
from XML_to_HTML import NER_XML_to_HTML
from NER_Distiller import distill_entities

print("Version ---- 2")

app = FastAPI()

# Encoder/tokenizer backing the NER tagger.
pretrained_path = "aubmindlab/bert-base-arabertv2"  # must match training
tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
encoder = AutoModel.from_pretrained(pretrained_path).eval()


def download_file_from_hf(repo_id, filename):
    """Download a single file from the HF repo *repo_id* into the local
    sinatools cache directory and return its local path.

    NOTE(review): the target directory is named like a tar file
    ("Wj27012000.tar") but is used as a directory — confirm downstream
    consumers expect this layout.
    """
    target_dir = os.path.expanduser("~/.sinatools/Wj27012000.tar")
    os.makedirs(target_dir, exist_ok=True)
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir=target_dir,
        local_dir_use_symlinks=False,  # deprecated and ignored by recent huggingface_hub
    )


# Fetch model artifacts. NOTE(review): these two files come from
# "SinaLab/Nested-v1", yet the files actually *read* below come from
# "SinaLab/Nested" and a local "Nested/utils" path — confirm which repo
# is authoritative.
download_file_from_hf("SinaLab/Nested-v1", "args.json")
download_file_from_hf("SinaLab/Nested-v1", "tag_vocab.pkl")

# Download only the checkpoints/ folder of the model repo. (The original
# called snapshot_download twice with identical arguments and discarded the
# first result; a single call is sufficient — results are cached anyway.)
checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")

# Training-time arguments, consumed by get_dataloaders() below.
args_path = hf_hub_download(
    repo_id="SinaLab/Nested",
    filename="args.json"
)
with open(args_path, 'r') as f:
    args_data = json.load(f)

# Load model tag vocabulary: maps prediction ids back to label strings.
with open("Nested/utils/tag_vocab.pkl", "rb") as f:
    label_vocab = pickle.load(f)
label_vocab = label_vocab[0]  # the list loaded from pickle
id2label = {i: s for i, s in enumerate(label_vocab.itos)}
def split_text_into_groups_of_Ns(sentence, max_words_per_sentence):
    """Split *sentence* into chunks of at most ``max_words_per_sentence``
    whitespace-separated words, each chunk re-joined with single spaces.

    Fixes a defect in the previous accumulator implementation: when
    ``max_words_per_sentence == 1`` every chunk carried a spurious leading
    space (the else-branch appended ``" " + word`` to an empty string).
    """
    words = sentence.split()
    return [
        " ".join(words[start:start + max_words_per_sentence])
        for start in range(0, len(words), max_words_per_sentence)
    ]


def remove_empty_values(sentences):
    """Drop empty-string entries from *sentences*, preserving order."""
    return [value for value in sentences if value != '']


def sentence_tokenizer(text, dot=True, new_line=True, question_mark=True, exclamation_mark=True):
    """Split *text* into sentences on the enabled separator characters.

    Each separator stays attached to the sentence it terminates. Note that
    only the trailing fragment of each splitting pass is ``strip()``-ped;
    this mirrors the original behavior, which callers may rely on.
    """
    separators = []
    split_text = [text]
    if new_line:
        separators.append('\n')
    if dot:
        separators.append('.')
    if question_mark:
        separators.append('?')
        separators.append('؟')  # Arabic question mark
    if exclamation_mark:
        separators.append('!')
    for sep in separators:
        new_split_text = []
        for part in split_text:
            tokens = part.split(sep)
            # Re-attach the separator to every fragment except the last.
            tokens_with_separator = [token + sep for token in tokens[:-1]]
            tokens_with_separator.append(tokens[-1].strip())
            new_split_text.extend(tokens_with_separator)
        split_text = new_split_text
    return remove_empty_values(split_text)


def jsons_to_list_of_lists(json_list):
    """Convert ``[{'token': t, 'tags': g}, ...]`` into ``[[t, g], ...]``."""
    return [[d['token'], d['tags']] for d in json_list]


# Load the trained tagger once at import time; reused by every request.
tagger, tag_vocab, train_config = load_checkpoint(checkpoint_path)


def extract(sentence):
    """Run the nested-NER tagger over *sentence*.

    Returns a list of ``{'token': str, 'tags': str}`` dicts, where 'tags'
    is a space-joined string of predicted tags, or "O" when no entity tag
    applies to the token.
    """
    dataset, token_vocab = text2segments(sentence)
    vocabs = namedtuple("Vocab", ["tags", "tokens"])
    vocab = vocabs(tokens=token_vocab, tags=tag_vocab)
    dataloader = get_dataloaders(
        (dataset,),
        vocab,
        args_data,
        batch_size=32,
        shuffle=(False,),
    )[0]
    segments = tagger.infer(dataloader)
    lists = []
    for segment in segments:
        for token in segment:
            # Keep only meaningful predictions; "O", blanks and empties are noise.
            tags = [t["tag"] for t in token.pred_tag]
            tags = [t for t in tags if t not in ("O", " ", "")]
            lists.append({
                "token": token.text,
                "tags": " ".join(tags) if tags else "O",
            })
    return lists


def NER(sentence, mode):
    """Run NER on *sentence* and format the result according to *mode*:

    "1" -> list of [token, tags] pairs
    "2" -> XML string
    "3" -> HTML string
    "4" -> distilled ("json short") entity list
    anything else -> None (as the original implicitly returned).

    The original carried dead ``if output_list != []`` / ``if xml != ""``
    branches that could never execute (both variables were checked right
    after being initialized empty); removing them changes no behavior.
    """
    mode = mode.strip()
    if mode == "1":
        return jsons_to_list_of_lists(extract(sentence))
    if mode == "2":
        return IBO_to_XML(jsons_to_list_of_lists(extract(sentence)))
    if mode == "3":
        return NER_XML_to_HTML(IBO_to_XML(jsons_to_list_of_lists(extract(sentence))))
    if mode == "4":  # json short
        return distill_entities(jsons_to_list_of_lists(extract(sentence)))
    return None


BASE_DIR = os.path.expanduser("~/.sinatools")
NER_DIR = os.path.join(BASE_DIR, "Wj27012000.tar")

# Paths expected by sinatools
RELATION_MODEL_DIR = os.path.join(BASE_DIR, "relation_model")
os.makedirs(BASE_DIR, exist_ok=True)

# -------------------------
# 1. Download relation model
# -------------------------
# Fetch the relation-extraction model once; skipped when a non-empty local
# copy is already present.
if not os.path.exists(RELATION_MODEL_DIR) or not os.listdir(RELATION_MODEL_DIR):
    snapshot_download(
        repo_id="aaljabari/arabic-relation-extraction-model",
        local_dir=RELATION_MODEL_DIR,
        local_dir_use_symlinks=False,  # deprecated and ignored by recent huggingface_hub
    )

# First run only: materialize the NER dir and place tag_vocab.pkl where
# sinatools expects it. NOTE(review): the original paste's indentation was
# ambiguous; the whole body is assumed to belong under this guard — confirm.
if not os.path.exists(NER_DIR):
    os.makedirs(NER_DIR, exist_ok=True)
    nested_repo_path = snapshot_download(
        repo_id="SinaLab/Nested"
    )
    # Copy tag_vocab.pkl to expected location
    src_vocab = os.path.join(nested_repo_path, "Nested", "utils", "tag_vocab.pkl")
    dst_vocab = os.path.join(NER_DIR, "tag_vocab.pkl")
    if os.path.exists(src_vocab):
        shutil.copy(src_vocab, dst_vocab)

# Imported only after the downloads above — presumably sinatools reads these
# paths at import time (the original deliberately imported here, not at the
# top of the file); verify before hoisting.
from sinatools.relations.relation_extractor import relation_extraction
from sinatools.relations.event_relation_extractor import event_argument_relation_extraction


class RelationRequest(BaseModel):
    # Raw input text to analyze.
    text: str


def _ok_response(result):
    """Wrap *result* in the service's standard success envelope.

    Factored out of the two endpoints below, which previously duplicated
    this construction verbatim.
    """
    return JSONResponse(
        content={"resp": result, "statusText": "OK", "statusCode": 0},
        media_type="application/json",
        status_code=200,
    )


@app.post("/predict_relation")
def predict_relation(request: RelationRequest):
    """Extract entity-entity relations from the request text."""
    return _ok_response(relation_extraction(request.text))


@app.post("/predict_event")
def predict_event(request: RelationRequest):
    """Extract event-argument relations from the request text."""
    return _ok_response(event_argument_relation_extraction(request.text))