# Page header captured together with the file (not part of the program):
# TymaaHammouda's picture
# Update app.py
# b5d6f8a verified
from fastapi import FastAPI
from huggingface_hub import hf_hub_download, snapshot_download
from Nested.nn.BertSeqTagger import BertSeqTagger
import os
from pydantic import BaseModel
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModel
import json
from IBO_to_XML import IBO_to_XML
from XML_to_HTML import NER_XML_to_HTML
from NER_Distiller import distill_entities
from collections import namedtuple
from Nested.utils.helpers import load_checkpoint
from Nested.utils.data import get_dataloaders, text2segments
import pickle
print("Version ---- 2")
from huggingface_hub import snapshot_download, hf_hub_download
import os
import shutil
from fastapi import FastAPI
from huggingface_hub import hf_hub_download
import os
from pydantic import BaseModel
from fastapi.responses import JSONResponse
# Startup marker written to stdout when this module is imported.
print("Version ---- 2")
# FastAPI application object; the @app.post handlers further down register on it.
app = FastAPI()
# Hugging Face model id shared by the tokenizer and the encoder loaded below.
pretrained_path = "aubmindlab/bert-base-arabertv2" # must match training
# Both objects are loaded at import time; the encoder is switched to evaluation
# mode via .eval().
# NOTE(review): neither name is referenced elsewhere in the visible part of this
# file — confirm they are still needed.
tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
encoder = AutoModel.from_pretrained(pretrained_path).eval()
def download_file_from_hf(repo_id, filename):
    """Fetch one file from a Hugging Face repo into the local sinatools folder.

    The destination ``~/.sinatools/Wj27012000.tar`` is created as a directory
    (despite its ``.tar`` suffix) when it does not exist yet.

    Args:
        repo_id: Hugging Face repository id to download from.
        filename: Path of the wanted file inside that repository.

    Returns:
        Whatever ``hf_hub_download`` returns for the downloaded file.
    """
    destination = os.path.expanduser("~/.sinatools/Wj27012000.tar")
    os.makedirs(destination, exist_ok=True)
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir=destination,
        local_dir_use_symlinks=False,
    )
# Place the NER args file and tag vocabulary in the local sinatools folder.
download_file_from_hf("SinaLab/Nested-v1", "args.json")
download_file_from_hf("SinaLab/Nested-v1", "tag_vocab.pkl")

# One call is enough here: it downloads the files matching the pattern and
# returns the local snapshot folder that is later handed to load_checkpoint().
checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")

# args.json of the SinaLab/Nested repo; its parsed content is passed to
# get_dataloaders() inside extract().
args_path = hf_hub_download(
    repo_id="SinaLab/Nested",
    filename="args.json"
)
# JSON text is UTF-8, so do not rely on the platform default encoding.
with open(args_path, 'r', encoding="utf-8") as f:
    args_data = json.load(f)

# Load model
# NOTE(review): pickle.load executes code embedded in the file; only use a
# tag_vocab.pkl that comes from a trusted source.
with open("Nested/utils/tag_vocab.pkl", "rb") as f:
    label_vocab = pickle.load(f)
label_vocab = label_vocab[0]  # the list loaded from pickle
# Map every index of the vocabulary's ``itos`` sequence to its label string.
id2label = {i: s for i, s in enumerate(label_vocab.itos)}
def split_text_into_groups_of_Ns(sentence, max_words_per_sentence):
    """Split *sentence* into consecutive groups of whitespace-separated words.

    Args:
        sentence: Text to split; words are taken from ``sentence.split()``.
        max_words_per_sentence: Maximum number of words per group. Values
            below 1 are treated as 1.

    Returns:
        A list of strings, each holding up to ``max_words_per_sentence`` words
        joined by single spaces. Only the last group may be shorter. An empty
        or whitespace-only sentence gives ``[]``.
    """
    words = sentence.split()
    # Clamp the size so a size of 1 (or less) yields one clean word per group
    # instead of groups that start with a stray space.
    size = max(1, max_words_per_sentence)
    return [
        " ".join(words[start:start + size])
        for start in range(0, len(words), size)
    ]
def remove_empty_values(sentences):
    """Return a new list holding every entry of *sentences* that is not ``''``."""
    kept = []
    for entry in sentences:
        if entry != '':
            kept.append(entry)
    return kept
def sentence_tokenizer(text, dot=True, new_line=True, question_mark=True, exclamation_mark=True):
    """Split *text* into sentences on the enabled separator characters.

    Separators are applied one after another in a fixed order: newline, ``.``,
    ``?`` and the Arabic ``؟``, then ``!``. A separator stays attached to the
    end of the sentence it closes.

    Args:
        text: Text to split.
        dot: Split on ``.`` when truthy.
        new_line: Split on the newline character when truthy.
        question_mark: Split on ``?`` and ``؟`` when truthy.
        exclamation_mark: Split on ``!`` when truthy.

    Returns:
        The non-empty pieces in their original order. With no separator
        enabled the result is ``[text]`` (or ``[]`` for an empty string).
    """
    separators = []
    if new_line:
        separators.append('\n')
    if dot:
        separators.append('.')
    if question_mark:
        separators.extend(['?', '؟'])
    if exclamation_mark:
        separators.append('!')

    pieces = [text]
    for sep in separators:
        next_pieces = []
        for piece in pieces:
            *closed, tail = piece.split(sep)
            # Drop the whitespace that followed the previous separator so every
            # sentence starts cleanly, whichever separator pass produced it.
            next_pieces.extend(fragment.lstrip() + sep for fragment in closed)
            next_pieces.append(tail.strip())
        pieces = next_pieces
    return [piece for piece in pieces if piece != '']
