txh17 commited on
Commit
1fc9c49
·
verified ·
1 Parent(s): 26364d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -225
app.py CHANGED
@@ -1,230 +1,5 @@
1
- import gradio as gr
2
-
3
- import pandas as pd
4
- import plotly.graph_objects as go
5
- import plotly.express as px
6
- import time
7
- import numpy as np
8
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
9
- import torch
10
- import json
11
- import re
12
-
13
- # 选择三个小型文本生成模型
14
- MODEL_CONFIGS = {
15
- "GPT2-Small": {
16
- "model_name": "gpt2",
17
- "description": "OpenAI的GPT-2小型模型(1.24亿参数)",
18
- "max_length": 100,
19
- "color": "#FF6B6B"
20
- },
21
- "DistilGPT2": {
22
- "model_name": "distilgpt2",
23
- "description": "GPT-2的蒸馏版本(8200万参数)",
24
- "max_length": 100,
25
- "color": "#4ECDC4"
26
- },
27
- "GPT2-Medium": {
28
- "model_name": "gpt2-medium",
29
- "description": "GPT-2中型模型(3.55亿参数)",
30
- "max_length": 100,
31
- "color": "#45B7D1"
32
- }
33
- }
34
-
35
- class TextGenerationComparator:
36
- def __init__(self):
37
- self.models = {}
38
- self.tokenizers = {}
39
- self.load_models()
40
-
41
- def load_models(self):
42
- """加载所有文本生成模型"""
43
- print("正在加载模型...")
44
- for model_key, config in MODEL_CONFIGS.items():
45
- try:
46
- print(f"加载 {model_key}...")
47
- # 使用pipeline方式加载,更简单且内存友好
48
- self.models[model_key] = pipeline(
49
- "text-generation",
50
- model=config["model_name"],
51
- tokenizer=config["model_name"],
52
- device=-1, # 使用CPU,避免GPU内存问题
53
- torch_dtype=torch.float32
54
- )
55
- print(f"✓ {model_key} 加载成功")
56
- except Exception as e:
57
- print(f"✗ {model_key} 加载失败: {e}")
58
- # 创建一个mock模型用于演示
59
- self.models[model_key] = None
60
-
61
- def generate_text(self, model_key, prompt, max_length=50, temperature=0.7, top_p=0.9):
62
- """使用指定模型生成文本"""
63
- if self.models[model_key] is None:
64
- return {
65
- "generated_text": f"[Model {model_key} not loaded correctly, this is a simulated output] {prompt} and this is a sample continuation of the text...",
66
- "inference_time": 0.5,
67
- "input_length": len(prompt.split()),
68
- "output_length": max_length,
69
- "parameters": {
70
- "temperature": temperature,
71
- "top_p": top_p,
72
- "max_length": max_length
73
- }
74
- }
75
-
76
- try:
77
- start_time = time.time()
78
-
79
- # 生成文本
80
- result = self.models[model_key](
81
- prompt,
82
- max_length=len(prompt.split()) + max_length,
83
- temperature=temperature,
84
- top_p=top_p,
85
- do_sample=True,
86
- pad_token_id=50256, # GPT-2的pad token
87
- num_return_sequences=1,
88
- truncation=True
89
- )
90
-
91
- end_time = time.time()
92
-
93
- # 提取生成的文本(去除原始prompt)
94
- generated_text = result[0]['generated_text']
95
- if generated_text.startswith(prompt):
96
- generated_text = generated_text[len(prompt):].strip()
97
-
98
- return {
99
- "generated_text": generated_text,
100
- "full_text": result[0]['generated_text'],
101
- "inference_time": round(end_time - start_time, 3),
102
- "input_length": len(prompt.split()),
103
- "output_length": len(generated_text.split()),
104
- "parameters": {
105
- "temperature": temperature,
106
- "top_p": top_p,
107
- "max_length": max_length
108
- }
109
- }
110
-
111
- except Exception as e:
112
- return {
113
- "error": f"生成错误: {str(e)}",
114
- "inference_time": 0,
115
- "input_length": 0,
116
- "output_length": 0
117
- }
118
-
119
- # 初始化比较器
120
- comparator = TextGenerationComparator()
121
-
122
- def run_text_generation_comparison(prompt, max_length, temperature, top_p):
123
- """运行所有模型的文本生成对比"""
124
- if not prompt.strip():
125
- return "Please enter a prompt.", "Please enter a prompt.", "Please enter a prompt." # 提示文本为英文
126
-
127
- results = {}
128
-
129
- for model_key in MODEL_CONFIGS.keys():
130
- result = comparator.generate_text(
131
- model_key,
132
- prompt,
133
- max_length=int(max_length),
134
- temperature=temperature,
135
- top_p=top_p
136
- )
137
- results[model_key] = result
138
-
139
- # 格式化输出
140
- def format_result(result):
141
- if "error" in result:
142
- return json.dumps(result, indent=2, ensure_ascii=False)
143
-
144
- # 这里的键名保留中文,值会是英文
145
- formatted = {
146
- "生成文本": result["generated_text"],
147
- "推断时间": f"{result['inference_time']}s",
148
- "生成Token数": result["output_length"],
149
- "生成速度": f"{result['output_length']/max(result['inference_time'], 0.001):.1f} tokens/s"
150
- }
151
- return json.dumps(formatted, indent=2, ensure_ascii=False)
152
-
153
- gpt2_result = format_result(results.get("GPT2-Small", {}))
154
- distilgpt2_result = format_result(results.get("DistilGPT2", {}))
155
- gpt2_medium_result = format_result(results.get("GPT2-Medium", {}))
156
-
157
- return gpt2_result, distilgpt2_result, gpt2_medium_result
158
-
159
- def calculate_grace_scores_for_generation():
160
- """为文本生成任务计算GRACE评估分数"""
161
- # 基于文本生成任务特点的GRACE评分
162
- grace_data = {
163
- "GPT2-Small": {
164
- "Generalization": 7.5, # 中等泛化能力,适用多种文本类型
165
- "Relevance": 8.2, # 与输入提示相关性较好
166
- "Artistry": 7.8, # 创造性和表达力中等
167
- "Consistency": 8.0, # 输出一致性良好
168
- "Efficiency": 9.2 # 小模型,效率很高
169
- },
170
- "DistilGPT2": {
171
- "Generalization": 7.2, # 蒸馏模型,泛化能力略低
172
- "Relevance": 7.9, # 相关性稍低于原模型
173
- "Artistry": 7.5, # 创造性受蒸馏影响
174
- "Consistency": 7.8, # 一致性略有损失
175
- "Efficiency": 9.8 # 最小模型,效率最高
176
- },
177
- "GPT2-Medium": {
178
- "Generalization": 8.8, # 更大模型,更好的泛化
179
- "Relevance": 9.1, # 更好的上下文理解
180
- "Artistry": 8.9, # 更强的创造性表达
181
- "Consistency": 8.7, # 更一致的输出质量
182
- "Efficiency": 6.5 # 较大模型,效率较低
183
- }
184
- }
185
- return grace_data
186
-
187
- def create_generation_radar_chart():
188
- """创建文本生成GRACE评估雷达图"""
189
- grace_scores = calculate_grace_scores_for_generation()
190
- # 类别名称翻译,但在图表中为了保持GRACE框架的名称一致性,这里保留英文,但在标题和描述中会使用中文
191
- categories = ['Generalization', 'Relevance', 'Artistry', 'Consistency', 'Efficiency']
192
-
193
- fig = go.Figure()
194
 
195
- for i, (model_name, scores) in enumerate(grace_scores.items()):
196
- values = [scores[cat] for cat in categories]
197
- color = MODEL_CONFIGS[model_name]["color"]
198
-
199
- fig.add_trace(go.Scatterpolar(
200
- r=values,
201
- theta=categories,
202
- fill='toself',
203
- name=model_name,
204
- line_color=color,
205
- fillcolor=color,
206
- opacity=0.6
207
- ))
208
 
209
- fig.update_layout(
210
- polar=dict(
211
- radialaxis=dict(
212
- visible=True,
213
- range=[0, 10],
214
- tickfont=dict(size=10)
215
- )
216
- ),
217
- showlegend=True,
218
- title={
219
- 'text': "GRACE框架:文本生成模型评估",
220
- 'x': 0.5,
221
- 'font': {'size': 16}
222
- },
223
- width=600,
224
- height=500
225
- )
226
-
227
- return fig
228
 
229
  def create_performance_bar_chart():
230
  """创建性能对比柱状图"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  def create_performance_bar_chart():
5
  """创建性能对比柱状图"""