txh17 committed on
Commit
ce873b2
·
verified ·
1 Parent(s): 39e5ba2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -21
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import gradio as gr
2
-
3
  import pandas as pd
4
  import plotly.graph_objects as go
5
  import plotly.express as px
@@ -9,6 +8,7 @@ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
9
  import torch
10
  import json
11
  import re
 
12
 
13
  # 选择三个小型文本生成模型
14
  MODEL_CONFIGS = {
@@ -37,6 +37,7 @@ class TextGenerationComparator:
37
  self.models = {}
38
  self.tokenizers = {}
39
  self.load_models()
 
40
 
41
  def load_models(self):
42
  """加载所有文本生成模型"""
@@ -58,11 +59,26 @@ class TextGenerationComparator:
58
  # 创建一个mock模型用于演示
59
  self.models[model_key] = None
60
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def generate_text(self, model_key, prompt, max_length=50, temperature=0.7, top_p=0.9):
62
  """使用指定模型生成文本"""
63
  if self.models[model_key] is None:
 
 
 
64
  return {
65
- "generated_text": f"[模型 {model_key} 未正确加载,这是模拟输出] {prompt} 这是一个模拟的文本续写...",
66
  "inference_time": 0.5,
67
  "input_length": len(prompt.split()),
68
  "output_length": max_length,
@@ -91,16 +107,19 @@ class TextGenerationComparator:
91
  end_time = time.time()
92
 
93
  # 提取生成的文本(去除原始prompt)
94
- generated_text = result[0]['generated_text']
95
- if generated_text.startswith(prompt):
96
- generated_text = generated_text[len(prompt):].strip()
 
 
 
97
 
98
  return {
99
- "generated_text": generated_text,
100
- "full_text": result[0]['generated_text'],
101
  "inference_time": round(end_time - start_time, 3),
102
  "input_length": len(prompt.split()),
103
- "output_length": len(generated_text.split()),
104
  "parameters": {
105
  "temperature": temperature,
106
  "top_p": top_p,
@@ -122,7 +141,7 @@ comparator = TextGenerationComparator()
122
  def run_text_generation_comparison(prompt, max_length, temperature, top_p):
123
  """运行所有模型的文本生成对比"""
124
  if not prompt.strip():
125
- return "请输入提示文本", "请输入提示文本", "请输入提示文本"
126
 
127
  results = {}
128
 
@@ -141,10 +160,11 @@ def run_text_generation_comparison(prompt, max_length, temperature, top_p):
141
  if "error" in result:
142
  return json.dumps(result, indent=2, ensure_ascii=False)
143
 
 
144
  formatted = {
145
  "生成文本": result["generated_text"],
146
  "推断时间": f"{result['inference_time']}s",
147
- "生成Token数": result["output_length"],
148
  "生成速度": f"{result['output_length']/max(result['inference_time'], 0.001):.1f} tokens/s"
149
  }
150
  return json.dumps(formatted, indent=2, ensure_ascii=False)
@@ -302,13 +322,13 @@ def create_summary_scores_table():
302
  df = pd.DataFrame(summary_data)
303
  return df
304
 
305
- # 预设的示例提示
306
  EXAMPLE_PROMPTS = [
307
- "很久很久以前,在一个魔法森林里,",
308
- "人工智能的未来是",
309
- "2050年,人们将会",
310
- "我学到的最重要的一课是",
311
- "科技改变了我们的生活,因为"
312
  ]
313
 
314
  def create_app():
@@ -320,22 +340,22 @@ def create_app():
320
  # Arena选项卡
321
  with gr.TabItem("🏟️ 生成竞技场"):
322
  gr.Markdown("## 文本生成竞技场")
323
- gr.Markdown("输入一个提示文本,查看不同GPT-2模型如何续写。")
324
 
325
  with gr.Row():
326
  with gr.Column(scale=3):
327
  input_prompt = gr.Textbox(
328
- label="输入提示文本",
329
- placeholder="请在这里输入您的文本提示...",
330
  lines=3,
331
- value="很久很久以前,在一个数字世界里,"
332
  )
333
 
334
  # 预设示例按钮
335
  with gr.Row():
336
  example_buttons = []
337
  for i, example in enumerate(EXAMPLE_PROMPTS[:3]):
338
- btn = gr.Button(f"示例 {i+1}", size="sm")
339
  example_buttons.append(btn)
340
 
341
  with gr.Column(scale=1):
 
1
  import gradio as gr
 
2
  import pandas as pd
3
  import plotly.graph_objects as go
4
  import plotly.express as px
 
8
  import torch
9
  import json
10
  import re
11
+ from googletrans import Translator # 导入翻译库
12
 
13
  # 选择三个小型文本生成模型
14
  MODEL_CONFIGS = {
 
37
  self.models = {}
38
  self.tokenizers = {}
39
  self.load_models()
40
+ self.translator = Translator() # 初始化翻译器
41
 
42
  def load_models(self):
43
  """加载所有文本生成模型"""
 
59
  # 创建一个mock模型用于演示
60
  self.models[model_key] = None
61
 
62
+ def translate_to_chinese(self, text):
63
+ """将文本翻译成中文"""
64
+ try:
65
+ # 尝试翻译,如果原文是中文则不翻译
66
+ if re.search(r'[\u4e00-\u9fff]', text): # 检查是否包含中文字符
67
+ return text
68
+ translation = self.translator.translate(text, dest='zh-cn')
69
+ return translation.text
70
+ except Exception as e:
71
+ print(f"翻译失败: {e}")
72
+ return text # 翻译失败时返回原文
73
+
74
  def generate_text(self, model_key, prompt, max_length=50, temperature=0.7, top_p=0.9):
75
  """使用指定模型生成文本"""
76
  if self.models[model_key] is None:
77
+ # 模拟输出也翻译成中文
78
+ mock_output_en = f"[Model {model_key} not loaded correctly, this is a simulated output] {prompt} and this is a sample continuation of the text..."
79
+ mock_output_zh = self.translate_to_chinese(mock_output_en)
80
  return {
81
+ "generated_text": mock_output_zh,
82
  "inference_time": 0.5,
83
  "input_length": len(prompt.split()),
84
  "output_length": max_length,
 
107
  end_time = time.time()
108
 
109
  # 提取生成的文本(去除原始prompt)
110
+ generated_text_en = result[0]['generated_text']
111
+ if generated_text_en.startswith(prompt):
112
+ generated_text_en = generated_text_en[len(prompt):].strip()
113
+
114
+ # 将生成的英文文本翻译成中文
115
+ generated_text_zh = self.translate_to_chinese(generated_text_en)
116
 
117
  return {
118
+ "generated_text": generated_text_zh, # 返回中文文本
119
+ "full_text": self.translate_to_chinese(result[0]['generated_text']), # 完整文本也翻译
120
  "inference_time": round(end_time - start_time, 3),
121
  "input_length": len(prompt.split()),
122
+ "output_length": len(generated_text_zh.split()), # 基于中文文本长度
123
  "parameters": {
124
  "temperature": temperature,
125
  "top_p": top_p,
 
141
  def run_text_generation_comparison(prompt, max_length, temperature, top_p):
142
  """运行所有模型的文本生成对比"""
143
  if not prompt.strip():
144
+ return "Please enter a prompt.", "Please enter a prompt.", "Please enter a prompt." # 提示文本为英文
145
 
146
  results = {}
147
 
 
160
  if "error" in result:
161
  return json.dumps(result, indent=2, ensure_ascii=False)
162
 
163
+ # 这里的键名保留中文,值会是中文
164
  formatted = {
165
  "生成文本": result["generated_text"],
166
  "推断时间": f"{result['inference_time']}s",
167
+ "生成Token数": result["output_length"], # 这里的token数可能因为翻译导致不准确,但保留原逻辑
168
  "生成速度": f"{result['output_length']/max(result['inference_time'], 0.001):.1f} tokens/s"
169
  }
170
  return json.dumps(formatted, indent=2, ensure_ascii=False)
 
322
  df = pd.DataFrame(summary_data)
323
  return df
324
 
325
+ # 预设的示例提示(英文)
326
  EXAMPLE_PROMPTS = [
327
+ "Once upon a time in a magical forest,",
328
+ "The future of artificial intelligence is",
329
+ "In the year 2050, people will",
330
+ "The most important lesson I learned was",
331
+ "Technology has changed our lives by"
332
  ]
333
 
334
  def create_app():
 
340
  # Arena选项卡
341
  with gr.TabItem("🏟️ 生成竞技场"):
342
  gr.Markdown("## 文本生成竞技场")
343
+ gr.Markdown("输入一个**英文**提示文本,查看不同GPT-2模型如何续写,续写结果将翻译成中文显示。")
344
 
345
  with gr.Row():
346
  with gr.Column(scale=3):
347
  input_prompt = gr.Textbox(
348
+ label="Input Prompt (英文提示文本)", # 标签为英文
349
+ placeholder="Please enter your English text prompt here...", # 占位符为英文
350
  lines=3,
351
+ value="Once upon a time in a digital world," # 初始值为英文
352
  )
353
 
354
  # 预设示例按钮
355
  with gr.Row():
356
  example_buttons = []
357
  for i, example in enumerate(EXAMPLE_PROMPTS[:3]):
358
+ btn = gr.Button(f"Example {i+1}", size="sm")
359
  example_buttons.append(btn)
360
 
361
  with gr.Column(scale=1):