| | import platform |
| | from typing import Tuple, List, Dict, Optional |
| |
|
| | import streamlit as st |
| | import torch |
| | from trecover.config import var |
| | from trecover.utils.beam_search import beam_search, dashboard_loop |
| | from trecover.utils.inference import data_to_columns, create_noisy_columns |
| | from trecover.utils.transform import columns_to_tensor, tensor_to_target |
| | from trecover.utils.visualization import visualize_columns, visualize_target |
| |
|
# Maximum characters accepted in the plain-text input field; the noisy-column
# input allows MAX_CHARS * 4 (see inference_page).
MAX_CHARS = 256

# Demo plain-text inputs for the example selectbox.  The first entry is the
# "nothing selected" sentinel and maps to None so the input field falls back
# to the previously entered text.
PLAIN_EXAMPLES = {
    'Select example': None,
    'Example 1': 'As people around the country went into the streets to cheer the conviction, some businesses in '
                 'Portland boarded up their windows once again.',
    'Example 2': 'That night, a small group of activists wearing black approached a group of journalists, threatening'
                 ' to smash the cameras of those who remained on scene.',
    'Example 3': 'English as we know it today came to be exported to other parts of the world through British '
                 'colonisation, and is now the dominant language in Britain'
}

# Demo noisy-column input (space-separated letter columns), same selectbox
# format as PLAIN_EXAMPLES.
NOISED_EXAMPLES = {
    'Select example': None,
    'Example 1': 'a ds fpziq ofe ngkhbo p pghl ue waq frlqjo o u dnxrm dgr yrtsco kho deuasm dhysc ao u nwzhy tle r '
                 'yzpe xwabc gce nger klqto wiq nfprso t no tpgq tcfh ae twas tw ur re e t gyutsm t xgo rc ubhq e wle '
                 'r ty h nwpeaq xdsc o dnhelm v thir ikcq tkuo i o twn ps frio mo oe b kuiqtb jsq zi tnye ge dgrqs s '
                 'cioe ys whic wne wp thlo dnprsc xvpyrt hurlm kveaj nbfp dome pbeaj dusmo a r dzrqsm xace du nxkuai '
                 'gpulcm tpi h pie uim r wbhrj ui n dwgp dkeio nkwhqs zs'
}
| |
|
# Run on GPU when available.  Note: torch.hub.load fetches the model at import
# time (network access on first run), so importing this module has side effects.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.hub.load('alex-snd/TRecover', model='trecover', device=device, version='latest')
| |
|
| |
|
def main() -> None:
    """Configure the Streamlit page, seed session state, and render the sidebar."""
    st.set_page_config(page_title='TRecover',
                       page_icon='🩹',
                       layout='wide',
                       initial_sidebar_state='expanded')

    # Session-state entries are created on the first run only; on later reruns
    # Streamlit preserves whatever values the user interaction produced.
    defaults = {
        'history': [],                                   # saved inference records
        'data': '',                                      # last submitted input text
        'regenerate': False,                             # force noisy-column rebuild
        'columns': None,                                 # cached noisy columns
        'is_unix': platform.system() != 'Windows',       # progress dashboard support
    }

    for key, default in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default

    sidebar()
| |
|
| |
|
def set_regenerate() -> None:
    """Mark that the noisy columns must be rebuilt on the next rerun."""
    st.session_state['regenerate'] = True
| |
|
| |
|
def unset_regenerate() -> None:
    """Clear the pending noisy-column rebuild flag."""
    st.session_state['regenerate'] = False
| |
|
| |
|
def sidebar() -> None:
    """Render the sidebar header and dispatch to the section the user selects."""
    st.sidebar.markdown(body=
                        """
                        <h1 align="center">
                            <font size="20">🤷</font>
                            <a href="https://alex-snd.github.io/TRecover">About the Project</a>
                        </h1>
                        <br><br>
                        """,
                        unsafe_allow_html=True)

    section = st.sidebar.radio('Sections', ('Inference', 'Inference history'))

    if section == 'Inference history':
        history_sidebar()
        history_page()
    else:
        is_plain, min_noise, max_noise, bw = inference_sidebar()
        inference_page(is_plain, min_noise, max_noise, bw)
