# ilo-toki — src/streamlit_app.py
# Hugging Face Space by NetherQuartz (commit 2707897, verified).
import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
# Merged Gemma-2-2B checkpoint fine-tuned on Tatoeba Toki Pona pairs.
MODEL_PATH = "NetherQuartz/tatoeba-tok-multi-gemma-2-2b-merged"

# Prefer Apple MPS, then CUDA, then CPU. Use torch.backends.mps.is_available():
# it has existed since torch 1.12, whereas torch.mps.is_available() is only
# present in recent releases and raises AttributeError on older ones.
DEVICE = (
    "mps"
    if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available() else "cpu"
)
st.set_page_config(
page_icon="πŸ’¬",
page_title="ilo toki"
)
@st.cache_resource
def get_model() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the translation model and tokenizer from the Hub.

    Decorated with ``st.cache_resource`` so the (slow) download and
    initialisation run once per server process rather than on every
    Streamlit rerun.
    """
    lm = AutoModelForCausalLM.from_pretrained(MODEL_PATH).to(DEVICE)
    tok = AutoTokenizer.from_pretrained(MODEL_PATH)
    return lm, tok
# Eagerly load the cached model at script start so the first translation
# request doesn't pay the cost; show progress while it happens.
with st.spinner(text="Loading model, please wait...", show_time=True):
    model, tokenizer = get_model()
@torch.inference_mode()
def translate(src_lang: str, tgt_lang: str, query: str) -> str:
    """Translate *query* from *src_lang* to *tgt_lang*.

    Builds the prompt format the model was fine-tuned on, generates a
    continuation, and strips the echoed prompt from the decoded output.
    """
    prompt = f"Translate {src_lang} to {tgt_lang}.\nQuery: {query}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    generated = model.generate(**inputs)
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # The decoded text starts with the prompt itself; keep only the answer.
    return decoded.removeprefix(prompt).strip()
st.title("💬 ilo toki")

# Direction switch: off = chosen language -> Toki Pona, on = the reverse.
from_toki = st.toggle("From Toki Pona")

# The pills select the non-Toki-Pona side of the pair. Use single quotes for
# the strings nested inside the f-string: reusing the outer double quotes
# (f"{"Source"...}") is only legal on Python 3.12+ (PEP 701) and is a
# SyntaxError on earlier interpreters.
chosen_language = st.pills(
    f"{'Source' if not from_toki else 'Target'} language",
    ["English", "Russian", "Vietnamese"],
    default="English"
)
# Pin one side of the pair to Toki Pona; the other comes from the pills.
src, tgt = (
    ("Toki Pona", chosen_language)
    if from_toki
    else (chosen_language, "Toki Pona")
)
# Translate as soon as the user submits a non-empty query; the walrus
# skips the spinner/output entirely while the input box is empty.
if query := st.text_input("Query", placeholder=f"Write in {src}"):
    with st.spinner():
        st.text(translate(src, tgt, query))