hellokawei commited on
Commit
0938f57
·
verified ·
1 Parent(s): 451b0b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -41
app.py CHANGED
@@ -18,15 +18,15 @@ MODEL_CONFIGS = {
18
  "max_length": 200, # 翻译输出的最大长度
19
  "color": "#FF6B6B"
20
  },
21
- "Chinese-to-English (mBART-Large-50)": { # 替换为mBART模型
22
- "model_name": "facebook/mbart-large-50-many-to-one-mmt",
23
- "description": "中文到英文的机器翻译模型 (Facebook mBART-Large-50)",
24
  "max_length": 200, # 翻译输出的最大长度
25
  "color": "#4ECDC4"
26
  }
27
  # 如果需要第三个模型,可以取消注释下面这个,或替换成您想要的
28
  # "Chinese-to-English (Another Model)": {
29
- # "model_name": "facebook/mbart-large-50-one-to-many-mmt", # 另一个多语言模型,需要指定 src_lang/tgt_lang
30
  # "description": "中文到英文的机器翻译模型 (Facebook mBART-Large-50)",
31
  # "max_length": 200,
32
  # "color": "#45B7D1"
@@ -45,26 +45,15 @@ class TranslationComparator:
45
  try:
46
  print(f"加载 {model_key} ({config['model_name']})...")
47
 
48
- # 对于翻译任务,使用 "translation" pipeline
49
- # 注意:mBART模型需要指定 source_lang 和 target_lang
50
- if "mbart-large-50" in config["model_name"]:
51
- self.models[model_key] = pipeline(
52
- "translation",
53
- model=config["model_name"],
54
- tokenizer=config["model_name"],
55
- src_lang="zh_CN", # 源语言为中文
56
- tgt_lang="en_US", # 目标语言为英文
57
- device=-1, # 使用CPU,避免GPU内存不足问题
58
- torch_dtype=torch.float32 # 保持一致,或根据模型精度调整
59
- )
60
- else: # 对于Helsinki-NLP/opus-mt-zh-en等
61
- self.models[model_key] = pipeline(
62
- "translation", # 也可以用 "translation_zh_to_en" 如果 pipeline 支持
63
- model=config["model_name"],
64
- tokenizer=config["model_name"],
65
- device=-1, # 使用CPU
66
- torch_dtype=torch.float32
67
- )
68
  print(f"✓ {model_key} 加载成功")
69
  except Exception as e:
70
  print(f"✗ {model_key} 加载失败: {e}")
@@ -87,12 +76,19 @@ class TranslationComparator:
87
  try:
88
  start_time = time.time()
89
 
90
- # 翻译文本
91
- # pipeline("translation") 的返回格式是 [{"translation_text": "..."}]
92
- result = model_entry( # 直接使用 model_entry,因为现在都是pipeline对象
93
- text_to_translate,
94
- max_length=max_length
95
- )
 
 
 
 
 
 
 
96
 
97
  end_time = time.time()
98
 
@@ -163,12 +159,12 @@ def calculate_grace_scores_for_translation():
163
  "Consistency": 7.9, # 翻译稳定性
164
  "Efficiency": 7.5 # 推理效率
165
  },
166
- "Chinese-to-English (mBART-Large-50)": { # **这里已修改!**
167
- "Generalization": 8.5, # 更大型多语言模型,泛化性通常更强
168
- "Relevance": 8.8,
169
- "Accuracy": 8.6,
170
- "Consistency": 8.5,
171
- "Efficiency": 6.0 # 模型较大,效率可能略低
172
  }
173
  }
174
  return grace_data
@@ -183,8 +179,7 @@ def create_translation_radar_chart():
183
 
184
  for i, (model_name, scores) in enumerate(grace_scores.items()):
185
  values = [scores[cat] for cat in categories]
186
- # **这里使用 MODEL_CONFIGS[model_name]["color"] 依赖于 MODEL_CONFIGS 和 grace_scores 的键名一致**
187
- # 这是导致之前 KeyError 的地方,现在应该已修复,因为 calculate_grace_scores_for_translation 的键名已更新
188
  color = MODEL_CONFIGS[model_name]["color"]
189
 
