Spaces:
Running
Running
| import time | |
| import spacy | |
| import json | |
| import gradio as gr | |
| from spacy.tokens import Doc, Span | |
| from spacy import displacy | |
| import matplotlib.pyplot as plt | |
| from matplotlib.colors import to_hex | |
| import tempfile | |
| from inference.model_inference import Inference | |
| from configs import * | |
# HTML/markdown shown as the description at the top of the Gradio interface.
DESC_MD = """
<font size="3">
This space is a demo for <a href="https://arxiv.org/abs/2406.14654">Major Entity Identification (MEI)</a>. MEI takes entities as additional input and aims to detect only the mentions that refer to these entities. <br/>
<br/>
Enter your text in the text box and wrap a single mention of each entity you want to track in double curly braces (for example, a single instance of {{Ron}} if you want to track Ron). Mark only one phrase per entity; multiple entities can be marked. See the examples below for clarity. <br/>
<br/>
Static: uses an instance of the MEIRa-S model <br/>
Hybrid: uses an instance of the MEIRa-H model <br/>
<br/>
The demo outputs a JSON file with the clusters and an HTML file with a visualization that is color-coded by cluster. <br/>
</font>
"""
# Loaded Inference models keyed by model type, so each checkpoint is loaded
# once per process instead of on every request.
# NOTE(review): this keeps every model that has been used resident in memory;
# confirm the Space's RAM fits both the static and hybrid models.
_MODEL_CACHE = {}


def get_MEIRa_clusters(doc_name, text, model_type):
    """Run MEIRa coreference on ``text`` and return the raw model output.

    Args:
        doc_name: Document name forwarded to ``Inference.perform_coreference``.
        text: Input text, with the tracked entities marked in ``{{...}}``.
        model_type: Key into ``MODELS`` (the UI offers "static" and "hybrid").
            It may be ``None`` if the user submits without picking an option.

    Returns:
        Whatever ``Inference.perform_coreference`` returns. ``coref_visualizer``
        reads its "tokenized_doc", "clusters" and "representative_names" keys.

    Raises:
        gr.Error: If ``model_type`` is not a key of ``MODELS``. This is shown
            to the user in the UI, instead of an opaque ``KeyError``.
    """
    if model_type not in MODELS:
        raise gr.Error(
            "Please select a valid model type: "
            + ", ".join(str(key) for key in MODELS)
        )
    model = _MODEL_CACHE.get(model_type)
    if model is None:
        # Loading the model is the expensive step; do it only once per type.
        model = Inference(MODELS[model_type])
        _MODEL_CACHE[model_type] = model
    return model.perform_coreference(text, doc_name)
def _build_color_palette(labels):
    """Return a ``{label: "#rrggbb"}`` map with one tab20 color per label."""
    if not labels:
        # Nothing to color. This also avoids resampling a colormap to 0 entries.
        return {}
    # ``plt.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in 3.9.
    # ``plt.get_cmap(name, lut)`` is the supported equivalent. Build it once,
    # not once per label.
    cmap = plt.get_cmap("tab20", len(labels))
    return {label: to_hex(cmap(i)) for i, label in enumerate(labels)}


def _build_coref_doc(tokens, clusters, labels):
    """Build a spaCy ``Doc`` whose "coref_spans" group holds one labelled span per mention."""
    nlp = spacy.blank("en")
    doc = Doc(nlp.vocab, words=tokens)
    spans = []
    # The last cluster is deliberately excluded (as in the original code).
    # NOTE(review): presumably it holds mentions that are not linked to any
    # selected entity. Confirm against the Inference output format.
    for cluster_ind, cluster in enumerate(clusters[:-1]):
        label = labels[cluster_ind]
        for (start, end), _mention in cluster:
            # The model's end offset is treated as inclusive (hence +1);
            # spaCy's Span end index is exclusive.
            spans.append(Span(doc, start, end + 1, label=label))
    doc.spans["coref_spans"] = spans
    return doc


