File size: 2,738 Bytes
0d698bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr
import torch
import spaces

from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face Hub repo of the merged Toki Pona translation model.
MODEL_PATH = "NetherQuartz/tatoeba-tok-multi-gemma-2-2b-merged"

# Pick the best available device. Use torch.backends.mps.is_available() —
# the stable, documented MPS query; torch.mps.is_available() does not exist
# in many torch releases and raises AttributeError at import time.
DEVICE = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"

# Languages offered in the UI alongside Toki Pona.
LANGUAGE_LIST = ["English", "Russian", "Vietnamese"]


# App-wide look: red primary / pink secondary on a neutral base, with
# extra-large corner rounding.
theme = gr.themes.Base(primary_hue="red", secondary_hue="pink", neutral_hue="neutral", radius_size="xxl")


def get_model() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the translation model and its tokenizer from the Hub.

    Returns the model already moved to DEVICE, paired with its tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).to(DEVICE)
    return model, tokenizer


model, tokenizer = get_model()


@spaces.GPU
@torch.inference_mode()
def translate(src_lang: str, tgt_lang: str, query: str) -> str:
    """Translate *query* from *src_lang* to *tgt_lang* with the loaded model.

    Builds the instruction-style prompt the model was fine-tuned on, runs
    generation, and returns only the newly generated answer text.
    """
    text = f"Translate {src_lang} to {tgt_lang}.\nQuery: {query}\nAnswer:"
    tokens = tokenizer(text, return_tensors="pt").to(DEVICE)
    # Give the model room for a full translation: the library default
    # generation length (~20 tokens) silently truncates longer answers.
    outputs = model.generate(**tokens, max_new_tokens=256)
    # Strip the prompt by token position rather than string prefix removal:
    # decode(encode(text)) is not guaranteed to round-trip byte-for-byte,
    # so removeprefix(text) could leave prompt fragments in the output.
    generated = outputs[0][tokens["input_ids"].shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()


def process_input(from_toki: bool, chosen_language: str, query: str) -> str:
    """Dispatch to translate() with the source/target pair for the chosen direction.

    When *from_toki* is set, Toki Pona is the source and *chosen_language*
    the target; otherwise the direction is reversed.
    """
    src, tgt = ("Toki Pona", chosen_language) if from_toki else (chosen_language, "Toki Pona")
    return translate(src, tgt, query)


def from_toki_handler(chosen_language: str, from_toki: bool):
    """Refresh the radio label and query placeholder when the direction toggles.

    Translating *from* Toki Pona means the radio picks the target language
    and the user types Toki Pona; otherwise the radio picks the source.
    """
    lang = "Toki Pona" if from_toki else chosen_language
    label = "Target language" if from_toki else "Source language"
    radio = gr.Radio(choices=LANGUAGE_LIST, label=label)
    textbox = gr.Text(placeholder=f"Write in {lang}")
    return radio, textbox


def language_handler(chosen_language: str, from_toki: bool):
    """Update the query placeholder after a language selection.

    The placeholder always names the language the user should type in:
    Toki Pona when translating from it, else the selected language.
    """
    lang = "Toki Pona" if from_toki else chosen_language
    return gr.Text(placeholder=f"Write in {lang}")


# UI layout and event wiring. Component creation order inside the Blocks
# context defines the on-screen order, so statements here must not be
# reordered.
with gr.Blocks(theme=theme, title="πŸ’¬ ilo toki") as demo:
    gr.Markdown("# πŸ’¬ ilo toki")

    # Direction toggle, language picker, and the input box.
    from_toki = gr.Checkbox(label="From Toki Pona")
    chosen_language = gr.Radio(choices=LANGUAGE_LIST, label="Source language", value="English")
    query = gr.Text(placeholder="Write in English", label="Query", max_lines=1)

    # Toggling direction relabels the radio and updates the placeholder.
    from_toki.change(
        from_toki_handler,
        inputs=[chosen_language, from_toki],
        outputs=[chosen_language, query]
    )

    # Picking a language only needs to refresh the placeholder.
    chosen_language.select(
        language_handler,
        inputs=[chosen_language, from_toki],
        outputs=query
    )

    # Submitting the query (Enter) runs the translation.
    output = gr.Text(show_label=False, placeholder="Translation result", max_lines=1)
    query.submit(
        process_input,
        inputs=[from_toki, chosen_language, query],
        outputs=output
    )


# Start the Gradio server when run as a script (Spaces also executes this).
if __name__ == "__main__":
    demo.launch()