txh17 commited on
Commit
26364d2
·
verified ·
1 Parent(s): 9579639

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -30
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  import pandas as pd
3
  import plotly.graph_objects as go
4
  import plotly.express as px
@@ -8,7 +9,6 @@ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
8
  import torch
9
  import json
10
  import re
11
- from googletrans import Translator # 导入翻译库
12
 
13
  # 选择三个小型文本生成模型
14
  MODEL_CONFIGS = {
@@ -37,7 +37,6 @@ class TextGenerationComparator:
37
  self.models = {}
38
  self.tokenizers = {}
39
  self.load_models()
40
- self.translator = Translator() # 初始化翻译器
41
 
42
  def load_models(self):
43
  """加载所有文本生成模型"""
@@ -59,26 +58,11 @@ class TextGenerationComparator:
59
  # 创建一个mock模型用于演示
60
  self.models[model_key] = None
61
 
62
- def translate_to_chinese(self, text):
63
- """将文本翻译成中文"""
64
- try:
65
- # 尝试翻译,如果原文是中文则不翻译
66
- if re.search(r'[\u4e00-\u9fff]', text): # 检查是否包含中文字符
67
- return text
68
- translation = self.translator.translate(text, dest='zh-cn')
69
- return translation.text
70
- except Exception as e:
71
- print(f"翻译失败: {e}")
72
- return text # 翻译失败时返回原文
73
-
74
  def generate_text(self, model_key, prompt, max_length=50, temperature=0.7, top_p=0.9):
75
  """使用指定模型生成文本"""
76
  if self.models[model_key] is None:
77
- # 模拟输出也翻译成中文
78
- mock_output_en = f"[Model {model_key} not loaded correctly, this is a simulated output] {prompt} and this is a sample continuation of the text..."
79
- mock_output_zh = self.translate_to_chinese(mock_output_en)
80
  return {
81
- "generated_text": mock_output_zh,
82
  "inference_time": 0.5,
83
  "input_length": len(prompt.split()),
84
  "output_length": max_length,
@@ -107,19 +91,16 @@ class TextGenerationComparator:
107
  end_time = time.time()
108
 
109
  # 提取生成的文本(去除原始prompt)
110
- generated_text_en = result[0]['generated_text']
111
- if generated_text_en.startswith(prompt):
112
- generated_text_en = generated_text_en[len(prompt):].strip()
113
-
114
- # 将生成的英文文本翻译成中文
115
- generated_text_zh = self.translate_to_chinese(generated_text_en)
116
 
117
  return {
118
- "generated_text": generated_text_zh, # 返回中文文本
119
- "full_text": self.translate_to_chinese(result[0]['generated_text']), # 完整文本也翻译
120
  "inference_time": round(end_time - start_time, 3),
121
  "input_length": len(prompt.split()),
122
- "output_length": len(generated_text_zh.split()), # 基于中文文本长度
123
  "parameters": {
124
  "temperature": temperature,
125
  "top_p": top_p,
@@ -160,11 +141,11 @@ def run_text_generation_comparison(prompt, max_length, temperature, top_p):
160
  if "error" in result:
161
  return json.dumps(result, indent=2, ensure_ascii=False)
162
 
163
- # 这里的键名保留中文,值会是中文
164
  formatted = {
165
  "生成文本": result["generated_text"],
166
  "推断时间": f"{result['inference_time']}s",
167
- "生成Token数": result["output_length"], # 这里的token数可能因为翻译导致不准确,但保留原逻辑
168
  "生成速度": f"{result['output_length']/max(result['inference_time'], 0.001):.1f} tokens/s"
169
  }
170
  return json.dumps(formatted, indent=2, ensure_ascii=False)
@@ -340,7 +321,7 @@ def create_app():
340
  # Arena选项卡
341
  with gr.TabItem("🏟️ 生成竞技场"):
342
  gr.Markdown("## 文本生成竞技场")
343
- gr.Markdown("输入一个**英文**提示文本,查看不同GPT-2模型如何续写,续写结果将翻译成中文显示。")
344
 
345
  with gr.Row():
346
  with gr.Column(scale=3):
 
1
  import gradio as gr
2
+
3
  import pandas as pd
4
  import plotly.graph_objects as go
5
  import plotly.express as px
 
9
  import torch
10
  import json
11
  import re
 
12
 
13
  # 选择三个小型文本生成模型
14
  MODEL_CONFIGS = {
 
37
  self.models = {}
38
  self.tokenizers = {}
39
  self.load_models()
 
40
 
41
  def load_models(self):
42
  """加载所有文本生成模型"""
 
58
  # 创建一个mock模型用于演示
59
  self.models[model_key] = None
60
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def generate_text(self, model_key, prompt, max_length=50, temperature=0.7, top_p=0.9):
62
  """使用指定模型生成文本"""
63
  if self.models[model_key] is None:
 
 
 
64
  return {
65
+ "generated_text": f"[Model {model_key} not loaded correctly, this is a simulated output] {prompt} and this is a sample continuation of the text...",
66
  "inference_time": 0.5,
67
  "input_length": len(prompt.split()),
68
  "output_length": max_length,
 
91
  end_time = time.time()
92
 
93
  # 提取生成的文本(去除原始prompt)
94
+ generated_text = result[0]['generated_text']
95
+ if generated_text.startswith(prompt):
96
+ generated_text = generated_text[len(prompt):].strip()
 
 
 
97
 
98
  return {
99
+ "generated_text": generated_text,
100
+ "full_text": result[0]['generated_text'],
101
  "inference_time": round(end_time - start_time, 3),
102
  "input_length": len(prompt.split()),
103
+ "output_length": len(generated_text.split()),
104
  "parameters": {
105
  "temperature": temperature,
106
  "top_p": top_p,
 
141
  if "error" in result:
142
  return json.dumps(result, indent=2, ensure_ascii=False)
143
 
144
+ # 这里的键名保留中文,值会是英文
145
  formatted = {
146
  "生成文本": result["generated_text"],
147
  "推断时间": f"{result['inference_time']}s",
148
+ "生成Token数": result["output_length"],
149
  "生成速度": f"{result['output_length']/max(result['inference_time'], 0.001):.1f} tokens/s"
150
  }
151
  return json.dumps(formatted, indent=2, ensure_ascii=False)
 
321
  # Arena选项卡
322
  with gr.TabItem("🏟️ 生成竞技场"):
323
  gr.Markdown("## 文本生成竞技场")
324
+ gr.Markdown("输入一个**英文**提示文本,查看不同GPT-2模型如何续写,续写结果将保持英文显示。")
325
 
326
  with gr.Row():
327
  with gr.Column(scale=3):