190
  fig.add_trace(go.Scatterpolar(
@@ -253,9 +248,9 @@ def create_model_info_table():
253
  if "opus-mt-zh-en" in config["model_name"]:
254
  params = "~3亿"
255
  size = "~1.2GB"
256
- elif "mbart-large-50" in config["model_name"]: # 修改为mBART的参数
257
- params = "~6.1亿" # mBART-Large-50 的实际参数量
258
- size = "~2.4GB" # mBART-Large-50 的实际模型大小
259
  else: # 默认值
260
  params = "未知"
261
  size = "未知"
 
18
  "max_length": 200, # 翻译输出的最大长度
19
  "color": "#FF6B6B"
20
  },
21
+ "Chinese-to-English (T5-Small)": { # **更改为 T5-Small 模型**
22
+ "model_name": "google-t5/t5-small",
23
+ "description": "中文到英文的机器翻译模型 (Google T5-Small)",
24
  "max_length": 200, # 翻译输出的最大长度
25
  "color": "#4ECDC4"
26
  }
27
  # 如果需要第三个模型,可以取消注释下面这个,或替换成您想要的
28
  # "Chinese-to-English (Another Model)": {
29
+ # "model_name": "facebook/mbart-large-50-one-to-many-mmt",
30
  # "description": "中文到英文的机器翻译模型 (Facebook mBART-Large-50)",
31
  # "max_length": 200,
32
  # "color": "#45B7D1"
 
45
  try:
46
  print(f"加载 {model_key} ({config['model_name']})...")
47
 
48
+ # T5模型通常用于多任务,这里我们明确指定它用于翻译
49
+ # pipeline("translation") 会尝试自动处理,但T5需要特定输入格式
50
+ self.models[model_key] = pipeline(
51
+ "translation", # T5可以用'translation' task
52
+ model=config["model_name"],
53
+ tokenizer=config["model_name"],
54
+ device=-1, # 使用CPU,避免GPU内存不足问题
55
+ torch_dtype=torch.float32 # 保持一致,或根据模型精度调整
56
+ )
 
 
 
 
 
 
 
 
 
 
 
57
  print(f"✓ {model_key} 加载成功")
58
  except Exception as e:
59
  print(f"✗ {model_key} 加载失败: {e}")
 
76
  try:
77
  start_time = time.time()
78
 
79
+ # **针对 T5 模型添加输入格式化**
80
+ if "t5-small" in model_key.lower(): # 检查是否是T5-Small模型
81
+ # T5的翻译任务通常需要这样的前缀
82
+ formatted_text = f"translate Chinese to English: {text_to_translate}"
83
+ result = model_entry(
84
+ formatted_text,
85
+ max_length=max_length
86
+ )
87
+ else: # 对于Helsinki-NLP/opus-mt-zh-en等其他模型
88
+ result = model_entry( # 直接使用 model_entry,因为现在都是pipeline对象
89
+ text_to_translate,
90
+ max_length=max_length
91
+ )
92
 
93
  end_time = time.time()
94
 
 
159
  "Consistency": 7.9, # 翻译稳定性
160
  "Efficiency": 7.5 # 推理效率
161
  },
162
+ "Chinese-to-English (T5-Small)": { # **T5-Small 的模拟 GRACE 分数**
163
+ "Generalization": 6.8, # 比T5-Base略低,泛化性可能稍弱
164
+ "Relevance": 7.0,
165
+ "Accuracy": 6.5,
166
+ "Consistency": 6.8,
167
+ "Efficiency": 9.0 # 模型更小,效率更高
168
  }
169
  }
170
  return grace_data
 
179
 
180
  for i, (model_name, scores) in enumerate(grace_scores.items()):
181
  values = [scores[cat] for cat in categories]
182
+ # 这里使用 MODEL_CONFIGS[model_name]["color"] 依赖于 MODEL_CONFIGS 和 grace_scores 的键名一致
 
183
  color = MODEL_CONFIGS[model_name]["color"]
184
 
185
  fig.add_trace(go.Scatterpolar(
 
248
  if "opus-mt-zh-en" in config["model_name"]:
249
  params = "~3亿"
250
  size = "~1.2GB"
251
+ elif "t5-small" in config["model_name"]: # **更新 T5-Small 的参数**
252
+ params = "~6千万" # T5-Small 实际参数量约 60 million
253
+ size = "~240MB" # T5-Small 实际模型大小约 240MB
254
  else: # 默认值
255
  params = "未知"
256
  size = "未知"