Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from config import load_config | |
| from models.bilstm_crf import load_model as load_bilstm_crf_model | |
| from models.bilstm import load_model as load_bilstm_model | |
| import torch | |
| from utils import tokenize_sentence | |
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Vocabulary and tag tables come from the project configuration.
configs = load_config()
word2idx = configs['word2idx']  # used below to map tokens to input indices
idx2tag = configs['idx2tag']    # used below to map predicted indices to tags

# Both taggers are constructed with exactly the same hyperparameters.
_shared_model_kwargs = dict(
    vocab_size=len(word2idx),
    output_size=len(idx2tag) + 1,  # +1 for the <PAD> token
    embedding_size=200,
    hidden_size=256,
    pad_idx=configs['PAD_IDX'],
)
model_bilstm_crf = load_bilstm_crf_model(**_shared_model_kwargs)
model_bilstm = load_bilstm_model(**_shared_model_kwargs)
def render_ner(tokens_with_labels):
    """Render (token, label) pairs as an HTML snippet with highlighted entities.

    Tokens labelled ``"O"`` are emitted as plain text. Every other token is
    wrapped in a coloured ``<span>`` followed by the entity type in a
    ``<sub>`` tag. The entity type is the part of the label after the last
    ``-`` (e.g. ``"B-geo"`` -> ``"geo"``); unknown types get a grey background.

    Args:
        tokens_with_labels: Iterable of ``(token, label)`` pairs.

    Returns:
        The HTML string; every rendered token is followed by a single space.
        An empty input yields an empty string.
    """
    # Function-scope import so the block stays self-contained.
    from html import escape

    color_map = {
        "art": "#f4cccc",
        "eve": "#d9ead3",
        "geo": "#a0c4ff",
        "gpe": "#b6d7a8",
        "nat": "#ffe599",
        "org": "#d5a6bd",
        "per": "#ffb3b3",
        "tim": "#ffd6a5",
    }
    default_color = "#dddddd"
    span_style = "padding:2px 4px; margin:2px; border-radius:4px;"

    pieces = []
    for token, label in tokens_with_labels:
        # Tokens originate from user-typed text and end up in an HTML
        # component, so they must be escaped to prevent markup injection.
        safe_token = escape(str(token))
        if label == "O":
            pieces.append(f"{safe_token} ")
            continue
        ent_type = label.split("-")[-1]  # "geo", "per", ...
        color = color_map.get(ent_type, default_color)
        pieces.append(
            f'<span style="background-color:{color}; {span_style}">'
            f'{safe_token} <sub>{escape(ent_type)}</sub></span> '
        )
    # Join once instead of growing a string with repeated "+=".
    return "".join(pieces)
def predict_sentence(model, sentence: str):
    """Tag a sentence with the plain (non-CRF) model.

    The sentence is tokenized, each token is mapped to its vocabulary index
    (falling back to the ``"<UNK>"`` entry for unknown tokens), and the
    highest-scoring tag index at each position is looked up in ``idx2tag``.

    Args:
        model: Callable model; invoked as ``model(input_tensor)`` with a
            ``(1, num_tokens)`` long tensor.
        sentence: Raw input text.

    Returns:
        A list of ``(token, tag)`` pairs; empty when the sentence yields no
        tokens.
    """
    model.eval()
    model.to(device)

    tokens = tokenize_sentence(sentence)
    if not tokens:
        # An empty token list would build a (1, 0) tensor; a zero-length
        # sequence is presumably rejected by the recurrent model, so return
        # an empty result instead of running inference.
        return []

    unk_idx = word2idx["<UNK>"]
    idxs = [word2idx.get(token, unk_idx) for token in tokens]
    # Batch of one sentence, created directly on the target device.
    input_tensor = torch.tensor([idxs], dtype=torch.long, device=device)

    with torch.inference_mode():
        outputs = model(input_tensor)
        # argmax over the last axis picks one tag index per position;
        # squeeze(0) drops the batch axis of size one.
        preds = outputs.argmax(dim=-1).squeeze(0).tolist()

    # NOTE(review): output_size is len(idx2tag) + 1, so the model could in
    # principle emit the extra <PAD> index, which may be absent from idx2tag
    # — confirm how idx2tag is keyed before adding a fallback.
    predicted_labels = [idx2tag[idx] for idx in preds]
    return list(zip(tokens, predicted_labels))