| |
|
| |
|
def inference_sidebar() -> Tuple[bool, int, int, int]:
    """Render the inference controls in the sidebar.

    Returns:
        (is_plain, min_noise, max_noise, bw) where max_noise is returned as an
        EXCLUSIVE upper bound (slider value + 1) and bw is the beam width.
    """
    st.sidebar.text('\n')

    previous_type = st.session_state.get('data_type', 'Plain text')
    data_type = st.sidebar.radio('Input type', ('Plain text', 'Noisy columns'), key='data_type',
                                 index=0 if previous_type == 'Plain text' else 1)
    is_plain = data_type == 'Plain text'

    st.sidebar.text('\n')

    if not is_plain:
        # Ready-made noisy columns: no additional noise is applied.
        min_noise = max_noise = 0
    else:
        # Changing the range invalidates the cached columns via set_regenerate.
        min_noise, max_noise = st.sidebar.slider('\nNoise range', 0, 5, key='noise_range',
                                                 value=st.session_state.get('noise_range', (0, 5)),
                                                 on_change=set_regenerate)

    bw = st.sidebar.slider('Beam search width', 1, 25, key='beam_width',
                           value=st.session_state.get('beam_width', 5))

    if max_noise > var.MAX_NOISE:
        st.sidebar.warning('Max noise value is too large. This will entail poor performance')

    return is_plain, min_noise, max_noise + 1, bw
| |
|
| |
|
def history_sidebar() -> None:
    """Placeholder: the 'Inference history' section currently has no sidebar controls."""
    pass
| |
|
| |
|
def save_to_history(is_plain: bool,
                    min_noise: int,
                    max_noise: int,
                    bw: int,
                    columns: List[str],
                    chains: List[Tuple[str, float]]
                    ) -> None:
    """Append the current inference run to the session history.

    The source text is only recorded for plain-text input; for column input the
    columns themselves are the source, so the text slot is None.
    """
    source_text = st.session_state.data if is_plain else None
    record = (is_plain, source_text, min_noise, max_noise, bw, columns, chains)

    st.session_state.history.append(record)
| |
|
| |
|
@st.cache(ttl=3600, show_spinner=False, suppress_st_warning=True)
def predict(columns: List[str], bw: int) -> List[Tuple[str, float]]:
    """Beam-search the model over the columns and return readable candidates.

    Results are cached for an hour keyed on (columns, bw).

    Returns:
        List of (recovered_text, probability) pairs, one per beam chain.
    """
    src = columns_to_tensor(columns, device)
    raw_chains = beam_search(src, model, bw, device, beam_loop=dashboard_loop)

    return [(visualize_target(tensor_to_target(chain)), prob) for chain, prob in raw_chains]
| |
|
| |
|
def get_noisy_columns(data: str, min_noise: int, max_noise: int) -> List[str]:
    """Generate deduplicated noisy columns for the given plain text.

    Args:
        data: Plain text to turn into columns.
        min_noise: Inclusive lower noise bound, forwarded to create_noisy_columns.
        max_noise: Exclusive upper noise bound, forwarded to create_noisy_columns.

    Returns:
        One string per column with duplicate letters removed.
    """
    columns = create_noisy_columns(data, min_noise, max_noise)

    # Deduplicate with dict.fromkeys (insertion-ordered) instead of set():
    # set iteration order varies with string-hash randomization, so the same
    # text would produce differently-ordered column strings across runs,
    # which also breaks st.cache hits in predict() keyed on these strings.
    return [''.join(dict.fromkeys(column)) for column in columns]
| |
|
| |
|
def get_input_data(examples: Dict[str, Optional[str]], max_chars: int) -> str:
    """Show a text input next to an example selector and return the chosen text.

    If no example is selected (the sentinel maps to None), the input field is
    pre-filled with the previously submitted text from session state.
    """
    input_field, examples_field = st.columns([1, 0.27])

    selected = examples_field.selectbox(label='', options=examples.keys())
    default_text = examples[selected] or st.session_state.data

    return input_field.text_input(label='', value=default_text, max_chars=max_chars)
