|
|
import functools
import logging
import os
import re
import time

import gradio as gr
import supabase
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
|
|
# Model repo ids offered in the "Select Model?" dropdown and passed to
# from_pretrained(); MODELS[0] is the dropdown's initial selection.
MODELS = [
    'VuAI/khi-van',
    'VuAI/tuan',
    'VuAI/vi2vi_vn98',
]
|
|
|
|
|
|
|
|
# Supabase project credentials, taken from the environment (None when unset).
supabase_url = os.environ.get("supabase_url")
supabase_token = os.environ.get("supabase_token")

# Shared Supabase client; login() authenticates users through it.
client = supabase.create_client(supabase_url, supabase_token)
|
|
|
|
|
@functools.lru_cache(maxsize=8)
def _load_model(name_model):
    """Load and memoize the (tokenizer, model) pair for one model repo id.

    Loading weights dominates the cost of a translation, and ``translator``
    calls ``inference`` once per input line, so without this cache every line
    re-instantiated the model. The cache is bounded so an unexpected stream of
    distinct names cannot grow it without limit.
    """
    access_token = os.getenv("access_token")
    tokenizer = AutoTokenizer.from_pretrained(name_model, use_auth_token=access_token)
    model = AutoModelForSeq2SeqLM.from_pretrained(name_model, use_auth_token=access_token)
    return tokenizer, model


def inference(input_sentence, name_model):
    """Translate one piece of text with the selected seq2seq model.

    Args:
        input_sentence: Text to translate (``translator`` passes one line at a time).
        name_model: Model repo id; normally one of ``MODELS`` (the dropdown's choices).

    Returns:
        The decoded model output, with any '▁' characters replaced by spaces.
    """
    tokenizer, model = _load_model(name_model)

    input_ids = tokenizer.encode(input_sentence, return_tensors="pt")

    # Only one sequence is generated, so keep the first (and only) row.
    translation = model.generate(input_ids=input_ids, max_length=512)[0]

    translated_text = tokenizer.decode(translation, skip_special_tokens=True)

    # '▁' (U+2581) is presumably a leftover SentencePiece word-boundary marker
    # that decode() did not strip — TODO confirm against the tokenizers in use.
    return translated_text.replace('▁', ' ')
|
|
|
|
|
# Ordered (old, new) substitutions applied to the raw input before it is split
# into lines. Order matters: curly quotes must be normalised to '"' before the
# quote-based line-break rules further down can match.
_PUNCTUATION_RULES = (
    ('【', '['),
    ('】', ']'),
    ('“', '"'),
    ('”', '"'),
    # Break the line after a closing quote that ends a sentence...
    ('." ', '."\n'),
    ('?" ', '?"\n'),
    ('!" ', '!"\n'),
    # ...and before an opening quote that introduces speech.
    (': "', ':\n"'),
    # NOTE(review): unlike the rule above, this one also turns the comma into a
    # colon. It may be a deliberate normalisation of dialogue punctuation or a
    # copy-paste slip for ',\n"' -- behaviour kept as-is; confirm intent.
    (', "', ':\n"'),
)

# A chapter heading ("Chương <n>: <title>") at the very start of the output.
# Raw string: '\s' and '\d' are invalid escapes in a normal string literal.
_CHAPTER_TITLE_RE = re.compile(r'^Chương\s\d+:\s.*\n')


def _translate_line(line, name_model):
    """Translate one non-blank line, keeping a wrapping "..." or leading [...] intact."""
    if line.startswith('"') and line.find('"', 1) == len(line) - 1:
        # The only other quote is the last character: translate the quoted
        # text without its quote marks, then re-wrap it.
        return '"' + inference(line.replace('"', ''), name_model) + '"'
    if line.startswith('['):
        # Bracketed line: drop every bracket, translate, re-wrap in one pair.
        return '[' + inference(line.replace('[', '').replace(']', ''), name_model) + ']'
    return inference(line, name_model)


