Update app.py
Browse files
app.py
CHANGED
|
@@ -9,21 +9,25 @@ import io
|
|
| 9 |
SERVER_URL = "http://103.235.229.133:7000/predict"
|
| 10 |
# -------------------------------------------------------
|
| 11 |
|
| 12 |
-
#
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
DEFAULT_DATAFRAME_VALUE = [
|
| 15 |
["France", "capital city of", "Paris", "she traveled to France for the first time"],
|
| 16 |
-
["Germany", "capital city of", "Berlin", "
|
|
|
|
| 17 |
]
|
| 18 |
|
| 19 |
-
def call_api(
|
| 20 |
"""
|
| 21 |
这个函数会被Gradio调用,它负责收集UI上的输入,
|
| 22 |
将其转换成API需要的格式,然后发送请求。
|
|
|
|
| 23 |
"""
|
| 24 |
-
# 1. 数据校验
|
| 25 |
-
if not model_name.strip():
|
| 26 |
-
raise gr.Error("模型名称不能为空!")
|
| 27 |
if data_df is None or data_df.empty:
|
| 28 |
raise gr.Error("输入数据不能为空!请在表格中至少添加一行数据。")
|
| 29 |
|
|
@@ -35,9 +39,9 @@ def call_api(model_name, only_final_result, data_df):
|
|
| 35 |
data_df_with_id.to_csv(output, index=False)
|
| 36 |
csv_data = output.getvalue()
|
| 37 |
|
| 38 |
-
# 3. 构建发送到API的JSON数据体
|
| 39 |
payload = {
|
| 40 |
-
"model_name":
|
| 41 |
"data_csv": csv_data,
|
| 42 |
"only_final_result": only_final_result
|
| 43 |
}
|
|
@@ -77,58 +81,56 @@ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280p
|
|
| 77 |
"""
|
| 78 |
)
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
label="模型名称 (model_name)",
|
| 87 |
-
value=DEFAULT_MODEL_NAME,
|
| 88 |
-
info="请输入部署在服务器上的模型路径。"
|
| 89 |
-
)
|
| 90 |
-
only_final_result_input = gr.Checkbox(
|
| 91 |
-
label="仅返回最终结果 (only_final_result)",
|
| 92 |
-
value=False,
|
| 93 |
-
info="勾选此项,API将只返回最终聚合结果,而不是所有层的详细信息。"
|
| 94 |
-
)
|
| 95 |
-
|
| 96 |
-
gr.Markdown("### 2. 输入数据")
|
| 97 |
-
|
| 98 |
-
# -------------------- 主要修改部分在这里 --------------------
|
| 99 |
-
# 将原来的 info 参数内容,移到一个独立的 Markdown 组件中
|
| 100 |
-
gr.Markdown("ℹ️ **操作提示**: 点击表格右下方的 **✏️ Edit** 按钮,然后点击 **+ Add row** 来添加新行,或点击垃圾桶图标删除行。")
|
| 101 |
|
| 102 |
-
|
| 103 |
-
label="
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
row_count=(1, "dynamic"),
|
| 107 |
-
col_count=(4, "fixed"),
|
| 108 |
-
interactive=True
|
| 109 |
-
# 已移除 'info' 参数
|
| 110 |
)
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
| 127 |
|
| 128 |
-
# 定义按钮点击事件
|
| 129 |
submit_btn.click(
|
| 130 |
fn=call_api,
|
| 131 |
-
inputs=[
|
| 132 |
outputs=[output_json, output_dataframe]
|
| 133 |
)
|
| 134 |
|
|
|
|
| 9 |
SERVER_URL = "http://103.235.229.133:7000/predict"
|
| 10 |
# -------------------------------------------------------
|
| 11 |
|
| 12 |
+
# ----------------- 固定的模型名称 -----------------
|
| 13 |
+
# 根据要求,模型名称是固定的,不再作为UI输入
|
| 14 |
+
FIXED_MODEL_NAME = "/data/tsq/MODELS/Qwen2.5-1.5B-Instruct"
|
| 15 |
+
# -------------------------------------------------------
|
| 16 |
+
|
| 17 |
+
# 默认的输入数据,方便用户测试
|
| 18 |
DEFAULT_DATAFRAME_VALUE = [
|
| 19 |
["France", "capital city of", "Paris", "she traveled to France for the first time"],
|
| 20 |
+
["Germany", "capital city of", "Berlin", "Germany is a country in Central Europe"],
|
| 21 |
+
["Japan", "capital city of", "Tokyo", "Tokyo is the largest city in Japan"]
|
| 22 |
]
|
| 23 |
|
| 24 |
+
def call_api(only_final_result, data_df):
|
| 25 |
"""
|
| 26 |
这个函数会被Gradio调用,它负责收集UI上的输入,
|
| 27 |
将其转换成API需要的格式,然后发送请求。
|
| 28 |
+
(已修改:不再接收 model_name 参数)
|
| 29 |
"""
|
| 30 |
+
# 1. 数据校验 (不再需要校验模型名称)
|
|
|
|
|
|
|
| 31 |
if data_df is None or data_df.empty:
|
| 32 |
raise gr.Error("输入数据不能为空!请在表格中至少添加一行数据。")
|
| 33 |
|
|
|
|
| 39 |
data_df_with_id.to_csv(output, index=False)
|
| 40 |
csv_data = output.getvalue()
|
| 41 |
|
| 42 |
+
# 3. 构建发送到API的JSON数据体 (使用固定的模型名称)
|
| 43 |
payload = {
|
| 44 |
+
"model_name": FIXED_MODEL_NAME,
|
| 45 |
"data_csv": csv_data,
|
| 46 |
"only_final_result": only_final_result
|
| 47 |
}
|
|
|
|
| 81 |
"""
|
| 82 |
)
|
| 83 |
|
| 84 |
+
# ---------- 新布局:输入部分 ----------
|
| 85 |
+
with gr.Column():
|
| 86 |
+
gr.Markdown("### 1. 配置参数")
|
| 87 |
+
with gr.Group():
|
| 88 |
+
# 使用 Markdown 或不可编辑的 Textbox 展示固定的模型名称
|
| 89 |
+
gr.Markdown(f"**当前模型 (固定):** `{FIXED_MODEL_NAME}`")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
+
only_final_result_input = gr.Checkbox(
|
| 92 |
+
label="仅返回最终结果 (only_final_result)",
|
| 93 |
+
value=False,
|
| 94 |
+
info="勾选此项,API将只返回最终聚合结果,而不是所有层的详细信息。"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
)
|
| 96 |
+
|
| 97 |
+
gr.Markdown("### 2. 输入数据")
|
| 98 |
+
|
| 99 |
+
gr.Markdown("ℹ️ **操作提示**: 点击表格右下方的 **✏️ Edit** 按钮,然后点击 **+ Add row** 来添加新行,或点击垃圾桶图标删除行。")
|
| 100 |
+
|
| 101 |
+
# -------------------- 主要修改部分在这里 --------------------
|
| 102 |
+
# 增加了 height 参数来让表格更高,并且放在了更宽的布局中
|
| 103 |
+
data_input = gr.DataFrame(
|
| 104 |
+
label="数据表格",
|
| 105 |
+
headers=["subject", "relation", "object", "prompt_source"],
|
| 106 |
+
value=DEFAULT_DATAFRAME_VALUE,
|
| 107 |
+
row_count=(3, "dynamic"), # 默认显示3行
|
| 108 |
+
col_count=(4, "fixed"),
|
| 109 |
+
interactive=True,
|
| 110 |
+
height=300 # 明确设置组件高度,单位为像素
|
| 111 |
+
)
|
| 112 |
+
# -------------------- 修改结束 --------------------
|
| 113 |
+
|
| 114 |
+
submit_btn = gr.Button("提交推理", variant="primary")
|
| 115 |
|
| 116 |
+
# ---------- 新布局:输出部分 ----------
|
| 117 |
+
with gr.Column():
|
| 118 |
+
gr.Markdown("### 3. 查看推理结果")
|
| 119 |
+
with gr.Tabs():
|
| 120 |
+
with gr.TabItem("结果表格"):
|
| 121 |
+
output_dataframe = gr.DataFrame(
|
| 122 |
+
label="推理结果详情",
|
| 123 |
+
interactive=False,
|
| 124 |
+
# 增加了列数以适应可能的更多输出
|
| 125 |
+
headers=["id", "subject", "relation", "object", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
|
| 126 |
+
)
|
| 127 |
+
with gr.TabItem("原始JSON响应"):
|
| 128 |
+
output_json = gr.JSON(label="API 返回的完整JSON")
|
| 129 |
|
| 130 |
+
# 定义按钮点击事件 (已更新 inputs 列表)
|
| 131 |
submit_btn.click(
|
| 132 |
fn=call_api,
|
| 133 |
+
inputs=[only_final_result_input, data_input], # 移除了 model_name_input
|
| 134 |
outputs=[output_json, output_dataframe]
|
| 135 |
)
|
| 136 |
|