"""Gradio web app for bidirectional Latin <-> Chinese translation of fish species names.

The app serves fine-tuned mT5 checkpoints from the Hugging Face Hub repository
``MHBS-IHB/fish-mt5``; each selectable model is a sub-folder of that repo.
"""
import gradio as gr
from transformers import MT5Tokenizer, MT5ForConditionalGeneration
import torch

# Custom page styling. A plain (non-f) string is used because nothing is
# interpolated, so the CSS braces need no escaping.
css_code = """
.container {
    max-width: 1500px;
    margin: auto;
    padding: 20px;
    background-color: rgba(255, 255, 255, 0.8); /* 半透明背景使内容更清晰 */
    border-radius: 8px;
}
.title {
    text-align: center;
    margin-bottom: 30px;
}
.title h1 {
    color: #2c3e50;
}
.title h2 {
    color: #34495e;
}
.gr-button {
    border-radius: 8px;
}
.footer {
    max-width: 1500px;
    text-align: center;
    margin-top: 30px;
    color: #666;
}
"""

# Hugging Face Hub repository holding every checkpoint; the model name chosen
# in the UI is used as the sub-folder inside it.
MODEL_REPO = "MHBS-IHB/fish-mt5"

# Task prefix prepended to the user's text for each translation direction.
# Insertion order is also the order of the direction choices shown in the UI.
TASK_PREFIXES = {
    "Chinese to Latin": "translate Chinese to Latin: ",
    "Latin to Chinese": "translate Latin to Chinese: ",
}

# Cache of loaded checkpoints: model_choice -> (model, tokenizer).
loaded_models = {}


def get_model_and_tokenizer(model_choice, device):
    """Return ``(model, tokenizer)`` for *model_choice*, loading it on first use.

    Args:
        model_choice: Sub-folder name inside ``MODEL_REPO``
            (e.g. ``"fish_mt5_large"``).
        device: torch device the model is moved to when it is first loaded.

    Returns:
        Tuple ``(model, tokenizer)``. Later calls with the same *model_choice*
        return the cached pair without reloading. The cache is keyed on the
        name only, so *device* only takes effect on the first load.
    """
    cached = loaded_models.get(model_choice)
    if cached is not None:
        return cached
    tokenizer = MT5Tokenizer.from_pretrained(MODEL_REPO, subfolder=model_choice)
    model = MT5ForConditionalGeneration.from_pretrained(MODEL_REPO, subfolder=model_choice)
    model.to(device)
    loaded_models[model_choice] = (model, tokenizer)
    return model, tokenizer


def predict(model, tokenizer, input_text, device, translation_type,
            *, max_input_length=100, max_output_length=50, num_beams=5):
    """Translate *input_text* with *model* and return the decoded text.

    Args:
        model: Seq2seq model exposing ``eval()`` and ``generate()``.
        tokenizer: Tokenizer matching *model*.
        input_text: Species name to translate.
        device: torch device the encoded input is moved to.
        translation_type: ``"Chinese to Latin"`` or ``"Latin to Chinese"``.
            It selects the task prefix. Any other value sends the text
            without a prefix.
        max_input_length: Prompt tokens kept; longer input is truncated.
        max_output_length: ``max_length`` passed to ``generate``.
        num_beams: Beam-search width.

    Returns:
        The best beam decoded with special tokens removed.
    """
    model.eval()  # inference mode for layers such as dropout
    prompt = TASK_PREFIXES.get(translation_type, "") + input_text
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
    ).to(device)
    with torch.no_grad():  # no gradients are needed for generation
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_output_length,
            num_beams=num_beams,
            early_stopping=True,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def gradio_predict(input_text, translation_type, model_choice):
    """Gradio callback: translate *input_text* with the selected checkpoint.

    Surrounding whitespace is stripped. Blank input returns ``""`` at once,
    so neither a model load nor a beam search is run on an empty prompt.
    """
    text = (input_text or "").strip()
    if not text:
        return ""
    model, tokenizer = get_model_and_tokenizer(model_choice, DEVICE)
    return predict(
        model=model,
        tokenizer=tokenizer,
        input_text=text,
        device=DEVICE,
        translation_type=translation_type,
    )


def clear_text():
    """Reset the input and output textboxes."""
    return "", ""


with gr.Blocks(theme=gr.themes.Soft(), css=css_code) as iface:
    # Page header.
    # NOTE(review): the CSS styles `.title h1` / `.title h2`, but this text is
    # not wrapped in a `.title` element. Confirm whether markup was lost.
    gr.Markdown(
        """

🐟 世界鱼类拉汉互译 🐠

Dual Latin-Chinese Translation of Global Fish Species

""",
        elem_classes="container",
    )

    with gr.Row(elem_classes="container"):
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="Enter Fish Species Names (Chinese or Latin)",
                placeholder="例如:中华鲟 / Acipenser sinensis",
            )
            # Translation-direction selector; choices mirror TASK_PREFIXES.
            translation_type = gr.Radio(
                choices=list(TASK_PREFIXES),
                label="Bidirectional translation",
                value="Chinese to Latin",
            )
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Translation result")
            model_choice = gr.Dropdown(
                choices=["fish_mt5_small", "fish_mt5_base", "fish_mt5_large"],
                label="Select Model",
                value="fish_mt5_large",
            )

    with gr.Row(elem_classes="container"):
        translate_btn = gr.Button("Translate 🔄", variant="primary")
        clear_btn = gr.Button("Clear 🗑️", variant="secondary")

    # NOTE(review): empty placeholder. The `.footer` CSS class is defined but
    # never used, so footer content may be missing here.
    gr.Markdown(""" """, elem_classes="container")

    translate_btn.click(
        fn=gradio_predict,
        inputs=[input_text, translation_type, model_choice],
        outputs=output_text,
    )
    clear_btn.click(
        fn=clear_text,
        inputs=[],
        outputs=[input_text, output_text],
    )

    gr.Examples(
        examples=[
            ["中华鲟", "Chinese to Latin", "fish_mt5_small"],
            ["Acipenser sinensis", "Latin to Chinese", "fish_mt5_small"],
        ],
        inputs=[input_text, translation_type, model_choice],
        outputs=output_text,
        label="📋 Translation Examples",
    )

if __name__ == "__main__":
    iface.launch()