File size: 1,525 Bytes
1ae966e
b0f8be7
 
1ae966e
 
07c5c8c
2707897
fb9637d
 
 
 
 
1ae966e
 
 
fb9637d
2707897
1ae966e
 
 
 
7d80ddc
 
1ae966e
 
b24ae08
fb9637d
1ae966e
2707897
1ae966e
 
fb9637d
1ae966e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2707897
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import streamlit as st

from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face checkpoint: Gemma-2-2b fine-tuned on Tatoeba Toki Pona pairs.
MODEL_PATH = "NetherQuartz/tatoeba-tok-multi-gemma-2-2b-merged"

# Pick the best available accelerator: Apple MPS, then CUDA, then CPU.
# torch.backends.mps.is_available() is the long-supported API; the original
# torch.mps.is_available() only exists in recent torch releases and raises
# AttributeError on older ones.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

st.set_page_config(
    page_icon="πŸ’¬",
    page_title="ilo toki"
)


@st.cache_resource
def get_model() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the translation model and tokenizer once per server process.

    Streamlit's ``cache_resource`` ensures the (large) model download and
    load happen a single time rather than on every script rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
    return model.to(DEVICE), tokenizer


# Load (or fetch from Streamlit's resource cache) the model at script level so
# the UI below can use it; the spinner gives feedback during the slow first load.
with st.spinner(text="Loading model, please wait...", show_time=True):
    model, tokenizer = get_model()


@torch.inference_mode()
def translate(src_lang: str, tgt_lang: str, query: str) -> str:
    """Translate *query* from ``src_lang`` to ``tgt_lang`` with the cached model.

    Builds a plain-text instruction prompt, generates a continuation, and
    returns only the newly generated text.
    """
    prompt = f"Translate {src_lang} to {tgt_lang}.\nQuery: {query}\nAnswer:"
    tokens = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    outputs = model.generate(**tokens)
    # Decode only the tokens produced after the prompt. The original decoded
    # the full sequence and called removeprefix(prompt), which is fragile: the
    # tokenizer's round-trip of the prompt text is not guaranteed to be
    # byte-identical, in which case the prefix silently failed to strip and
    # the whole prompt leaked into the answer.
    prompt_len = tokens["input_ids"].shape[1]
    answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return answer.strip()


st.title("πŸ’¬ ilo toki")

# When enabled, Toki Pona is the source language and the pill picks the target.
from_toki = st.toggle("From Toki Pona")

# Single quotes inside the f-string: nested same-type quotes (as in the
# original f"{"Source" ...}") are PEP 701 syntax, valid only on Python 3.12+.
# This form keeps the script runnable on 3.10/3.11 with identical output.
chosen_language = st.pills(
    f"{'Target' if from_toki else 'Source'} language",
    ["English", "Russian", "Vietnamese"],
    default="English"
)

if from_toki:
    src, tgt = "Toki Pona", chosen_language
else:
    src, tgt = chosen_language, "Toki Pona"

# Walrus: only translate once the user has typed a non-empty query.
if query := st.text_input("Query", placeholder=f"Write in {src}"):
    with st.spinner():
        st.text(translate(src, tgt, query))