# relation-api / main.py
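# FastAPI service for nested Arabic NER (/predict) and relation extraction
# (/predict_re), with a static front end served at "/".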
from fastapi import FastAPI
import torch
import pickle
from huggingface_hub import hf_hub_download, snapshot_download
from Nested.nn.BertSeqTagger import BertSeqTagger
from transformers import AutoTokenizer, AutoModel
from collections import namedtuple
from Nested.utils.helpers import load_checkpoint
from Nested.utils.data import get_dataloaders, text2segments
import json
from pydantic import BaseModel
from fastapi.responses import JSONResponse
from IBO_to_XML import IBO_to_XML
from XML_to_HTML import NER_XML_to_HTML
from NER_Distiller import distill_entities
app = FastAPI()
pretrained_path = "aubmindlab/bert-base-arabertv2" # must match training
tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
encoder = AutoModel.from_pretrained(pretrained_path).eval()
checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")
args_path = hf_hub_download(
repo_id="SinaLab/Nested",
filename="args.json"
)
with open(args_path, 'r') as f:
args_data = json.load(f)
# Load model
with open("Nested/utils/tag_vocab.pkl", "rb") as f:
label_vocab = pickle.load(f)
label_vocab = label_vocab[0] # the list loaded from pickle
id2label = {i: s for i, s in enumerate(label_vocab.itos)}
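# id2label maps tag indices to tag strings, e.g. {0: "O", 1: "B-PERS", ...}
# (example entries are illustrative; the real set comes from tag_vocab.pkl)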
def split_text_into_groups_of_Ns(sentence, max_words_per_sentence):
    """Split a sentence into chunks of at most max_words_per_sentence words."""
    words = sentence.split()
    return [
        " ".join(words[i:i + max_words_per_sentence])
        for i in range(0, len(words), max_words_per_sentence)
    ]
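# e.g. split_text_into_groups_of_Ns("a b c d", 3) -> ["a b c", "d"]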
def remove_empty_values(sentences):
return [value for value in sentences if value != '']
def sentence_tokenizer(text, dot=True, new_line=True, question_mark=True, exclamation_mark=True):
    """Split text into sentences on the enabled separators, keeping each separator attached."""
    separators = []
    if new_line:
        separators.append('\n')
    if dot:
        separators.append('.')
    if question_mark:
        separators.extend(['?', '؟'])  # Latin and Arabic question marks
    if exclamation_mark:
        separators.append('!')
    split_text = [text]
    for sep in separators:
        new_split_text = []
        for part in split_text:
            tokens = part.split(sep)
            # Re-attach the separator to every piece except the last
            new_split_text.extend(token + sep for token in tokens[:-1])
            new_split_text.append(tokens[-1].strip())
        split_text = new_split_text
    return remove_empty_values(split_text)
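# e.g. sentence_tokenizer("Hi. Bye", dot=True, new_line=False,
#                         question_mark=False, exclamation_mark=False)
# -> ["Hi.", "Bye"]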
def jsons_to_list_of_lists(json_list):
return [[d['token'], d['tags']] for d in json_list]
# Load the trained tagger, its tag vocabulary, and training config once at startup
tagger, tag_vocab, train_config = load_checkpoint(checkpoint_path)
def extract(sentence):
    """Tag one sentence with the nested NER tagger and return a list of
    {"token": ..., "tags": ...} dicts; nested tags are space-joined."""
dataset, token_vocab = text2segments(sentence)
vocabs = namedtuple("Vocab", ["tags", "tokens"])
vocab = vocabs(tokens=token_vocab, tags=tag_vocab)
dataloader = get_dataloaders(
(dataset,),
vocab,
args_data,
batch_size=32,
shuffle=(False,),
)[0]
segments = tagger.infer(dataloader)
lists = []
for segment in segments:
for token in segment:
item = {}
item["token"] = token.text
list_of_tags = [t["tag"] for t in token.pred_tag]
list_of_tags = [i for i in list_of_tags if i not in ("O", " ", "")]
if not list_of_tags:
item["tags"] = "O"
else:
item["tags"] = " ".join(list_of_tags)
lists.append(item)
return lists
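# Illustrative output shape (tag values assumed):
#   [{"token": "محمد", "tags": "B-PERS"}, {"token": "وصل", "tags": "O"}, ...]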
def NER(sentence, mode):
    """Format NER output by mode: "1" token/tag pairs, "2" XML, "3" HTML, "4" distilled JSON."""
    mode = mode.strip()
    if mode not in ("1", "2", "3", "4"):
        return None
    output_list = jsons_to_list_of_lists(extract(sentence))
    if mode == "1":
        return output_list
    if mode == "4":  # json short
        return distill_entities(output_list)
    xml = IBO_to_XML(output_list)
    if mode == "2":
        return xml
    return NER_XML_to_HTML(xml)  # mode == "3"
class NERRequest(BaseModel):
text: str
mode: str
@app.post("/predict")
def predict(request: NERRequest):
    # The tagger was loaded once at import time; requests only run inference
text = request.text
mode = request.mode
sentences = sentence_tokenizer(
text, dot=False, new_line=True, question_mark=False, exclamation_mark=False
)
lists = []
for sentence in sentences:
        se = split_text_into_groups_of_Ns(sentence, max_words_per_sentence=300)  # cap segment length before tagging
for s in se:
output_list = NER(s, mode)
lists.append(output_list)
content = {
"resp": lists,
"statusText": "OK",
"statusCode": 0,
}
return JSONResponse(
content=content,
media_type="application/json",
status_code=200,
)
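# Example request (host/port are deployment-specific):
#   curl -X POST http://localhost:8000/predict \
#     -H 'Content-Type: application/json' \
#     -d '{"text": "...", "mode": "1"}'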
# ============ Relation Extraction ==============
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedTokenizerFast, BertModel
from itertools import permutations
from collections import defaultdict
# =========================
# Relation Extraction Model
# =========================
repo_id = "aaljabari/arabic-relation-extraction-v1"
# tokenizer
relation_tokenizer = PreTrainedTokenizerFast(
tokenizer_file=hf_hub_download(repo_id, "tokenizer.json")
)
# vocab
rel_vocab_path = hf_hub_download(repo_id, "tag_vocab.pkl")
with open(rel_vocab_path, "rb") as f:
vocab = pickle.load(f)
rel2id = vocab["rel2id"]
id2rel = vocab["id2rel"]
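# rel2id / id2rel map relation labels to class indices; the rel.split(".")[-1]
# normalization in relation_extractor suggests labels may carry a dotted prefix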
class BertRE(nn.Module):
    """Relation classifier: the encoder states at the [Sub] and [Obj] marker
    positions are concatenated and passed through a linear layer."""
    def __init__(self, num_labels):
        super().__init__()
        self.bert = BertModel.from_pretrained(repo_id)
        hidden = self.bert.config.hidden_size
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        self.classifier = nn.Linear(hidden * 2, num_labels)

    def forward(self, input_ids, attention_mask, sub_pos, obj_pos):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        hidden = outputs.last_hidden_state
        batch = hidden.shape[0]
        # Pick the hidden state at each entity's opening marker token
        sub_vec = hidden[torch.arange(batch), sub_pos]
        obj_vec = hidden[torch.arange(batch), obj_pos]
        pair = torch.cat([sub_vec, obj_vec], dim=1)
        pair = self.dropout(pair)
        return self.classifier(pair)
weights_path = hf_hub_download(repo_id, "pytorch_model.bin")
re_model = BertRE(num_labels=len(rel2id))
re_model.load_state_dict(torch.load(weights_path, map_location="cpu"))
re_model.eval()
def convert_ner_format(ner_output):
return [[item["token"], item["tags"]] for item in ner_output]
def entities_and_types(sentence):
ner_output = extract(sentence)
converted = convert_ner_format(ner_output)
entities = distill_entities(converted)
entity_dict = {}
for name, entity_type, _, _ in entities:
entity_dict[name] = entity_type
return entity_dict
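# Illustrative result (entity values assumed): {"محمد": "PERS", "القدس": "GPE"}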
# Allowed (domain, range) entity-type pairs per relation. Relation names must
# match the trained label set exactly, so spellings such as "has_compititor"
# and "subsidary" are kept verbatim.
relation_domain_range = [
{
"relation": "manager_of",
"domain": ["PERS"],
"range": ["ORG", "FAC"]
},
{
"relation": "birth_date",
"domain": ["PERS"],
"range": ["DATE"]
},
{
"relation": "has_parent",
"domain": ["PERS"],
"range": ["PERS"]
},
{
"relation": "has_sibling",
"domain": ["PERS"],
"range": ["PERS"]
},
{
"relation": "has_spouse",
"domain": ["PERS"],
"range": ["PERS"]
},
{
"relation": "has_relative",
"domain": ["PERS"],
"range": ["PERS"]
},
{
"relation": "death_date",
"domain": ["PERS"],
"range": ["DATE"]
},
{
"relation": "birth_place",
"domain": ["PERS"],
"range": ["GPE", "LOC"]
},
{
"relation": "has_occupation",
"domain": ["PERS"],
"range": ["OCC"]
},
{
"relation": "has_conflict_with",
"domain": ["ORG", "NORP", "GPE"],
"range": ["ORG", "NORP", "GPE"]
},
{
"relation": "has_compititor",
"domain": ["PERS", "ORG"],
"range": ["PERS", "ORG"]
},
{
"relation": "has_partner_with",
"domain": ["ORG"],
"range": ["ORG"]
},
{
"relation": "president_of",
"domain": ["PERS"],
"range": ["ORG", "GPE"]
},
{
"relation": "leader_of",
"domain": ["PERS"],
"range": ["ORG"]
},
{
"relation": "geopolitical_division",
"domain": ["GPE", "LOC"],
"range": ["GPE", "LOC"]
},
{
"relation": "member_of",
"domain": ["PERS"],
"range": ["ORG", "NORP"]
},
{
"relation": "subsidary",
"domain": ["ORG"],
"range": ["ORG"]
},
{
"relation": "employee_of",
"domain": ["PERS"],
"range": ["ORG", "FAC"]
},
{
"relation": "student_at",
"domain": ["PERS"],
"range": ["ORG"]
},
{
"relation": "owner_of",
"domain": ["PERS"],
"range": ["ORG", "FAC"]
},
{
"relation": "inventor_of",
"domain": ["PERS"],
"range": ["PRODUCT"]
},
{
"relation": "manufacturer_of",
"domain": ["ORG"],
"range": ["PRODUCT"]
},
{
"relation": "builder_of",
"domain": ["PERS", "NORP"],
"range": ["FAC"]
},
{
"relation": "founder_of",
"domain": ["PERS"],
"range": ["ORG"]
},
{
"relation": "lives_in",
"domain": ["PERS", "NORP"],
"range": ["GPE", "LOC"]
},
{
"relation": "located_in",
"domain": ["FAC", "ORG"],
"range": ["GPE", "LOC"]
},
{
"relation": "headquartered_in",
"domain": ["ORG"],
"range": ["GPE", "LOC"]
},
{
"relation": "has_border_with",
"domain": ["LOC", "GPE"],
"range": ["LOC", "GPE"]
},
{
"relation": "nearby",
"domain": ["GPE", "LOC", "ORG", "FAC"],
"range": ["GPE", "LOC", "ORG", "FAC"]
},
{
"relation": "has_property",
"domain": ["ORG"],
"range": ["PRODUCT"]
},
{
"relation": "branch_count",
"domain": ["ORG"],
"range": ["CARDINAL"]
},
{
"relation": "has_revenue",
"domain": ["ORG"],
"range": ["MONEY"]
},
{
"relation": "employs",
"domain": ["ORG"],
"range": ["CARDINAL"]
},
{
"relation": "found_on",
"domain": ["ORG"],
"range": ["DATE", "TIME"]
},
{
"relation": "has_alternate_name",
"domain": ["ORG", "FAC"],
"range": ["ORG", "FAC"]
},
{
"relation": "has_area",
"domain": ["GPE", "LOC"],
"range": ["QUANTITY"]
},
{
"relation": "official_language",
"domain": ["GPE", "LOC"],
"range": ["LANGUAGE"]
},
{
"relation": "has_currency",
"domain": ["GPE", "LOC"],
"range": ["CURR"]
},
{
"relation": "has_population",
"domain": ["GPE"],
"range": ["CARDINAL"]
},
{
"relation": "capital_of",
"domain": ["GPE"],
"range": ["GPE"]
}
]
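# Invert the schema into domain-type -> range-type -> [relations], e.g.
# relation_lookup["PERS"]["DATE"] -> ["birth_date", "death_date"]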
relation_lookup = defaultdict(lambda: defaultdict(list))
for rel in relation_domain_range:
for d in rel["domain"]:
for r in rel["range"]:
relation_lookup[d][r].append(rel["relation"])
def insert_markers(sentence, ent1, ent2):
    """Wrap the first occurrence of each entity with [Sub]/[Obj] markers."""
    if ent1 not in sentence or ent2 not in sentence:
        return None
    marked = sentence.replace(ent1, f"[Sub] {ent1} [/Sub]", 1)
    marked = marked.replace(ent2, f"[Obj] {ent2} [/Obj]", 1)
    return marked
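# e.g. insert_markers("Ahmed visited Paris", "Ahmed", "Paris")
#   -> "[Sub] Ahmed [/Sub] visited [Obj] Paris [/Obj]"
# (English example only to show marker placement)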
def encode(sentence):
enc = relation_tokenizer(
sentence,
max_length=128,
padding="max_length",
truncation=True,
return_tensors="pt"
)
input_ids = enc["input_ids"]
attention_mask = enc["attention_mask"]
    # [Sub]/[Obj] are assumed to be registered in the repo's tokenizer.json;
    # otherwise convert_tokens_to_ids would return the unknown-token id
    sub_id = relation_tokenizer.convert_tokens_to_ids("[Sub]")
    obj_id = relation_tokenizer.convert_tokens_to_ids("[Obj]")
    # Sequence positions of the opening markers (batch dimension dropped)
    sub_pos = (input_ids == sub_id).nonzero(as_tuple=True)[1]
    obj_pos = (input_ids == obj_id).nonzero(as_tuple=True)[1]
return input_ids, attention_mask, sub_pos, obj_pos
def predict_relation(sentence):
    """Return (relation_label, confidence) for a marker-annotated sentence."""
    input_ids, mask, sub_pos, obj_pos = encode(sentence)
    # Truncation at max_length=128 can drop a marker; skip such inputs
    if len(sub_pos) == 0 or len(obj_pos) == 0:
        return None, 0.0
with torch.no_grad():
logits = re_model(input_ids, mask, sub_pos, obj_pos)
probs = F.softmax(logits, dim=-1)
pred = torch.argmax(probs, dim=-1).item()
conf = probs[0, pred].item()
return id2rel[pred], conf
def relation_extractor(sentence):
entities = entities_and_types(sentence)
output = []
    entity_items = list(entities.items())
    # Every ordered pair of distinct entities is a candidate (subject, object)
    pairs = list(permutations(entity_items, 2))
for (ent1, type1), (ent2, type2) in pairs:
valid_rels = relation_lookup.get(type1, {}).get(type2, [])
if not valid_rels:
continue
marked_sentence = insert_markers(sentence, ent1, ent2)
if marked_sentence is None:
continue
rel, conf = predict_relation(marked_sentence)
if rel is None:
continue
        # Keep only confident predictions whose type signature fits the schema
        if conf > 0.80 and rel != "no_relation" and rel.split(".")[-1] in valid_rels:
output.append({
"Subject": {
"Type": type1,
"Label": ent1
},
"Relation": rel,
"Object": {
"Type": type2,
"Label": ent2
},
"Confidence": float(round(conf, 4))
})
return output
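# Each accepted pair is emitted as (values illustrative):
#   {"Subject": {"Type": "PERS", "Label": ...}, "Relation": "birth_date",
#    "Object": {"Type": "DATE", "Label": ...}, "Confidence": 0.97}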
class RERequest(BaseModel):
text: str
@app.post("/predict_re")
def predict_re(request: RERequest):
try:
results = relation_extractor(request.text)
return JSONResponse(
content={
"resp": results,
"statusText": "OK",
"statusCode": 0,
},
media_type="application/json",
status_code=200,
)
    except Exception as e:
        return JSONResponse(
            content={"error": str(e)},
            status_code=500,
        )
# =========== Front End =============================
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
# mount frontend
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/")
def home():
return FileResponse("static/index.html")