def _write_temp_file(content, suffix):
    """Write ``content`` as UTF-8 bytes to a new temp file and return its path."""
    # delete=False keeps the file after the handle closes so Gradio can serve it.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
        tmp_file.write(content.encode("utf-8"))
        return tmp_file.name


def coref_visualizer(doc_name, text, model_type):
    """Run MEIRa on ``text`` and produce downloadable HTML and JSON results.

    Args:
        doc_name: Document name passed through to the model.
        text: Input text with the tracked entities marked in ``{{...}}``.
        model_type: Model variant key (e.g. "static" or "hybrid").

    Returns:
        A 4-tuple ``(html_path, json_path, html_button, json_button)``.
        The two buttons are visible ``gr.DownloadButton`` updates pointing at
        the generated HTML visualization and the JSON dump of the raw model
        output.
    """
    coref_output = get_MEIRa_clusters(doc_name, text, model_type)
    tokens = coref_output["tokenized_doc"]
    clusters = coref_output["clusters"]
    labels = coref_output["representative_names"]

    color_palette = _build_color_palette(labels)
    doc = _build_coref_doc(tokens, clusters, labels)
    print("Tokens:", tokens, flush=True)
    print(color_palette)

    print("Rendering the visualization...")
    html = displacy.render(
        doc,
        style="span",
        options={
            "spans_key": "coref_spans",
            "colors": color_palette,
        },
        jupyter=False,
    )

    html_file = _write_temp_file(html, ".html")
    json_file = _write_temp_file(json.dumps(coref_output), ".json")
    print("HTML file:", html_file)
    print("JSON file:", json_file)
    return (
        html_file,
        json_file,
        gr.DownloadButton(value=html_file, visible=True),
        gr.DownloadButton(value=json_file, visible=True),
    )
def download_html():
    """Return a ``gr.DownloadButton`` update with ``visible=False``.

    By its name, this is meant to hide the HTML download button.
    """
    hidden_button = gr.DownloadButton(visible=False)
    return hidden_button
def download_json():
    """Return a ``gr.DownloadButton`` update with ``visible=False``.

    By its name, this is meant to hide the JSON download button.
    """
    hidden_button = gr.DownloadButton(visible=False)
    return hidden_button
# Longer example document offered in the Examples section of the UI.
# Read it explicitly as UTF-8 so the result does not depend on the host locale.
with open("example_harry.txt", "r", encoding="utf-8") as f:
    example_harry = f.read()

# Model variants offered in the radio selector; each is a key into MODELS.
options = ["static", "hybrid"]

with gr.Blocks() as demo:
    # Hidden file outputs that receive the generated file paths.
    html_file = gr.File(visible=False)
    json_file = gr.File(visible=False)
    # The download buttons start hidden; coref_visualizer returns updates that
    # reveal them once results exist.
    html_button = gr.DownloadButton("Download HTML", visible=False)
    json_button = gr.DownloadButton("Download JSON", visible=False)
    # A bare ``.click()`` registers no handler. Wire the helpers so each
    # button hides itself again after it is clicked.
    html_button.click(download_html, None, html_button)
    json_button.click(download_json, None, json_button)
    iface = gr.Interface(
        fn=coref_visualizer,
        inputs=[
            gr.Textbox(lines=1, placeholder="Enter document name:"),
            gr.Textbox(lines=10, placeholder="Enter text for coreference resolution:"),
            # Default to the first option so a model type is always submitted.
            gr.Radio(choices=options, value=options[0], label="Select an Option"),
        ],
        outputs=[
            html_file,
            json_file,
            html_button,
            json_button,
        ],
        title="MEI Visualizer",
        description=DESC_MD,
        examples=[
            [
                "example",
                "{{Harry}} went to Hogwarts to meet Hermione and {{Ron}} . He also met Ron's mother at the railway station.",
                "static",
            ],
            ["example_large", example_harry, "static"],
        ],
    )

if __name__ == "__main__":
    demo.launch(debug=True)