import gradio as gr
from transformers import MT5Tokenizer, MT5ForConditionalGeneration
import torch
# Custom CSS for the Gradio UI.
# NOTE: this was an f-string with no placeholders, which forced every CSS
# brace to be escaped as {{ }}. A plain string produces the same value with
# normal CSS syntax.
css_code = """
.container {
max-width: 1500px;
margin: auto;
padding: 20px;
background-color: rgba(255, 255, 255, 0.8); /* 半透明背景使内容更清晰 */
border-radius: 8px;
}
.title { text-align: center; margin-bottom: 30px; }
.title h1 { color: #2c3e50; }
.title h2 { color: #34495e; }
.gr-button { border-radius: 8px; }
.footer { max-width: 1500px; text-align: center; margin-top: 30px; color: #666; }
"""
# Process-wide cache of already-loaded (model, tokenizer) pairs, keyed by the
# sub-model name, so each checkpoint is only downloaded/initialised once.
loaded_models = {}


def get_model_and_tokenizer(model_choice, device):
    """Return the (model, tokenizer) pair for *model_choice*, loading it on first use.

    The pair is memoised in ``loaded_models``; a cached model stays on the
    device it was first moved to.
    """
    cached = loaded_models.get(model_choice)
    if cached is not None:
        return cached

    repo_id = "MHBS-IHB/fish-mt5"
    tokenizer = MT5Tokenizer.from_pretrained(repo_id, subfolder=model_choice)
    model = MT5ForConditionalGeneration.from_pretrained(repo_id, subfolder=model_choice)
    model.to(device)

    loaded_models[model_choice] = (model, tokenizer)
    return model, tokenizer
def predict(model, tokenizer, input_text, device, translation_type,
            max_input_length=100, max_output_length=50, num_beams=5):
    """Translate *input_text* with a fine-tuned seq2seq model.

    Args:
        model: a ``generate``-capable seq2seq model (e.g. MT5), already on *device*.
        tokenizer: the matching tokenizer.
        input_text: raw species name to translate.
        device: torch device the tokenized inputs are moved to.
        translation_type: "Chinese to Latin" or "Latin to Chinese"; any other
            value leaves the text un-prefixed and the model receives it as-is.
        max_input_length: truncation limit for the tokenized input
            (previously hard-coded to 100).
        max_output_length: generation length limit (previously hard-coded to 50).
        num_beams: beam-search width (previously hard-coded to 5).

    Returns:
        The decoded translation with special tokens stripped.
    """
    model.eval()

    # Prepend the task prefix the fine-tuned checkpoints were trained with.
    if translation_type == "Chinese to Latin":
        input_text = f"translate Chinese to Latin: {input_text}"
    elif translation_type == "Latin to Chinese":
        input_text = f"translate Latin to Chinese: {input_text}"

    inputs = tokenizer(
        input_text,
        return_tensors='pt',
        max_length=max_input_length,
        truncation=True
    ).to(device)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=max_output_length,
            num_beams=num_beams,
            early_stopping=True
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Run on GPU when available; all models and tokenized inputs are moved here.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def gradio_predict(input_text, translation_type, model_choice):
    """Gradio callback: resolve the chosen model, then run the translation."""
    model, tokenizer = get_model_and_tokenizer(model_choice, DEVICE)
    return predict(
        model=model,
        tokenizer=tokenizer,
        input_text=input_text,
        device=DEVICE,
        translation_type=translation_type,
    )
def clear_text():
    """Reset both the input and output textboxes to empty strings."""
    return ("", "")
# Build the Gradio UI: a two-column layout (input + direction on the left,
# result + model picker on the right), action buttons, a footer, and
# clickable example rows. NOTE(review): source indentation was stripped in
# extraction; component nesting below is reconstructed — confirm against the
# original layout.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css=css_code
) as iface:
    # Top title banner (raw HTML rendered inside a Markdown component).
    gr.Markdown(
        """
<div class="title">
<h1>🐟 世界鱼类拉汉互译 🐠</h1>
<h2>Dual Latin-Chinese Translation of Global Fish Species</h2>
</div>
""",
        elem_classes="container"
    )
    with gr.Row(elem_classes="container"):
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="Enter Fish Species Names (Chinese or Latin)",
                placeholder="例如:中华鲟 / Acipenser sinensis"
            )
            # Translation direction, labelled "Bidirectional translation",
            # with the two concrete directions offered as options.
            translation_type = gr.Radio(
                choices=["Chinese to Latin", "Latin to Chinese"],
                label="Bidirectional translation",
                value="Chinese to Latin"
            )
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Translation result"
            )
            # Model size picker; values match the repo subfolders used by
            # get_model_and_tokenizer.
            model_choice = gr.Dropdown(
                choices=["fish_mt5_small", "fish_mt5_base", "fish_mt5_large"],
                label="Select Model",
                value="fish_mt5_large"
            )
    with gr.Row(elem_classes="container"):
        translate_btn = gr.Button("Translate 🔄", variant="primary")
        clear_btn = gr.Button("Clear 🗑️", variant="secondary")
    # Footer note about first-load latency.
    gr.Markdown(
        """
<div class="footer">
<p>🌊 Powered by fine-tuned MT5 Model | The model might take a while to load for the first time, so please be patient—but once it's loaded, translations are lightning fast! 🚀</p>
</div>
""",
        elem_classes="container"
    )
    # Wire the buttons: translate runs the model, clear empties both boxes.
    translate_btn.click(
        fn=gradio_predict,
        inputs=[input_text, translation_type, model_choice],
        outputs=output_text
    )
    clear_btn.click(
        fn=clear_text,
        inputs=[],
        outputs=[input_text, output_text]
    )
    # Example rows; clicking one fills the inputs above.
    gr.Examples(
        examples=[
            ["中华鲟", "Chinese to Latin", "fish_mt5_small"],
            ["Acipenser sinensis", "Latin to Chinese", "fish_mt5_small"]
        ],
        inputs=[input_text, translation_type, model_choice],
        outputs=output_text,
        label="📋 Translation Examples"
    )

# Launch the web server only when run as a script.
if __name__ == "__main__":
    iface.launch()