radinplaid commited on
Commit
8b335c7
·
verified ·
1 Parent(s): e6aa17d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -154
app.py CHANGED
@@ -1,162 +1,148 @@
1
- import faicons as fa
2
- import plotly.express as px
3
-
4
- # Load data and compute static values
5
- from shared import app_dir, tips
6
- from shinywidgets import render_plotly
7
-
8
- from shiny import reactive, render
9
- from shiny.express import input, ui
10
-
11
- bill_rng = (min(tips.total_bill), max(tips.total_bill))
12
-
13
- # Add page title and sidebar
14
- ui.page_opts(title="Restaurant tipping", fillable=True)
15
-
16
- with ui.sidebar(open="desktop"):
17
- ui.input_slider(
18
- "total_bill",
19
- "Bill amount",
20
- min=bill_rng[0],
21
- max=bill_rng[1],
22
- value=bill_rng,
23
- pre="$",
24
- )
25
- ui.input_checkbox_group(
26
- "time",
27
- "Food service",
28
- ["Lunch", "Dinner"],
29
- selected=["Lunch", "Dinner"],
30
- inline=True,
31
- )
32
- ui.input_action_button("reset", "Reset filter")
33
-
34
- # Add main content
35
- ICONS = {
36
- "user": fa.icon_svg("user", "regular"),
37
- "wallet": fa.icon_svg("wallet"),
38
- "currency-dollar": fa.icon_svg("dollar-sign"),
39
- "ellipsis": fa.icon_svg("ellipsis"),
40
- }
41
-
42
- with ui.layout_columns(fill=False):
43
- with ui.value_box(showcase=ICONS["user"]):
44
- "Total tippers"
45
-
46
- @render.express
47
- def total_tippers():
48
- tips_data().shape[0]
49
-
50
- with ui.value_box(showcase=ICONS["wallet"]):
51
- "Average tip"
52
-
53
- @render.express
54
- def average_tip():
55
- d = tips_data()
56
- if d.shape[0] > 0:
57
- perc = d.tip / d.total_bill
58
- f"{perc.mean():.1%}"
59
-
60
- with ui.value_box(showcase=ICONS["currency-dollar"]):
61
- "Average bill"
62
-
63
- @render.express
64
- def average_bill():
65
- d = tips_data()
66
- if d.shape[0] > 0:
67
- bill = d.total_bill.mean()
68
- f"${bill:.2f}"
69
-
70
-
71
- with ui.layout_columns(col_widths=[6, 6, 12]):
72
- with ui.card(full_screen=True):
73
- ui.card_header("Tips data")
74
-
75
- @render.data_frame
76
- def table():
77
- return render.DataGrid(tips_data())
78
-
79
- with ui.card(full_screen=True):
80
- with ui.card_header(class_="d-flex justify-content-between align-items-center"):
81
- "Total bill vs tip"
82
- with ui.popover(title="Add a color variable", placement="top"):
83
- ICONS["ellipsis"]
84
- ui.input_radio_buttons(
85
- "scatter_color",
86
- None,
87
- ["none", "sex", "smoker", "day", "time"],
88
- inline=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  )
 
 
90
 
91
- @render_plotly
92
- def scatterplot():
93
- color = input.scatter_color()
94
- return px.scatter(
95
- tips_data(),
96
- x="total_bill",
97
- y="tip",
98
- color=None if color == "none" else color,
99
- trendline="lowess",
100
- )
101
-
102
- with ui.card(full_screen=True):
103
- with ui.card_header(class_="d-flex justify-content-between align-items-center"):
104
- "Tip percentages"
105
- with ui.popover(title="Add a color variable"):
106
- ICONS["ellipsis"]
107
- ui.input_radio_buttons(
108
- "tip_perc_y",
109
- "Split by:",
110
- ["sex", "smoker", "day", "time"],
111
- selected="day",
112
- inline=True,
113
  )
 
114
 
115
- @render_plotly
116
- def tip_perc():
117
- from ridgeplot import ridgeplot
118
-
119
- dat = tips_data()
120
- dat["percent"] = dat.tip / dat.total_bill
121
- yvar = input.tip_perc_y()
122
- uvals = dat[yvar].unique()
123
-
124
- samples = [[dat.percent[dat[yvar] == val]] for val in uvals]
125
-
126
- plt = ridgeplot(
127
- samples=samples,
128
- labels=uvals,
129
- bandwidth=0.01,
130
- colorscale="viridis",
131
- colormode="row-index",
132
- )
133
-
134
- plt.update_layout(
135
- legend=dict(
136
- orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5
137
- )
138
- )
139
-
140
- return plt
141
-
142
-
143
- ui.include_css(app_dir / "styles.css")
144
-
145
- # --------------------------------------------------------
146
- # Reactive calculations and effects
147
- # --------------------------------------------------------
148
 
 
149
 
150
- @reactive.calc
151
- def tips_data():
152
- bill = input.total_bill()
153
- idx1 = tips.total_bill.between(bill[0], bill[1])
154
- idx2 = tips.time.isin(input.time())
155
- return tips[idx1 & idx2]
156
 
157
 
158
- @reactive.effect
159
- @reactive.event(input.reset)
160
- def _():
161
- ui.update_slider("total_bill", value=bill_rng)
162
- ui.update_checkbox_group("time", selected=["Lunch", "Dinner"])
 
