Spaces:
Sleeping
Sleeping
File size: 17,664 Bytes
713cdf1 c607ad4 26dc72e c607ad4 a17116d c607ad4 a17116d c607ad4 a17116d c607ad4 a17116d c607ad4 26dc72e c607ad4 a17116d c607ad4 a17116d c607ad4 a17116d 26dc72e c607ad4 26dc72e c607ad4 a17116d c607ad4 a17116d c607ad4 a17116d 0ae384d a17116d c607ad4 a17116d c607ad4 a17116d c607ad4 a17116d c607ad4 a17116d c607ad4 a17116d c607ad4 a17116d c607ad4 a17116d c607ad4 a17116d c607ad4 a17116d c607ad4 26dc72e c607ad4 a17116d c607ad4 0ae384d c607ad4 a17116d c607ad4 a17116d c607ad4 713cdf1 a17116d c607ad4 a17116d c607ad4 26dc72e c607ad4 a17116d c607ad4 a17116d c607ad4 a17116d c607ad4 a17116d c607ad4 a17116d c607ad4 a17116d 0ae384d c607ad4 a17116d c607ad4 a17116d c607ad4 c760938 c607ad4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 |
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
import time
import numpy as np
from transformers import pipeline
import torch
import json
import re
# 选择两个中文到英文的翻译模型
MODEL_CONFIGS = {
"Chinese-to-English (Opus-MT)": {
"model_name": "Helsinki-NLP/opus-mt-zh-en",
"description": "中文到英文的机器翻译模型 (Helsinki-NLP OPUS-MT)",
"max_length": 200, # 翻译输出的最大长度
"color": "#FF6B6B"
},
"Chinese-to-English (M4-Small)": {
"model_name": "HuggingFaceM4/m4-small-en-zh", # 这是一个多语言模型,支持zh-en
"description": "中文到英文的机器翻译模型 (HuggingFaceM4 M4-Small)",
"max_length": 200, # 翻译输出的最大长度
"color": "#4ECDC4"
}
# 如果需要第三个模型,可以取消注释下面这个,或替换成您想要的
# "Chinese-to-English (Another Model)": {
# "model_name": "facebook/mbart-large-50-one-to-many-mmt", # 另一个多语言模型,需要指定 src_lang/tgt_lang
# "description": "中文到英文的机器翻译模型 (Facebook mBART-Large-50)",
# "max_length": 200,
# "color": "#45B7D1"
# }
}
class TranslationComparator:
def __init__(self):
self.models = {}
self.load_models()
def load_models(self):
"""加载所有翻译模型"""
print("正在加载翻译模型...")
for model_key, config in MODEL_CONFIGS.items():
try:
print(f"加载 {model_key} ({config['model_name']})...")
# 对于翻译任务,使用 "translation" pipeline
# 注意:某些多语言模型(如 m4-small)可能需要显式指定源语言和目标语言
# 对于 Helsinki-NLP/opus-mt-zh-en,pipeline会自动处理
# 对于 HuggingFaceM4/m4-small-en-zh,虽然名字是en-zh,但它内部支持zh-en。
# 如果遇到问题,可能需要更复杂的tokenizer/model加载方式而非pipeline
if "opus-mt-zh-en" in config["model_name"]:
task = "translation_zh_to_en" # 更明确的翻译任务
elif "m4-small" in config["model_name"]:
# m4-small是一个多语言模型,需要提供源语言和目标语言。
# pipeline("translation") 不直接支持 src_lang/tgt_lang 参数
# 需要手动加载 AutoModelForSeq2SeqLM 和 AutoTokenizer
print(f"特别加载 {model_key} 及其Tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
model = AutoModelForSeq2SeqLM.from_pretrained(config["model_name"])
# 将其包装成一个简单的可调用对象,模拟pipeline的行为
self.models[model_key] = {
"tokenizer": tokenizer,
"model": model,
"pipeline_type": "custom_translation"
}
print(f"✓ {model_key} 加载成功 (自定义翻译模式)")
continue # 跳过pipeline加载
else: # 默认翻译任务
task = "translation"
self.models[model_key] = pipeline(
task,
model=config["model_name"],
tokenizer=config["model_name"],
device=-1, # 使用CPU
torch_dtype=torch.float32
)
print(f"✓ {model_key} 加载成功")
except Exception as e:
print(f"✗ {model_key} 加载失败: {e}")
self.models[model_key] = None
def translate_text(self, model_key, text_to_translate, max_length=200):
"""使用指定模型进行翻译"""
model_entry = self.models.get(model_key)
if model_entry is None:
return {
"translated_text": f"[Model {model_key} not loaded correctly, this is a simulated translation]",
"inference_time": 0.5,
"input_length": len(text_to_translate.split()),
"output_length": 50, # 模拟输出长度
"parameters": {
"max_length": max_length
}
}
try:
start_time = time.time()
if isinstance(model_entry, dict) and model_entry.get("pipeline_type") == "custom_translation":
# 对于需要自定义处理的模型 (如 HuggingFaceM4/m4-small-en-zh)
tokenizer = model_entry["tokenizer"]
model = model_entry["model"]
# 对于 m4-small,需要手动设置源语言和目标语言
# 假设输入是中文
input_ids = tokenizer(text_to_translate, return_tensors="pt", truncation=True, max_length=512).input_ids
# 设置生成参数,特别是强制生成目标语言的 token (en_XX)
# 对于 m4-small 而言,`en_XX` 是英文的目标语言token
# 请注意:这可能需要根据具体的m4模型进行微调,因为它可能没有直接的force_bos_token_id
# 一个更通用的方法是手动构建decoder_input_ids
# 尝试一个通用的生成方式,让模型自己识别语言
# 对于翻译任务,transformers pipeline已经封装了大部分复杂性
# 如果手动调用generate,需要确保输入格式和语言ID正确
# 简单的直接生成(可能不带force_bos_token_id)
generated_ids = model.generate(input_ids, max_new_tokens=max_length)
translated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
else: # 使用 pipeline
result = model_entry(
text_to_translate,
max_length=max_length
)
translated_text = result[0]['translation_text']
end_time = time.time()
return {
"translated_text": translated_text,
"inference_time": round(end_time - start_time, 3),
"input_length": len(text_to_translate.split()),
"output_length": len(translated_text.split()),
"parameters": {
"max_length": max_length
}
}
except Exception as e:
return {
"error": f"翻译错误: {str(e)}",
"inference_time": 0,
"input_length": 0,
"output_length": 0
}
# 初始化比较器
comparator = TranslationComparator()
def run_translation_comparison(zh_prompt, max_length):
"""运行所有中文到英文模型的翻译对比"""
if not zh_prompt.strip():
# 返回与模型数量相匹配的错误消息
return tuple([gr.Code.update(value=json.dumps({"错误信息": "请输入中文文本进行翻译"}, ensure_ascii=False)) for _ in MODEL_CONFIGS])
results = {}
outputs_list = []
for model_key in MODEL_CONFIGS.keys():
result = comparator.translate_text(
model_key,
zh_prompt,
max_length=int(max_length)
)
results[model_key] = result
# 格式化输出
if "error" in result:
outputs_list.append(json.dumps({"错误信息": result["error"]}, indent=2, ensure_ascii=False))
else:
formatted = {
"翻译文本": result["translated_text"],
"推断时间": f"{result['inference_time']}s",
"翻译Token数": result["output_length"],
"翻译速度": f"{result['output_length']/max(result['inference_time'], 0.001):.1f} tokens/s"
}
outputs_list.append(json.dumps(formatted, indent=2, ensure_ascii=False))
return tuple(outputs_list)
def calculate_grace_scores_for_translation():
"""为翻译任务计算GRACE评估分数"""
# 模拟中文到英文翻译模型的GRACE分数
grace_data = {
"Chinese-to-English (Opus-MT)": {
"Generalization": 7.8, # 处理不同领域中翻英能力
"Relevance": 8.3, # 翻译内容与原文语义相关性
"Accuracy": 8.0, # 翻译精确性
"Consistency": 7.9, # 翻译稳定性
"Efficiency": 7.5 # 推理效率
},
"Chinese-to-English (M4-Small)": {
"Generalization": 7.0, # 多语言模型可能在特定语对上略逊色于专用模型
"Relevance": 7.5,
"Accuracy": 7.2,
"Consistency": 7.0,
"Efficiency": 8.5 # 通常小模型效率更高
}
# 如果有第三个模型,在这里添加其分数
}
return grace_data
def create_translation_radar_chart():
"""创建翻译GRACE评估雷达图"""
grace_scores = calculate_grace_scores_for_translation()
categories = ['Generalization', 'Relevance', 'Accuracy', 'Consistency', 'Efficiency'] # 更改为翻译维度
fig = go.Figure()
for i, (model_name, scores) in enumerate(grace_scores.items()):
values = [scores[cat] for cat in categories]
color = MODEL_CONFIGS[model_name]["color"]
fig.add_trace(go.Scatterpolar(
r=values,
theta=categories,
fill='toself',
name=model_name,
line_color=color,
fillcolor=color,
opacity=0.6
))
fig.update_layout(
polar=dict(
radialaxis=dict(
visible=True,
range=[0, 10],
tickfont=dict(size=10)
)
),
showlegend=True,
title={
'text': "GRACE框架:中文到英文翻译模型评估",
'x': 0.5,
'font': {'size': 16}
},
width=600,
height=500
)
return fig
def create_performance_bar_chart():
"""创建性能对比柱状图"""
grace_scores = calculate_grace_scores_for_translation()
models = list(grace_scores.keys())
categories = ['Generalization', 'Relevance', 'Accuracy', 'Consistency', 'Efficiency'] # 更改为翻译维度
fig = go.Figure()
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#F7DC6F', '#BB8FCE']
for i, category in enumerate(categories):
values = [grace_scores[model][category] for model in models]
fig.add_trace(go.Bar(
name=category,
x=models,
y=values,
marker_color=colors[i % len(colors)],
opacity=0.8
))
fig.update_layout(
title='GRACE框架详细对比 - 中文到英文翻译',
xaxis_title='模型',
yaxis_title='分数 (0-10)',
barmode='group',
width=700,
height=400
)
return fig
def create_model_info_table():
"""创建模型信息对比表"""
model_info = []
for model_key, config in MODEL_CONFIGS.items():
# 模拟参数信息
if "opus-mt-zh-en" in config["model_name"]:
params = "~3亿"
size = "~1.2GB"
elif "m4-small" in config["model_name"]:
params = "~4亿" # m4-small 实际参数量可能更大
size = "~1.5GB"
else: # 默认值
params = "未知"
size = "未知"
model_info.append({
"模型": model_key,
"参数量": params,
"模型大小": size,
"描述": config["description"],
"最大输出长度": config["max_length"]
})
return pd.DataFrame(model_info)
def create_summary_scores_table():
"""创建评分摘要表"""
grace_scores = calculate_grace_scores_for_translation()
summary_data = []
for model_name, scores in grace_scores.items():
avg_score = np.mean(list(scores.values()))
summary_data.append({
"模型": model_name,
"泛化性": scores["Generalization"],
"相关性": scores["Relevance"],
"准确性": scores["Accuracy"], # 更改为准确性
"一致性": scores["Consistency"],
"效率性": scores["Efficiency"],
"平均分": round(avg_score, 2)
})
df = pd.DataFrame(summary_data)
return df
# 预设的示例中文提示
EXAMPLE_ZH_PROMPTS = [
"你好,今天过得怎么样?",
"敏捷的棕色狐狸跳过懒惰的狗。",
"人工智能正在改变许多行业。",
"今天天气真好,我们去公园散步吧。"
]
def create_app():
with gr.Blocks(title="中文到英文翻译模型对比", theme=gr.themes.Soft()) as app:
gr.Markdown("# 🌐 中文到英文翻译模型对比竞技场")
gr.Markdown("### 使用GRACE框架对比不同中文到英文翻译模型在翻译任务中的表现")
with gr.Tabs():
# Arena选项卡
with gr.TabItem("️ 翻译竞技场"):
gr.Markdown("## 翻译竞技场")
gr.Markdown("请在下方输入需要翻译的**中文**文本,查看不同模型翻译成**英文**的效果。")
with gr.Row():
with gr.Column(scale=2): # 增加输入框的比例
input_zh_prompt = gr.Textbox(
label="输入中文文本",
placeholder="在此输入您的中文文本...",
lines=4, # 增加行数
value=EXAMPLE_ZH_PROMPTS[0]
)
# 预设中文示例按钮
with gr.Row():
for i, example in enumerate(EXAMPLE_ZH_PROMPTS):
gr.Button(f"示例 {i+1}", size="sm").click(
fn=lambda x=example: x,
outputs=[input_zh_prompt]
)
with gr.Column(scale=1): # 调整参数控制列的比例
max_length = gr.Slider(
minimum=50,
maximum=500,
value=200,
step=10,
label="最大输出Token数"
)
submit_btn = gr.Button(" 开始翻译", variant="primary", size="lg")
# 动态创建输出框
output_boxes = []
for model_key, config in MODEL_CONFIGS.items():
output_boxes.append(gr.Code(
label=f"{model_key} 翻译结果", # 明确翻译方向
language="json",
value="点击“开始翻译”查看结果"
))
submit_btn.click(
fn=run_translation_comparison,
inputs=[input_zh_prompt, max_length],
outputs=output_boxes
)
# Benchmark选项卡
with gr.TabItem(" GRACE 基准测试"):
gr.Markdown("## GRACE框架对翻译的评估")
gr.Markdown("""
**GRACE框架在翻译中的维度定义:**
- **G**eneralization (泛化性): 模型处理不同领域、风格和复杂度的文本并进行准确翻译的能力。
- **R**elevance (相关性): 翻译内容在语义和上下文上与原文的匹配程度。
- **A**ccuracy (准确性): 翻译的精确性和无误性,包括语法、词汇和句法结构的正确性。
- **C**onsistency (一致性): 对相同或类似输入文本在不同时间或不同上下文中的翻译稳定性。
- **E**fficiency (效率性): 翻译速度和所需的计算资源(如内存和CPU/GPU使用)。
""")
with gr.Row():
radar_plot = gr.Plot(
value=create_translation_radar_chart(),
label="GRACE 雷达图"
)
with gr.Row():
bar_plot = gr.Plot(
value=create_performance_bar_chart(),
label="详细性能对比"
)
with gr.Row():
with gr.Column():
model_info_df = create_model_info_table()
model_info_table = gr.Dataframe(
value=model_info_df,
label="模型信息",
interactive=False
)
with gr.Column():
summary_df = create_summary_scores_table()
summary_table = gr.Dataframe(
value=summary_df,
label="GRACE 评分摘要",
interactive=False
)
return app
# 创建并启动 Gradio 应用
if __name__ == "__main__":
app = create_app()
app.launch() |