Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
|
| 3 |
import pandas as pd
|
| 4 |
import plotly.graph_objects as go
|
| 5 |
import plotly.express as px
|
|
@@ -9,6 +8,7 @@ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
|
|
| 9 |
import torch
|
| 10 |
import json
|
| 11 |
import re
|
|
|
|
| 12 |
|
| 13 |
# 选择三个小型文本生成模型
|
| 14 |
MODEL_CONFIGS = {
|
|
@@ -37,6 +37,7 @@ class TextGenerationComparator:
|
|
| 37 |
self.models = {}
|
| 38 |
self.tokenizers = {}
|
| 39 |
self.load_models()
|
|
|
|
| 40 |
|
| 41 |
def load_models(self):
|
| 42 |
"""加载所有文本生成模型"""
|
|
@@ -58,11 +59,26 @@ class TextGenerationComparator:
|
|
| 58 |
# 创建一个mock模型用于演示
|
| 59 |
self.models[model_key] = None
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
def generate_text(self, model_key, prompt, max_length=50, temperature=0.7, top_p=0.9):
|
| 62 |
"""使用指定模型生成文本"""
|
| 63 |
if self.models[model_key] is None:
|
|
|
|
|
|
|
|
|
|
| 64 |
return {
|
| 65 |
-
"generated_text":
|
| 66 |
"inference_time": 0.5,
|
| 67 |
"input_length": len(prompt.split()),
|
| 68 |
"output_length": max_length,
|
|
@@ -91,16 +107,19 @@ class TextGenerationComparator:
|
|
| 91 |
end_time = time.time()
|
| 92 |
|
| 93 |
# 提取生成的文本(去除原始prompt)
|
| 94 |
-
|
| 95 |
-
if
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
return {
|
| 99 |
-
"generated_text":
|
| 100 |
-
"full_text": result[0]['generated_text'],
|
| 101 |
"inference_time": round(end_time - start_time, 3),
|
| 102 |
"input_length": len(prompt.split()),
|
| 103 |
-
"output_length": len(
|
| 104 |
"parameters": {
|
| 105 |
"temperature": temperature,
|
| 106 |
"top_p": top_p,
|
|
@@ -122,7 +141,7 @@ comparator = TextGenerationComparator()
|
|
| 122 |
def run_text_generation_comparison(prompt, max_length, temperature, top_p):
|
| 123 |
"""运行所有模型的文本生成对比"""
|
| 124 |
if not prompt.strip():
|
| 125 |
-
return "
|
| 126 |
|
| 127 |
results = {}
|
| 128 |
|
|
@@ -141,10 +160,11 @@ def run_text_generation_comparison(prompt, max_length, temperature, top_p):
|
|
| 141 |
if "error" in result:
|
| 142 |
return json.dumps(result, indent=2, ensure_ascii=False)
|
| 143 |
|
|
|
|
| 144 |
formatted = {
|
| 145 |
"生成文本": result["generated_text"],
|
| 146 |
"推断时间": f"{result['inference_time']}s",
|
| 147 |
-
"生成Token数": result["output_length"],
|
| 148 |
"生成速度": f"{result['output_length']/max(result['inference_time'], 0.001):.1f} tokens/s"
|
| 149 |
}
|
| 150 |
return json.dumps(formatted, indent=2, ensure_ascii=False)
|
|
@@ -302,13 +322,13 @@ def create_summary_scores_table():
|
|
| 302 |
df = pd.DataFrame(summary_data)
|
| 303 |
return df
|
| 304 |
|
| 305 |
-
#
|
| 306 |
EXAMPLE_PROMPTS = [
|
| 307 |
-
"
|
| 308 |
-
"
|
| 309 |
-
"
|
| 310 |
-
"
|
| 311 |
-
"
|
| 312 |
]
|
| 313 |
|
| 314 |
def create_app():
|
|
@@ -320,22 +340,22 @@ def create_app():
|
|
| 320 |
# Arena选项卡
|
| 321 |
with gr.TabItem("🏟️ 生成竞技场"):
|
| 322 |
gr.Markdown("## 文本生成竞技场")
|
| 323 |
-
gr.Markdown("
|
| 324 |
|
| 325 |
with gr.Row():
|
| 326 |
with gr.Column(scale=3):
|
| 327 |
input_prompt = gr.Textbox(
|
| 328 |
-
label="
|
| 329 |
-
placeholder="
|
| 330 |
lines=3,
|
| 331 |
-
value="
|
| 332 |
)
|
| 333 |
|
| 334 |
# 预设示例按钮
|
| 335 |
with gr.Row():
|
| 336 |
example_buttons = []
|
| 337 |
for i, example in enumerate(EXAMPLE_PROMPTS[:3]):
|
| 338 |
-
btn = gr.Button(f"
|
| 339 |
example_buttons.append(btn)
|
| 340 |
|
| 341 |
with gr.Column(scale=1):
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
| 2 |
import pandas as pd
|
| 3 |
import plotly.graph_objects as go
|
| 4 |
import plotly.express as px
|
|
|
|
| 8 |
import torch
|
| 9 |
import json
|
| 10 |
import re
|
| 11 |
+
from googletrans import Translator # 导入翻译库
|
| 12 |
|
| 13 |
# 选择三个小型文本生成模型
|
| 14 |
MODEL_CONFIGS = {
|
|
|
|
| 37 |
self.models = {}
|
| 38 |
self.tokenizers = {}
|
| 39 |
self.load_models()
|
| 40 |
+
self.translator = Translator() # 初始化翻译器
|
| 41 |
|
| 42 |
def load_models(self):
|
| 43 |
"""加载所有文本生成模型"""
|
|
|
|
| 59 |
# 创建一个mock模型用于演示
|
| 60 |
self.models[model_key] = None
|
| 61 |
|
| 62 |
+
def translate_to_chinese(self, text):
    """Translate text into Simplified Chinese via self.translator.

    Text that already contains CJK ideographs is returned unchanged, and
    any failure during detection or translation (network error, library
    error, bad input) also falls back to returning the input as-is.
    """
    try:
        # Check runs inside the try on purpose: a non-string input makes
        # re.search raise, and we want the same best-effort fallback then.
        if re.search(r'[\u4e00-\u9fff]', text):
            # Input already looks Chinese -- skip the translation call.
            return text
        return self.translator.translate(text, dest='zh-cn').text
    except Exception as err:
        # Best-effort behaviour: log the problem and hand back the
        # untranslated original rather than crash the UI.
        print(f"翻译失败: {err}")
        return text
| 73 |
+
|
| 74 |
def generate_text(self, model_key, prompt, max_length=50, temperature=0.7, top_p=0.9):
|
| 75 |
"""使用指定模型生成文本"""
|
| 76 |
if self.models[model_key] is None:
|
| 77 |
+
# 模拟输出也翻译成中文
|
| 78 |
+
mock_output_en = f"[Model {model_key} not loaded correctly, this is a simulated output] {prompt} and this is a sample continuation of the text..."
|
| 79 |
+
mock_output_zh = self.translate_to_chinese(mock_output_en)
|
| 80 |
return {
|
| 81 |
+
"generated_text": mock_output_zh,
|
| 82 |
"inference_time": 0.5,
|
| 83 |
"input_length": len(prompt.split()),
|
| 84 |
"output_length": max_length,
|
|
|
|
| 107 |
end_time = time.time()
|
| 108 |
|
| 109 |
# 提取生成的文本(去除原始prompt)
|
| 110 |
+
generated_text_en = result[0]['generated_text']
|
| 111 |
+
if generated_text_en.startswith(prompt):
|
| 112 |
+
generated_text_en = generated_text_en[len(prompt):].strip()
|
| 113 |
+
|
| 114 |
+
# 将生成的英文文本翻译成中文
|
| 115 |
+
generated_text_zh = self.translate_to_chinese(generated_text_en)
|
| 116 |
|
| 117 |
return {
|
| 118 |
+
"generated_text": generated_text_zh, # 返回中文文本
|
| 119 |
+
"full_text": self.translate_to_chinese(result[0]['generated_text']), # 完整文本也翻译
|
| 120 |
"inference_time": round(end_time - start_time, 3),
|
| 121 |
"input_length": len(prompt.split()),
|
| 122 |
+
"output_length": len(generated_text_zh.split()), # 基于中文文本长度
|
| 123 |
"parameters": {
|
| 124 |
"temperature": temperature,
|
| 125 |
"top_p": top_p,
|
|
|
|
| 141 |
def run_text_generation_comparison(prompt, max_length, temperature, top_p):
|
| 142 |
"""运行所有模型的文本生成对比"""
|
| 143 |
if not prompt.strip():
|
| 144 |
+
return "Please enter a prompt.", "Please enter a prompt.", "Please enter a prompt." # 提示文本为英文
|
| 145 |
|
| 146 |
results = {}
|
| 147 |
|
|
|
|
| 160 |
if "error" in result:
|
| 161 |
return json.dumps(result, indent=2, ensure_ascii=False)
|
| 162 |
|
| 163 |
+
# 这里的键名保留中文,值会是中文
|
| 164 |
formatted = {
|
| 165 |
"生成文本": result["generated_text"],
|
| 166 |
"推断时间": f"{result['inference_time']}s",
|
| 167 |
+
"生成Token数": result["output_length"], # 这里的token数可能因为翻译导致不准确,但保留原逻辑
|
| 168 |
"生成速度": f"{result['output_length']/max(result['inference_time'], 0.001):.1f} tokens/s"
|
| 169 |
}
|
| 170 |
return json.dumps(formatted, indent=2, ensure_ascii=False)
|
|
|
|
| 322 |
df = pd.DataFrame(summary_data)
|
| 323 |
return df
|
| 324 |
|
| 325 |
+
# Preset example prompts (English — the models continue English text and
# the continuation is translated to Chinese for display).
EXAMPLE_PROMPTS = [
    "Once upon a time in a magical forest,",
    "The future of artificial intelligence is",
    "In the year 2050, people will",
    "The most important lesson I learned was",
    "Technology has changed our lives by",
]
|
| 333 |
|
| 334 |
def create_app():
|
|
|
|
| 340 |
# Arena选项卡
|
| 341 |
with gr.TabItem("🏟️ 生成竞技场"):
|
| 342 |
gr.Markdown("## 文本生成竞技场")
|
| 343 |
+
gr.Markdown("输入一个**英文**提示文本,查看不同GPT-2模型如何续写,续写结果将翻译成中文显示。")
|
| 344 |
|
| 345 |
with gr.Row():
|
| 346 |
with gr.Column(scale=3):
|
| 347 |
input_prompt = gr.Textbox(
|
| 348 |
+
label="Input Prompt (英文提示文本)", # 标签为英文
|
| 349 |
+
placeholder="Please enter your English text prompt here...", # 占位符为英文
|
| 350 |
lines=3,
|
| 351 |
+
value="Once upon a time in a digital world," # 初始值为英文
|
| 352 |
)
|
| 353 |
|
| 354 |
# 预设示例按钮
|
| 355 |
with gr.Row():
|
| 356 |
example_buttons = []
|
| 357 |
for i, example in enumerate(EXAMPLE_PROMPTS[:3]):
|
| 358 |
+
btn = gr.Button(f"Example {i+1}", size="sm")
|
| 359 |
example_buttons.append(btn)
|
| 360 |
|
| 361 |
with gr.Column(scale=1):
|