import gradio as gr from transformers import MT5Tokenizer, MT5ForConditionalGeneration import torch css_code = f""" .container {{ max-width: 1500px; margin: auto; padding: 20px; background-color: rgba(255, 255, 255, 0.8); /* 半透明背景使内容更清晰 */ border-radius: 8px; }} .title {{ text-align: center; margin-bottom: 30px; }} .title h1 {{ color: #2c3e50; }} .title h2 {{ color: #34495e; }} .gr-button {{ border-radius: 8px; }} .footer {{ max-width: 1500px; text-align: center; margin-top: 30px; color: #666; }} """ loaded_models = {} def get_model_and_tokenizer(model_choice, device): model_repo = "MHBS-IHB/fish-mt5" if model_choice in loaded_models: return loaded_models[model_choice] tokenizer = MT5Tokenizer.from_pretrained(model_repo, subfolder=model_choice) model = MT5ForConditionalGeneration.from_pretrained(model_repo, subfolder=model_choice) model.to(device) loaded_models[model_choice] = (model, tokenizer) return model, tokenizer def predict(model, tokenizer, input_text, device, translation_type): model.eval() if translation_type == "Chinese to Latin": input_text = f"translate Chinese to Latin: {input_text}" elif translation_type == "Latin to Chinese": input_text = f"translate Latin to Chinese: {input_text}" inputs = tokenizer( input_text, return_tensors='pt', max_length=100, truncation=True ).to(device) with torch.no_grad(): outputs = model.generate( input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], max_length=50, num_beams=5, early_stopping=True ) return tokenizer.decode(outputs[0], skip_special_tokens=True) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") def gradio_predict(input_text, translation_type, model_choice): model, tokenizer = get_model_and_tokenizer(model_choice, DEVICE) return predict(model=model, tokenizer=tokenizer, input_text=input_text, device=DEVICE, translation_type=translation_type) def clear_text(): return "", "" with gr.Blocks( theme=gr.themes.Soft(), css=css_code ) as iface: # 顶部标题 gr.Markdown( """