def predict_sentence_crf(model, sentence: str):
    """Tag a sentence with the CRF model.

    The sentence is tokenized, each token is mapped to its vocabulary index
    (falling back to the ``"<UNK>"`` entry for unknown tokens), and the model
    is called with an all-true mask covering every token. The first element
    of the model's result is read as the sequence of tag indices.

    Args:
        model: Callable model; invoked as ``model(input_tensor, mask=mask)``
            with ``(1, num_tokens)`` tensors.
        sentence: Raw input text.

    Returns:
        A list of ``(token, tag)`` pairs; empty when the sentence yields no
        tokens.
    """
    model.eval()
    model.to(device)

    tokens = tokenize_sentence(sentence)
    if not tokens:
        # An empty token list would build (1, 0) input and mask tensors; a
        # zero-length sequence is presumably rejected by the model, so return
        # an empty result instead of running inference.
        return []

    unk_idx = word2idx["<UNK>"]
    idxs = [word2idx.get(token, unk_idx) for token in tokens]
    # Batch of one sentence, created directly on the target device.
    input_tensor = torch.tensor([idxs], dtype=torch.long, device=device)
    # Single unpadded sentence: every position is a real token.
    mask = torch.ones((1, len(idxs)), dtype=torch.bool, device=device)

    with torch.inference_mode():
        pred_tags = model(input_tensor, mask=mask)

    # pred_tags[0] holds the tag indices for the only sentence in the batch.
    predicted_labels = [idx2tag[idx] for idx in pred_tags[0]]
    return list(zip(tokens, predicted_labels))
def predict(text, model_choice):
    """Tag ``text`` with the selected model and return the rendered HTML.

    ``model_choice`` equal to ``"Without CRF"`` selects the plain model;
    any other value selects the CRF model.
    """
    if model_choice != "Without CRF":
        tagged = predict_sentence_crf(model_bilstm_crf, text)
        return render_ner(tagged)

    tagged = predict_sentence(model_bilstm, text)
    return render_ner(tagged)
# Example inputs for the interface: one single-element row per sentence.
examples = [
    ["On July 20, 1969, Neil Armstrong became the first human to walk on the Moon during the Apollo 11 mission."],
    ["Amazon acquired Whole Foods for $13.7 billion in 2017, expanding its footprint in the U.S. grocery market."],
    ["The United Nations held a climate summit in Geneva to discuss global warming with representatives from over 100 countries."],
    ["Taylor Swift performed at the Wembley Stadium in London as part of her Eras Tour in June 2023."],
    ["The COVID-19 pandemic was declared a global health emergency by the World Health Organization in March 2020."],
    ["In the movie The Social Network, Jesse Eisenberg portrays Mark Zuckerberg, the founder of Facebook."],
]
# Input widgets: the sentence to analyse and the model selector.
sentence_input = gr.Textbox(
    lines=2,
    placeholder="Enter a sentence to analyze named entities...",
    label="Input Sentence",
)
model_selector = gr.Radio(
    ["Without CRF", "CRF"],
    label="Model Choice",
    value="Without CRF",
    type="value",
)
# The highlighted entities are shown as raw HTML.
ner_output = gr.HTML(label="NER Output")

demo = gr.Interface(
    fn=predict,
    inputs=[sentence_input, model_selector],
    outputs=ner_output,
    examples=examples,
    title="Named Entity Recognition (NER) Demo",
    description="Enter an English sentence or select from examples to see entity recognition results.",
)

# Start the web app as soon as the module is executed.
demo.launch()