hellokawei commited on
Commit
e7a2656
·
verified ·
1 Parent(s): a70ed43

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -43
app.py CHANGED
@@ -4,7 +4,8 @@ import plotly.graph_objects as go
4
  import plotly.express as px
5
  import time
6
  import numpy as np
7
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
 
8
  import torch
9
  import json
10
  import re
@@ -23,6 +24,13 @@ MODEL_CONFIGS = {
23
  "max_length": 200, # 翻译输出的最大长度
24
  "color": "#4ECDC4"
25
  }
 
 
 
 
 
 
 
26
  }
27
 
28
  class TranslationComparator:
@@ -46,8 +54,8 @@ class TranslationComparator:
46
  tokenizer=config["model_name"],
47
  src_lang="zh_CN", # 源语言为中文
48
  tgt_lang="en_US", # 目标语言为英文
49
- device=-1, # 使用CPU
50
- torch_dtype=torch.float32
51
  )
52
  else: # 对于Helsinki-NLP/opus-mt-zh-en等
53
  self.models[model_key] = pipeline(
@@ -78,38 +86,18 @@ class TranslationComparator:
78
 
79
  try:
80
  start_time = time.time()
81
-
82
- if isinstance(model_entry, dict) and model_entry.get("pipeline_type") == "custom_translation":
83
- # 对于需要自定义处理的模型 ( HuggingFaceM4/m4-small-en-zh)
84
- tokenizer = model_entry["tokenizer"]
85
- model = model_entry["model"]
86
-
87
- # 对于 m4-small,需要手动设置源语言和目标语言
88
- # 假设输入是中文
89
- input_ids = tokenizer(text_to_translate, return_tensors="pt", truncation=True, max_length=512).input_ids
90
-
91
- # 设置生成参数,特别是强制生成目标语言的 token (en_XX)
92
- # 对于 m4-small 而言,`en_XX` 是英文的目标语言token
93
- # 请注意:这可能需要根据具体的m4模型进行微调,因为它可能没有直接的force_bos_token_id
94
- # 一个更通用的方法是手动构建decoder_input_ids
95
-
96
- # 尝试一个通用的生成方式,让模型自己识别语言
97
- # 对于翻译任务,transformers pipeline已经封装了大部分复杂性
98
- # 如果手动调用generate,需要确保输入格式和语言ID正确
99
-
100
- # 简单的直接生成(可能不带force_bos_token_id)
101
- generated_ids = model.generate(input_ids, max_new_tokens=max_length)
102
- translated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
103
-
104
- else: # 使用 pipeline
105
- result = model_entry(
106
- text_to_translate,
107
- max_length=max_length
108
- )
109
- translated_text = result[0]['translation_text']
110
 
111
  end_time = time.time()
112
 
 
 
113
  return {
114
  "translated_text": translated_text,
115
  "inference_time": round(end_time - start_time, 3),
@@ -175,14 +163,13 @@ def calculate_grace_scores_for_translation():
175
  "Consistency": 7.9, # 翻译稳定性
176
  "Efficiency": 7.5 # 推理效率
177
  },
178
- "Chinese-to-English (M4-Small)": {
179
- "Generalization": 7.0, # 多语言模型可能在特定语对上略逊色于专用模型
180
- "Relevance": 7.5,
181
- "Accuracy": 7.2,
182
- "Consistency": 7.0,
183
- "Efficiency": 8.5 # 通常小模型效率更高
184
  }
185
- # 如果有第三个模型,在这里添加其分数
186
  }
187
  return grace_data
188
 
@@ -196,7 +183,9 @@ def create_translation_radar_chart():
196
 
197
  for i, (model_name, scores) in enumerate(grace_scores.items()):
198
  values = [scores[cat] for cat in categories]
199
- color = MODEL_CONFIGS[model_name]["color"]
 
 
200
 
