Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,95 +1,84 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
-
from transformers import pipeline,
|
| 4 |
import pandas as pd
|
| 5 |
import plotly.express as px
|
| 6 |
import os # 用于检查文件是否存在
|
| 7 |
|
| 8 |
# --- 1. 模型加载 ---
|
| 9 |
-
#
|
| 10 |
-
#
|
| 11 |
-
#
|
| 12 |
-
#
|
| 13 |
-
|
| 14 |
-
# --- 模型 1: DistilGPT2 (小型通用文本生成模型) ---
|
| 15 |
-
# 负责同学: [牛正武]
|
| 16 |
try:
|
| 17 |
-
model1_name = "
|
| 18 |
-
# device=0 表示使用第一个GPU,如果没有GPU则使用-1表示CPU
|
| 19 |
generator1 = pipeline("text-generation", model=model1_name, device=0 if torch.cuda.is_available() else -1)
|
| 20 |
-
print(f"✅ 模型 1 ({model1_name}) 加载成功!")
|
| 21 |
except Exception as e:
|
| 22 |
-
print(f"❌ 模型 1 ({model1_name}) 加载失败: {e}")
|
| 23 |
-
generator1 = None
|
| 24 |
|
| 25 |
-
#
|
| 26 |
-
#
|
| 27 |
try:
|
| 28 |
-
model2_name = "
|
| 29 |
-
|
| 30 |
-
print(f"✅ 模型 2 ({model2_name}) 加载成功!")
|
| 31 |
except Exception as e:
|
| 32 |
-
print(f"❌ 模型 2 ({model2_name}) 加载失败: {e}")
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
# --- [可选] 模型 3: 你可以根据需要添加第三个模型 ---
|
| 36 |
-
# 例如:一个翻译模型,或者一个专门的对话模型
|
| 37 |
-
# model3_name = "Helsinki-NLP/opus-mt-en-zh" # 这是一个英译中翻译模型
|
| 38 |
-
# try:
|
| 39 |
-
# translator = pipeline("translation_en_to_zh", model=model3_name, device=0 if torch.cuda.is_available() else -1)
|
| 40 |
-
# print(f"✅ 模型 3 ({model3_name}) 加载成功!")
|
| 41 |
-
# except Exception as e:
|
| 42 |
-
# print(f"❌ 模型 3 ({model3_name}) 加载失败: {e}")
|
| 43 |
-
# translator = None
|
| 44 |
|
| 45 |
|
| 46 |
# --- 2. 推理函数 ---
|
| 47 |
-
# 这个函数接
|
| 48 |
-
def
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
# output3 = "模型 3 未加载或生成失败。" # 如果有第三个模型
|
| 52 |
|
|
|
|
| 53 |
if generator1:
|
| 54 |
try:
|
| 55 |
-
#
|
| 56 |
-
|
| 57 |
-
|
|
|
|
| 58 |
# 清理:移除输入部分,只保留生成内容
|
| 59 |
-
if
|
| 60 |
-
|
| 61 |
except Exception as e:
|
| 62 |
-
|
| 63 |
|
| 64 |
-
|
|
|
|
| 65 |
try:
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
if output2.startswith(prompt):
|
| 69 |
-
output2 = output2[len(prompt):].strip()
|
| 70 |
except Exception as e:
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
# if translator:
|
| 75 |
-
# try:
|
| 76 |
-
# trans_result = translator(prompt)
|
| 77 |
-
# output3 = trans_result[0]['translation_text']
|
| 78 |
-
# except Exception as e:
|
| 79 |
-
# output3 = f"模型 3 (翻译模型) 生成错误: {e}"
|
| 80 |
|
| 81 |
-
return
|
| 82 |
|
| 83 |
|
| 84 |
# --- 3. GRACE 评估数据(示例数据,请根据你们的实际评估结果修改) ---
|
| 85 |
-
#
|
| 86 |
-
# 评分范围通常是 1-5 分,分数越高代表表现越好。
|
| 87 |
grace_data = {
|
| 88 |
"维度": ["Generalization (泛化性)", "Relevance (相关性)", "Artistry (创新表现力)", "Efficiency (效率性)"],
|
| 89 |
-
#
|
| 90 |
-
"
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
}
|
| 94 |
grace_df = pd.DataFrame(grace_data)
|
| 95 |
|
|
@@ -116,7 +105,8 @@ def create_benchmark_tab():
|
|
| 116 |
|
| 117 |
return gr.Column(
|
| 118 |
gr.Markdown("## 📊 模型性能对比 (GRACE 评估)"),
|
| 119 |
-
gr.Markdown("本页展示了我们选用的模型在 GRACE 框架下的评估结果。数据为 1-5 分,分数越高代表表现越好。"
|
|
|
|
| 120 |
gr.Plot(fig, label="GRACE 评估雷达图"),
|
| 121 |
gr.Markdown("### GRACE 评估数据"),
|
| 122 |
gr.DataFrame(grace_df, label="详细评估数据")
|
|
@@ -126,28 +116,30 @@ def create_benchmark_tab():
|
|
| 126 |
def create_arena_tab():
|
| 127 |
with gr.Blocks() as arena_block:
|
| 128 |
gr.Markdown("## ⚔️ Arena: 模型实时对比")
|
| 129 |
-
gr.Markdown("在这里,您可以输入一段文本
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
with gr.Row():
|
| 132 |
-
#
|
| 133 |
-
|
| 134 |
-
# 增加生成长度控制
|
| 135 |
-
gen_length_slider = gr.Slider(minimum=20, maximum=300, value=100, step=10, label="生成文本最大长度")
|
| 136 |
generate_btn = gr.Button("🚀 生成并对比")
|
| 137 |
|
| 138 |
with gr.Row():
|
| 139 |
-
# 模型 1
|
| 140 |
-
|
| 141 |
-
# 模型 2 输出
|
| 142 |
-
|
| 143 |
-
# # 如果有第三个模型
|
| 144 |
-
# output_model3 = gr.Textbox(label="模型 3 (翻译模型) 输出:", interactive=False, lines=10)
|
| 145 |
|
| 146 |
# 绑定按钮点击事件到推理函数
|
| 147 |
generate_btn.click(
|
| 148 |
-
fn=
|
| 149 |
-
inputs=[
|
| 150 |
-
outputs=[
|
| 151 |
)
|
| 152 |
return arena_block
|
| 153 |
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
+
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification # 导入AutoTokenizer, AutoModelForSequenceClassification用于问答模型
|
| 4 |
import pandas as pd
|
| 5 |
import plotly.express as px
|
| 6 |
import os # 用于检查文件是否存在
|
| 7 |
|
| 8 |
# --- 1. 模型加载 ---
|
| 9 |
+
# 负责同学: [填写负责这个模型的同学姓名]
|
| 10 |
+
# 注意:QuantFactory/Apollo2-7B-GGUF 模型通常不直接兼容 pipeline("text-generation", ...)
|
| 11 |
+
# 除非有额外的llama.cpp或特定的transformers加载配置。
|
| 12 |
+
# 为了演示和确保运行流畅,这里使用 gpt2-large 作为替代。
|
|
|
|
|
|
|
|
|
|
| 13 |
try:
|
| 14 |
+
model1_name = "gpt2-large" # 替代 QuantFactory/Apollo2-7B-GGUF 以确保兼容性
|
|
|
|
| 15 |
generator1 = pipeline("text-generation", model=model1_name, device=0 if torch.cuda.is_available() else -1)
|
| 16 |
+
print(f"✅ 模型 1 (文本生成: {model1_name}) 加载成功!")
|
| 17 |
except Exception as e:
|
| 18 |
+
print(f"❌ 模型 1 (文本生成: {model1_name}) 加载失败: {e}")
|
| 19 |
+
generator1 = None
|
| 20 |
|
| 21 |
+
# 负责同学: [填写负责这个模型的同学姓名]
|
| 22 |
+
# deepset/roberta-base-squad2 是一个问答模型,需要 context
|
| 23 |
try:
|
| 24 |
+
model2_name = "deepset/roberta-base-squad2"
|
| 25 |
+
qa_model = pipeline("question-answering", model=model2_name, device=0 if torch.cuda.is_available() else -1)
|
| 26 |
+
print(f"✅ 模型 2 (问答: {model2_name}) 加载成功!")
|
| 27 |
except Exception as e:
|
| 28 |
+
print(f"❌ 模型 2 (问答: {model2_name}) 加载失败: {e}")
|
| 29 |
+
qa_model = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
# --- 2. 推理函数 ---
|
| 33 |
+
# 这个函数现在接受一个问题/提示词和一个上下文
|
| 34 |
+
def get_model_outputs(question_or_prompt, context, max_length=100):
|
| 35 |
+
output_text_gen = "文本生成模型未加载或生成失败。"
|
| 36 |
+
output_qa = "问答模型未加载或生成失败。"
|
|
|
|
| 37 |
|
| 38 |
+
# 模型 1: 文本生成
|
| 39 |
if generator1:
|
| 40 |
try:
|
| 41 |
+
# 文本生成模型将问题和上下文作为其prompt的一部分
|
| 42 |
+
full_prompt_for_gen = f"{question_or_prompt}\nContext: {context}" if context else question_or_prompt
|
| 43 |
+
gen_result = generator1(full_prompt_for_gen, max_new_tokens=max_length, num_return_sequences=1, truncation=True)
|
| 44 |
+
output_text_gen = gen_result[0]['generated_text']
|
| 45 |
# 清理:移除输入部分,只保留生成内容
|
| 46 |
+
if output_text_gen.startswith(full_prompt_for_gen):
|
| 47 |
+
output_text_gen = output_text_gen[len(full_prompt_for_gen):].strip()
|
| 48 |
except Exception as e:
|
| 49 |
+
output_text_gen = f"文本生成模型 ({model1_name}) 错误: {e}"
|
| 50 |
|
| 51 |
+
# 模型 2: 问答
|
| 52 |
+
if qa_model and context: # 问答模型必须有上下文
|
| 53 |
try:
|
| 54 |
+
qa_result = qa_model(question=question_or_prompt, context=context)
|
| 55 |
+
output_qa = qa_result['answer']
|
|
|
|
|
|
|
| 56 |
except Exception as e:
|
| 57 |
+
output_qa = f"问答模型 ({model2_name}) 错误: {e}"
|
| 58 |
+
elif qa_model and not context:
|
| 59 |
+
output_qa = "问答模型需要提供上下文才能回答问题。"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
+
return output_text_gen, output_qa
|
| 62 |
|
| 63 |
|
| 64 |
# --- 3. GRACE 评估数据(示例数据,请根据你们的实际评估结果修改) ---
|
| 65 |
+
# 请根据 gpt2-large 和 deepset/roberta-base-squad2 的实际表现进行评分
|
|
|
|
| 66 |
grace_data = {
|
| 67 |
"维度": ["Generalization (泛化性)", "Relevance (相关性)", "Artistry (创新表现力)", "Efficiency (效率性)"],
|
| 68 |
+
# 模型 1: gpt2-large (通用文本生成模型)
|
| 69 |
+
"GPT2-Large": [
|
| 70 |
+
4.0, # 泛化性: 能处理多种文本生成任务
|
| 71 |
+
3.5, # 相关性: 对于特定事实性问题可能不如问答模型精确
|
| 72 |
+
4.2, # 创新表现力: 生成文本流畅,有一定创造性
|
| 73 |
+
3.8 # 效率性: 相对 GPT2 较大,但比 Llama-2-7b 小
|
| 74 |
+
],
|
| 75 |
+
# 模型 2: deepset/roberta-base-squad2 (问答模型)
|
| 76 |
+
"RoBERTa-SQuAD2": [
|
| 77 |
+
3.0, # 泛化性: 专门用于问答,不能生成开放式文本
|
| 78 |
+
4.8, # 相关性: 从给定上下文中抽取答案,相关性极高
|
| 79 |
+
2.0, # 创新表现力: 抽取式问答,无创新表现
|
| 80 |
+
4.5 # 效率性: 推理速度快,效率高
|
| 81 |
+
]
|
| 82 |
}
|
| 83 |
grace_df = pd.DataFrame(grace_data)
|
| 84 |
|
|
|
|
| 105 |
|
| 106 |
return gr.Column(
|
| 107 |
gr.Markdown("## 📊 模型性能对比 (GRACE 评估)"),
|
| 108 |
+
gr.Markdown("本页展示了我们选用的模型在 GRACE 框架下的评估结果。数据为 1-5 分,分数越高代表表现越好。\n"
|
| 109 |
+
"**注意**: GPT2-Large 主要用于文本生成,RoBERTa-SQuAD2 主要用于问答,它们的评估维度侧重有所不同。"),
|
| 110 |
gr.Plot(fig, label="GRACE 评估雷达图"),
|
| 111 |
gr.Markdown("### GRACE 评估数据"),
|
| 112 |
gr.DataFrame(grace_df, label="详细评估数据")
|
|
|
|
| 116 |
def create_arena_tab():
|
| 117 |
with gr.Blocks() as arena_block:
|
| 118 |
gr.Markdown("## ⚔️ Arena: 模型实时对比")
|
| 119 |
+
gr.Markdown("在这里,您可以输入一个问题或提示词,并提供一段上下文。文本生成模型将根据问题和上下文生成文本,问答模型将从上下文中抽取答案。")
|
| 120 |
+
|
| 121 |
+
with gr.Row():
|
| 122 |
+
# 统一输入框 1: 问题/提示词
|
| 123 |
+
question_input = gr.Textbox(label="问题/提示词:", placeholder="请输入您的问题或想让模型生成的提示词...", lines=3)
|
| 124 |
+
# 统一输入框 2: 上下文 (主要用于问答模型)
|
| 125 |
+
context_input = gr.Textbox(label="上下文 (Context):", placeholder="请输入问答模型需要从中抽取答案的上下文...", lines=5)
|
| 126 |
|
| 127 |
with gr.Row():
|
| 128 |
+
# 增加生成长度控制(主要针对文本生成模型)
|
| 129 |
+
gen_length_slider = gr.Slider(minimum=20, maximum=300, value=100, step=10, label="文本生成最大长度")
|
|
|
|
|
|
|
| 130 |
generate_btn = gr.Button("🚀 生成并对比")
|
| 131 |
|
| 132 |
with gr.Row():
|
| 133 |
+
# 模型 1 输出 (文本生成)
|
| 134 |
+
output_text_gen = gr.Textbox(label=f"模型 1 (文本生成: {model1_name}) 输出:", interactive=False, lines=10)
|
| 135 |
+
# 模型 2 输出 (问答)
|
| 136 |
+
output_qa = gr.Textbox(label=f"模型 2 (问答: {model2_name}) 输出:", interactive=False, lines=10)
|
|
|
|
|
|
|
| 137 |
|
| 138 |
# 绑定按钮点击事件到推理函数
|
| 139 |
generate_btn.click(
|
| 140 |
+
fn=get_model_outputs,
|
| 141 |
+
inputs=[question_input, context_input, gen_length_slider],
|
| 142 |
+
outputs=[output_text_gen, output_qa]
|
| 143 |
)
|
| 144 |
return arena_block
|
| 145 |
|