def translator(input_sentence, name_model):
    """Translate a multi-line text, one line at a time.

    The text is first normalised (full-width brackets and curly quotes become
    ASCII) and re-split so quoted dialogue sits on its own lines. Blank and
    whitespace-only lines are skipped. If the output starts with a chapter
    heading, that heading is converted to title case.

    Args:
        input_sentence: Raw text from the input box.
        name_model: Model repo id forwarded to ``inference``.

    Returns:
        The translated lines, each terminated by a newline.
    """
    for old, new in _PUNCTUATION_RULES:
        input_sentence = input_sentence.replace(old, new)

    # Collect pieces and join once instead of growing a string in the loop.
    translated = []
    for line in input_sentence.split('\n'):
        if not line.strip():
            # Nothing to translate; don't spend a model call on whitespace.
            continue
        translated.append(_translate_line(line, name_model) + '\n')
    result = ''.join(translated)

    title_match = _CHAPTER_TITLE_RE.match(result)
    if title_match:
        heading = title_match.group(0)
        result = result.replace(heading, heading.title())

    return result
|
|
|
|
|
def words(input_sentence):
    """Return the word-count label shown under the input box.

    Words are runs of non-whitespace characters, so repeated spaces, tabs and
    newlines never inflate the count. Empty or ``None`` input counts as 0,
    which matches the label's initial 'Words: 0 / 9999' text.

    Args:
        input_sentence: Current contents of the input textbox.

    Returns:
        A label of the form 'Words: <count> / 9999'.
    """
    word_count = len(input_sentence.split()) if input_sentence else 0
    return 'Words: ' + str(word_count) + ' / 9999'
|
|
|
|
|
def login(username, password):
    """Gradio ``auth`` callback: check an email/password pair against Supabase.

    Args:
        username: The email address typed into the login form.
        password: The password typed into the login form.

    Returns:
        The result of ``client.auth.sign_in_with_password`` (truthy) when
        sign-in succeeds, otherwise ``False``. A falsy return is how a Gradio
        auth callable rejects a login, so failures no longer escape as errors.
    """
    if not username or not password:
        # Nothing to check; skip the Supabase round-trip.
        return False
    try:
        data = client.auth.sign_in_with_password({"email": username, "password": password})
    except Exception:
        # Auth boundary: the Supabase client raises on rejected credentials
        # (and presumably on network failures). Log and deny instead of letting
        # the exception propagate into Gradio's login handler.
        logging.getLogger(__name__).warning("Supabase sign-in failed", exc_info=True)
        return False
    return data
|
|
|
|
|
|
|
|
|
|
|
# Word limit shown in the initial word-count label (words() prints the same 9999).
curr_words = 9999

with gr.Blocks(title="VN98 Translator", theme=gr.themes.Default(primary_hue="red", secondary_hue="red")) as demo:

    # Header row: app title on the left, a narrow light/dark toggle on the right.
    with gr.Row():
        with gr.Column(scale=20):
            gr.Text('VN98 Translation', show_label=False).style(container=False)
        with gr.Column(scale=1, min_width=50):
            darkBtn = gr.Button("Dark", variant='secondary', elem_id="darkTheme")

    # Main row: input and controls on the left, translation on the right.
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(lines=5, label="Input Text")
            # Live word counter; its output Text is created here, so it renders
            # directly under the input box.
            inp.change(fn=words, inputs=inp, outputs=gr.Text(show_label=False, value='Words: 0 / ' + str(curr_words)))
            sel = gr.Dropdown(label="Select Model?", choices=MODELS, value=MODELS[0])
            btn = gr.Button("Run", variant='primary')
        with gr.Column():
            out = gr.Textbox(lines=5, label="Result", interactive=True).style(show_copy_button=True)

    # Client-side theme toggle: flips the 'dark' class on <body> and swaps the
    # button caption. Locals are declared so nothing leaks onto `window`.
    changeTheme = """function changeTheme() {
        document.querySelector('body').classList.toggle('dark');
        const btn_theme = document.querySelector('#darkTheme');
        btn_theme.innerText = (btn_theme.innerText == 'Light') ? 'Dark' : 'Light';
    }
    """

    btn.click(translator, inputs=[inp, sel], outputs=[out])
    # fn=None: the toggle runs purely in the browser, with no server round-trip.
    darkBtn.click(fn=None, _js=changeTheme)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve the UI behind a login form whose credentials are checked by login().
    launch_options = dict(auth=login, auth_message="VN98 Translator.", max_threads=80)
    demo.launch(**launch_options)