| |
|
| |
|
def inference_page(is_plain: bool, min_noise: int, max_noise: int, bw: int) -> None:
    """Render the main inference view: input, column preview, recovery, history save.

    Args:
        is_plain: True when the user types plain text (columns are generated here);
            False when the user pastes ready noisy columns.
        min_noise: inclusive lower noise bound for column generation.
        max_noise: exclusive upper noise bound (inference_sidebar returns slider + 1).
        bw: beam search width passed to predict.
    """
    st.subheader('Insert plain text' if is_plain else 'Insert noisy columns separated by spaces')

    if is_plain:
        data = get_input_data(PLAIN_EXAMPLES, max_chars=MAX_CHARS)
    else:
        # Column input is letter groups plus separators, so allow more characters.
        data = get_input_data(NOISED_EXAMPLES, max_chars=MAX_CHARS * 4)

    if not data:
        st.stop()  # nothing to do until the user enters something; halts this rerun

    if is_plain:
        # Rebuild columns when explicitly requested (Regenerate button / noise
        # slider change), when none are cached yet, or when the text changed.
        if st.session_state.regenerate or not st.session_state.columns or data != st.session_state.data:
            columns = get_noisy_columns(data, min_noise, max_noise)
            st.session_state.columns = columns
            unset_regenerate()
        else:
            columns = st.session_state.columns
    else:
        columns = data_to_columns(data, separator=' ')

    # Remember the input so the next rerun can detect changes and pre-fill the field.
    st.session_state.data = data

    st.subheader('\nColumns')
    st.text(visualize_columns(columns, delimiter=''))
    st.subheader('\n')

    # One placeholder is reused for the action buttons, the progress view and
    # finally the prediction, so each stage replaces the previous one in place.
    placeholder = st.empty()
    recover_field, regen_filed = placeholder.columns([.11, 1])  # NOTE(review): 'filed' looks like a typo for 'field'

    if is_plain:
        regen_filed.button('Regenerate', on_click=set_regenerate)

    if columns and recover_field.button('Recover'):
        if st.session_state.is_unix:
            # On Unix the dashboard beam loop can stream progress; show it with
            # a Stop button (any button press triggers a rerun, ending the search).
            with placeholder.container():
                progress_bar_placeholder = st.empty()
                st.button('Stop')

                with progress_bar_placeholder:
                    chains = predict(columns, bw)
        else:
            with placeholder:
                chains = predict(columns, bw)

        # Replace the buttons/progress view with the final prediction.
        with placeholder.container():
            st.subheader('\nPrediction')
            st.text('\n\n'.join(chain for chain, _ in chains))

            if st.button('Clear'):
                # NOTE(review): task_id is never read in this file — confirm this
                # Clear handler still does what it was meant to.
                st.session_state.task_id = None

        save_to_history(is_plain, min_noise, max_noise, bw, columns, chains)
| |
|
| |
|
def history_page() -> None:
    """Display every saved inference record in chronological order."""
    st.header('Inference History')

    if not st.session_state.history:
        st.info('No records saved')
        return

    for record_id, record in enumerate(st.session_state.history, start=1):
        is_plain, text, min_noise, max_noise, bw, columns, chains = record

        st.info(f'Record {record_id}')

        # The source text is only stored for plain-text runs.
        if is_plain:
            st.text(f'Plain data: {text}')

        st.text(f'Noise range: [{min_noise}, {max_noise}]')
        st.text(f'Beam search width: {bw}')

        st.text('Columns:')
        st.text(visualize_columns(columns, delimiter=''))

        st.text('Prediction:')
        st.text('\n\n'.join(chain for chain, _ in chains))
| |
|
| |
|
# Entry point when the dashboard script is executed directly.
if __name__ == '__main__':
    main()