# Hugging Face Spaces status: Running
| from pathlib import Path | |
| import uvicorn | |
| from faicons import icon_svg as icon | |
| from fire import Fire | |
| from shiny import App, reactive, render, ui | |
| from quickmt import Translator | |
| from quickmt.hub import hf_download, hf_list | |
# Translator cached at module level. `server.translate` creates it lazily and
# replaces it when the selected model differs from the loaded one.
t = None

# Bind settings used by the `uvicorn.run` call in the `__main__` guard.
# NOTE(review): these look like they were once parameters of a Fire-driven
# `main()` (the `Fire` import is otherwise unused) -- confirm.
# No trailing comma after 7860: `port: int = 7860,` would bind the tuple (7860,).
port: int = 7860
host: str = "0.0.0.0"
# Page layout:
# - one translation panel (input text area and translation output);
# - navbar controls (dark-mode toggle, GitHub link);
# - a sidebar with the model picker and the beam-size slider.
#
# `hf_list()` is called when this module is imported, to populate the picker.
# (A stray `ui.navbar_options(bg="red")` whose result was discarded used to
# precede this; the options actually used are passed via `navbar_options=`.)
app_ui = ui.page_navbar(
    ui.nav_panel(
        None,
        ui.layout_columns(
            ui.card(
                ui.h4("Input Text"),
                ui.input_text_area(
                    "input_text",
                    "",
                    value="",
                    width="100%",
                    height="600px",
                ),
                ui.input_action_button(
                    "translate_button", "Translate!", class_="btn-primary"
                ),
            ),
            # Output slot with id "translate"; the server is expected to render it.
            ui.card(ui.h4("Translation"), ui.output_ui("translate")),
        ),
    ),
    ui.nav_spacer(),
    ui.nav_control(
        ui.input_dark_mode(
            id="darkmode_toggle", mode="dark", style="padding-top: 10px;"
        ),
    ),
    ui.nav_control(
        ui.a(
            icon("github", height="30px", width="30px", fill="#17a2b8"),
            href="https://github.com/quickmt/quickmt",
            target="_blank",
            class_="btn btn-link",
        ),
    ),
    sidebar=ui.sidebar(
        ui.tooltip(
            ui.input_selectize(
                "model",
                "Select model",
                # hf_list() presumably yields repo ids like "quickmt/<name>".
                # Only <name> is shown; the server re-adds "quickmt/" when
                # downloading. Taking the last segment also tolerates bare
                # names, where indexing [1] would raise IndexError.
                choices=[repo_id.split("/")[-1] for repo_id in hf_list()],
            ),
            "QuickMT model to use. quickmt-fr-en will translate from French (fr) to English (en)",
        ),
        ui.tooltip(
            ui.input_slider(
                "beam_size", "Beam size", min=1, max=8, step=1, value=2
            ),
            "Balances speed and quality. 1 for fastest speed, 8 for highest quality, in between for a balance.",
        ),
        width="350px",
    ),
    title=ui.h2("QuickMT Machine Translation Demo"),
    window_title="QuickMT",
    # Theme built from the project's brand file, looked up from this file's location.
    theme=ui.Theme.from_brand(__file__),
    navbar_options=ui.navbar_options(underline=False, theme="auto"),
)
def server(input, output, session):
    """Shiny server: translate the input text with the selected QuickMT model.

    Models are downloaded on demand into ``/code/models/<name>``. The loaded
    ``Translator`` is kept in the module-level ``t`` and reused until a
    different model is selected.
    """
    import traceback

    models_dir = Path("/code/models")

    def _ensure_model(name):
        """Return the local path of model `name`, downloading quickmt/<name> if absent."""
        model_path = models_dir / name
        if not model_path.exists():
            # NOTE(review): this is a sync render, so the notification is
            # flushed only after the download returns.
            ui.notification_show(
                f"Downloading model {name}...",
                type="message",
                duration=3,
            )
            hf_download(
                model_name="quickmt/" + name,
                output_dir=model_path,
            )
        return model_path

    # Re-render only when the Translate button is clicked.
    @render.ui
    @reactive.event(input.translate_button)
    def translate():
        """Render one <p> per input line, translated with the selected model.

        Returns "" for empty/whitespace-only input (without touching the
        model). On any download, load or translation failure, logs the
        traceback and returns an error value box.
        """
        global t
        text = input.input_text()
        # Skip the (possibly slow) download/load when there is nothing to translate.
        if not text.strip():
            return ""
        name = input.model()
        try:
            model_path = _ensure_model(name)
            # Reload only when no model is loaded or a different one is selected.
            if t is None or Path(t.model_path).name != name:
                print(f"Loading model {name}")
                t = Translator(
                    str(model_path),
                    device="cpu",
                    inter_threads=2,
                )
            translations = t(text.splitlines(), beam_size=input.beam_size())
        except Exception:
            # Surface the failure in the server log instead of swallowing it.
            traceback.print_exc()
            return [
                ui.value_box(
                    title="Unexpected error",
                    value="Failed to load model",
                    showcase=icon("bug"),
                ),
            ]
        return [ui.p(line) for line in translations]
# ASGI application pairing the page layout with the server function.
app = App(app_ui, server)

if __name__ == "__main__":
    # Running as a script: serve the app directly with uvicorn.
    uvicorn.run(app, host=host, port=port)