hellokawei committed on
Commit
639702f
·
verified ·
1 Parent(s): cb502ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -30
app.py CHANGED
@@ -8,7 +8,6 @@ import time
8
  import numpy as np
9
 
10
  # 初始化模型
11
- @gr.cache
12
  def load_models():
13
  """加载三个不同的文本生成模型"""
14
  models = {}
@@ -16,33 +15,53 @@ def load_models():
16
  try:
17
  # 模型1: GPT-2 (轻量级)
18
  models['gpt2'] = {
19
- 'pipeline': pipeline("text-generation", model="gpt2", max_length=100),
20
  'name': 'GPT-2',
21
  'description': '经典的自回归语言模型,适合短文本生成'
22
  }
23
 
24
  # 模型2: DistilGPT-2 (更快速)
25
  models['distilgpt2'] = {
26
- 'pipeline': pipeline("text-generation", model="distilgpt2", max_length=100),
27
  'name': 'DistilGPT-2',
28
  'description': '轻量化的GPT-2,速度更快但质量略低'
29
  }
30
 
31
- # 模型3: Microsoft DialoGPT (对话优化)
32
- models['dialogpt'] = {
33
- 'pipeline': pipeline("text-generation", model="microsoft/DialoGPT-medium", max_length=100),
34
- 'name': 'DialoGPT-medium',
35
- 'description': '针对对话场景优化的生成模型'
36
  }
37
 
38
  except Exception as e:
39
  print(f"模型加载错误: {e}")
40
- # 备用方案:使用更简单的模型
41
- models['gpt2'] = {
42
- 'pipeline': pipeline("text-generation", model="gpt2", max_length=50),
43
- 'name': 'GPT-2',
44
- 'description': '经典的自回归语言模型'
45
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  return models
48
 
@@ -63,35 +82,41 @@ GRACE_DATA = {
63
  'Artistry': 6.8,
64
  'Efficiency': 9.2
65
  },
66
- 'DialoGPT-medium': {
67
- 'Generalization': 7.0,
68
- 'Relevance': 8.8,
69
- 'Artistry': 8.0,
70
- 'Efficiency': 7.5
71
  }
72
  }
73
 
74
- def generate_text_with_model(model_key, prompt, max_length=100):
75
  """使用指定模型生成文本"""
76
  try:
77
  start_time = time.time()
78
 
79
- if model_key not in MODELS:
80
- return "模型未找到", 0
81
 
82
  result = MODELS[model_key]['pipeline'](
83
  prompt,
84
- max_length=max_length,
85
  num_return_sequences=1,
86
  temperature=0.7,
87
  do_sample=True,
88
- pad_token_id=50256
 
 
89
  )
90
 
91
  end_time = time.time()
92
  generation_time = end_time - start_time
93
 
94
- generated_text = result[0]['generated_text']
 
 
 
 
95
  return generated_text, generation_time
96
 
97
  except Exception as e:
@@ -161,7 +186,7 @@ def arena_interface(prompt, max_length):
161
  # 格式化输出
162
  output1 = f"**{MODELS['gpt2']['name']}** (生成时间: {times.get('gpt2', 0):.2f}s)\n\n{results.get('gpt2', '生成失败')}"
163
  output2 = f"**{MODELS['distilgpt2']['name']}** (生成时间: {times.get('distilgpt2', 0):.2f}s)\n\n{results.get('distilgpt2', '生成失败')}"
164
- output3 = f"**{MODELS['dialogpt']['name']}** (生成时间: {times.get('dialogpt', 0):.2f}s)\n\n{results.get('dialogpt', '生成失败')}"
165
 
166
  # 生成对比分析
167
  analysis = f"""
@@ -170,13 +195,13 @@ def arena_interface(prompt, max_length):
170
  ### 速度对比
171
  - GPT-2: {times.get('gpt2', 0):.2f}秒
172
  - DistilGPT-2: {times.get('distilgpt2', 0):.2f}秒
173
- - DialoGPT: {times.get('dialogpt', 0):.2f}秒
174
 
175
  ### 质量评估
176
  根据GRACE框架,不同模型在各维度的表现存在差异:
177
- - **效率性**: DistilGPT-2表现最佳
178
- - **相关性**: DialoGPT在对话场景中表现突出
179
  - **泛化性**: GPT-2具有最强的通用性
 
180
  """
181
 
182
  return output1, output2, output3, analysis
@@ -239,7 +264,7 @@ def create_app():
239
  with gr.Row():
240
  model1_output = gr.Markdown(label="GPT-2 输出")
241
  model2_output = gr.Markdown(label="DistilGPT-2 输出")
242
- model3_output = gr.Markdown(label="DialoGPT 输出")
243
 
244
  analysis_output = gr.Markdown(label="对比分析")
245
 
 
8
  import numpy as np
9
 
10
  # 初始化模型
 
11
  def load_models():
12
  """加载三个不同的文本生成模型"""
13
  models = {}
 
15
  try:
16
  # 模型1: GPT-2 (轻量级)