def jsons_to_list_of_lists(json_list):
    """Turn a list of ``{"token", "tags"}`` dicts into ``[token, tags]`` pairs."""
    rows = []
    for entry in json_list:
        rows.append([entry['token'], entry['tags']])
    return rows
# Unpack the three values load_checkpoint() returns for the downloaded snapshot
# path. `tagger` and `tag_vocab` are used by extract() below; `train_config` is
# not referenced in the visible part of this file.
tagger, tag_vocab, train_config = load_checkpoint(checkpoint_path)
def extract(sentence):
    """Tag *sentence* and return one dict per token.

    The text is segmented with ``text2segments``, wrapped in a dataloader
    (batch size 32, no shuffling) and passed to ``tagger.infer``.

    Returns:
        A flat list of ``{"token": <text>, "tags": <labels>}`` dicts, where
        ``<labels>`` is the space-joined predicted tags after dropping
        ``"O"``, ``" "`` and ``""``, or ``"O"`` when nothing is left.
    """
    dataset, token_vocab = text2segments(sentence)
    Vocab = namedtuple("Vocab", ["tags", "tokens"])
    loaders = get_dataloaders(
        (dataset,),
        Vocab(tags=tag_vocab, tokens=token_vocab),
        args_data,
        batch_size=32,
        shuffle=(False,),
    )

    results = []
    for segment in tagger.infer(loaders[0]):
        for token in segment:
            text = token.text
            predicted = [prediction["tag"] for prediction in token.pred_tag]
            labels = [tag for tag in predicted if tag not in ("O", " ", "")]
            results.append({"token": text, "tags": " ".join(labels) if labels else "O"})
    return results
def NER(sentence, mode):
    """Run entity tagging on *sentence* and return it in the requested format.

    Args:
        sentence: Text to tag.
        mode: Output selector, compared after ``strip()``:
            ``"1"`` - list of ``[token, tags]`` pairs;
            ``"2"`` - result of ``IBO_to_XML`` on those pairs;
            ``"3"`` - result of ``NER_XML_to_HTML`` on that XML;
            ``"4"`` - result of ``distill_entities`` on the pairs.

    Returns:
        The value for the selected mode, or ``None`` when *mode* is not one of
        the four selectors (no tagging is performed in that case).
    """
    selected = mode.strip()
    if selected not in ("1", "2", "3", "4"):
        return None

    # Every supported mode starts from the same tagged token list.
    tagged = jsons_to_list_of_lists(extract(sentence))
    if selected == "1":
        return tagged
    if selected == "4":  # json short
        return distill_entities(tagged)

    xml = IBO_to_XML(tagged)
    if selected == "2":
        return xml
    return NER_XML_to_HTML(xml)
# Folder under the user's home directory that holds the downloaded resources.
BASE_DIR = os.path.expanduser("~/.sinatools")
# NER resource folder; it is used as a directory (see os.makedirs below) even
# though the name ends in ".tar".
NER_DIR = os.path.join(BASE_DIR, "Wj27012000.tar")
# Paths expected by sinatools
RELATION_MODEL_DIR = os.path.join(BASE_DIR, "relation_model")
os.makedirs(BASE_DIR, exist_ok=True)
# -------------------------
# 1. Download relation model
# -------------------------
# The download is requested only when the folder is missing or empty.
if not os.path.exists(RELATION_MODEL_DIR) or not os.listdir(RELATION_MODEL_DIR):
snapshot_download(
repo_id="aaljabari/arabic-relation-extraction-model",
local_dir=RELATION_MODEL_DIR,
local_dir_use_symlinks=False
)
# NOTE(review): indentation is missing in this copy of the file, so it is not
# clear how many of the statements below belong to this `if`; confirm the
# nesting against the original app.py before changing anything here.
# NOTE(review): download_file_from_hf() earlier in this file already creates
# this same folder at import time, so this condition looks like it is always
# false by the time it is evaluated — verify.
if not os.path.exists(NER_DIR):
os.makedirs(NER_DIR, exist_ok=True)
# Local snapshot folder of the SinaLab/Nested repository.
nested_repo_path = snapshot_download(
repo_id="SinaLab/Nested"
)
# Copy tag_vocab.pkl to expected location
src_vocab = os.path.join(nested_repo_path, "Nested", "utils", "tag_vocab.pkl")
dst_vocab = os.path.join(NER_DIR, "tag_vocab.pkl")
# The copy is skipped silently when the snapshot has no such file.
if os.path.exists(src_vocab):
shutil.copy(src_vocab, dst_vocab)
from sinatools.relations.relation_extractor import relation_extraction
from sinatools.relations.event_relation_extractor import event_argument_relation_extraction
# Request body shared by the /predict_relation and /predict_event endpoints.
class RelationRequest(BaseModel):
    # Raw input text handed to the extractor.
    text: str
# POST /predict_relation: runs relation_extraction() on the posted text and
# returns the result in a {"resp", "statusText", "statusCode"} envelope.
@app.post("/predict_relation")
def predict_relation(request: RelationRequest):
    extracted = relation_extraction(request.text)
    return JSONResponse(
        content={"resp": extracted, "statusText": "OK", "statusCode": 0},
        media_type="application/json",
        status_code=200,
    )
# POST /predict_event: runs event_argument_relation_extraction() on the posted
# text and returns the result in a {"resp", "statusText", "statusCode"} envelope.
@app.post("/predict_event")
def predict_event(request: RelationRequest):
    extracted = event_argument_relation_extraction(request.text)
    return JSONResponse(
        content={"resp": extracted, "statusText": "OK", "statusCode": 0},
        media_type="application/json",
        status_code=200,
    )