Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import spaces | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Identifier handed to `from_pretrained` for both the model and the tokenizer.
MODEL_PATH = "NetherQuartz/tatoeba-tok-multi-gemma-2-2b-merged"

# Pick the compute device in order of preference: MPS, then CUDA, then CPU.
# The MPS probe goes through `torch.backends.mps`, which is available on
# PyTorch releases that do not yet expose `torch.mps.is_available`.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Languages offered in the UI as the counterpart of Toki Pona.
LANGUAGE_LIST = ["English", "Russian", "Vietnamese"]
# Visual theme passed to `gr.Blocks` when the page is built.
theme = gr.themes.Base(
    radius_size="xxl",
    neutral_hue="neutral",
    primary_hue="red",
    secondary_hue="pink",
)
def get_model() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the causal language model and its tokenizer from ``MODEL_PATH``.

    The model is moved to ``DEVICE`` before it is returned; the tokenizer is
    returned exactly as loaded.

    Returns:
        A ``(model, tokenizer)`` pair.
    """
    loaded_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
    loaded_model = loaded_model.to(DEVICE)
    loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    return loaded_model, loaded_tokenizer
# Load the model and tokenizer once at import time; `translate` reads these
# module-level objects on every call instead of reloading them.
model, tokenizer = get_model()
def translate(src_lang: str, tgt_lang: str, query: str, *, max_new_tokens: int = 128) -> str:
    """Translate ``query`` from ``src_lang`` to ``tgt_lang`` with the loaded model.

    The request is rendered into the prompt
    ``"Translate {src} to {tgt}.\\nQuery: {query}\\nAnswer:"`` and the model's
    continuation of that prompt is returned.

    Args:
        src_lang: Name of the language ``query`` is written in.
        tgt_lang: Name of the language to translate into.
        query: Text to translate.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated continuation with surrounding whitespace stripped.
    """
    prompt = f"Translate {src_lang} to {tgt_lang}.\nQuery: {query}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    # Pass an explicit budget for *new* tokens. Without it, `generate` falls
    # back to its default length limit, which also counts the prompt tokens
    # and can cut the answer short.
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # The returned sequence starts with the prompt tokens. Drop them by
    # position rather than by string prefix: decoding is not guaranteed to
    # reproduce the prompt text byte-for-byte, and a failed prefix match
    # would leak the whole prompt into the result.
    prompt_length = inputs["input_ids"].shape[1]
    completion_ids = output_ids[0][prompt_length:]
    completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return completion.strip()
def process_input(from_toki: bool, chosen_language: str, query: str) -> str:
    """Translate ``query`` between Toki Pona and ``chosen_language``.

    Args:
        from_toki: When true, translate from Toki Pona into
            ``chosen_language``; otherwise translate from ``chosen_language``
            into Toki Pona.
        chosen_language: The non-Toki-Pona language of the pair.
        query: Text to translate.

    Returns:
        The translation, or an empty string when ``query`` is empty or
        contains only whitespace.
    """
    # Nothing to translate: skip the model call instead of letting it
    # generate text for an empty prompt.
    if not (query and query.strip()):
        return ""

    if from_toki:
        src, tgt = "Toki Pona", chosen_language
    else:
        src, tgt = chosen_language, "Toki Pona"
    return translate(src, tgt, query)
def from_toki_handler(chosen_language: str, from_toki: bool):
    """Build the language selector and query box for the current direction.

    Args:
        chosen_language: Language currently selected in the radio group.
        from_toki: Whether the "From Toki Pona" checkbox is ticked.

    Returns:
        A ``(gr.Radio, gr.Text)`` pair. The radio is labelled
        "Target language" when translating from Toki Pona and
        "Source language" otherwise; the text box placeholder names the
        language the user is expected to type in.
    """
    radio_label = "Target language" if from_toki else "Source language"
    input_language = "Toki Pona" if from_toki else chosen_language

    language_picker = gr.Radio(choices=LANGUAGE_LIST, label=radio_label)
    query_box = gr.Text(placeholder=f"Write in {input_language}")
    return language_picker, query_box
def language_handler(chosen_language: str, from_toki: bool):
    """Build the query box with a placeholder naming the input language.

    Args:
        chosen_language: Language currently selected in the radio group.
        from_toki: Whether the "From Toki Pona" checkbox is ticked.

    Returns:
        A ``gr.Text`` whose placeholder reads "Write in Toki Pona" when
        ``from_toki`` is true and "Write in <chosen_language>" otherwise.
    """
    input_language = "Toki Pona" if from_toki else chosen_language
    return gr.Text(placeholder=f"Write in {input_language}")
# Assemble the page: declare every component first, then wire the events.
# NOTE(review): the leading characters of the title/heading look like a
# mis-encoded emoji — confirm against the original file before changing them.
with gr.Blocks(theme=theme, title="π¬ ilo toki") as demo:
    gr.Markdown("# π¬ ilo toki")

    # Translation direction and the non-Toki-Pona language of the pair.
    from_toki = gr.Checkbox(label="From Toki Pona")
    chosen_language = gr.Radio(
        choices=LANGUAGE_LIST,
        label="Source language",
        value="English",
    )

    # Single-line input and single-line result boxes.
    query = gr.Text(
        placeholder="Write in English",
        label="Query",
        max_lines=1,
    )
    output = gr.Text(
        show_label=False,
        placeholder="Translation result",
        max_lines=1,
    )

    # Toggling the direction refreshes the radio label and the placeholder.
    from_toki.change(
        from_toki_handler,
        inputs=[chosen_language, from_toki],
        outputs=[chosen_language, query],
    )

    # Picking another language only refreshes the query placeholder.
    chosen_language.select(
        language_handler,
        inputs=[chosen_language, from_toki],
        outputs=query,
    )

    # Submitting the query runs the translation and shows the result.
    query.submit(
        process_input,
        inputs=[from_toki, chosen_language, query],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()