# ilo-toki — app.py
# Gradio Space by NetherQuartz: Toki Pona <-> English/Russian/Vietnamese translator.
import gradio as gr
import torch
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hub repo of the merged (adapter weights folded in) Gemma-2-2B translation model.
MODEL_PATH = "NetherQuartz/tatoeba-tok-multi-gemma-2-2b-merged"

# Pick the best available accelerator. torch.backends.mps.is_available() is the
# documented, version-stable MPS probe; torch.mps.is_available() is missing on
# older torch builds and would raise AttributeError there.
DEVICE = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

# Languages offered in the UI; Toki Pona is always the other side of the pair.
LANGUAGE_LIST = ["English", "Russian", "Vietnamese"]
# Shared look-and-feel for the demo: red/pink accents, extra-round corners.
_THEME_KWARGS = {
    "primary_hue": "red",
    "secondary_hue": "pink",
    "neutral_hue": "neutral",
    "radius_size": "xxl",
}
theme = gr.themes.Base(**_THEME_KWARGS)
def get_model() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the merged translation model (moved to DEVICE) and its tokenizer."""
    tok = AutoTokenizer.from_pretrained(MODEL_PATH)
    net = AutoModelForCausalLM.from_pretrained(MODEL_PATH).to(DEVICE)
    return net, tok


# Loaded once at import time so every request reuses the same weights.
model, tokenizer = get_model()
@spaces.GPU
@torch.inference_mode()
def translate(src_lang: str, tgt_lang: str, query: str) -> str:
    """Translate *query* from *src_lang* to *tgt_lang* with the Gemma model.

    Builds the instruction prompt the model was fine-tuned on, generates a
    completion, and returns only the newly generated text.
    """
    prompt = f"Translate {src_lang} to {tgt_lang}.\nQuery: {query}\nAnswer:"
    tokens = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    # Cap generation explicitly; without it generate() falls back to the model
    # config's default length, which can truncate or run unnecessarily long.
    outputs = model.generate(**tokens, max_new_tokens=256)
    # Slice off the prompt's token span instead of string-stripping the decoded
    # text: decoding is not guaranteed to round-trip byte-identically, so
    # removeprefix(prompt) could silently leave the prompt in the result.
    generated = outputs[0][tokens["input_ids"].shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()
def process_input(from_toki: bool, chosen_language: str, query: str) -> str:
    """Dispatch *query* to translate(); Toki Pona is always one side of the pair.

    *from_toki* selects the direction: True translates out of Toki Pona into
    *chosen_language*, False translates *chosen_language* into Toki Pona.
    """
    if from_toki:
        return translate("Toki Pona", chosen_language, query)
    return translate(chosen_language, "Toki Pona", query)
def from_toki_handler(chosen_language: str, from_toki: bool):
    """React to the direction checkbox.

    Returns updated (radio, query) components: the radio's label flips between
    source/target language, and the query placeholder names the language the
    user should now type in.
    """
    if from_toki:
        label, input_lang = "Target language", "Toki Pona"
    else:
        label, input_lang = "Source language", chosen_language
    radio = gr.Radio(choices=LANGUAGE_LIST, label=label)
    query_box = gr.Text(placeholder=f"Write in {input_lang}")
    return radio, query_box
def language_handler(chosen_language: str, from_toki: bool):
    """Refresh the query placeholder after a language pick.

    When translating out of Toki Pona the input language is fixed to Toki Pona
    regardless of the radio selection.
    """
    input_lang = "Toki Pona" if from_toki else chosen_language
    return gr.Text(placeholder=f"Write in {input_lang}")
# UI layout and event wiring. Component creation order inside the Blocks
# context defines the on-screen order, so statements here must not be moved.
with gr.Blocks(theme=theme, title="💬 ilo toki") as demo:
    gr.Markdown("# 💬 ilo toki")
    # Direction toggle: unchecked = into Toki Pona, checked = out of it.
    from_toki = gr.Checkbox(label="From Toki Pona")
    # The non-Toki-Pona side of the language pair.
    chosen_language = gr.Radio(choices=LANGUAGE_LIST, label="Source language", value="English")
    query = gr.Text(placeholder="Write in English", label="Query", max_lines=1)
    # Flipping direction relabels the radio and updates the query placeholder.
    from_toki.change(
        from_toki_handler,
        inputs=[chosen_language, from_toki],
        outputs=[chosen_language, query]
    )
    # Picking a language only needs to refresh the query placeholder.
    chosen_language.select(
        language_handler,
        inputs=[chosen_language, from_toki],
        outputs=query
    )
    output = gr.Text(show_label=False, placeholder="Translation result", max_lines=1)
    # Enter in the query box runs the translation.
    query.submit(
        process_input,
        inputs=[from_toki, chosen_language, query],
        outputs=output
    )
# Start the Gradio server only when run as a script (Spaces imports this module).
if __name__ == "__main__":
    demo.launch()