from collections import namedtuple
import json
import os
import pickle
import shutil

from fastapi import FastAPI
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download, snapshot_download
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel

from Nested.nn.BertSeqTagger import BertSeqTagger
from Nested.utils.data import get_dataloaders, text2segments
from Nested.utils.helpers import load_checkpoint
from IBO_to_XML import IBO_to_XML
from NER_Distiller import distill_entities
from XML_to_HTML import NER_XML_to_HTML

print("Version ---- 2")

app = FastAPI()

# Must match the encoder the checkpoint was trained with.
pretrained_path = "aubmindlab/bert-base-arabertv2"
tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
encoder = AutoModel.from_pretrained(pretrained_path).eval()
def download_file_from_hf(repo_id, filename):
    # sinatools expects its assets under ~/.sinatools/Wj27012000.tar
    # (used as a directory, despite the .tar suffix).
    target_dir = os.path.expanduser("~/.sinatools/Wj27012000.tar")
    os.makedirs(target_dir, exist_ok=True)
    file_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir=target_dir,
        local_dir_use_symlinks=False,
    )
    return file_path

download_file_from_hf("SinaLab/Nested-v1", "args.json")
download_file_from_hf("SinaLab/Nested-v1", "tag_vocab.pkl")
# Fetch the checkpoint directory. allow_patterns is fnmatch-style, so
# "checkpoints/*" is needed to match the files inside the folder.
checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/*")

args_path = hf_hub_download(
    repo_id="SinaLab/Nested",
    filename="args.json",
)
with open(args_path, "r") as f:
    args_data = json.load(f)
# Load the label vocabulary saved at training time.
with open("Nested/utils/tag_vocab.pkl", "rb") as f:
    label_vocab = pickle.load(f)
label_vocab = label_vocab[0]  # the pickle holds a one-element list
id2label = {i: s for i, s in enumerate(label_vocab.itos)}
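# id2label maps integer tag ids back to label strings, e.g. (illustrative;
# the actual labels depend on the checkpoint) {0: "O", 1: "B-PERS", 2: "I-PERS", ...}.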
def split_text_into_groups_of_Ns(sentence, max_words_per_sentence):
    """Split a sentence into groups of at most max_words_per_sentence words."""
    words = sentence.split()
    groups = []
    current_group = ""
    group_size = 0
    for word in words:
        if group_size < max_words_per_sentence - 1:
            if len(current_group) == 0:
                current_group = word
            else:
                current_group += " " + word
            group_size += 1
        else:
            # Limit reached: close the current group and start a new one.
            current_group += " " + word
            groups.append(current_group)
            current_group = ""
            group_size = 0
    # Add the trailing group if it holds fewer than max_words_per_sentence words.
    if current_group:
        groups.append(current_group)
    return groups
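# Illustrative call (hypothetical input): with max_words_per_sentence=5,
# a 12-word string splits into two 5-word groups plus the remainder.
# split_text_into_groups_of_Ns("w1 w2 w3 w4 w5 w6 w7 w8 w9 w10 w11 w12", 5)
# -> ["w1 w2 w3 w4 w5", "w6 w7 w8 w9 w10", "w11 w12"]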
def remove_empty_values(sentences):
    return [value for value in sentences if value != ""]

def sentence_tokenizer(text, dot=True, new_line=True, question_mark=True, exclamation_mark=True):
    # Split on the enabled separators while keeping each separator
    # attached to the sentence it terminates.
    separators = []
    split_text = [text]
    if new_line:
        separators.append('\n')
    if dot:
        separators.append('.')
    if question_mark:
        separators.append('?')
        separators.append('؟')  # Arabic question mark
    if exclamation_mark:
        separators.append('!')
    for sep in separators:
        new_split_text = []
        for part in split_text:
            tokens = part.split(sep)
            tokens_with_separator = [token + sep for token in tokens[:-1]]
            tokens_with_separator.append(tokens[-1].strip())
            new_split_text.extend(tokens_with_separator)
        split_text = new_split_text
    split_text = remove_empty_values(split_text)
    return split_text
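# Illustrative call: each separator stays attached to the sentence it ends,
# and the Arabic question mark "؟" is handled alongside "?".
# sentence_tokenizer("ذهب إلى السوق. هل عاد؟")
# -> ["ذهب إلى السوق.", "هل عاد؟"]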
def jsons_to_list_of_lists(json_list):
    return [[d["token"], d["tags"]] for d in json_list]

tagger, tag_vocab, train_config = load_checkpoint(checkpoint_path)
def extract(sentence):
    dataset, token_vocab = text2segments(sentence)
    vocabs = namedtuple("Vocab", ["tags", "tokens"])
    vocab = vocabs(tokens=token_vocab, tags=tag_vocab)
    dataloader = get_dataloaders(
        (dataset,),
        vocab,
        args_data,
        batch_size=32,
        shuffle=(False,),
    )[0]
    segments = tagger.infer(dataloader)
    lists = []
    for segment in segments:
        for token in segment:
            item = {}
            item["token"] = token.text
            # Collect the predicted (possibly nested) tags, dropping "O" and blanks.
            list_of_tags = [t["tag"] for t in token.pred_tag]
            list_of_tags = [i for i in list_of_tags if i not in ("O", " ", "")]
            if not list_of_tags:
                item["tags"] = "O"
            else:
                item["tags"] = " ".join(list_of_tags)
            lists.append(item)
    return lists
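# Sketch of the per-token output (tag values are illustrative; the actual
# tag set comes from the SinaLab checkpoint):
# extract("ولد محمد في القدس")
# -> [{"token": "ولد", "tags": "O"}, {"token": "محمد", "tags": "B-PERS"},
#     {"token": "في", "tags": "O"}, {"token": "القدس", "tags": "B-GPE"}]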
def NER(sentence, mode):
    # Every call starts from scratch, so each mode runs the full
    # extract -> convert pipeline and differs only in the final rendering.
    mode = mode.strip()
    if mode == "1":  # token/tag pairs
        return jsons_to_list_of_lists(extract(sentence))
    elif mode == "2":  # XML
        output_list = jsons_to_list_of_lists(extract(sentence))
        return IBO_to_XML(output_list)
    elif mode == "3":  # HTML
        output_list = jsons_to_list_of_lists(extract(sentence))
        xml = IBO_to_XML(output_list)
        return NER_XML_to_HTML(xml)
    elif mode == "4":  # json short
        output_list = jsons_to_list_of_lists(extract(sentence))
        return distill_entities(output_list)
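# Illustrative mode map (output values are sketches, not verified samples):
# NER(text, "1") -> [["محمد", "B-PERS"], ...]   token/tag pairs
# NER(text, "2") -> XML string from IBO_to_XML
# NER(text, "3") -> HTML rendered from that XML
# NER(text, "4") -> distilled entity list from distill_entities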
BASE_DIR = os.path.expanduser("~/.sinatools")
NER_DIR = os.path.join(BASE_DIR, "Wj27012000.tar")

# Paths expected by sinatools
RELATION_MODEL_DIR = os.path.join(BASE_DIR, "relation_model")
os.makedirs(BASE_DIR, exist_ok=True)

# -------------------------
# 1. Download relation model
# -------------------------
if not os.path.exists(RELATION_MODEL_DIR) or not os.listdir(RELATION_MODEL_DIR):
    snapshot_download(
        repo_id="aaljabari/arabic-relation-extraction-model",
        local_dir=RELATION_MODEL_DIR,
        local_dir_use_symlinks=False,
    )

# -------------------------
# 2. Download nested NER assets
# -------------------------
if not os.path.exists(NER_DIR):
    os.makedirs(NER_DIR, exist_ok=True)
    nested_repo_path = snapshot_download(repo_id="SinaLab/Nested")
    # Copy tag_vocab.pkl to the location sinatools expects
    src_vocab = os.path.join(nested_repo_path, "Nested", "utils", "tag_vocab.pkl")
    dst_vocab = os.path.join(NER_DIR, "tag_vocab.pkl")
    if os.path.exists(src_vocab):
        shutil.copy(src_vocab, dst_vocab)
from sinatools.relations.relation_extractor import relation_extraction
from sinatools.relations.event_relation_extractor import event_argument_relation_extraction

class RelationRequest(BaseModel):
    text: str

# NOTE: the route paths ("/relation", "/event") are assumed; without
# decorators these handlers would never be registered on the app.
@app.post("/relation")
def predict_relation(request: RelationRequest):
    text = request.text
    result = relation_extraction(text)
    content = {"resp": result, "statusText": "OK", "statusCode": 0}
    return JSONResponse(
        content=content,
        media_type="application/json",
        status_code=200,
    )

@app.post("/event")
def predict_event(request: RelationRequest):
    text = request.text
    result = event_argument_relation_extraction(text)
    content = {"resp": result, "statusText": "OK", "statusCode": 0}
    return JSONResponse(
        content=content,
        media_type="application/json",
        status_code=200,
    )
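# A minimal local smoke test, assuming the route paths registered above
# (FastAPI's TestClient also needs the httpx package installed):
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    client = TestClient(app)
    resp = client.post("/relation", json={"text": "ولد محمد في القدس"})
    print(resp.json())  # expected shape: {"resp": ..., "statusText": "OK", "statusCode": 0}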