# NER_17_labels / app.py — Gradio demo for 17-label named-entity recognition
# (BiLSTM vs. BiLSTM-CRF taggers).
# Provenance (Hugging Face Space, user dungquang), commit f033e68:
# "Add: models, predict fn, gradio interface"
import gradio as gr
from config import load_config
from models.bilstm_crf import load_model as load_bilstm_crf_model
from models.bilstm import load_model as load_bilstm_model
import torch
from utils import tokenize_sentence
# Run on GPU when one is available; inputs and models are moved here at
# prediction time.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Project configuration. Keys used below: 'word2idx' (token -> vocab index),
# 'idx2tag' (tag index -> BIO label), 'PAD_IDX' (padding index for embeddings).
# NOTE(review): exact schema comes from load_config() — not visible here.
configs = load_config()
word2idx = configs['word2idx']
idx2tag = configs['idx2tag']

# BiLSTM-CRF tagger (used for the "CRF" model choice in the UI).
model_bilstm_crf = load_bilstm_crf_model(
    vocab_size=len(word2idx),
    output_size=len(idx2tag) + 1,  # +1 for the <PAD> token
    embedding_size=200,
    hidden_size=256,
    pad_idx=configs['PAD_IDX']
)

# Plain BiLSTM tagger (the "Without CRF" model choice). Same hyperparameters
# as the CRF variant so the two are directly comparable.
model_bilstm = load_bilstm_model(
    vocab_size=len(word2idx),
    output_size=len(idx2tag) + 1,  # +1 for the <PAD> token
    embedding_size=200,
    hidden_size=256,
    pad_idx=configs['PAD_IDX']
)
def render_ner(tokens_with_labels):
    """Render (token, label) pairs as inline-highlighted HTML.

    Args:
        tokens_with_labels: iterable of ``(token, label)`` pairs where
            ``label`` is ``"O"`` or a BIO tag such as ``"B-geo"`` / ``"I-per"``.

    Returns:
        An HTML string. Entity tokens are wrapped in colored ``<span>``s with
        the entity type as a subscript; ``"O"`` tokens are plain text.
    """
    from html import escape  # stdlib; local import keeps the module header unchanged

    # Pastel background per entity type; unknown types fall back to grey below.
    color_map = {
        "art": "#f4cccc",
        "eve": "#d9ead3",
        "geo": "#a0c4ff",
        "gpe": "#b6d7a8",
        "nat": "#ffe599",
        "org": "#d5a6bd",
        "per": "#ffb3b3",
        "tim": "#ffd6a5",
    }
    parts = []
    for token, label in tokens_with_labels:
        # Fix: escape user-derived text — a token like "<b>" or "&" previously
        # broke the markup (and allowed HTML injection into the output pane).
        safe = escape(token)
        if label == "O":
            parts.append(f"{safe} ")
        else:
            ent_type = label.split("-")[-1]  # "B-geo" -> "geo"
            color = color_map.get(ent_type, "#dddddd")
            parts.append(
                f'<span style="background-color:{color}; padding:2px 4px; '
                f'margin:2px; border-radius:4px;">{safe} <sub>{ent_type}</sub></span> '
            )
    # str.join instead of repeated += (linear, and the idiomatic build).
    return "".join(parts)
def predict_sentence(model, sentence: str):
    """Tag one sentence with the plain (softmax) BiLSTM model.

    Args:
        model: a tagger whose forward pass returns per-token tag scores.
        sentence: raw input text; tokenized via the project's tokenize_sentence.

    Returns:
        A list of ``(token, label)`` pairs, one per token.
    """
    model.eval()
    model.to(device)
    words = tokenize_sentence(sentence)
    unk = word2idx["<UNK>"]  # out-of-vocabulary tokens map to <UNK>
    ids = [word2idx.get(w, unk) for w in words]
    batch = torch.tensor([ids], dtype=torch.long, device=device)  # shape (1, seq)
    with torch.inference_mode():
        scores = model(batch)
    # Greedy decode: highest-scoring tag per position.
    tag_ids = scores.argmax(dim=-1).squeeze(0).tolist()
    return [(w, idx2tag[i]) for w, i in zip(words, tag_ids)]
def predict_sentence_crf(model, sentence: str):
    """Tag one sentence with the BiLSTM-CRF model (Viterbi-style decode).

    Args:
        model: a tagger whose forward pass accepts ``mask=`` and returns a
            batch of decoded tag-index sequences.
        sentence: raw input text; tokenized via the project's tokenize_sentence.

    Returns:
        A list of ``(token, label)`` pairs, one per token.
    """
    model.eval()
    model.to(device)
    words = tokenize_sentence(sentence)
    unk = word2idx["<UNK>"]  # out-of-vocabulary tokens map to <UNK>
    ids = [word2idx.get(w, unk) for w in words]
    batch = torch.tensor([ids], dtype=torch.long, device=device)  # shape (1, seq)
    # Single unpadded sentence -> every position is real: all-True mask.
    attn = torch.ones_like(batch, dtype=torch.bool)
    with torch.inference_mode():
        decoded = model(batch, mask=attn)
    # decoded[0] is the tag-index sequence for the one sentence in the batch.
    return [(w, idx2tag[i]) for w, i in zip(words, decoded[0])]
def predict(text, model_choice):
    """Gradio callback: tag ``text`` with the chosen model and return HTML.

    Args:
        text: the sentence typed into the Textbox.
        model_choice: the Radio value — "Without CRF" selects the plain
            BiLSTM; anything else selects the BiLSTM-CRF.

    Returns:
        An HTML string produced by render_ner.
    """
    tagged = (
        predict_sentence(model_bilstm, text)
        if model_choice == "Without CRF"
        else predict_sentence_crf(model_bilstm_crf, text)
    )
    return render_ner(tagged)
# Example sentences shown under the input box.
example_sentences = [
    "On July 20, 1969, Neil Armstrong became the first human to walk on the Moon during the Apollo 11 mission.",
    "Amazon acquired Whole Foods for $13.7 billion in 2017, expanding its footprint in the U.S. grocery market.",
    "The United Nations held a climate summit in Geneva to discuss global warming with representatives from over 100 countries.",
    "Taylor Swift performed at the Wembley Stadium in London as part of her Eras Tour in June 2023.",
    "The COVID-19 pandemic was declared a global health emergency by the World Health Organization in March 2020.",
    "In the movie The Social Network, Jesse Eisenberg portrays Mark Zuckerberg, the founder of Facebook."
]
# Fix: the interface has TWO input components (sentence, model choice), so
# each example row must supply a value for both — previously each row held
# only the sentence, leaving the Radio value unset when an example was picked.
examples = [[s, "Without CRF"] for s in example_sentences]

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter a sentence to analyze named entities...", label="Input Sentence"),
        gr.Radio(["Without CRF", "CRF"], label="Model Choice", value="Without CRF", type="value")
    ],
    outputs=gr.HTML(label="NER Output"),
    examples=examples,
    title="Named Entity Recognition (NER) Demo",
    description="Enter an English sentence or select from examples to see entity recognition results."
)

# Start the Gradio server (blocking call) when the module is executed.
demo.launch()