from fastapi import FastAPI
import torch
import pickle
from huggingface_hub import hf_hub_download, snapshot_download
from Nested.nn.BertSeqTagger import BertSeqTagger
from transformers import AutoTokenizer, AutoModel
import inspect
from collections import namedtuple
from Nested.utils.helpers import load_checkpoint
from Nested.utils.data import get_dataloaders, text2segments
import json
from pydantic import BaseModel
from fastapi.responses import JSONResponse
from IBO_to_XML import IBO_to_XML
from XML_to_HTML import NER_XML_to_HTML
from NER_Distiller import distill_entities

app = FastAPI()
pretrained_path = "aubmindlab/bert-base-arabertv2"  # must match training
tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
encoder = AutoModel.from_pretrained(pretrained_path).eval()

checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")
args_path = hf_hub_download(
    repo_id="SinaLab/Nested",
    filename="args.json"
)
with open(args_path, 'r') as f:
    args_data = json.load(f)

# Load tag vocabulary
with open("Nested/utils/tag_vocab.pkl", "rb") as f:
    label_vocab = pickle.load(f)
label_vocab = label_vocab[0]  # the list loaded from pickle
id2label = {i: s for i, s in enumerate(label_vocab.itos)}
def split_text_into_groups_of_Ns(sentence, max_words_per_sentence):
    # Split the text into words
    words = sentence.split()
    # Initialize variables
    groups = []
    current_group = ""
    group_size = 0
    # Iterate through the words
    for word in words:
        if group_size < max_words_per_sentence - 1:
            if len(current_group) == 0:
                current_group = word
            else:
                current_group += " " + word
            group_size += 1
        else:
            current_group += " " + word
            groups.append(current_group)
            current_group = ""
            group_size = 0
    # Add the last group if it contains less than n words
    if current_group:
        groups.append(current_group)
    return groups
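
# Example: with max_words_per_sentence=3,
#   split_text_into_groups_of_Ns("w1 w2 w3 w4 w5", 3)
# returns ["w1 w2 w3", "w4 w5"].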
def remove_empty_values(sentences):
    return [value for value in sentences if value != '']
def sentence_tokenizer(text, dot=True, new_line=True, question_mark=True, exclamation_mark=True):
    separators = []
    split_text = [text]
    if new_line:
        separators.append('\n')
    if dot:
        separators.append('.')
    if question_mark:
        separators.append('?')
        separators.append('؟')
    if exclamation_mark:
        separators.append('!')
    for sep in separators:
        new_split_text = []
        for part in split_text:
            tokens = part.split(sep)
            tokens_with_separator = [token + sep for token in tokens[:-1]]
            tokens_with_separator.append(tokens[-1].strip())
            new_split_text.extend(tokens_with_separator)
        split_text = new_split_text
    split_text = remove_empty_values(split_text)
    return split_text
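
# Example: each separator stays attached to its sentence, so
#   sentence_tokenizer("مرحبا. كيف الحال؟")
# returns ["مرحبا.", "كيف الحال؟"].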
def jsons_to_list_of_lists(json_list):
    return [[d['token'], d['tags']] for d in json_list]

tagger, tag_vocab, train_config = load_checkpoint(checkpoint_path)
def extract(sentence):
    dataset, token_vocab = text2segments(sentence)
    vocabs = namedtuple("Vocab", ["tags", "tokens"])
    vocab = vocabs(tokens=token_vocab, tags=tag_vocab)
    dataloader = get_dataloaders(
        (dataset,),
        vocab,
        args_data,
        batch_size=32,
        shuffle=(False,),
    )[0]
    segments = tagger.infer(dataloader)
    lists = []
    for segment in segments:
        for token in segment:
            item = {}
            item["token"] = token.text
            list_of_tags = [t["tag"] for t in token.pred_tag]
            list_of_tags = [i for i in list_of_tags if i not in ("O", " ", "")]
            if not list_of_tags:
                item["tags"] = "O"
            else:
                item["tags"] = " ".join(list_of_tags)
            lists.append(item)
    return lists
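
# extract() yields one dict per token; nested mentions produce space-joined tags,
# e.g. (tag values illustrative):
#   [{"token": "رام", "tags": "B-GPE B-ORG"}, {"token": "الله", "tags": "I-GPE I-ORG"}]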
def NER(sentence, mode):
    # Modes: "1" = token/tag pairs, "2" = XML, "3" = HTML, "4" = distilled JSON.
    mode = mode.strip()
    if mode == "1":
        return jsons_to_list_of_lists(extract(sentence))
    elif mode == "2":
        return IBO_to_XML(jsons_to_list_of_lists(extract(sentence)))
    elif mode == "3":
        return NER_XML_to_HTML(IBO_to_XML(jsons_to_list_of_lists(extract(sentence))))
    elif mode == "4":  # json short
        return distill_entities(jsons_to_list_of_lists(extract(sentence)))
    return None  # unknown mode
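
# Example (illustrative; actual tags depend on the model):
#   NER("ولد محمد في القدس", "2")
#   -> "ولد <PERS>محمد</PERS> في <GPE>القدس</GPE>"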
class NERRequest(BaseModel):
    text: str
    mode: str
@app.post("/predict")  # route path assumed; a decorator is required for FastAPI to expose this handler
def predict(request: NERRequest):
    text = request.text
    mode = request.mode
    sentences = sentence_tokenizer(
        text, dot=False, new_line=True, question_mark=False, exclamation_mark=False
    )
    lists = []
    for sentence in sentences:
        se = split_text_into_groups_of_Ns(sentence, max_words_per_sentence=300)
        for s in se:
            output_list = NER(s, mode)
            lists.append(output_list)
    content = {
        "resp": lists,
        "statusText": "OK",
        "statusCode": 0,
    }
    return JSONResponse(
        content=content,
        media_type="application/json",
        status_code=200,
    )
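
# Example request/response (illustrative; one inner list per 300-word chunk):
#   POST /predict  {"text": "ولد محمد في القدس", "mode": "1"}
#   -> {"resp": [[["ولد", "O"], ["محمد", "B-PERS"], ...]], "statusText": "OK", "statusCode": 0}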
# ============ Relation Extraction ==============
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedTokenizerFast, BertModel
from itertools import permutations
from collections import defaultdict

# =========================
# Relation Extraction Model
# =========================
repo_id = "aaljabari/arabic-relation-extraction-v1"

# tokenizer
relation_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file=hf_hub_download(repo_id, "tokenizer.json")
)

# vocab
rel_vocab_path = hf_hub_download(repo_id, "tag_vocab.pkl")
with open(rel_vocab_path, "rb") as f:
    vocab = pickle.load(f)
rel2id = vocab["rel2id"]
id2rel = vocab["id2rel"]
class BertRE(nn.Module):
    def __init__(self, num_labels):
        super().__init__()
        self.bert = BertModel.from_pretrained(repo_id)
        hidden = self.bert.config.hidden_size
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        # classify from the concatenated subject/object marker representations
        self.classifier = nn.Linear(hidden * 2, num_labels)

    def forward(self, input_ids, attention_mask, sub_pos, obj_pos):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        hidden = outputs.last_hidden_state
        batch = hidden.shape[0]
        # pick the hidden state at each sequence's [Sub]/[Obj] marker position
        sub_vec = hidden[torch.arange(batch), sub_pos]
        obj_vec = hidden[torch.arange(batch), obj_pos]
        pair = torch.cat([sub_vec, obj_vec], dim=1)
        pair = self.dropout(pair)
        return self.classifier(pair)
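
# Shape sketch (assuming a bert-base encoder, hidden_size=768, max_length=128):
#   hidden: (B, 128, 768) -> sub_vec/obj_vec: (B, 768)
#   pair:   (B, 1536)     -> logits: (B, num_labels)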
weights_path = hf_hub_download(repo_id, "pytorch_model.bin")
re_model = BertRE(num_labels=len(rel2id))
re_model.load_state_dict(torch.load(weights_path, map_location="cpu"))
re_model.eval()

def convert_ner_format(ner_output):
    return [[item["token"], item["tags"]] for item in ner_output]
def entities_and_types(sentence):
    ner_output = extract(sentence)
    converted = convert_ner_format(ner_output)
    entities = distill_entities(converted)
    entity_dict = {}
    for name, entity_type, _, _ in entities:
        entity_dict[name] = entity_type
    return entity_dict
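
# Example result (illustrative): {"محمد": "PERS", "القدس": "GPE"}
# Note that keying by surface form keeps only one type per distinct mention.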
# Domain/range schema: which entity types each relation may connect.
# These strings are matched against the model's predicted labels, so the
# spellings ("has_compititor", "subsidary", "found_on") are kept as-is.
relation_domain_range = [
    {"relation": "manager_of", "domain": ["PERS"], "range": ["ORG", "FAC"]},
    {"relation": "birth_date", "domain": ["PERS"], "range": ["DATE"]},
    {"relation": "has_parent", "domain": ["PERS"], "range": ["PERS"]},
    {"relation": "has_sibling", "domain": ["PERS"], "range": ["PERS"]},
    {"relation": "has_spouse", "domain": ["PERS"], "range": ["PERS"]},
    {"relation": "has_relative", "domain": ["PERS"], "range": ["PERS"]},
    {"relation": "death_date", "domain": ["PERS"], "range": ["DATE"]},
    {"relation": "birth_place", "domain": ["PERS"], "range": ["GPE", "LOC"]},
    {"relation": "has_occupation", "domain": ["PERS"], "range": ["OCC"]},
    {"relation": "has_conflict_with", "domain": ["ORG", "NORP", "GPE"], "range": ["ORG", "NORP", "GPE"]},
    {"relation": "has_compititor", "domain": ["PERS", "ORG"], "range": ["PERS", "ORG"]},
    {"relation": "has_partner_with", "domain": ["ORG"], "range": ["ORG"]},
    {"relation": "president_of", "domain": ["PERS"], "range": ["ORG", "GPE"]},
    {"relation": "leader_of", "domain": ["PERS"], "range": ["ORG"]},
    {"relation": "geopolitical_division", "domain": ["GPE", "LOC"], "range": ["GPE", "LOC"]},
    {"relation": "member_of", "domain": ["PERS"], "range": ["ORG", "NORP"]},
    {"relation": "subsidary", "domain": ["ORG"], "range": ["ORG"]},
    {"relation": "employee_of", "domain": ["PERS"], "range": ["ORG", "FAC"]},
    {"relation": "student_at", "domain": ["PERS"], "range": ["ORG"]},
    {"relation": "owner_of", "domain": ["PERS"], "range": ["ORG", "FAC"]},
    {"relation": "inventor_of", "domain": ["PERS"], "range": ["PRODUCT"]},
    {"relation": "manufacturer_of", "domain": ["ORG"], "range": ["PRODUCT"]},
    {"relation": "builder_of", "domain": ["PERS", "NORP"], "range": ["FAC"]},
    {"relation": "founder_of", "domain": ["PERS"], "range": ["ORG"]},
    {"relation": "lives_in", "domain": ["PERS", "NORP"], "range": ["GPE", "LOC"]},
    {"relation": "located_in", "domain": ["FAC", "ORG"], "range": ["GPE", "LOC"]},
    {"relation": "headquartered_in", "domain": ["ORG"], "range": ["GPE", "LOC"]},
    {"relation": "has_border_with", "domain": ["LOC", "GPE"], "range": ["LOC", "GPE"]},
    {"relation": "nearby", "domain": ["GPE", "LOC", "ORG", "FAC"], "range": ["GPE", "LOC", "ORG", "FAC"]},
    {"relation": "has_property", "domain": ["ORG"], "range": ["PRODUCT"]},
    {"relation": "branch_count", "domain": ["ORG"], "range": ["CARDINAL"]},
    {"relation": "has_revenue", "domain": ["ORG"], "range": ["MONEY"]},
    {"relation": "employs", "domain": ["ORG"], "range": ["CARDINAL"]},
    {"relation": "found_on", "domain": ["ORG"], "range": ["DATE", "TIME"]},
    {"relation": "has_alternate_name", "domain": ["ORG", "FAC"], "range": ["ORG", "FAC"]},
    {"relation": "has_area", "domain": ["GPE", "LOC"], "range": ["QUANTITY"]},
    {"relation": "official_language", "domain": ["GPE", "LOC"], "range": ["LANGUAGE"]},
    {"relation": "has_currency", "domain": ["GPE", "LOC"], "range": ["CURR"]},
    {"relation": "has_population", "domain": ["GPE"], "range": ["CARDINAL"]},
    {"relation": "capital_of", "domain": ["GPE"], "range": ["GPE"]},
]
# Index the schema as domain type -> range type -> [allowed relations]
relation_lookup = defaultdict(lambda: defaultdict(list))
for rel in relation_domain_range:
    for d in rel["domain"]:
        for r in rel["range"]:
            relation_lookup[d][r].append(rel["relation"])
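
# Example: relation_lookup["PERS"]["ORG"] ->
#   ["manager_of", "has_compititor", "president_of", "leader_of", "member_of",
#    "employee_of", "student_at", "owner_of", "founder_of"]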
def insert_markers(sentence, ent1, ent2):
    # Wrap the first occurrence of each entity with subject/object marker tokens.
    if ent1 not in sentence or ent2 not in sentence:
        return None
    marked = sentence
    marked = marked.replace(ent1, f"[Sub] {ent1} [/Sub]", 1)
    marked = marked.replace(ent2, f"[Obj] {ent2} [/Obj]", 1)
    return marked
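
# Example: insert_markers("ولد محمد في القدس", "محمد", "القدس")
#   -> "ولد [Sub] محمد [/Sub] في [Obj] القدس [/Obj]"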
def encode(sentence):
    enc = relation_tokenizer(
        sentence,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )
    input_ids = enc["input_ids"]
    attention_mask = enc["attention_mask"]
    # locate the [Sub]/[Obj] marker tokens; their hidden states represent the pair
    sub_id = relation_tokenizer.convert_tokens_to_ids("[Sub]")
    obj_id = relation_tokenizer.convert_tokens_to_ids("[Obj]")
    sub_pos = (input_ids == sub_id).nonzero(as_tuple=True)[1]
    obj_pos = (input_ids == obj_id).nonzero(as_tuple=True)[1]
    return input_ids, attention_mask, sub_pos, obj_pos
def predict_relation(sentence):
    input_ids, mask, sub_pos, obj_pos = encode(sentence)
    if len(sub_pos) == 0 or len(obj_pos) == 0:
        # markers were never inserted or were truncated away
        return None, 0.0
    with torch.no_grad():
        logits = re_model(input_ids, mask, sub_pos, obj_pos)
        probs = F.softmax(logits, dim=-1)
        pred = torch.argmax(probs, dim=-1).item()
        conf = probs[0, pred].item()
    return id2rel[pred], conf
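
# Example (illustrative output):
#   predict_relation("ولد [Sub] محمد [/Sub] في [Obj] القدس [/Obj]")
#   -> ("birth_place", 0.93)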
def relation_extractor(sentence):
    entities = entities_and_types(sentence)
    output = []
    entity_items = list(entities.items())
    # score every ordered entity pair whose types are allowed by the schema
    pairs = [(e1, e2) for e1, e2 in permutations(entity_items, 2)]
    for (ent1, type1), (ent2, type2) in pairs:
        valid_rels = relation_lookup.get(type1, {}).get(type2, [])
        if not valid_rels:
            continue
        marked_sentence = insert_markers(sentence, ent1, ent2)
        if marked_sentence is None:
            continue
        rel, conf = predict_relation(marked_sentence)
        if rel is None:
            continue
        # keep confident predictions that respect the domain/range constraints
        if conf > 0.80 and rel != "no_relation" and rel.split(".")[-1] in valid_rels:
            output.append({
                "Subject": {"Type": type1, "Label": ent1},
                "Relation": rel,
                "Object": {"Type": type2, "Label": ent2},
                "Confidence": float(round(conf, 4)),
            })
    return output
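
# Example output (illustrative):
#   [{"Subject": {"Type": "PERS", "Label": "محمد"},
#     "Relation": "birth_place",
#     "Object": {"Type": "GPE", "Label": "القدس"},
#     "Confidence": 0.93}]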
class RERequest(BaseModel):
    text: str
@app.post("/predict_re")  # route path assumed; a decorator is required to expose the handler
def predict_re(request: RERequest):
    try:
        results = relation_extractor(request.text)
        return JSONResponse(
            content={
                "resp": results,
                "statusText": "OK",
                "statusCode": 0,
            },
            media_type="application/json",
            status_code=200,
        )
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
# =========== Front End =============================
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse

# mount frontend
app.mount("/static", StaticFiles(directory="static"), name="static")

@app.get("/")  # root route assumed; serves the single-page frontend
def home():
    return FileResponse("static/index.html")
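
# Local entry point: a minimal sketch assuming the app is served by uvicorn
# (port 7860 matches the Hugging Face Spaces default; both are assumptions).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)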