17
  models['gpt2'] = {
18
+ 'pipeline': pipeline("text-generation", model="gpt2", max_new_tokens=50),
19
  'name': 'GPT-2',
20
  'description': '经典的自回归语言模型,适合短文本生成'
21
  }
22
 
23
  # 模型2: DistilGPT-2 (更快速)
24
  models['distilgpt2'] = {
25
+ 'pipeline': pipeline("text-generation", model="distilgpt2", max_new_tokens=50),
26
  'name': 'DistilGPT-2',
27
  'description': '轻量化的GPT-2,速度更快但质量略低'
28
  }
29
 
30
+ # 模型3: OpenELM (苹果开源模型)
31
+ models['openelm'] = {
32
+ 'pipeline': pipeline("text-generation", model="apple/OpenELM-270M", max_new_tokens=50, trust_remote_code=True),
33
+ 'name': 'OpenELM-270M',
34
+ 'description': '苹果开源的轻量级语言模型'
35
  }
36
 
37
  except Exception as e:
38
  print(f"模型加载错误: {e}")
39
+ # 备用方案:只使用最基础的模型
40
+ try:
41
+ models['gpt2'] = {
42
+ 'pipeline': pipeline("text-generation", model="gpt2", max_new_tokens=30),
43
+ 'name': 'GPT-2',
44
+ 'description': '经典的自回归语言模型'
45
+ }
46
+ models['distilgpt2'] = {
47
+ 'pipeline': pipeline("text-generation", model="distilgpt2", max_new_tokens=30),
48
+ 'name': 'DistilGPT-2',
49
+ 'description': '轻量化版本'
50
+ }
51
+ # 第三个模型用简单的替代
52
+ models['openelm'] = {
53
+ 'pipeline': pipeline("text-generation", model="gpt2", max_new_tokens=20),
54
+ 'name': 'GPT-2-Variant',
55
+ 'description': '备用模型配置'
56
+ }
57
+ except Exception as e2:
58
+ print(f"备用模型加载也失败: {e2}")
59
+ # 最终备用:至少确保有一个模型可用
60
+ models['gpt2'] = {
61
+ 'pipeline': None,
62
+ 'name': 'GPT-2',
63
+ 'description': '模型加载失败'
64
+ }
65
 
66
  return models
67
 
 
82
  'Artistry': 6.8,
83
  'Efficiency': 9.2
84
  },
85
+ 'OpenELM-270M': {
86
+ 'Generalization': 6.5,
87
+ 'Relevance': 7.0,
88
+ 'Artistry': 6.5,
89
+ 'Efficiency': 8.8
90
  }
91
  }
92
 
93
+ def generate_text_with_model(model_key, prompt, max_length=50):
94
  """使用指定模型生成文本"""
95
  try:
96
  start_time = time.time()
97
 
98
+ if model_key not in MODELS or MODELS[model_key]['pipeline'] is None:
99
+ return "模型未找到或加载失败", 0
100
 
101
  result = MODELS[model_key]['pipeline'](
102
  prompt,
103
+ max_new_tokens=min(max_length, 50),
104
  num_return_sequences=1,
105
  temperature=0.7,
106
  do_sample=True,
107
+ pad_token_id=50256,
108
+ truncation=True,
109
+ return_full_text=False
110
  )
111
 
112
  end_time = time.time()
113
  generation_time = end_time - start_time
114
 
115
+ if result and len(result) > 0:
116
+ generated_text = prompt + result[0]['generated_text']
117
+ else:
118
+ generated_text = "生成失败"
119
+
120
  return generated_text, generation_time
121
 
122
  except Exception as e:
 
186
  # 格式化输出
187
  output1 = f"**{MODELS['gpt2']['name']}** (生成时间: {times.get('gpt2', 0):.2f}s)\n\n{results.get('gpt2', '生成失败')}"
188
  output2 = f"**{MODELS['distilgpt2']['name']}** (生成时间: {times.get('distilgpt2', 0):.2f}s)\n\n{results.get('distilgpt2', '生成失败')}"
189
+ output3 = f"**{MODELS['openelm']['name']}** (生成时间: {times.get('openelm', 0):.2f}s)\n\n{results.get('openelm', '生成失败')}"
190
 
191
  # 生成对比分析
192
  analysis = f"""
 
195
  ### 速度对比
196
  - GPT-2: {times.get('gpt2', 0):.2f}秒
197
  - DistilGPT-2: {times.get('distilgpt2', 0):.2f}秒
198
+ - OpenELM: {times.get('openelm', 0):.2f}秒
199
 
200
  ### 质量评估
201
  根据GRACE框架,不同模型在各维度的表现存在差异:
202
+ - **效率性**: DistilGPT-2和OpenELM表现优异
 
203
  - **泛化性**: GPT-2具有最强的通用性
204
+ - **相关性**: 各模型在相关性上表现相近
205
  """
206
 
207
  return output1, output2, output3, analysis
 
264
  with gr.Row():
265
  model1_output = gr.Markdown(label="GPT-2 输出")
266
  model2_output = gr.Markdown(label="DistilGPT-2 输出")
267
+ model3_output = gr.Markdown(label="OpenELM 输出")
268
 
269
  analysis_output = gr.Markdown(label="对比分析")
270