Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,19 +14,19 @@ import re
|
|
| 14 |
MODEL_CONFIGS = {
|
| 15 |
"GPT2-Small": {
|
| 16 |
"model_name": "gpt2",
|
| 17 |
-
"description": "OpenAI
|
| 18 |
"max_length": 100,
|
| 19 |
"color": "#FF6B6B"
|
| 20 |
},
|
| 21 |
"DistilGPT2": {
|
| 22 |
-
"model_name": "distilgpt2",
|
| 23 |
-
"description": "
|
| 24 |
"max_length": 100,
|
| 25 |
"color": "#4ECDC4"
|
| 26 |
},
|
| 27 |
"GPT2-Medium": {
|
| 28 |
"model_name": "gpt2-medium",
|
| 29 |
-
"description": "GPT-2
|
| 30 |
"max_length": 100,
|
| 31 |
"color": "#45B7D1"
|
| 32 |
}
|
|
@@ -37,7 +37,7 @@ class TextGenerationComparator:
|
|
| 37 |
self.models = {}
|
| 38 |
self.tokenizers = {}
|
| 39 |
self.load_models()
|
| 40 |
-
|
| 41 |
def load_models(self):
|
| 42 |
"""加载所有文本生成模型"""
|
| 43 |
print("正在加载模型...")
|
|
@@ -57,12 +57,12 @@ class TextGenerationComparator:
|
|
| 57 |
print(f"✗ {model_key} 加载失败: {e}")
|
| 58 |
# 创建一个mock模型用于演示
|
| 59 |
self.models[model_key] = None
|
| 60 |
-
|
| 61 |
def generate_text(self, model_key, prompt, max_length=50, temperature=0.7, top_p=0.9):
|
| 62 |
"""使用指定模型生成文本"""
|
| 63 |
if self.models[model_key] is None:
|
| 64 |
return {
|
| 65 |
-
"generated_text": f"[模型 {model_key} 未正确加载,这是模拟输出] {prompt}
|
| 66 |
"inference_time": 0.5,
|
| 67 |
"input_length": len(prompt.split()),
|
| 68 |
"output_length": max_length,
|
|
@@ -72,10 +72,10 @@ class TextGenerationComparator:
|
|
| 72 |
"max_length": max_length
|
| 73 |
}
|
| 74 |
}
|
| 75 |
-
|
| 76 |
try:
|
| 77 |
start_time = time.time()
|
| 78 |
-
|
| 79 |
# 生成文本
|
| 80 |
result = self.models[model_key](
|
| 81 |
prompt,
|
|
@@ -87,14 +87,14 @@ class TextGenerationComparator:
|
|
| 87 |
num_return_sequences=1,
|
| 88 |
truncation=True
|
| 89 |
)
|
| 90 |
-
|
| 91 |
end_time = time.time()
|
| 92 |
-
|
| 93 |
# 提取生成的文本(去除原始prompt)
|
| 94 |
generated_text = result[0]['generated_text']
|
| 95 |
if generated_text.startswith(prompt):
|
| 96 |
generated_text = generated_text[len(prompt):].strip()
|
| 97 |
-
|
| 98 |
return {
|
| 99 |
"generated_text": generated_text,
|
| 100 |
"full_text": result[0]['generated_text'],
|
|
@@ -107,7 +107,7 @@ class TextGenerationComparator:
|
|
| 107 |
"max_length": max_length
|
| 108 |
}
|
| 109 |
}
|
| 110 |
-
|
| 111 |
except Exception as e:
|
| 112 |
return {
|
| 113 |
"error": f"生成错误: {str(e)}",
|
|
@@ -123,36 +123,36 @@ def run_text_generation_comparison(prompt, max_length, temperature, top_p):
|
|
| 123 |
"""运行所有模型的文本生成对比"""
|
| 124 |
if not prompt.strip():
|
| 125 |
return "请输入提示文本", "请输入提示文本", "请输入提示文本"
|
| 126 |
-
|
| 127 |
results = {}
|
| 128 |
-
|
| 129 |
for model_key in MODEL_CONFIGS.keys():
|
| 130 |
result = comparator.generate_text(
|
| 131 |
-
model_key,
|
| 132 |
-
prompt,
|
| 133 |
max_length=int(max_length),
|
| 134 |
temperature=temperature,
|
| 135 |
top_p=top_p
|
| 136 |
)
|
| 137 |
results[model_key] = result
|
| 138 |
-
|
| 139 |
# 格式化输出
|
| 140 |
def format_result(result):
|
| 141 |
if "error" in result:
|
| 142 |
return json.dumps(result, indent=2, ensure_ascii=False)
|
| 143 |
-
|
| 144 |
formatted = {
|
| 145 |
-
"
|
| 146 |
-
"
|
| 147 |
-
"
|
| 148 |
-
"
|
| 149 |
}
|
| 150 |
return json.dumps(formatted, indent=2, ensure_ascii=False)
|
| 151 |
-
|
| 152 |
gpt2_result = format_result(results.get("GPT2-Small", {}))
|
| 153 |
distilgpt2_result = format_result(results.get("DistilGPT2", {}))
|
| 154 |
gpt2_medium_result = format_result(results.get("GPT2-Medium", {}))
|
| 155 |
-
|
| 156 |
return gpt2_result, distilgpt2_result, gpt2_medium_result
|
| 157 |
|
| 158 |
def calculate_grace_scores_for_generation():
|
|
@@ -186,14 +186,15 @@ def calculate_grace_scores_for_generation():
|
|
| 186 |
def create_generation_radar_chart():
|
| 187 |
"""创建文本生成GRACE评估���达图"""
|
| 188 |
grace_scores = calculate_grace_scores_for_generation()
|
|
|
|
| 189 |
categories = ['Generalization', 'Relevance', 'Artistry', 'Consistency', 'Efficiency']
|
| 190 |
-
|
| 191 |
fig = go.Figure()
|
| 192 |
-
|
| 193 |
for i, (model_name, scores) in enumerate(grace_scores.items()):
|
| 194 |
values = [scores[cat] for cat in categories]
|
| 195 |
color = MODEL_CONFIGS[model_name]["color"]
|
| 196 |
-
|
| 197 |
fig.add_trace(go.Scatterpolar(
|
| 198 |
r=values,
|
| 199 |
theta=categories,
|
|
@@ -203,7 +204,7 @@ def create_generation_radar_chart():
|
|
| 203 |
fillcolor=color,
|
| 204 |
opacity=0.6
|
| 205 |
))
|
| 206 |
-
|
| 207 |
fig.update_layout(
|
| 208 |
polar=dict(
|
| 209 |
radialaxis=dict(
|
|
@@ -214,40 +215,31 @@ def create_generation_radar_chart():
|
|
| 214 |
),
|
| 215 |
showlegend=True,
|
| 216 |
title={
|
| 217 |
-
'text': "GRACE
|
| 218 |
'x': 0.5,
|
| 219 |
'font': {'size': 16}
|
| 220 |
},
|
| 221 |
width=600,
|
| 222 |
height=500
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
)
|
| 234 |
-
|
| 235 |
return fig
|
| 236 |
|
| 237 |
def create_performance_bar_chart():
|
| 238 |
"""创建性能对比柱状图"""
|
| 239 |
grace_scores = calculate_grace_scores_for_generation()
|
| 240 |
-
|
| 241 |
models = list(grace_scores.keys())
|
|
|
|
| 242 |
categories = ['Generalization', 'Relevance', 'Artistry', 'Consistency', 'Efficiency']
|
| 243 |
-
|
| 244 |
fig = go.Figure()
|
| 245 |
-
|
| 246 |
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#F7DC6F', '#BB8FCE']
|
| 247 |
-
|
| 248 |
for i, category in enumerate(categories):
|
| 249 |
values = [grace_scores[model][category] for model in models]
|
| 250 |
-
|
| 251 |
fig.add_trace(go.Bar(
|
| 252 |
name=category,
|
| 253 |
x=models,
|
|
@@ -255,16 +247,16 @@ def create_performance_bar_chart():
|
|
| 255 |
marker_color=colors[i % len(colors)],
|
| 256 |
opacity=0.8
|
| 257 |
))
|
| 258 |
-
|
| 259 |
fig.update_layout(
|
| 260 |
-
title='GRACE
|
| 261 |
-
xaxis_title='
|
| 262 |
-
yaxis_title='
|
| 263 |
barmode='group',
|
| 264 |
width=700,
|
| 265 |
height=400
|
| 266 |
)
|
| 267 |
-
|
| 268 |
return fig
|
| 269 |
|
| 270 |
def create_model_info_table():
|
|
@@ -273,107 +265,96 @@ def create_model_info_table():
|
|
| 273 |
for model_key, config in MODEL_CONFIGS.items():
|
| 274 |
# 模拟参数信息
|
| 275 |
if "small" in model_key.lower() or model_key == "GPT2-Small":
|
| 276 |
-
params = "
|
| 277 |
size = "~500MB"
|
| 278 |
elif "distil" in model_key.lower():
|
| 279 |
-
params = "
|
| 280 |
size = "~350MB"
|
| 281 |
else:
|
| 282 |
-
params = "
|
| 283 |
size = "~1.4GB"
|
| 284 |
-
|
| 285 |
model_info.append({
|
| 286 |
-
"
|
| 287 |
-
"
|
| 288 |
-
"
|
| 289 |
-
"
|
| 290 |
-
"
|
| 291 |
})
|
| 292 |
-
|
| 293 |
return pd.DataFrame(model_info)
|
| 294 |
|
| 295 |
def create_summary_scores_table():
|
| 296 |
"""创建评分摘要表"""
|
| 297 |
grace_scores = calculate_grace_scores_for_generation()
|
| 298 |
-
|
| 299 |
summary_data = []
|
| 300 |
for model_name, scores in grace_scores.items():
|
| 301 |
avg_score = np.mean(list(scores.values()))
|
| 302 |
summary_data.append({
|
| 303 |
-
"
|
| 304 |
-
"
|
| 305 |
-
"
|
| 306 |
-
"
|
| 307 |
-
"
|
| 308 |
-
"
|
| 309 |
-
"
|
| 310 |
})
|
| 311 |
-
|
| 312 |
df = pd.DataFrame(summary_data)
|
| 313 |
return df
|
| 314 |
|
| 315 |
# 预设的示例提示
|
| 316 |
EXAMPLE_PROMPTS = [
|
| 317 |
-
"
|
| 318 |
-
"
|
| 319 |
-
"
|
| 320 |
-
"
|
| 321 |
-
"
|
| 322 |
]
|
| 323 |
|
| 324 |
def create_app():
|
| 325 |
-
with gr.Blocks(title="
|
| 326 |
-
gr.Markdown("# 📝
|
| 327 |
-
gr.Markdown("###
|
| 328 |
-
|
| 329 |
with gr.Tabs():
|
| 330 |
# Arena选项卡
|
| 331 |
-
with gr.TabItem("🏟️
|
| 332 |
-
gr.Markdown("##
|
| 333 |
-
gr.Markdown("
|
| 334 |
-
|
| 335 |
with gr.Row():
|
| 336 |
with gr.Column(scale=3):
|
| 337 |
input_prompt = gr.Textbox(
|
| 338 |
-
label="
|
| 339 |
-
placeholder="
|
| 340 |
lines=3,
|
| 341 |
-
value="
|
| 342 |
)
|
| 343 |
-
|
| 344 |
-
# 预设示例按钮
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
|
|
|
|
| 354 |
with gr.Row():
|
| 355 |
example_buttons = []
|
| 356 |
for i, example in enumerate(EXAMPLE_PROMPTS[:3]):
|
| 357 |
-
btn = gr.Button(f"
|
| 358 |
example_buttons.append(btn)
|
| 359 |
-
|
| 360 |
with gr.Column(scale=1):
|
| 361 |
max_length = gr.Slider(
|
| 362 |
minimum=10,
|
| 363 |
maximum=200,
|
| 364 |
value=50,
|
| 365 |
step=10,
|
| 366 |
-
label="
|
| 367 |
)
|
| 368 |
-
|
| 369 |
temperature = gr.Slider(
|
| 370 |
minimum=0.1,
|
| 371 |
maximum=2.0,
|
| 372 |
value=0.7,
|
| 373 |
step=0.1,
|
| 374 |
-
label="Temperature"
|
| 375 |
)
|
| 376 |
-
|
| 377 |
top_p = gr.Slider(
|
| 378 |
minimum=0.1,
|
| 379 |
maximum=1.0,
|
|
@@ -381,83 +362,82 @@ def create_app():
|
|
| 381 |
step=0.05,
|
| 382 |
label="Top-p"
|
| 383 |
)
|
| 384 |
-
|
| 385 |
-
submit_btn = gr.Button("🚀
|
| 386 |
-
|
| 387 |
# 设置示例按钮点击事件
|
| 388 |
for i, btn in enumerate(example_buttons):
|
| 389 |
btn.click(
|
| 390 |
fn=lambda x=EXAMPLE_PROMPTS[i]: x,
|
| 391 |
outputs=[input_prompt]
|
| 392 |
)
|
| 393 |
-
|
| 394 |
with gr.Row():
|
| 395 |
with gr.Column():
|
| 396 |
gpt2_output = gr.Code(
|
| 397 |
-
label="GPT2-Small (
|
| 398 |
language="json",
|
| 399 |
-
value="
|
| 400 |
)
|
| 401 |
-
|
| 402 |
with gr.Column():
|
| 403 |
distilgpt2_output = gr.Code(
|
| 404 |
-
label="DistilGPT2 (
|
| 405 |
language="json",
|
| 406 |
-
value="
|
| 407 |
)
|
| 408 |
-
|
| 409 |
with gr.Column():
|
| 410 |
gpt2_medium_output = gr.Code(
|
| 411 |
-
label="GPT2-Medium (
|
| 412 |
-
language="json",
|
| 413 |
-
value="
|
| 414 |
)
|
| 415 |
-
|
| 416 |
submit_btn.click(
|
| 417 |
fn=run_text_generation_comparison,
|
| 418 |
inputs=[input_prompt, max_length, temperature, top_p],
|
| 419 |
outputs=[gpt2_output, distilgpt2_output, gpt2_medium_output]
|
| 420 |
)
|
| 421 |
-
|
| 422 |
# Benchmark选项卡
|
| 423 |
-
with gr.TabItem("📊 GRACE
|
| 424 |
-
gr.Markdown("## GRACE
|
| 425 |
gr.Markdown("""
|
| 426 |
-
**GRACE
|
| 427 |
-
- **G**eneralization:
|
| 428 |
-
- **R**elevance:
|
| 429 |
-
- **A**rtistry:
|
| 430 |
-
- **C**onsistency:
|
| 431 |
-
- **E**fficiency:
|
| 432 |
""")
|
| 433 |
-
|
| 434 |
with gr.Row():
|
| 435 |
radar_plot = gr.Plot(
|
| 436 |
value=create_generation_radar_chart(),
|
| 437 |
-
label="GRACE
|
| 438 |
)
|
| 439 |
-
|
| 440 |
with gr.Row():
|
| 441 |
bar_plot = gr.Plot(
|
| 442 |
value=create_performance_bar_chart(),
|
| 443 |
-
label="
|
| 444 |
-
|
| 445 |
)
|
| 446 |
-
|
| 447 |
with gr.Row():
|
| 448 |
with gr.Column():
|
| 449 |
model_info_df = create_model_info_table()
|
| 450 |
model_info_table = gr.Dataframe(
|
| 451 |
value=model_info_df,
|
| 452 |
-
label="
|
| 453 |
interactive=False
|
| 454 |
)
|
| 455 |
-
|
| 456 |
with gr.Column():
|
| 457 |
summary_df = create_summary_scores_table()
|
| 458 |
summary_table = gr.Dataframe(
|
| 459 |
value=summary_df,
|
| 460 |
-
label="GRACE
|
| 461 |
interactive=False
|
| 462 |
)
|
| 463 |
|
|
@@ -654,7 +634,7 @@ graph TD
|
|
| 654 |
|
| 655 |
## 5. 合作与反思
|
| 656 |
|
| 657 |
-
###
|
| 658 |
- **负责内容**:
|
| 659 |
- 模型集成和pipeline构建
|
| 660 |
- Arena界面开发和交互逻辑
|
|
@@ -672,7 +652,7 @@ graph TD
|
|
| 672 |
- 生成质量的客观评估方法设计
|
| 673 |
- CPU推理性能优化
|
| 674 |
|
| 675 |
-
###
|
| 676 |
- **负责内容**:
|
| 677 |
- GRACE评估框架的文本生成适配
|
| 678 |
- 数据可视化和图表制作
|
|
|
|
| 14 |
MODEL_CONFIGS = {
|
| 15 |
"GPT2-Small": {
|
| 16 |
"model_name": "gpt2",
|
| 17 |
+
"description": "OpenAI的GPT-2小型模型(1.24亿参数)",
|
| 18 |
"max_length": 100,
|
| 19 |
"color": "#FF6B6B"
|
| 20 |
},
|
| 21 |
"DistilGPT2": {
|
| 22 |
+
"model_name": "distilgpt2",
|
| 23 |
+
"description": "GPT-2的蒸馏版本(8200万参数)",
|
| 24 |
"max_length": 100,
|
| 25 |
"color": "#4ECDC4"
|
| 26 |
},
|
| 27 |
"GPT2-Medium": {
|
| 28 |
"model_name": "gpt2-medium",
|
| 29 |
+
"description": "GPT-2中型模型(3.55亿参数)",
|
| 30 |
"max_length": 100,
|
| 31 |
"color": "#45B7D1"
|
| 32 |
}
|
|
|
|
| 37 |
self.models = {}
|
| 38 |
self.tokenizers = {}
|
| 39 |
self.load_models()
|
| 40 |
+
|
| 41 |
def load_models(self):
|
| 42 |
"""加载所有文本生成模型"""
|
| 43 |
print("正在加载模型...")
|
|
|
|
| 57 |
print(f"✗ {model_key} 加载失败: {e}")
|
| 58 |
# 创建一个mock模型用于演示
|
| 59 |
self.models[model_key] = None
|
| 60 |
+
|
| 61 |
def generate_text(self, model_key, prompt, max_length=50, temperature=0.7, top_p=0.9):
|
| 62 |
"""使用指定模型生成文本"""
|
| 63 |
if self.models[model_key] is None:
|
| 64 |
return {
|
| 65 |
+
"generated_text": f"[模型 {model_key} 未正确加载,这是模拟输出] {prompt} 这是一个模拟的文本续写...",
|
| 66 |
"inference_time": 0.5,
|
| 67 |
"input_length": len(prompt.split()),
|
| 68 |
"output_length": max_length,
|
|
|
|
| 72 |
"max_length": max_length
|
| 73 |
}
|
| 74 |
}
|
| 75 |
+
|
| 76 |
try:
|
| 77 |
start_time = time.time()
|
| 78 |
+
|
| 79 |
# 生成文本
|
| 80 |
result = self.models[model_key](
|
| 81 |
prompt,
|
|
|
|
| 87 |
num_return_sequences=1,
|
| 88 |
truncation=True
|
| 89 |
)
|
| 90 |
+
|
| 91 |
end_time = time.time()
|
| 92 |
+
|
| 93 |
# 提取生成的文本(去除原始prompt)
|
| 94 |
generated_text = result[0]['generated_text']
|
| 95 |
if generated_text.startswith(prompt):
|
| 96 |
generated_text = generated_text[len(prompt):].strip()
|
| 97 |
+
|
| 98 |
return {
|
| 99 |
"generated_text": generated_text,
|
| 100 |
"full_text": result[0]['generated_text'],
|
|
|
|
| 107 |
"max_length": max_length
|
| 108 |
}
|
| 109 |
}
|
| 110 |
+
|
| 111 |
except Exception as e:
|
| 112 |
return {
|
| 113 |
"error": f"生成错误: {str(e)}",
|
|
|
|
| 123 |
"""运行所有模型的文本生成对比"""
|
| 124 |
if not prompt.strip():
|
| 125 |
return "请输入提示文本", "请输入提示文本", "请输入提示文本"
|
| 126 |
+
|
| 127 |
results = {}
|
| 128 |
+
|
| 129 |
for model_key in MODEL_CONFIGS.keys():
|
| 130 |
result = comparator.generate_text(
|
| 131 |
+
model_key,
|
| 132 |
+
prompt,
|
| 133 |
max_length=int(max_length),
|
| 134 |
temperature=temperature,
|
| 135 |
top_p=top_p
|
| 136 |
)
|
| 137 |
results[model_key] = result
|
| 138 |
+
|
| 139 |
# 格式化输出
|
| 140 |
def format_result(result):
|
| 141 |
if "error" in result:
|
| 142 |
return json.dumps(result, indent=2, ensure_ascii=False)
|
| 143 |
+
|
| 144 |
formatted = {
|
| 145 |
+
"生成文本": result["generated_text"],
|
| 146 |
+
"推断时间": f"{result['inference_time']}s",
|
| 147 |
+
"生成Token数": result["output_length"],
|
| 148 |
+
"生成速度": f"{result['output_length']/max(result['inference_time'], 0.001):.1f} tokens/s"
|
| 149 |
}
|
| 150 |
return json.dumps(formatted, indent=2, ensure_ascii=False)
|
| 151 |
+
|
| 152 |
gpt2_result = format_result(results.get("GPT2-Small", {}))
|
| 153 |
distilgpt2_result = format_result(results.get("DistilGPT2", {}))
|
| 154 |
gpt2_medium_result = format_result(results.get("GPT2-Medium", {}))
|
| 155 |
+
|
| 156 |
return gpt2_result, distilgpt2_result, gpt2_medium_result
|
| 157 |
|
| 158 |
def calculate_grace_scores_for_generation():
|
|
|
|
| 186 |
def create_generation_radar_chart():
|
| 187 |
"""创建文本生成GRACE评估���达图"""
|
| 188 |
grace_scores = calculate_grace_scores_for_generation()
|
| 189 |
+
# 类别名称翻译,但在图表中为了保持GRACE框架的名称一致性,这里保留英文,但在标题和描述中会使用中文
|
| 190 |
categories = ['Generalization', 'Relevance', 'Artistry', 'Consistency', 'Efficiency']
|
| 191 |
+
|
| 192 |
fig = go.Figure()
|
| 193 |
+
|
| 194 |
for i, (model_name, scores) in enumerate(grace_scores.items()):
|
| 195 |
values = [scores[cat] for cat in categories]
|
| 196 |
color = MODEL_CONFIGS[model_name]["color"]
|
| 197 |
+
|
| 198 |
fig.add_trace(go.Scatterpolar(
|
| 199 |
r=values,
|
| 200 |
theta=categories,
|
|
|
|
| 204 |
fillcolor=color,
|
| 205 |
opacity=0.6
|
| 206 |
))
|
| 207 |
+
|
| 208 |
fig.update_layout(
|
| 209 |
polar=dict(
|
| 210 |
radialaxis=dict(
|
|
|
|
| 215 |
),
|
| 216 |
showlegend=True,
|
| 217 |
title={
|
| 218 |
+
'text': "GRACE框架:文本生成模型评估",
|
| 219 |
'x': 0.5,
|
| 220 |
'font': {'size': 16}
|
| 221 |
},
|
| 222 |
width=600,
|
| 223 |
height=500
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
)
|
| 225 |
+
|
| 226 |
return fig
|
| 227 |
|
| 228 |
def create_performance_bar_chart():
|
| 229 |
"""创建性能对比柱状图"""
|
| 230 |
grace_scores = calculate_grace_scores_for_generation()
|
| 231 |
+
|
| 232 |
models = list(grace_scores.keys())
|
| 233 |
+
# 类别名称翻译
|
| 234 |
categories = ['Generalization', 'Relevance', 'Artistry', 'Consistency', 'Efficiency']
|
| 235 |
+
|
| 236 |
fig = go.Figure()
|
| 237 |
+
|
| 238 |
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#F7DC6F', '#BB8FCE']
|
| 239 |
+
|
| 240 |
for i, category in enumerate(categories):
|
| 241 |
values = [grace_scores[model][category] for model in models]
|
| 242 |
+
|
| 243 |
fig.add_trace(go.Bar(
|
| 244 |
name=category,
|
| 245 |
x=models,
|
|
|
|
| 247 |
marker_color=colors[i % len(colors)],
|
| 248 |
opacity=0.8
|
| 249 |
))
|
| 250 |
+
|
| 251 |
fig.update_layout(
|
| 252 |
+
title='GRACE框架详细对比 - 文本生成',
|
| 253 |
+
xaxis_title='模型',
|
| 254 |
+
yaxis_title='分数 (0-10)',
|
| 255 |
barmode='group',
|
| 256 |
width=700,
|
| 257 |
height=400
|
| 258 |
)
|
| 259 |
+
|
| 260 |
return fig
|
| 261 |
|
| 262 |
def create_model_info_table():
|
|
|
|
| 265 |
for model_key, config in MODEL_CONFIGS.items():
|
| 266 |
# 模拟参数信息
|
| 267 |
if "small" in model_key.lower() or model_key == "GPT2-Small":
|
| 268 |
+
params = "1.24亿"
|
| 269 |
size = "~500MB"
|
| 270 |
elif "distil" in model_key.lower():
|
| 271 |
+
params = "8200万"
|
| 272 |
size = "~350MB"
|
| 273 |
else:
|
| 274 |
+
params = "3.55亿"
|
| 275 |
size = "~1.4GB"
|
| 276 |
+
|
| 277 |
model_info.append({
|
| 278 |
+
"模型": model_key,
|
| 279 |
+
"参数量": params,
|
| 280 |
+
"模型大小": size,
|
| 281 |
+
"描述": config["description"],
|
| 282 |
+
"最大长度": config["max_length"]
|
| 283 |
})
|
|
|
|
| 284 |
return pd.DataFrame(model_info)
|
| 285 |
|
| 286 |
def create_summary_scores_table():
|
| 287 |
"""创建评分摘要表"""
|
| 288 |
grace_scores = calculate_grace_scores_for_generation()
|
| 289 |
+
|
| 290 |
summary_data = []
|
| 291 |
for model_name, scores in grace_scores.items():
|
| 292 |
avg_score = np.mean(list(scores.values()))
|
| 293 |
summary_data.append({
|
| 294 |
+
"模型": model_name,
|
| 295 |
+
"泛化性": scores["Generalization"],
|
| 296 |
+
"相关性": scores["Relevance"],
|
| 297 |
+
"艺术性": scores["Artistry"],
|
| 298 |
+
"一致性": scores["Consistency"],
|
| 299 |
+
"效率性": scores["Efficiency"],
|
| 300 |
+
"平均分": round(avg_score, 2)
|
| 301 |
})
|
|
|
|
| 302 |
df = pd.DataFrame(summary_data)
|
| 303 |
return df
|
| 304 |
|
| 305 |
# 预设的示例提示
|
| 306 |
EXAMPLE_PROMPTS = [
|
| 307 |
+
"很久很久以前,在一个魔法森林里,",
|
| 308 |
+
"人工智能的未来是",
|
| 309 |
+
"在2050年,人们将会",
|
| 310 |
+
"我学到的最重要的一课是",
|
| 311 |
+
"科技改变了我们的生活,因为"
|
| 312 |
]
|
| 313 |
|
| 314 |
def create_app():
|
| 315 |
+
with gr.Blocks(title="文本生成模型对比", theme=gr.themes.Soft()) as app:
|
| 316 |
+
gr.Markdown("# 📝 文本生成模型对比竞技场")
|
| 317 |
+
gr.Markdown("### 使用GRACE框架对比不同GPT-2模型在文本生成任务中的表现")
|
| 318 |
+
|
| 319 |
with gr.Tabs():
|
| 320 |
# Arena选项卡
|
| 321 |
+
with gr.TabItem("🏟️ 生成竞技场"):
|
| 322 |
+
gr.Markdown("## 文本生成竞技场")
|
| 323 |
+
gr.Markdown("输入一个提示文本,查看不同GPT-2模型如何续写。")
|
| 324 |
+
|
| 325 |
with gr.Row():
|
| 326 |
with gr.Column(scale=3):
|
| 327 |
input_prompt = gr.Textbox(
|
| 328 |
+
label="输入提示文本",
|
| 329 |
+
placeholder="请在这里输入您的文本提示...",
|
| 330 |
lines=3,
|
| 331 |
+
value="很久很久以前,在一个数字世界里,"
|
| 332 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
|
| 334 |
+
# 预设示例按钮
|
| 335 |
with gr.Row():
|
| 336 |
example_buttons = []
|
| 337 |
for i, example in enumerate(EXAMPLE_PROMPTS[:3]):
|
| 338 |
+
btn = gr.Button(f"示例 {i+1}", size="sm")
|
| 339 |
example_buttons.append(btn)
|
| 340 |
+
|
| 341 |
with gr.Column(scale=1):
|
| 342 |
max_length = gr.Slider(
|
| 343 |
minimum=10,
|
| 344 |
maximum=200,
|
| 345 |
value=50,
|
| 346 |
step=10,
|
| 347 |
+
label="最大新Token数"
|
| 348 |
)
|
| 349 |
+
|
| 350 |
temperature = gr.Slider(
|
| 351 |
minimum=0.1,
|
| 352 |
maximum=2.0,
|
| 353 |
value=0.7,
|
| 354 |
step=0.1,
|
| 355 |
+
label="温度 (Temperature)"
|
| 356 |
)
|
| 357 |
+
|
| 358 |
top_p = gr.Slider(
|
| 359 |
minimum=0.1,
|
| 360 |
maximum=1.0,
|
|
|
|
| 362 |
step=0.05,
|
| 363 |
label="Top-p"
|
| 364 |
)
|
| 365 |
+
|
| 366 |
+
submit_btn = gr.Button("🚀 生成文本", variant="primary", size="lg")
|
| 367 |
+
|
| 368 |
# 设置示例按钮点击事件
|
| 369 |
for i, btn in enumerate(example_buttons):
|
| 370 |
btn.click(
|
| 371 |
fn=lambda x=EXAMPLE_PROMPTS[i]: x,
|
| 372 |
outputs=[input_prompt]
|
| 373 |
)
|
| 374 |
+
|
| 375 |
with gr.Row():
|
| 376 |
with gr.Column():
|
| 377 |
gpt2_output = gr.Code(
|
| 378 |
+
label="GPT2-Small (1.24亿参数)",
|
| 379 |
language="json",
|
| 380 |
+
value="点击“生成文本”查看结果"
|
| 381 |
)
|
| 382 |
+
|
| 383 |
with gr.Column():
|
| 384 |
distilgpt2_output = gr.Code(
|
| 385 |
+
label="DistilGPT2 (8200万参数)",
|
| 386 |
language="json",
|
| 387 |
+
value="点击“生成文本”查看结果"
|
| 388 |
)
|
| 389 |
+
|
| 390 |
with gr.Column():
|
| 391 |
gpt2_medium_output = gr.Code(
|
| 392 |
+
label="GPT2-Medium (3.55亿参数)",
|
| 393 |
+
language="json",
|
| 394 |
+
value="点击“生成文本”查看结���"
|
| 395 |
)
|
| 396 |
+
|
| 397 |
submit_btn.click(
|
| 398 |
fn=run_text_generation_comparison,
|
| 399 |
inputs=[input_prompt, max_length, temperature, top_p],
|
| 400 |
outputs=[gpt2_output, distilgpt2_output, gpt2_medium_output]
|
| 401 |
)
|
| 402 |
+
|
| 403 |
# Benchmark选项卡
|
| 404 |
+
with gr.TabItem("📊 GRACE 基准测试"):
|
| 405 |
+
gr.Markdown("## GRACE框架对文本生成的评估")
|
| 406 |
gr.Markdown("""
|
| 407 |
+
**GRACE框架在文本生成中的维度定义:**
|
| 408 |
+
- **G**eneralization (泛化性): 处理多样化提示和主题的能力
|
| 409 |
+
- **R**elevance (相关性): 输出与输入提示的逻辑连贯性
|
| 410 |
+
- **A**rtistry (艺术性): 创造性、连贯性和语言质量
|
| 411 |
+
- **C**onsistency (一致性): 多次生成时的可靠性和稳定性
|
| 412 |
+
- **E**fficiency (效率性): 生成速度和计算资源需求
|
| 413 |
""")
|
| 414 |
+
|
| 415 |
with gr.Row():
|
| 416 |
radar_plot = gr.Plot(
|
| 417 |
value=create_generation_radar_chart(),
|
| 418 |
+
label="GRACE 雷达图"
|
| 419 |
)
|
| 420 |
+
|
| 421 |
with gr.Row():
|
| 422 |
bar_plot = gr.Plot(
|
| 423 |
value=create_performance_bar_chart(),
|
| 424 |
+
label="详细性能对比"
|
|
|
|
| 425 |
)
|
| 426 |
+
|
| 427 |
with gr.Row():
|
| 428 |
with gr.Column():
|
| 429 |
model_info_df = create_model_info_table()
|
| 430 |
model_info_table = gr.Dataframe(
|
| 431 |
value=model_info_df,
|
| 432 |
+
label="模型信息",
|
| 433 |
interactive=False
|
| 434 |
)
|
| 435 |
+
|
| 436 |
with gr.Column():
|
| 437 |
summary_df = create_summary_scores_table()
|
| 438 |
summary_table = gr.Dataframe(
|
| 439 |
value=summary_df,
|
| 440 |
+
label="GRACE 评分摘要",
|
| 441 |
interactive=False
|
| 442 |
)
|
| 443 |
|
|
|
|
| 634 |
|
| 635 |
## 5. 合作与反思
|
| 636 |
|
| 637 |
+
### 成员一:谭秀辉
|
| 638 |
- **负责内容**:
|
| 639 |
- 模型集成和pipeline构建
|
| 640 |
- Arena界面开发和交互逻辑
|
|
|
|
| 652 |
- 生成质量的客观评估方法设计
|
| 653 |
- CPU推理性能优化
|
| 654 |
|
| 655 |
+
### 成员二:王旌旗
|
| 656 |
- **负责内容**:
|
| 657 |
- GRACE评估框架的文本生成适配
|
| 658 |
- 数据可视化和图表制作
|