File size: 3,982 Bytes
8b335c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3dbb08
8b335c7
 
 
 
 
 
 
 
 
dfc3287
8b335c7
 
a9289c4
8b335c7
 
 
 
 
 
 
a9289c4
640ff2e
 
244bd49
 
 
 
 
3d12515
640ff2e
a9289c4
640ff2e
8b335c7
 
 
 
 
9ffdab1
 
2235573
8b335c7
 
2235573
8b335c7
 
 
 
2235573
8b335c7
2235573
8b335c7
 
 
640ff2e
8b335c7
 
 
 
2235573
8b335c7
2235573
8b335c7
 
2235573
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from pathlib import Path

import uvicorn
from faicons import icon_svg as icon
from fire import Fire
from shiny import App, reactive, render, ui

from quickmt import Translator
from quickmt.hub import hf_download, hf_list

# Translator instance shared through `global t`; created on the first
# translation request and replaced when a different model is selected.
t = None

# Network settings handed to uvicorn.run() when this file runs as a script.
# There must be no trailing comma after the value: `port: int = 7860,` binds
# the one-element tuple (7860,) instead of the integer 7860.
port: int = 7860
host: str = "0.0.0.0"
# NOTE(review): the object returned by this call is neither assigned nor passed
# to anything, so it looks like leftover code; page_navbar() further down is
# given its own navbar_options=... argument. Confirm and remove if unused.
ui.navbar_options(
    bg="red",
)
# --- Main panel: source text on the left, translated output on the right ---
_input_card = ui.card(
    ui.h4("Input Text"),
    ui.input_text_area(
        "input_text",
        "",
        value="",
        width="100%",
        height="600px",
    ),
    ui.input_action_button("translate_button", "Translate!", class_="btn-primary"),
)
_output_card = ui.card(
    ui.h4("Translation"),
    ui.output_ui("translate"),
)
_main_panel = ui.nav_panel(None, ui.layout_columns(_input_card, _output_card))

# --- Navbar controls: dark-mode switch and a link to the project repository ---
_dark_mode_control = ui.nav_control(
    ui.input_dark_mode(id="darkmode_toggle", mode="dark", style="padding-top: 10px;"),
)
_github_control = ui.nav_control(
    ui.a(
        icon("github", height="30px", width="30px", fill="#17a2b8"),
        href="https://github.com/quickmt/quickmt",
        target="_blank",
        class_="btn btn-link",
    ),
)

# --- Sidebar: model picker and beam-size slider, each wrapped in a tooltip ---
# hf_list() entries are split on "/" and the second component is offered as
# the selectable model name.
_model_choices = [entry.split("/")[1] for entry in hf_list()]
_model_picker = ui.tooltip(
    ui.input_selectize("model", "Select model", choices=_model_choices),
    "QuickMT model to use. quickmt-fr-en will translate from French (fr) to English (en)",
)
_beam_slider = ui.tooltip(
    ui.input_slider("beam_size", "Beam size", min=1, max=8, step=1, value=2),
    "Balances speed and quality. 1 for fastest speed, 8 for highest quality, in between for a balance.",
)
_sidebar = ui.sidebar(_model_picker, _beam_slider, width="350px")

# Top-level page layout that is handed to App().
app_ui = ui.page_navbar(
    _main_panel,
    ui.nav_spacer(),
    _dark_mode_control,
    _github_control,
    sidebar=_sidebar,
    title=ui.h2("QuickMT Machine Translation Demo"),
    window_title="QuickMT",
    theme=ui.Theme.from_brand(__file__),
    navbar_options=ui.navbar_options(underline=False, theme="auto"),
)

def server(input, output, session):
    """Define the reactive outputs of the app.

    Args:
        input: Shiny input accessor. This function reads ``model``,
            ``beam_size``, ``input_text``, ``translate_button`` and
            ``quickmt_model_download`` from it.
        output: Not used directly; outputs are declared with ``@render.ui``.
        session: Not used.
    """

    @render.ui
    @reactive.event(input.quickmt_model_download)  # Take a dependency on the button
    def model_download_output():
        """Download the selected model into /code/models and report completion."""
        # NOTE(review): the UI visible in this file declares neither a
        # `quickmt_model_download` input nor a `model_download_output` output,
        # so this handler may be unused -- confirm before relying on it.
        hf_download(
            model_name="quickmt/" + input.model(),
            output_dir=Path("/code/models") / input.model(),
        )
        return "Model downloaded"

    @render.ui
    @reactive.event(input.translate_button)  # Take a dependency on the button
    def translate():
        """Translate the input text line by line with the selected model.

        Downloads the model first when its folder does not exist, (re)loads the
        shared translator when the selected model changed, and returns one
        paragraph per translated line. Returns an empty string for empty input
        and an error value box when loading or translating raises.
        """
        global t
        model_name = input.model()
        model_path = Path("/code/models") / model_name

        if not model_path.exists():
            ui.notification_show(
                f"Downloading model {model_name}...",
                type="message",
                duration=3,
            )
            hf_download(
                model_name="quickmt/" + model_name,
                output_dir=model_path,
            )
        try:
            # Reload only when no translator exists yet or another model was picked.
            if t is None or str(model_name) != str(Path(t.model_path).name):
                print(f"Loading model {model_name}")
                t = Translator(
                    str(model_path),
                    device="cpu",
                    inter_threads=2,
                )
            text = input.input_text()
            if not text:
                return ""

            translations = t(text.splitlines(), beam_size=input.beam_size())
            return [ui.p(line) for line in translations]

        except Exception as exc:
            # Catch Exception rather than using a bare `except:` so that
            # KeyboardInterrupt/SystemExit still propagate, and print the cause
            # instead of discarding it.
            print(f"Translation failed: {exc!r}")
            return [
                ui.value_box(
                    title="Unexpected error",
                    value="Failed to load model",
                    showcase=icon("bug"),
                ),
            ]

# Application object built from the page layout and the server function.
app = App(app_ui, server)

# When run as a script, serve the app with uvicorn on the configured address.
if __name__ == "__main__":
    uvicorn.run(app, host=host, port=port)