| | import gradio as gr |
| | from modeling_llama_seq2seq import LlamaCrossAttentionEncDec |
| | from transformers import AutoTokenizer, AutoConfig |
| | from PIL import Image |
| |
|
| | |
def load_model(model_name_or_path):
    """Load the tokenizer and the ``LlamaCrossAttentionEncDec`` model from one checkpoint.

    Args:
        model_name_or_path: Hugging Face hub id or local directory of the checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = model_name_or_path
    tok = AutoTokenizer.from_pretrained(checkpoint)
    # trust_remote_code lets the checkpoint supply its own configuration class.
    cfg = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
    return tok, LlamaCrossAttentionEncDec.from_pretrained(checkpoint, config=cfg)
| |
|
# Placeholder checkpoint location — point this at the real model directory/hub id.
_MODEL_PATH = 'path_to_your_model'
# Loaded once at import time; translate_text reads these module-level globals.
tokenizer, model = load_model(_MODEL_PATH)
| |
|
| | |
# Prompt templates keyed by the task labels offered in the UI's "任务类型" radio.
# The wording is exactly what the model is prompted with — do not reformat it.
_PROMPT_TEMPLATES = {
    # Regular translation.
    "常规翻译": (
        "Translate the following text from {src_lang} into {tgt_lang}.\n"
        "{src_lang}: {src_text}\n"
        "{tgt_lang}: "
    ),
    # Terminology-constrained translation.
    "术语受限翻译": (
        "Translate the following text from {src_lang} into {tgt_lang} "
        "using the provided terminology pairs.\n"
        "Terminology pairs: {term_text}\n"
        "{src_lang}: {src_text}\n"
        "{tgt_lang}: "
    ),
    # Automatic post-editing of an existing machine translation.
    "自动后期编辑": (
        "Improve the following machine-generated translation from {src_lang} to {tgt_lang}.\n"
        "{src_lang}: {src_text}\n"
        "{tgt_lang}: {mt_text}\n"
        "{tgt_lang}: "
    ),
}


def translate_text(src_text, src_lang, tgt_lang, task_type, term_text=None, mt_text=None,
                   *, num_beams=5, max_new_tokens=None):
    """Build the task prompt, run beam-search generation and return the decoded text.

    Args:
        src_text: Text to translate (or the source side when post-editing).
        src_lang: Source language name inserted into the prompt, e.g. "English".
        tgt_lang: Target language name inserted into the prompt.
        task_type: One of "常规翻译", "术语受限翻译" or "自动后期编辑".
        term_text: Terminology pairs for "术语受限翻译"; ``None`` is treated as empty.
        mt_text: Machine translation to improve for "自动后期编辑"; ``None`` is
            treated as empty.
        num_beams: Beam width passed to ``model.generate`` (default 5).
        max_new_tokens: Optional cap on generated tokens. When ``None`` it is not
            passed, so the model's own generation config decides the length.

    Returns:
        The first decoded output sequence. If the task type, source text or
        languages are missing or invalid, returns a Chinese message for the user
        instead.
    """
    template = _PROMPT_TEMPLATES.get(task_type)
    if template is None:
        return "请选择正确的任务类型"
    if not src_text or not src_text.strip():
        return "请输入需要翻译的文本"
    if not src_lang or not tgt_lang:
        # Otherwise the prompt would literally say "from None into None".
        return "请选择源语言和目标语言"

    # `or ""` keeps an omitted optional field from being rendered as the string "None".
    prompt = template.format(
        src_lang=src_lang,
        tgt_lang=tgt_lang,
        src_text=src_text,
        term_text=term_text or "",
        mt_text=mt_text or "",
    )

    # The tokenizer returns a BatchEncoding (input_ids + attention_mask). Move it to the
    # model's device so generation still works if the model is placed on a GPU.
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

    generate_kwargs = {"num_beams": num_beams, "do_sample": False}
    if max_new_tokens is not None:
        # Transformers' fallback max_length is short (20) unless the checkpoint's
        # generation config overrides it; let callers raise it explicitly.
        generate_kwargs["max_new_tokens"] = max_new_tokens
    output_ids = model.generate(**encoded, **generate_kwargs)

    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
| |
|
| | |
def create_interface(logo_path='path_to_your_logo/logo.png', server_name="0.0.0.0",
                     server_port=8080, share=True):
    """Build the translation web UI and launch the Gradio server.

    Args:
        logo_path: Image file shown at the top of the page.
        server_name: Interface the server binds to ("0.0.0.0" = all interfaces).
        server_port: TCP port for the local server.
        share: If True, Gradio also creates a public share link. Combined with
            "0.0.0.0" this exposes the model publicly — review before deploying.

    Returns:
        The ``gr.Blocks`` app object (reached once ``launch`` returns).
    """
    # Copy the logo into memory so the file handle is closed immediately instead of
    # being held open by PIL's lazily-loaded image.
    with Image.open(logo_path) as img:
        logo_image = img.copy()

    with gr.Blocks() as demo:
        gr.Image(logo_image, elem_id="logo", label="Logo")
        gr.Markdown("## 🌎 AI 翻译助手")

        with gr.Row():
            src_text = gr.Textbox(label="输入文本", placeholder="请输入需要翻译的文本")

        with gr.Row():
            # Default selections so the prompt never contains "None" as a language.
            src_lang = gr.Dropdown(["English", "Chinese", "French", "German"],
                                   value="English", label="源语言")
            tgt_lang = gr.Dropdown(["Chinese", "English", "French", "German"],
                                   value="Chinese", label="目标语言")

        with gr.Row():
            task_type = gr.Radio(["常规翻译", "术语受限翻译", "自动后期编辑"],
                                 value="常规翻译", label="任务类型")
            # Extra inputs, shown only for the task that needs them.
            term_text = gr.Textbox(label="术语表(术语受限翻译)", visible=False)
            mt_text = gr.Textbox(label="机器翻译结果(自动后期编辑)", visible=False)

        def show_extra_fields(selected_task):
            # Gradio uses a plain return value as the output component's new *value*,
            # so bare booleans would not toggle visibility; gr.update(visible=...) does.
            return (
                gr.update(visible=selected_task == "术语受限翻译"),
                gr.update(visible=selected_task == "自动后期编辑"),
            )

        task_type.change(show_extra_fields, inputs=[task_type], outputs=[term_text, mt_text])
        output_text = gr.Textbox(label="翻译结果")

        translate_button = gr.Button("翻译")
        translate_button.click(
            translate_text,
            inputs=[src_text, src_lang, tgt_lang, task_type, term_text, mt_text],
            outputs=output_text,
        )

    demo.launch(server_name=server_name, server_port=server_port, share=share)
    return demo
| |
|
def main():
    """Script entry point: build the Gradio UI and start serving it."""
    create_interface()


if __name__ == "__main__":
    main()
| |
|