hellokawei commited on
Commit
a70ed43
·
verified ·
1 Parent(s): 90d1820

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -41
app.py CHANGED
@@ -17,19 +17,12 @@ MODEL_CONFIGS = {
17
  "max_length": 200, # 翻译输出的最大长度
18
  "color": "#FF6B6B"
19
  },
20
- "Chinese-to-English (M4-Small)": {
21
- "model_name": "HuggingFaceM4/m4-small-en-zh", # 这是一个多语言模型,支持zh-en
22
- "description": "中文到英文的机器翻译模型 (HuggingFaceM4 M4-Small)",
23
  "max_length": 200, # 翻译输出的最大长度
24
  "color": "#4ECDC4"
25
  }
26
- # 如果需要第三个模型,可以取消注释下面这个,或替换成您想要的
27
- # "Chinese-to-English (Another Model)": {
28
- # "model_name": "facebook/mbart-large-50-one-to-many-mmt", # 另一个多语言模型,需要指定 src_lang/tgt_lang
29
- # "description": "中文到英文的机器翻译模型 (Facebook mBART-Large-50)",
30
- # "max_length": 200,
31
- # "color": "#45B7D1"
32
- # }
33
  }
34
 
35
  class TranslationComparator:
@@ -43,38 +36,27 @@ class TranslationComparator:
43
  for model_key, config in MODEL_CONFIGS.items():
44
  try:
45
  print(f"加载 {model_key} ({config['model_name']})...")
 
46
  # 对于翻译任务,使用 "translation" pipeline
47
- # 注意:某些多语言模型(如 m4-small)可能需要显式指定源语言和目标语言
48
- # 对于 Helsinki-NLP/opus-mt-zh-en,pipeline会自动处理
49
- # 对于 HuggingFaceM4/m4-small-en-zh,虽然名字是en-zh,但它内部支持zh-en。
50
- # 如果遇到问题,可能需要更复杂的tokenizer/model加载方式而非pipeline
51
- if "opus-mt-zh-en" in config["model_name"]:
52
- task = "translation_zh_to_en" # 更明确的翻译任务
53
- elif "m4-small" in config["model_name"]:
54
- # m4-small是一个多语言模型,需要提供源语言和目标语言。
55
- # pipeline("translation") 不直接支持 src_lang/tgt_lang 参数
56
- # 需要手动加载 AutoModelForSeq2SeqLM 和 AutoTokenizer
57
- print(f"特别加载 {model_key} 及其Tokenizer...")
58
- tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
59
- model = AutoModelForSeq2SeqLM.from_pretrained(config["model_name"])
60
- # 将其包装成一个简单的可调用对象,模拟pipeline的行为
61
- self.models[model_key] = {
62
- "tokenizer": tokenizer,
63
- "model": model,
64
- "pipeline_type": "custom_translation"
65
- }
66
- print(f"✓ {model_key} 加载成功 (自定义翻译模式)")
67
- continue # 跳过pipeline加载
68
- else: # 默认翻译任务
69
- task = "translation"
70
-
71
- self.models[model_key] = pipeline(
72
- task,
73
- model=config["model_name"],
74
- tokenizer=config["model_name"],
75
- device=-1, # 使用CPU
76
- torch_dtype=torch.float32
77
- )
78
  print(f"✓ {model_key} 加载成功")
79
  except Exception as e:
80
  print(f"✗ {model_key} 加载失败: {e}")
 
17
  "max_length": 200, # 翻译输出的最大长度
18
  "color": "#FF6B6B"
19
  },
20
+ "Chinese-to-English (mBART-Large-50)": { # 替换为mBART模型
21
+ "model_name": "facebook/mbart-large-50-many-to-one-mmt",
22
+ "description": "中文到英文的机器翻译模型 (Facebook mBART-Large-50)",
23
  "max_length": 200, # 翻译输出的最大长度
24
  "color": "#4ECDC4"
25
  }
 
 
 
 
 
 
 
26
  }
27
 
28
  class TranslationComparator:
 
36
  for model_key, config in MODEL_CONFIGS.items():
37
  try:
38
  print(f"加载 {model_key} ({config['model_name']})...")
39
+
40
  # 对于翻译任务,使用 "translation" pipeline
41
+ # 注意:mBART模型需要指定 source_lang 和 target_lang
42
+ if "mbart-large-50" in config["model_name"]:
43
+ self.models[model_key] = pipeline(
44
+ "translation",
45
+ model=config["model_name"],
46
+ tokenizer=config["model_name"],
47
+ src_lang="zh_CN", # 源语言为中文
48
+ tgt_lang="en_US", # 目标语言为英文
49
+ device=-1, # 使用CPU
50
+ torch_dtype=torch.float32
51
+ )
52
+ else: # 对于Helsinki-NLP/opus-mt-zh-en等
53
+ self.models[model_key] = pipeline(
54
+ "translation", # 也可以用 "translation_zh_to_en" 如果 pipeline 支持
55
+ model=config["model_name"],
56
+ tokenizer=config["model_name"],
57
+ device=-1, # 使用CPU
58
+ torch_dtype=torch.float32
59
+ )
 
 
 
 
 
 
 
 
 
 
 
 
60
  print(f"✓ {model_key} 加载成功")
61
  except Exception as e:
62
  print(f"✗ {model_key} 加载失败: {e}")