1
+ from pathlib import Path
2
+
3
+ import uvicorn
4
+ from faicons import icon_svg as icon
5
+ from fire import Fire
6
+ from shiny import App, reactive, render, ui
7
+
8
+ from quickmt import Translator
9
+ from quickmt.hub import hf_download, hf_list
10
+
11
+ t = None
12
+
13
+ port: int = 7860,
14
+ host: str = "0.0.0.0"
15
+ ui.navbar_options(
16
+ bg="red",
17
+ )
18
+ app_ui = ui.page_navbar(
19
+ ui.nav_panel(
20
+ None,
21
+ ui.layout_columns(
22
+ ui.card(
23
+ ui.h4("Input Text"),
24
+ ui.input_text_area(
25
+ "input_text",
26
+ "",
27
+ value="",
28
+ width="100%",
29
+ height="600px",
30
+ ),
31
+ ui.input_action_button(
32
+ "translate_button", "Translate!", class_="btn-primary"
33
+ ),
34
+ ),
35
+ ui.card(ui.h4("Translation"), ui.output_ui("translate")),
36
+ ),
37
+ ),
38
+ ui.nav_spacer(),
39
+ ui.nav_control(
40
+ ui.input_dark_mode(
41
+ id="darkmode_toggle", mode="dark", style="padding-top: 10px;"
42
+ ),
43
+ ),
44
+ ui.nav_control(
45
+ ui.a(
46
+ icon("github", height="30px", width="30px", fill="#17a2b8"),
47
+ href="https://github.com/quickmt/quickmt",
48
+ target="_blank",
49
+ class_="btn btn-link",
50
+ ),
51
+ ),
52
+ sidebar=ui.sidebar(
53
+ ui.tooltip(
54
+ ui.input_selectize(
55
+ "model",
56
+ "Select model",
57
+ choices=[i.split("/")[1] for i in hf_list()],
58
+ ),
59
+ "QuickMT model to use. quickmt-fr-en will translate from French (fr) to English (en)",
60
+ ),
61
+ ui.tooltip(
62
+ ui.input_text(
63
+ "model_folder", "Model Folder", value=str(Path(".").absolute())
64
+ ),
65
+ "Folder where QuickMT models are (or will be) stored.",
66
+ ),
67
+ ui.tooltip(
68
+ ui.input_slider(
69
+ "beam_size", "Beam size", min=1, max=8, step=1, value=2
70
+ ),
71
+ "Balances speed and quality. 1 for fastest speed, 8 for highest quality, in between for a balance.",
72
+ ),
73
+ ui.tooltip(
74
+ ui.input_numeric(
75
+ "num_threads", "CPU Threads", min=1, max=16, step=1, value=4
76
+ ),
77
+ "Number of CPU threads to use for translation. Does not affect speed when using GPU.",
78
+ ),
79
+ ui.tooltip(
80
+ ui.input_selectize(
81
+ "compute_device",
82
+ "Compute Device",
83
+ choices=["auto", "cpu", "cuda"],
84
+ selected="cpu",
85
+ ),
86
+ "Auto will use the GPU if available, otherwise will use CPU.",
87
+ ),
88
+ width="350px",
89
+ ),
90
+ title=ui.h2("QuickMT"),
91
+ window_title="QuickMT",
92
+ theme=ui.Theme.from_brand(__file__),
93
+ navbar_options=ui.navbar_options(underline=False, theme="auto"),
94
+ )
95
+
96
+ def server(input, output, session):
97
+ @render.ui
98
+ @reactive.event(input.quickmt_model_download) # Take a dependency on the button
99
+ def model_download_output():
100
+ print(f"Downloading {input.model()} to {input.model_folder()}")
101
+ hf_download(
102
+ model_name="quickmt/" + input.model(),
103
+ output_dir=Path(input.model_folder()) / input.model(),
104
+ )
105
+ return "Model downloaded"
106
+
107
+ @render.ui
108
+ @reactive.event(input.translate_button) # Take a dependency on the button
109
+ def translate():
110
+ global t
111
+ model_path = Path(input.model_folder()) / input.model()
112
+ try:
113
+ if t is None or str(input.model()) != str(Path(t.model_path).name):
114
+ print(f"Loading model {input.model()}")
115
+ t = Translator(
116
+ str(model_path),
117
+ device=input.compute_device(),
118
+ inter_threads=int(input.num_threads()),
119
  )
120
+ if len(input.input_text()) == 0:
121
+ return ""
122
 
123
+ return [
124
+ ui.p(i)
125
+ for i in t(
126
+ input.input_text().splitlines(), beam_size=input.beam_size()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  )
128
+ ]
129
 
130
+ except:
131
+ return [
132
+ ui.value_box(
133
+ title=f"Ensure model is downloaded to {input.model_folder()}",
134
+ value="Failed to load model",
135
+ showcase=icon("bug"),
136
+ ),
137
+ ui.input_action_button(
138
+ "quickmt_model_download", "Download Model", class_="btn-primary"
139
+ ),
140
+ ui.output_ui("model_download_output"),
141
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
+ app = App(app_ui, server)
144
 
145
+ if __name__=="__main__":
146
+ uvicorn.run(app, port=port, host=host)
 
 
 
 
147
 
148