File size: 4,622 Bytes
4cf75da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234eea5
4cf75da
 
 
234eea5
 
4cf75da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d3e976
4cf75da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2505a29
 
4cf75da
 
 
 
 
 
 
f5e7875
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
from transformers import MT5Tokenizer, MT5ForConditionalGeneration
import torch

# CSS injected into the Gradio Blocks layout below.
# Fix: this was an f-string with no placeholders, which forced every CSS
# brace to be doubled (`{{ }}`). A plain string literal produces the exact
# same value with normal braces.
css_code = """
.container { 
    max-width: 1500px; 
    margin: auto; 
    padding: 20px; 
    background-color: rgba(255, 255, 255, 0.8); /* 半透明背景使内容更清晰 */
    border-radius: 8px;
}
.title { text-align: center; margin-bottom: 30px; }
.title h1 { color: #2c3e50; }
.title h2 { color: #34495e; }
.gr-button { border-radius: 8px; }
.footer { max-width: 1500px; text-align: center; margin-top: 30px; color: #666; }
"""

# Cache of already-loaded (model, tokenizer) pairs, keyed by subfolder name,
# so repeated requests don't re-instantiate the weights.
loaded_models = {}

def get_model_and_tokenizer(model_choice, device):
    """Return a ``(model, tokenizer)`` pair for *model_choice*, loading lazily.

    Results are memoized in ``loaded_models``. NOTE(review): the cache key
    ignores *device*; that is fine here because the app only ever passes the
    single module-level ``DEVICE``.
    """
    try:
        # Fast path: already loaded once before.
        return loaded_models[model_choice]
    except KeyError:
        pass

    repo_id = "MHBS-IHB/fish-mt5"
    tokenizer = MT5Tokenizer.from_pretrained(repo_id, subfolder=model_choice)
    # nn.Module.to() returns the module itself, so chaining is safe here.
    model = MT5ForConditionalGeneration.from_pretrained(
        repo_id, subfolder=model_choice
    ).to(device)

    loaded_models[model_choice] = (model, tokenizer)
    return model, tokenizer

def predict(model, tokenizer, input_text, device, translation_type):
    """Translate *input_text* with *model* and return the decoded string.

    A task prefix expected by the fine-tuned MT5 checkpoints is prepended
    according to *translation_type*; any other value leaves the text as-is.
    """
    model.eval()

    # Map each supported direction to its MT5 task prefix.
    prefixes = {
        "Chinese to Latin": "translate Chinese to Latin: ",
        "Latin to Chinese": "translate Latin to Chinese: ",
    }
    prompt = prefixes.get(translation_type, "") + input_text

    encoded = tokenizer(
        prompt,
        return_tensors='pt',
        max_length=100,
        truncation=True,
    ).to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        generated = model.generate(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            max_length=50,
            num_beams=5,
            early_stopping=True,
        )

    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Run on GPU when one is visible to torch, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def gradio_predict(input_text, translation_type, model_choice):
    """Gradio callback: resolve the chosen model, then run one translation."""
    model, tokenizer = get_model_and_tokenizer(model_choice, DEVICE)
    return predict(
        model=model,
        tokenizer=tokenizer,
        input_text=input_text,
        device=DEVICE,
        translation_type=translation_type,
    )

def clear_text():
    """Reset both the input and output textboxes to empty strings."""
    empty = ""
    return empty, empty

# ---- UI layout -------------------------------------------------------------
# Fixes: the gr.Dropdown call arguments were mis-indented (flush with the
# call instead of one level in), a stray blank line sat inside the
# gr.Blocks(...) call, and doubled blank lines broke the visual rhythm.
# Comments translated to English; all runtime strings are unchanged.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css=css_code,
) as iface:
    # Page header.
    gr.Markdown(
        """
        <div class="title">
            <h1>🐟 世界鱼类拉汉互译 🐠</h1>
            <h2>Dual Latin-Chinese Translation of Global Fish Species</h2>
        </div>
        """,
        elem_classes="container"
    )

    with gr.Row(elem_classes="container"):
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="Enter Fish Species Names (Chinese or Latin)",
                placeholder="例如:中华鲟 / Acipenser sinensis"
            )
            # Translation direction: Chinese->Latin or Latin->Chinese.
            translation_type = gr.Radio(
                choices=["Chinese to Latin", "Latin to Chinese"],
                label="Bidirectional translation",
                value="Chinese to Latin"
            )
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Translation result"
            )
            # Model size selector; the largest checkpoint is the default.
            model_choice = gr.Dropdown(
                choices=["fish_mt5_small", "fish_mt5_base", "fish_mt5_large"],
                label="Select Model",
                value="fish_mt5_large"
            )

    with gr.Row(elem_classes="container"):
        translate_btn = gr.Button("Translate 🔄", variant="primary")
        clear_btn = gr.Button("Clear 🗑️", variant="secondary")

    # Page footer.
    gr.Markdown(
        """
        <div class="footer">
            <p>🌊 Powered by fine-tuned MT5 Model | The model might take a while to load for the first time, so please be patient—but once it's loaded, translations are lightning fast! 🚀</p>
        </div>
        """,
        elem_classes="container"
    )

    # Wire the buttons to their callbacks.
    translate_btn.click(
        fn=gradio_predict,
        inputs=[input_text, translation_type, model_choice],
        outputs=output_text
    )
    clear_btn.click(
        fn=clear_text,
        inputs=[],
        outputs=[input_text, output_text]
    )

    # Click-to-fill worked examples, one per translation direction.
    gr.Examples(
        examples=[
            ["中华鲟", "Chinese to Latin", "fish_mt5_small"],
            ["Acipenser sinensis", "Latin to Chinese", "fish_mt5_small"]
        ],
        inputs=[input_text, translation_type, model_choice],
        outputs=output_text,
        label="📋 Translation Examples"
    )

# Launch the Gradio server only when this file is executed directly,
# not when it is imported as a module (e.g. by a Spaces runner).
if __name__ == "__main__":
    iface.launch()