# Streamlit Space "ilo toki" — Toki Pona translation demo.
# (Header lines "Spaces / Running / Running" were web-page chrome from the
# Hugging Face Spaces UI, not part of the program.)
| import torch | |
| import streamlit as st | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Hub ID of the merged Gemma-2-2b Tatoeba Toki Pona translation model.
MODEL_PATH = "NetherQuartz/tatoeba-tok-multi-gemma-2-2b-merged"

# Prefer Apple MPS, then CUDA, then CPU.
# torch.backends.mps.is_available() is the documented, stable API;
# torch.mps.is_available() does not exist in older torch releases and
# raises AttributeError there.
DEVICE = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
# Browser-tab metadata must be set before any other Streamlit call.
# NOTE(review): "π¬" looks like a mojibake'd emoji (possibly a speech
# balloon) — confirm the intended icon; reproduced byte-for-byte here.
st.set_page_config(page_title="ilo toki", page_icon="π¬")
@st.cache_resource
def get_model() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the translation model and tokenizer, once per server process.

    Streamlit re-executes the entire script on every widget interaction;
    without caching, the multi-billion-parameter model would be reloaded
    on each rerun. ``st.cache_resource`` memoizes the returned pair across
    reruns and sessions.

    Returns:
        A ``(model, tokenizer)`` tuple with the model moved to ``DEVICE``.
    """
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    return model, tokenizer
# Show a blocking spinner while the model weights are loaded.
with st.spinner("Loading model, please wait...", show_time=True):
    model, tokenizer = get_model()
def translate(src_lang: str, tgt_lang: str, query: str) -> str:
    """Translate *query* from *src_lang* to *tgt_lang* with the loaded model.

    Builds the prompt format the fine-tuned model expects, generates a
    completion, and strips the echoed prompt from the decoded output.

    Args:
        src_lang: Human-readable source language name (e.g. "English").
        tgt_lang: Human-readable target language name.
        query: The text to translate.

    Returns:
        The model's translation with the prompt prefix and surrounding
        whitespace removed.
    """
    prompt = f"Translate {src_lang} to {tgt_lang}.\nQuery: {query}\nAnswer:"
    tokens = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    # inference_mode() disables autograd bookkeeping for generation —
    # lower memory use and faster, with identical outputs.
    with torch.inference_mode():
        outputs = model.generate(**tokens)
    # NOTE(review): generate() relies on the model's generation_config for
    # the output length; if translations come back truncated, pass an
    # explicit max_new_tokens here.
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer.removeprefix(prompt).strip()
st.title("π¬ ilo toki")

# When toggled on, the app translates *from* Toki Pona into the picked
# language; otherwise *into* Toki Pona from the picked language.
from_toki = st.toggle("From Toki Pona")

# The picked language is the target when translating from Toki Pona,
# otherwise the source. Inner quotes in the f-string expression use
# single quotes: reusing the outer double quotes inside an expression is
# a SyntaxError before Python 3.12 (PEP 701).
chosen_language = st.pills(
    f"{'Target' if from_toki else 'Source'} language",
    ["English", "Russian", "Vietnamese"],
    default="English",
)
# Toki Pona sits on whichever side of the language pair the toggle selects.
src, tgt = ("Toki Pona", chosen_language) if from_toki else (chosen_language, "Toki Pona")
# Run a translation whenever the user submits non-empty text.
query = st.text_input("Query", placeholder=f"Write in {src}")
if query:
    with st.spinner():
        st.text(translate(src, tgt, query))