# fishmt5 — app.py (Hugging Face Space, commit 6d3e976)
import gradio as gr
from transformers import MT5Tokenizer, MT5ForConditionalGeneration
import torch
# Page CSS: a wide, centered, semi-transparent content card, rounded
# buttons, and a muted footer.  A plain string (not an f-string) is used
# because nothing is interpolated — this also avoids having to double
# every CSS brace.
css_code = """
.container {
    max-width: 1500px;
    margin: auto;
    padding: 20px;
    background-color: rgba(255, 255, 255, 0.8); /* 半透明背景使内容更清晰 */
    border-radius: 8px;
}
.title { text-align: center; margin-bottom: 30px; }
.title h1 { color: #2c3e50; }
.title h2 { color: #34495e; }
.gr-button { border-radius: 8px; }
.footer { max-width: 1500px; text-align: center; margin-top: 30px; color: #666; }
"""
# Process-wide cache: model_choice -> (model, tokenizer).  The models are
# large, so each repo subfolder is downloaded and instantiated at most
# once per process.
loaded_models = {}


def get_model_and_tokenizer(model_choice, device):
    """Return the (model, tokenizer) pair for *model_choice*, loading on first use.

    Parameters
    ----------
    model_choice : str
        Subfolder name inside the HF repo (e.g. ``"fish_mt5_small"``).
    device : torch.device
        Device the returned model should live on.

    Returns
    -------
    tuple
        ``(model, tokenizer)`` with the model moved to *device*.
    """
    model_repo = "MHBS-IHB/fish-mt5"
    if model_choice in loaded_models:
        model, tokenizer = loaded_models[model_choice]
        # Fix: the cache is keyed only by model_choice, so a hit may hold a
        # model loaded for a different device.  ``.to`` is a no-op when the
        # model is already there, so this is safe and cheap.
        model.to(device)
        return model, tokenizer
    tokenizer = MT5Tokenizer.from_pretrained(model_repo, subfolder=model_choice)
    model = MT5ForConditionalGeneration.from_pretrained(model_repo, subfolder=model_choice)
    model.to(device)
    loaded_models[model_choice] = (model, tokenizer)
    return model, tokenizer
def predict(model, tokenizer, input_text, device, translation_type):
    """Translate one species name with the given MT5 model.

    The translation direction is selected by prepending the task prefix the
    model was fine-tuned on; an unrecognized ``translation_type`` leaves the
    text unprefixed.  Returns the decoded translation string.
    """
    model.eval()
    # Map the UI radio choice to the training-time task prefix.
    task_prefixes = {
        "Chinese to Latin": "translate Chinese to Latin: ",
        "Latin to Chinese": "translate Latin to Chinese: ",
    }
    prompt = task_prefixes.get(translation_type, "") + input_text
    encoded = tokenizer(
        prompt,
        return_tensors='pt',
        max_length=100,
        truncation=True
    ).to(device)
    # Beam search (5 beams) with early stopping; no gradients needed.
    with torch.no_grad():
        generated = model.generate(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            max_length=50,
            num_beams=5,
            early_stopping=True
        )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Pick the GPU when one is visible; otherwise run on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def gradio_predict(input_text, translation_type, model_choice):
    """Gradio callback: lazily load the chosen model, then translate."""
    model, tokenizer = get_model_and_tokenizer(model_choice, DEVICE)
    translation = predict(
        model=model,
        tokenizer=tokenizer,
        input_text=input_text,
        device=DEVICE,
        translation_type=translation_type,
    )
    return translation
def clear_text():
    """Reset both the input and output textboxes to empty strings."""
    return ("", "")
# ---- Gradio UI --------------------------------------------------------------
# Layout: bilingual title banner, then an input column (textbox + direction
# radio) beside an output column (result textbox + model dropdown), action
# buttons, a footer note, and clickable examples.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css=css_code
) as iface:
    # Top title banner (Chinese / English heading).
    gr.Markdown(
        """
        <div class="title">
            <h1>🐟 世界鱼类拉汉互译 🐠</h1>
            <h2>Dual Latin-Chinese Translation of Global Fish Species</h2>
        </div>
        """,
        elem_classes="container"
    )
    with gr.Row(elem_classes="container"):
        with gr.Column(scale=1):
            # Free-text input; accepts either a Chinese or a Latin name.
            input_text = gr.Textbox(
                label="Enter Fish Species Names (Chinese or Latin)",
                placeholder="例如:中华鲟 / Acipenser sinensis"
            )
            # Direction selector, labelled "Bidirectional translation",
            # offering the two concrete directions as options.
            translation_type = gr.Radio(
                choices=["Chinese to Latin", "Latin to Chinese"],
                label="Bidirectional translation",
                value="Chinese to Latin"
            )
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Translation result"
            )
            # Model size selector; choices must match subfolder names in the
            # "MHBS-IHB/fish-mt5" repo used by get_model_and_tokenizer.
            model_choice = gr.Dropdown(
                choices=["fish_mt5_small", "fish_mt5_base", "fish_mt5_large"],
                label="Select Model",
                value="fish_mt5_large"
            )
    with gr.Row(elem_classes="container"):
        translate_btn = gr.Button("Translate 🔄", variant="primary")
        clear_btn = gr.Button("Clear 🗑️", variant="secondary")
    gr.Markdown(
        """
        <div class="footer">
            <p>🌊 Powered by fine-tuned MT5 Model | The model might take a while to load for the first time, so please be patient—but once it's loaded, translations are lightning fast! 🚀</p>
        </div>
        """,
        elem_classes="container"
    )
    # Wire the buttons to their callbacks.
    translate_btn.click(
        fn=gradio_predict,
        inputs=[input_text, translation_type, model_choice],
        outputs=output_text
    )
    clear_btn.click(
        fn=clear_text,
        inputs=[],
        outputs=[input_text, output_text]
    )
    # One-click examples (use the small model so they load quickly).
    gr.Examples(
        examples=[
            ["中华鲟", "Chinese to Latin", "fish_mt5_small"],
            ["Acipenser sinensis", "Latin to Chinese", "fish_mt5_small"]
        ],
        inputs=[input_text, translation_type, model_choice],
        outputs=output_text,
        label="📋 Translation Examples"
    )
if __name__ == "__main__":
    iface.launch()