Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import gradio as gr
|
|
|
|
| 2 |
import pandas as pd
|
| 3 |
import plotly.graph_objects as go
|
| 4 |
import plotly.express as px
|
|
@@ -8,7 +9,6 @@ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
|
|
| 8 |
import torch
|
| 9 |
import json
|
| 10 |
import re
|
| 11 |
-
from googletrans import Translator # 导入翻译库
|
| 12 |
|
| 13 |
# 选择三个小型文本生成模型
|
| 14 |
MODEL_CONFIGS = {
|
|
@@ -37,7 +37,6 @@ class TextGenerationComparator:
|
|
| 37 |
self.models = {}
|
| 38 |
self.tokenizers = {}
|
| 39 |
self.load_models()
|
| 40 |
-
self.translator = Translator() # 初始化翻译器
|
| 41 |
|
| 42 |
def load_models(self):
|
| 43 |
"""加载所有文本生成模型"""
|
|
@@ -59,26 +58,11 @@ class TextGenerationComparator:
|
|
| 59 |
# 创建一个mock模型用于演示
|
| 60 |
self.models[model_key] = None
|
| 61 |
|
| 62 |
-
def translate_to_chinese(self, text):
|
| 63 |
-
"""将文本翻译成中文"""
|
| 64 |
-
try:
|
| 65 |
-
# 尝试翻译,如果原文是中文则不翻译
|
| 66 |
-
if re.search(r'[\u4e00-\u9fff]', text): # 检查是否包含中文字符
|
| 67 |
-
return text
|
| 68 |
-
translation = self.translator.translate(text, dest='zh-cn')
|
| 69 |
-
return translation.text
|
| 70 |
-
except Exception as e:
|
| 71 |
-
print(f"翻译失败: {e}")
|
| 72 |
-
return text # 翻译失败时返回原文
|
| 73 |
-
|
| 74 |
def generate_text(self, model_key, prompt, max_length=50, temperature=0.7, top_p=0.9):
|
| 75 |
"""使用指定模型生成文本"""
|
| 76 |
if self.models[model_key] is None:
|
| 77 |
-
# 模拟输出也翻译成中文
|
| 78 |
-
mock_output_en = f"[Model {model_key} not loaded correctly, this is a simulated output] {prompt} and this is a sample continuation of the text..."
|
| 79 |
-
mock_output_zh = self.translate_to_chinese(mock_output_en)
|
| 80 |
return {
|
| 81 |
-
"generated_text":
|
| 82 |
"inference_time": 0.5,
|
| 83 |
"input_length": len(prompt.split()),
|
| 84 |
"output_length": max_length,
|
|
@@ -107,19 +91,16 @@ class TextGenerationComparator:
|
|
| 107 |
end_time = time.time()
|
| 108 |
|
| 109 |
# 提取生成的文本(去除原始prompt)
|
| 110 |
-
|
| 111 |
-
if
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
# 将生成的英文文本翻译成中文
|
| 115 |
-
generated_text_zh = self.translate_to_chinese(generated_text_en)
|
| 116 |
|
| 117 |
return {
|
| 118 |
-
"generated_text":
|
| 119 |
-
"full_text":
|
| 120 |
"inference_time": round(end_time - start_time, 3),
|
| 121 |
"input_length": len(prompt.split()),
|
| 122 |
-
"output_length": len(
|
| 123 |
"parameters": {
|
| 124 |
"temperature": temperature,
|
| 125 |
"top_p": top_p,
|
|
@@ -160,11 +141,11 @@ def run_text_generation_comparison(prompt, max_length, temperature, top_p):
|
|
| 160 |
if "error" in result:
|
| 161 |
return json.dumps(result, indent=2, ensure_ascii=False)
|
| 162 |
|
| 163 |
-
#
|
| 164 |
formatted = {
|
| 165 |
"生成文本": result["generated_text"],
|
| 166 |
"推断时间": f"{result['inference_time']}s",
|
| 167 |
-
"生成Token数": result["output_length"],
|
| 168 |
"生成速度": f"{result['output_length']/max(result['inference_time'], 0.001):.1f} tokens/s"
|
| 169 |
}
|
| 170 |
return json.dumps(formatted, indent=2, ensure_ascii=False)
|
|
@@ -340,7 +321,7 @@ def create_app():
|
|
| 340 |
# Arena选项卡
|
| 341 |
with gr.TabItem("🏟️ 生成竞技场"):
|
| 342 |
gr.Markdown("## 文本生成竞技场")
|
| 343 |
-
gr.Markdown("输入一个**英文**提示文本,查看不同GPT-2
|
| 344 |
|
| 345 |
with gr.Row():
|
| 346 |
with gr.Column(scale=3):
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
|
| 3 |
import pandas as pd
|
| 4 |
import plotly.graph_objects as go
|
| 5 |
import plotly.express as px
|
|
|
|
| 9 |
import torch
|
| 10 |
import json
|
| 11 |
import re
|
|
|
|
| 12 |
|
| 13 |
# 选择三个小型文本生成模型
|
| 14 |
MODEL_CONFIGS = {
|
|
|
|
| 37 |
self.models = {}
|
| 38 |
self.tokenizers = {}
|
| 39 |
self.load_models()
|
|
|
|
| 40 |
|
| 41 |
def load_models(self):
|
| 42 |
"""加载所有文本生成模型"""
|
|
|
|
| 58 |
# 创建一个mock模型用于演示
|
| 59 |
self.models[model_key] = None
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
def generate_text(self, model_key, prompt, max_length=50, temperature=0.7, top_p=0.9):
|
| 62 |
"""使用指定模型生成文本"""
|
| 63 |
if self.models[model_key] is None:
|
|
|
|
|
|
|
|
|
|
| 64 |
return {
|
| 65 |
+
"generated_text": f"[Model {model_key} not loaded correctly, this is a simulated output] {prompt} and this is a sample continuation of the text...",
|
| 66 |
"inference_time": 0.5,
|
| 67 |
"input_length": len(prompt.split()),
|
| 68 |
"output_length": max_length,
|
|
|
|
| 91 |
end_time = time.time()
|
| 92 |
|
| 93 |
# 提取生成的文本(去除原始prompt)
|
| 94 |
+
generated_text = result[0]['generated_text']
|
| 95 |
+
if generated_text.startswith(prompt):
|
| 96 |
+
generated_text = generated_text[len(prompt):].strip()
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
return {
|
| 99 |
+
"generated_text": generated_text,
|
| 100 |
+
"full_text": result[0]['generated_text'],
|
| 101 |
"inference_time": round(end_time - start_time, 3),
|
| 102 |
"input_length": len(prompt.split()),
|
| 103 |
+
"output_length": len(generated_text.split()),
|
| 104 |
"parameters": {
|
| 105 |
"temperature": temperature,
|
| 106 |
"top_p": top_p,
|
|
|
|
| 141 |
if "error" in result:
|
| 142 |
return json.dumps(result, indent=2, ensure_ascii=False)
|
| 143 |
|
| 144 |
+
# 这里的键名保留中文,值会是英文
|
| 145 |
formatted = {
|
| 146 |
"生成文本": result["generated_text"],
|
| 147 |
"推断时间": f"{result['inference_time']}s",
|
| 148 |
+
"生成Token数": result["output_length"],
|
| 149 |
"生成速度": f"{result['output_length']/max(result['inference_time'], 0.001):.1f} tokens/s"
|
| 150 |
}
|
| 151 |
return json.dumps(formatted, indent=2, ensure_ascii=False)
|
|
|
|
| 321 |
# Arena选项卡
|
| 322 |
with gr.TabItem("🏟️ 生成竞技场"):
|
| 323 |
gr.Markdown("## 文本生成竞技场")
|
| 324 |
+
gr.Markdown("输入一个**英文**提示文本,查看不同GPT-2模型如何续写,续写结果将保持英文显示。")
|
| 325 |
|
| 326 |
with gr.Row():
|
| 327 |
with gr.Column(scale=3):
|