| | import streamlit as st |
| |
|
# Page-level settings for the demo (title, icon, layout, sidebar state).
# NOTE(review): this call sits ahead of the remaining imports, presumably
# because Streamlit expects the page configuration to be set before other
# st.* commands run -- confirm before moving it.
st.set_page_config(
    page_title="ZShot",
    page_icon="./logo_zshot.png",
    initial_sidebar_state="auto",
    layout="wide",
)
| |
|
| | import os |
| | import sys |
| | import warnings |
| |
|
| | import spacy |
| | from zshot.linker import LinkerSMXM, LinkerTARS, LinkerRegen |
| | from zshot.utils.data_models import Entity |
| | from zshot.mentions_extractor import MentionsExtractorSpacy |
| | from zshot.mentions_extractor.utils import ExtractorType |
| | from zshot import PipelineConfig, displacy |
| |
|
| | sys.path.append(os.path.abspath('./')) |
| | import streamlit_apps_config as config |
| |
|
# Install an "ignore" filter so Python warnings are not reported.
warnings.simplefilter('ignore')

# Inject the CSS exposed by the streamlit_apps_config module.
st.markdown(config.STYLE_CONFIG, unsafe_allow_html=True)

# Extra CSS rule that hides the element whose id is "MainMenu". The snippet
# starts and ends with a newline, hence the empty first and last items.
hide_menu_style = "\n".join([
    "",
    "<style>",
    "#MainMenu {visibility: hidden;}",
    "</style>",
    "",
])
st.markdown(hide_menu_style, unsafe_allow_html=True)
| |
|
| | |
| |
|
| | |
| | import base64 |
| |
|
| |
|
@st.cache_data()
def get_base64_of_bin_file(bin_file):
    """Return the contents of ``bin_file`` as a base64-encoded string.

    The file is opened in binary mode and read in full. The function is
    wrapped in ``st.cache_data``.

    Args:
        bin_file: Path of the file to encode.

    Returns:
        The base64 text of the file's bytes, decoded to ``str``.
    """
    with open(bin_file, 'rb') as handle:
        encoded = base64.b64encode(handle.read())
    return encoded.decode()
| |
|
| |
|
@st.cache_data()
def get_img_with_href(local_img_path, target_url, size='big'):
    """Build an HTML snippet showing a local image wrapped in a hyperlink.

    The image is embedded as a base64 ``data:`` URI, using the file extension
    (without the dot) as the image format.

    Args:
        local_img_path: Path of the image file to embed.
        target_url: URL the image links to.
        size: ``'big'`` renders at 90% height/width; any other value at 45%.

    Returns:
        The HTML markup as a string, starting with a newline.
    """
    extension = os.path.splitext(local_img_path)[-1]
    img_format = extension.replace('.', '')
    bin_str = get_base64_of_bin_file(local_img_path)
    # Height and width always share the same percentage.
    height = width = '90%' if size == 'big' else '45%'

    anchor_open = f"""<a href="{target_url}" style='text-align: center;'>"""
    img_tag = f"""<img height="{height}" width="{width}" style='display: block; margin-left: auto; margin-right: auto;' src="data:image/{img_format};base64,{bin_str}" />"""
    # The leading empty item reproduces the newline that opens the markup.
    return "\n".join(("", anchor_open, img_tag, "</a>"))
| |
|
| |
|
# Sidebar logos: (image path, link target, size) rendered top to bottom.
_sidebar_logos = (
    ('./logo.png', 'https://www.ibm.com/', 'big'),
    ('./logo_zshot.png', 'https://github.com/IBM/zshot', 'small'),
)
for _logo_path, _logo_url, _logo_size in _sidebar_logos:
    logo_html = get_img_with_href(_logo_path, _logo_url, size=_logo_size)
    st.sidebar.markdown(logo_html, unsafe_allow_html=True)
| |
|
| | |
# Sidebar control that picks which linker the rest of the page demonstrates.
linkers = ["REGEN", "SMXM", "TARS"]
st.sidebar.title("Linker to test")
# Streamlit discourages empty widget labels for accessibility reasons (it
# emits a warning). Give the selectbox a real label and collapse it instead,
# so the sidebar title above remains the only visible heading.
selected_model = st.sidebar.selectbox(
    "Linker to test",
    linkers,
    label_visibility="collapsed",
)
| |
|
| | |
| |
|
# Title and description shown for each linker, keyed by the sidebar choice.
_LINKER_INFO = {
    "REGEN": (
        "REGEN Linker",
        "REGEN is a T5 implementation of GENRE. It performs retrieval generating the unique entity name conditioned on the input text using constrained beam search to only generate valid identifiers.",
    ),
    "SMXM": (
        "SMXM Linker",
        "SMXM model uses the description of the entities to give the model information about the entities.",
    ),
    "TARS": (
        "TARS Linker",
        "TARS doesn't need the descriptions of the entities, so if you can't provide the descriptions of the entities maybe this is the approach you're looking for.",
    ),
}

# Render the header only for a known linker; anything else shows nothing.
if selected_model in _LINKER_INFO:
    app_title, app_description = _LINKER_INFO[selected_model]
    st.title(app_title)
    st.markdown(f"<h2>{app_description}</h2>", unsafe_allow_html=True)

