| | import gradio as gr |
| | from transformers import MT5Tokenizer, MT5ForConditionalGeneration |
| | import torch |
| |
|
# Custom CSS for the Gradio page. A plain (non-f) string is used on purpose:
# the original was an f-string with no interpolations, which forced every CSS
# brace to be escaped as {{ }}. The rendered value is byte-identical.
css_code = """
.container {
    max-width: 1500px;
    margin: auto;
    padding: 20px;
    background-color: rgba(255, 255, 255, 0.8); /* translucent background keeps content readable */
    border-radius: 8px;
}
.title { text-align: center; margin-bottom: 30px; }
.title h1 { color: #2c3e50; }
.title h2 { color: #34495e; }
.gr-button { border-radius: 8px; }
.footer { max-width: 1500px; text-align: center; margin-top: 30px; color: #666; }
"""
| |
|
| | loaded_models = {} |
| |
|
def get_model_and_tokenizer(model_choice, device):
    """Return the (model, tokenizer) pair for *model_choice*, loading lazily.

    Results are memoized in the module-level ``loaded_models`` dict, so each
    checkpoint is fetched from the Hub and moved to *device* only once.
    """
    cached = loaded_models.get(model_choice)
    if cached is not None:
        return cached

    repo_id = "MHBS-IHB/fish-mt5"
    tok = MT5Tokenizer.from_pretrained(repo_id, subfolder=model_choice)
    net = MT5ForConditionalGeneration.from_pretrained(
        repo_id, subfolder=model_choice
    ).to(device)

    loaded_models[model_choice] = (net, tok)
    return net, tok
| |
|
# Task prefixes understood by the fine-tuned MT5 checkpoints; an unrecognized
# translation_type leaves the input unprefixed, matching the original if/elif.
_TASK_PREFIXES = {
    "Chinese to Latin": "translate Chinese to Latin: ",
    "Latin to Chinese": "translate Latin to Chinese: ",
}


def predict(model, tokenizer, input_text, device, translation_type):
    """Translate *input_text* with *model* and return the decoded string.

    The text is prefixed with an MT5 task instruction chosen by
    *translation_type*, tokenized (truncated to 100 tokens), and decoded
    from a 5-beam greedy search capped at 50 output tokens.
    """
    model.eval()

    prefix = _TASK_PREFIXES.get(translation_type)
    if prefix is not None:
        input_text = prefix + input_text

    encoded = tokenizer(
        input_text,
        return_tensors='pt',
        max_length=100,
        truncation=True,
    ).to(device)

    with torch.no_grad():
        generated = model.generate(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            max_length=50,
            num_beams=5,
            early_stopping=True,
        )

    return tokenizer.decode(generated[0], skip_special_tokens=True)
| |
|
| | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| |
|
def gradio_predict(input_text, translation_type, model_choice):
    """Gradio callback: resolve the selected checkpoint, then translate."""
    model, tokenizer = get_model_and_tokenizer(model_choice, DEVICE)
    return predict(
        model=model,
        tokenizer=tokenizer,
        input_text=input_text,
        device=DEVICE,
        translation_type=translation_type,
    )
| |
|
def clear_text():
    """Return two empty strings, resetting the input and output textboxes."""
    blank = ""
    return blank, blank
| |
|
# Assemble the Gradio UI: header, a two-column translation form, action
# buttons, a footer note, event wiring, and clickable examples.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css=css_code
) as iface:

    # Page header (Chinese title + English subtitle).
    gr.Markdown(
        """
        <div class="title">
        <h1>🐟 世界鱼类拉汉互译 🐠</h1>
        <h2>Dual Latin-Chinese Translation of Global Fish Species</h2>
        </div>
        """,
        elem_classes="container"
    )

    with gr.Row(elem_classes="container"):
        with gr.Column(scale=1):
            # Free-text input; accepts either Chinese or Latin species names.
            input_text = gr.Textbox(
                label="Enter Fish Species Names (Chinese or Latin)",
                placeholder="例如:中华鲟 / Acipenser sinensis"
            )
            # Translation direction; selects the MT5 task prefix in predict().
            translation_type = gr.Radio(
                choices=["Chinese to Latin", "Latin to Chinese"],
                label="Bidirectional translation",
                value="Chinese to Latin"
            )
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Translation result"
            )
            # Fine-tuned checkpoint to load (each is a subfolder on the Hub).
            model_choice = gr.Dropdown(
                choices=["fish_mt5_small", "fish_mt5_base", "fish_mt5_large"],
                label="Select Model",
                value="fish_mt5_large"
            )

    with gr.Row(elem_classes="container"):
        translate_btn = gr.Button("Translate 🔄", variant="primary")
        clear_btn = gr.Button("Clear 🗑️", variant="secondary")

    # Footer: warns users about first-load model download latency.
    gr.Markdown(
        """
        <div class="footer">
        <p>🌊 Powered by fine-tuned MT5 Model | The model might take a while to load for the first time, so please be patient—but once it's loaded, translations are lightning fast! 🚀</p>
        </div>
        """,
        elem_classes="container"
    )

    # Wire the buttons: translate runs the model, clear empties both boxes.
    translate_btn.click(
        fn=gradio_predict,
        inputs=[input_text, translation_type, model_choice],
        outputs=output_text
    )
    clear_btn.click(
        fn=clear_text,
        inputs=[],
        outputs=[input_text, output_text]
    )

    # Clickable sample inputs (the small checkpoint keeps demos fast).
    gr.Examples(
        examples=[
            ["中华鲟", "Chinese to Latin", "fish_mt5_small"],
            ["Acipenser sinensis", "Latin to Chinese", "fish_mt5_small"]
        ],
        inputs=[input_text, translation_type, model_choice],
        outputs=output_text,
        label="📋 Translation Examples"
    )
| |
|
# Launch the Gradio server only when executed as a script (not on import).
if __name__ == "__main__":
    iface.launch()
| |
|