"""Streamlit front-end for translating to and from Toki Pona with a fine-tuned Gemma model."""

import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "NetherQuartz/tatoeba-tok-multi-gemma-2-2b-merged"

# Upper bound on tokens generated per translation. Without an explicit
# ``max_new_tokens``, ``generate`` falls back to the generation config's
# ``max_length`` (transformers' default is 20 tokens *including* the prompt,
# unless the model's config overrides it), which truncates or empties answers.
MAX_NEW_TOKENS = 256


def _pick_device() -> str:
    """Return the preferred torch device name: "mps", then "cuda", then "cpu"."""
    # torch.backends.mps.is_available() exists on every torch release with MPS
    # support, unlike torch.mps.is_available(), which only newer releases have.
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


DEVICE = _pick_device()

st.set_page_config(
    page_icon="💬",
    page_title="ilo toki"
)


@st.cache_resource
def get_model() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the model (moved to DEVICE) and its tokenizer.

    Wrapped in ``st.cache_resource`` so the weights are loaded once and reused
    across Streamlit reruns instead of on every interaction.
    """
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    return model, tokenizer


with st.spinner(text="Loading model, please wait...", show_time=True):
    model, tokenizer = get_model()


@torch.inference_mode()
def translate(
    src_lang: str,
    tgt_lang: str,
    query: str,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> str:
    """Translate ``query`` from ``src_lang`` to ``tgt_lang`` with the loaded model.

    Args:
        src_lang: Human-readable name of the source language (used in the prompt).
        tgt_lang: Human-readable name of the target language (used in the prompt).
        query: Text to translate.
        max_new_tokens: Maximum number of tokens to generate for the answer.

    Returns:
        The generated answer text, with special tokens removed and surrounding
        whitespace stripped.
    """
    text = f"Translate {src_lang} to {tgt_lang}.\nQuery: {query}\nAnswer:"
    tokens = tokenizer(text, return_tensors="pt").to(DEVICE)
    outputs = model.generate(**tokens, max_new_tokens=max_new_tokens)
    # A causal LM returns prompt + continuation. Decode only the continuation:
    # decoding the prompt ids is not guaranteed to reproduce `text` exactly,
    # so stripping it with str.removeprefix could silently leave it in place.
    prompt_len = tokens["input_ids"].shape[-1]
    ans = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return ans.strip()


st.title("💬 ilo toki")

from_toki = st.toggle("From Toki Pona")
# Computed outside the f-string: nesting same-type quotes inside an f-string
# replacement field is a SyntaxError before Python 3.12.
role = "Target" if from_toki else "Source"
chosen_language = st.pills(
    f"{role} language",
    ["English", "Russian", "Vietnamese"],
    default="English"
)

# st.pills returns None once the user deselects the active pill; stop here
# rather than building a prompt such as "Translate None to Toki Pona".
if chosen_language is None:
    st.info("Please choose a language.")
    st.stop()

if from_toki:
    src, tgt = "Toki Pona", chosen_language
else:
    src, tgt = chosen_language, "Toki Pona"

query = st.text_input("Query", placeholder=f"Write in {src}")
# Skip empty and whitespace-only input so the model is not run for nothing.
if query.strip():
    with st.spinner():
        st.text(translate(src, tgt, query))