# Empty subheader emitted after the header block.
st.subheader("")
| |
|
# Seed the session with the default entities when no 'entities' key exists
# in the session state yet.
if 'entities' not in st.session_state:
    _default_entities = (
        ("company", "The name of a company"),
        ("location", "A physical location"),
        ("chemical compound", "Any substance composed of identical molecules consisting of atoms of two or more chemical elements."),
    )
    st.session_state['entities'] = [
        Entity(name=_ent_name, description=_ent_description)
        for _ent_name, _ent_description in _default_entities
    ]
| |
|
def add_ent():
    """Form callback: append the entity typed into the form to the session list.

    Reads the ``name`` and ``description`` values from ``st.session_state``,
    appends a new ``Entity`` to ``st.session_state['entities']`` and clears
    both inputs. Surrounding whitespace is stripped, and a submission whose
    name is blank is ignored, leaving the inputs untouched so the user can
    complete them.
    """
    name = st.session_state["name"].strip()
    description = st.session_state["description"].strip()
    if not name:
        # An entity without a name cannot be referred to; skip it.
        return

    st.session_state['entities'].append(Entity(name=name, description=description))
    # Clear the form fields so the next entity starts from empty inputs.
    st.session_state['name'] = ""
    st.session_state['description'] = ""
| |
|
# One row per entity: name, description and a "Remove" button.
for i, entity in enumerate(st.session_state['entities']):
    col1, col2, col3 = st.columns([2, 5, 1])
    with col1:
        st.text(entity.name)
    with col2:
        st.text(entity.description)
    with col3:
        b = st.button('Remove', key=f"ent_{i}")
        if b:
            st.session_state['entities'].pop(i)
            # Restart the script right after the removal so the list is
            # redrawn. `st.experimental_rerun` is deprecated in favour of
            # `st.rerun`; use the new name when present and fall back to the
            # old one on older Streamlit versions.
            _rerun = getattr(st, "rerun", None) or st.experimental_rerun
            _rerun()
| |
|
# Form for adding an entity; the column widths match the entity rows.
with st.form(key="form"):
    col1, col2, col3 = st.columns([2, 5, 1])
    col1.text_input("Entity Name", key="name")
    col2.text_input("Entity Description", key="description")
    # add_ent runs as the submit callback.
    col3.form_submit_button('Add', on_click=add_ent)
| |
|
# Separator line, then the input holding the text to analyse.
st.markdown("________")
_default_text = (
    "CH2O2 is a chemical compound similar to Acetamide used in International Business "
    "Machines Corporation (IBM) to create new materials that act like PAGs."
)
text = st.text_input(
    "Type here your text and press enter to run:",
    value=_default_text,
)
| |
|
def build_pipeline(model_name=selected_model):
    """Create a spaCy pipeline with a ``zshot`` component for the given linker.

    Args:
        model_name: One of ``"REGEN"``, ``"TARS"`` or ``"SMXM"``. Defaults to
            the linker selected in the sidebar when the function is defined.

    Returns:
        The spaCy pipeline with the ``zshot`` pipe added last, configured with
        the entities currently held in ``st.session_state['entities']``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported linkers.
    """
    mentions_extractor = None

    if model_name == "REGEN":
        linker = LinkerRegen()
        # REGEN runs on the 'en_core_web_sm' model with a spaCy NER
        # mentions extractor.
        nlp = spacy.load('en_core_web_sm')
        mentions_extractor = MentionsExtractorSpacy(ExtractorType.NER)
    elif model_name == "TARS":
        linker = LinkerTARS()
        nlp = spacy.blank('en')
    elif model_name == "SMXM":
        linker = LinkerSMXM()
        nlp = spacy.blank('en')
    else:
        # Fail with a clear message instead of an unbound `linker` later on.
        raise ValueError(f"Unknown linker: {model_name!r}")

    # Named `pipeline_config` so it does not shadow the imported `config` module.
    pipeline_config = PipelineConfig(
        entities=st.session_state['entities'],
        mentions_extractor=mentions_extractor,
        linker=linker,
    )
    nlp.add_pipe("zshot", config=pipeline_config, last=True)

    return nlp
| |
|
# Run the selected linker over the input text when the button is pressed.
predict = st.button("Run ZShot")
if predict:
    # Temporary status message shown while the pipeline is built and run.
    placeholder = st.empty()
    placeholder.info("Processing...")

    try:
        nlp = build_pipeline()
        doc = nlp(text)
    finally:
        # Clear the status message even when the pipeline raises, so a stale
        # "Processing..." banner is not left next to the error.
        placeholder.empty()

    # Render the recognised entities as HTML and embed them in the page.
    ent_html = displacy.render(doc, style="ent", jupyter=False)
    st.markdown(ent_html, unsafe_allow_html=True)
| |
|
# Sidebar footer with links to the project's repository and documentation.
_see_more_lines = (
    "See more:",
    "- Check ZShot Github [here](https://github.com/IBM/zshot)",
    "- Check ZShot documentation [here](https://ibm.github.io/zshot/)",
)
st.sidebar.info("\n".join(_see_more_lines))
| |
|