NetherQuartz committed on
Commit
0d698bf
·
verified ·
1 Parent(s): 2707897

Add Gradio app

Browse files
Files changed (1) hide show
  1. app.py +100 -0
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import spaces
4
+
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+
7
# Identifier handed to `from_pretrained` for both the model and the tokenizer.
MODEL_PATH = "NetherQuartz/tatoeba-tok-multi-gemma-2-2b-merged"

# Pick the compute device once at import time: Apple MPS first, then CUDA,
# otherwise fall back to the CPU.
# `torch.backends.mps.is_available()` is used instead of
# `torch.mps.is_available()` because the latter is only present in recent
# PyTorch releases and raises AttributeError on older ones.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Languages offered in the UI as the counterpart of Toki Pona.
LANGUAGE_LIST = ["English", "Russian", "Vietnamese"]
11
+
12
+
13
# Visual theme applied to the whole Blocks app (passed to `gr.Blocks` below).
theme = gr.themes.Base(
    primary_hue="red",
    secondary_hue="pink",
    neutral_hue="neutral",
    radius_size="xxl",
)
19
+
20
+
21
def get_model() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the translation model and its tokenizer from ``MODEL_PATH``.

    The model is moved to ``DEVICE`` before the tokenizer is loaded.

    Returns:
        A ``(model, tokenizer)`` pair.
    """
    loaded_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
    loaded_model = loaded_model.to(DEVICE)
    loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    return loaded_model, loaded_tokenizer


# Loaded once at import time and shared by every request handler.
model, tokenizer = get_model()
28
+
29
+
30
@spaces.GPU
@torch.inference_mode()
def translate(src_lang: str, tgt_lang: str, query: str, max_new_tokens: int = 128) -> str:
    """Translate ``query`` from ``src_lang`` to ``tgt_lang`` with the loaded model.

    Args:
        src_lang: Name of the source language, inserted verbatim into the prompt.
        tgt_lang: Name of the target language, inserted verbatim into the prompt.
        query: Text to translate.
        max_new_tokens: Upper bound on the number of generated tokens. Passed
            explicitly so the output length does not depend on the library's
            or the checkpoint's default generation length.

    Returns:
        The generated continuation of the prompt with surrounding whitespace
        removed.
    """
    text = f"Translate {src_lang} to {tgt_lang}.\nQuery: {query}\nAnswer:"
    tokens = tokenizer(text, return_tensors="pt").to(DEVICE)
    outputs = model.generate(**tokens, max_new_tokens=max_new_tokens)

    # The returned sequence starts with the prompt tokens. Drop them by
    # position rather than by string prefix: decoding is not guaranteed to
    # reproduce the prompt text character-for-character.
    prompt_len = tokens["input_ids"].shape[1]
    ans = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return ans.strip()
38
+
39
+
40
def process_input(from_toki: bool, chosen_language: str, query: str) -> str:
    """Resolve the translation direction and run the translation.

    Args:
        from_toki: When true, translate from Toki Pona into
            ``chosen_language``; otherwise translate from ``chosen_language``
            into Toki Pona.
        chosen_language: The non-Toki-Pona language of the pair.
        query: Text entered by the user.

    Returns:
        The translation, or an empty string when ``query`` is empty or
        contains only whitespace.
    """
    # Skip the model entirely for blank input: there is nothing to translate.
    query = (query or "").strip()
    if not query:
        return ""

    if from_toki:
        src, tgt = "Toki Pona", chosen_language
    else:
        src, tgt = chosen_language, "Toki Pona"

    return translate(src, tgt, query)
49
+
50
+
51
def from_toki_handler(chosen_language: str, from_toki: bool):
    """Build the UI updates for a toggle of the "From Toki Pona" checkbox.

    Args:
        chosen_language: Currently selected value of the language radio.
        from_toki: New state of the checkbox.

    Returns:
        A pair ``(radio, text)``: a ``gr.Radio`` carrying the new label for
        the language selector and a ``gr.Text`` carrying the new placeholder
        for the query box.
    """
    radio_label = "Target language" if from_toki else "Source language"
    input_language = "Toki Pona" if from_toki else chosen_language

    radio_update = gr.Radio(choices=LANGUAGE_LIST, label=radio_label)
    text_update = gr.Text(placeholder=f"Write in {input_language}")
    return radio_update, text_update
62
+
63
+
64
def language_handler(chosen_language: str, from_toki: bool):
    """Build the query-box update for a change of the selected language.

    Args:
        chosen_language: Newly selected value of the language radio.
        from_toki: Current state of the "From Toki Pona" checkbox.

    Returns:
        A ``gr.Text`` whose placeholder names the language the user should
        type in: Toki Pona when ``from_toki`` is true, else ``chosen_language``.
    """
    input_language = "Toki Pona" if from_toki else chosen_language
    return gr.Text(placeholder=f"Write in {input_language}")
70
+
71
+
72
# UI definition. Components are created in the order they are declared here.
with gr.Blocks(theme=theme, title="💬 ilo toki") as demo:
    gr.Markdown("# 💬 ilo toki")

    # Direction toggle, language selector and input box.
    from_toki = gr.Checkbox(label="From Toki Pona")
    chosen_language = gr.Radio(
        choices=LANGUAGE_LIST,
        label="Source language",
        value="English",
    )
    query = gr.Text(
        placeholder="Write in English",
        label="Query",
        max_lines=1,
    )

    # Toggling the direction relabels the selector and refreshes the placeholder.
    from_toki.change(
        fn=from_toki_handler,
        inputs=[chosen_language, from_toki],
        outputs=[chosen_language, query],
    )

    # Selecting another language only refreshes the placeholder.
    chosen_language.select(
        fn=language_handler,
        inputs=[chosen_language, from_toki],
        outputs=query,
    )

    # Submitting the query box runs the translation and fills the output box.
    output = gr.Text(
        show_label=False,
        placeholder="Translation result",
        max_lines=1,
    )
    query.submit(
        fn=process_input,
        inputs=[from_toki, chosen_language, query],
        outputs=output,
    )


if __name__ == "__main__":
    demo.launch()