201
  fig.add_trace(go.Scatterpolar(
202
  r=values,
@@ -264,9 +253,9 @@ def create_model_info_table():
264
  if "opus-mt-zh-en" in config["model_name"]:
265
  params = "~3亿"
266
  size = "~1.2GB"
267
- elif "m4-small" in config["model_name"]:
268
- params = "~4亿" # m4-small 实际参数量可能更大
269
- size = "~1.5GB"
270
  else: # 默认值
271
  params = "未知"
272
  size = "未知"
 
4
  import plotly.express as px
5
  import time
6
  import numpy as np
7
+ # 导入 AutoTokenizer AutoModelForSeq2SeqLM
8
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
9
  import torch
10
  import json
11
  import re
 
24
  "max_length": 200, # 翻译输出的最大长度
25
  "color": "#4ECDC4"
26
  }
27
+ # 如果需要第三个模型,可以取消注释下面这个,或替换成您想要的
28
+ # "Chinese-to-English (Another Model)": {
29
+ # "model_name": "facebook/mbart-large-50-one-to-many-mmt", # 另一个多语言模型,需要指定 src_lang/tgt_lang
30
+ # "description": "中文到英文的机器翻译模型 (Facebook mBART-Large-50)",
31
+ # "max_length": 200,
32
+ # "color": "#45B7D1"
33
+ # }
34
  }
35
 
36
  class TranslationComparator:
 
54
  tokenizer=config["model_name"],
55
  src_lang="zh_CN", # 源语言为中文
56
  tgt_lang="en_US", # 目标语言为英文
57
+ device=-1, # 使用CPU,避免GPU内存不足问题
58
+ torch_dtype=torch.float32 # 保持一致,或根据模型精度调整
59
  )
60
  else: # 对于Helsinki-NLP/opus-mt-zh-en等
61
  self.models[model_key] = pipeline(
 
86
 
87
  try:
88
  start_time = time.time()
89
+
90
+ # 翻译文本
91
+ # pipeline("translation") 的返回格式是 [{"translation_text": "..."}]
92
+ result = model_entry( # 直接使用 model_entry,因为现在都是pipeline对象
93
+ text_to_translate,
94
+ max_length=max_length
95
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  end_time = time.time()
98
 
99
+ translated_text = result[0]['translation_text']
100
+
101
  return {
102
  "translated_text": translated_text,
103
  "inference_time": round(end_time - start_time, 3),
 
163
  "Consistency": 7.9, # 翻译稳定性
164
  "Efficiency": 7.5 # 推理效率
165
  },
166
+ "Chinese-to-English (mBART-Large-50)": { # **这里已修改!**
167
+ "Generalization": 8.5, # 更大型多语言模型,泛化性通常更强
168
+ "Relevance": 8.8,
169
+ "Accuracy": 8.6,
170
+ "Consistency": 8.5,
171
+ "Efficiency": 6.0 # 模型较大,效率可能略低
172
  }
 
173
  }
174
  return grace_data
175
 
 
183
 
184
  for i, (model_name, scores) in enumerate(grace_scores.items()):
185
  values = [scores[cat] for cat in categories]
186
+ # **这里使用 MODEL_CONFIGS[model_name]["color"] 依赖于 MODEL_CONFIGS 和 grace_scores 的键名一致**
187
+ # 这是导致之前 KeyError 的地方,现在应该已修复,因为 calculate_grace_scores_for_translation 的键名已更新
188
+ color = MODEL_CONFIGS[model_name]["color"]
189
 
190
  fig.add_trace(go.Scatterpolar(
191
  r=values,
 
253
  if "opus-mt-zh-en" in config["model_name"]:
254
  params = "~3亿"
255
  size = "~1.2GB"
256
+ elif "mbart-large-50" in config["model_name"]: # 修改为mBART的参数
257
+ params = "~6.1亿" # mBART-Large-50 的实际参数量
258
+ size = "~2.4GB" # mBART-Large-50 的实际模型大小
259
  else: # 默认值
260
  params = "未知"
261
  size = "未知"