"""Gradio app: translate between Toki Pona and natural languages.

Uses a Gemma-2-2b model fine-tuned on Tatoeba sentence pairs.
"""

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "NetherQuartz/tatoeba-tok-multi-gemma-2-2b-merged"

# Prefer Apple Silicon, then NVIDIA, then CPU.
# Fix: torch.backends.mps.is_available() is the documented API;
# torch.mps.is_available() does not exist on older torch builds.
DEVICE = (
    "mps"
    if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available() else "cpu"
)

LANGUAGE_LIST = ["English", "Russian", "Vietnamese"]

theme = gr.themes.Base(
    primary_hue="red",
    secondary_hue="pink",
    neutral_hue="neutral",
    radius_size="xxl",
)


def get_model() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the translation model onto DEVICE and its tokenizer."""
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    return model, tokenizer


model, tokenizer = get_model()


@spaces.GPU
@torch.inference_mode()
def translate(src_lang: str, tgt_lang: str, query: str) -> str:
    """Translate *query* from *src_lang* to *tgt_lang*.

    Builds the prompt format the model was fine-tuned on and returns
    only the newly generated answer text, stripped of the prompt.
    """
    text = f"Translate {src_lang} to {tgt_lang}.\nQuery: {query}\nAnswer:"
    tokens = tokenizer(text, return_tensors="pt").to(DEVICE)
    # Fix: without max_new_tokens, generate() falls back to the default
    # generation length (20 tokens total), silently truncating answers.
    outputs = model.generate(**tokens, max_new_tokens=256)
    # Fix: strip the prompt at the token level rather than with
    # str.removeprefix — decode() may normalize whitespace or special
    # characters, so the decoded prefix need not match `text` exactly.
    prompt_len = tokens["input_ids"].shape[1]
    answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return answer.strip()


def process_input(from_toki: bool, chosen_language: str, query: str) -> str:
    """Resolve the translation direction and run the model.

    *from_toki* selects Toki Pona as the source language; otherwise it
    is the target.
    """
    if from_toki:
        src, tgt = "Toki Pona", chosen_language
    else:
        src, tgt = chosen_language, "Toki Pona"
    return translate(src, tgt, query)


def from_toki_handler(chosen_language: str, from_toki: bool):
    """Relabel the language radio and query placeholder on direction flip."""
    if from_toki:
        lang, label = "Toki Pona", "Target language"
    else:
        lang, label = chosen_language, "Source language"
    return (
        gr.Radio(choices=LANGUAGE_LIST, label=label),
        gr.Text(placeholder=f"Write in {lang}"),
    )


def language_handler(chosen_language: str, from_toki: bool):
    """Refresh the query placeholder when the language selection changes."""
    lang = "Toki Pona" if from_toki else chosen_language
    return gr.Text(placeholder=f"Write in {lang}")


with gr.Blocks(theme=theme, title="💬 ilo toki") as demo:
    gr.Markdown("# 💬 ilo toki")

    from_toki = gr.Checkbox(label="From Toki Pona")
    chosen_language = gr.Radio(
        choices=LANGUAGE_LIST, label="Source language", value="English"
    )
    query = gr.Text(placeholder="Write in English", label="Query", max_lines=1)

    from_toki.change(
        from_toki_handler,
        inputs=[chosen_language, from_toki],
        outputs=[chosen_language, query],
    )
    chosen_language.select(
        language_handler,
        inputs=[chosen_language, from_toki],
        outputs=query,
    )

    output = gr.Text(show_label=False, placeholder="Translation result", max_lines=1)
    query.submit(
        process_input,
        inputs=[from_toki, chosen_language, query],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()