Update app.py
Browse files
app.py
CHANGED
|
@@ -25,13 +25,10 @@ def call_api(only_final_result, data_df):
|
|
| 25 |
"""
|
| 26 |
这个函数会被Gradio调用,它负责收集UI上的输入,
|
| 27 |
将其转换成API需要的格式,然后发送请求。
|
| 28 |
-
(已修改:不再接收 model_name 参数)
|
| 29 |
"""
|
| 30 |
-
# 1. 数据校验 (不再需要校验模型名称)
|
| 31 |
if data_df is None or data_df.empty:
|
| 32 |
raise gr.Error("输入数据不能为空!请在表格中至少添加一行数据。")
|
| 33 |
|
| 34 |
-
# 2. 将输入的DataFrame转换成API需要的CSV字符串
|
| 35 |
data_df_with_id = data_df.copy()
|
| 36 |
data_df_with_id.insert(0, 'id', [f'{i:03d}' for i in range(len(data_df_with_id))])
|
| 37 |
|
|
@@ -39,7 +36,6 @@ def call_api(only_final_result, data_df):
|
|
| 39 |
data_df_with_id.to_csv(output, index=False)
|
| 40 |
csv_data = output.getvalue()
|
| 41 |
|
| 42 |
-
# 3. 构建发送到API的JSON数据体 (使用固定的模型名称)
|
| 43 |
payload = {
|
| 44 |
"model_name": FIXED_MODEL_NAME,
|
| 45 |
"data_csv": csv_data,
|
|
@@ -47,13 +43,10 @@ def call_api(only_final_result, data_df):
|
|
| 47 |
}
|
| 48 |
|
| 49 |
try:
|
| 50 |
-
# 4. 发送POST请求
|
| 51 |
response = requests.post(SERVER_URL, json=payload, timeout=300, stream=True)
|
| 52 |
response.raise_for_status()
|
| 53 |
-
|
| 54 |
result_json = response.json()
|
| 55 |
|
| 56 |
-
# 5. 处理和返回结果
|
| 57 |
dataframe_result_data = result_json.get("result", [])
|
| 58 |
|
| 59 |
if isinstance(dataframe_result_data, list) and len(dataframe_result_data) > 0:
|
|
@@ -81,13 +74,10 @@ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280p
|
|
| 81 |
"""
|
| 82 |
)
|
| 83 |
|
| 84 |
-
# ---------- 新布局:输入部分 ----------
|
| 85 |
with gr.Column():
|
| 86 |
gr.Markdown("### 1. 配置参数")
|
| 87 |
with gr.Group():
|
| 88 |
-
# 使用 Markdown 或不可编辑的 Textbox 展示固定的模型名称
|
| 89 |
gr.Markdown(f"**当前模型 (固定):** `{FIXED_MODEL_NAME}`")
|
| 90 |
-
|
| 91 |
only_final_result_input = gr.Checkbox(
|
| 92 |
label="仅返回最终结果 (only_final_result)",
|
| 93 |
value=False,
|
|
@@ -95,25 +85,26 @@ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280p
|
|
| 95 |
)
|
| 96 |
|
| 97 |
gr.Markdown("### 2. 输入数据")
|
| 98 |
-
|
| 99 |
gr.Markdown("ℹ️ **操作提示**: 点击表格右下方的 **✏️ Edit** 按钮,然后点击 **+ Add row** 来添加新行,或点击垃圾桶图标删除行。")
|
| 100 |
|
| 101 |
# -------------------- 主要修改部分在这里 --------------------
|
| 102 |
-
#
|
|
|
|
| 103 |
data_input = gr.DataFrame(
|
| 104 |
label="数据表格",
|
| 105 |
headers=["subject", "relation", "object", "prompt_source"],
|
| 106 |
value=DEFAULT_DATAFRAME_VALUE,
|
| 107 |
-
|
|
|
|
|
|
|
| 108 |
col_count=(4, "fixed"),
|
| 109 |
-
interactive=True
|
| 110 |
-
|
| 111 |
)
|
| 112 |
# -------------------- 修改结束 --------------------
|
| 113 |
|
| 114 |
submit_btn = gr.Button("提交推理", variant="primary")
|
| 115 |
|
| 116 |
-
# ---------- 新布局:输出部分 ----------
|
| 117 |
with gr.Column():
|
| 118 |
gr.Markdown("### 3. 查看推理结果")
|
| 119 |
with gr.Tabs():
|
|
@@ -121,16 +112,14 @@ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280p
|
|
| 121 |
output_dataframe = gr.DataFrame(
|
| 122 |
label="推理结果详情",
|
| 123 |
interactive=False,
|
| 124 |
-
# 增加了列数以适应可能的更多输出
|
| 125 |
headers=["id", "subject", "relation", "object", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
|
| 126 |
)
|
| 127 |
with gr.TabItem("原始JSON响应"):
|
| 128 |
output_json = gr.JSON(label="API 返回的完整JSON")
|
| 129 |
|
| 130 |
-
# 定义按钮点击事件 (已更新 inputs 列表)
|
| 131 |
submit_btn.click(
|
| 132 |
fn=call_api,
|
| 133 |
-
inputs=[only_final_result_input, data_input],
|
| 134 |
outputs=[output_json, output_dataframe]
|
| 135 |
)
|
| 136 |
|
|
|
|
| 25 |
"""
|
| 26 |
这个函数会被Gradio调用,它负责收集UI上的输入,
|
| 27 |
将其转换成API需要的格式,然后发送请求。
|
|
|
|
| 28 |
"""
|
|
|
|
| 29 |
if data_df is None or data_df.empty:
|
| 30 |
raise gr.Error("输入数据不能为空!请在表格中至少添加一行数据。")
|
| 31 |
|
|
|
|
| 32 |
data_df_with_id = data_df.copy()
|
| 33 |
data_df_with_id.insert(0, 'id', [f'{i:03d}' for i in range(len(data_df_with_id))])
|
| 34 |
|
|
|
|
| 36 |
data_df_with_id.to_csv(output, index=False)
|
| 37 |
csv_data = output.getvalue()
|
| 38 |
|
|
|
|
| 39 |
payload = {
|
| 40 |
"model_name": FIXED_MODEL_NAME,
|
| 41 |
"data_csv": csv_data,
|
|
|
|
| 43 |
}
|
| 44 |
|
| 45 |
try:
|
|
|
|
| 46 |
response = requests.post(SERVER_URL, json=payload, timeout=300, stream=True)
|
| 47 |
response.raise_for_status()
|
|
|
|
| 48 |
result_json = response.json()
|
| 49 |
|
|
|
|
| 50 |
dataframe_result_data = result_json.get("result", [])
|
| 51 |
|
| 52 |
if isinstance(dataframe_result_data, list) and len(dataframe_result_data) > 0:
|
|
|
|
| 74 |
"""
|
| 75 |
)
|
| 76 |
|
|
|
|
| 77 |
with gr.Column():
|
| 78 |
gr.Markdown("### 1. 配置参数")
|
| 79 |
with gr.Group():
|
|
|
|
| 80 |
gr.Markdown(f"**当前模型 (固定):** `{FIXED_MODEL_NAME}`")
|
|
|
|
| 81 |
only_final_result_input = gr.Checkbox(
|
| 82 |
label="仅返回最终结果 (only_final_result)",
|
| 83 |
value=False,
|
|
|
|
| 85 |
)
|
| 86 |
|
| 87 |
gr.Markdown("### 2. 输入数据")
|
|
|
|
| 88 |
gr.Markdown("ℹ️ **操作提示**: 点击表格右下方的 **✏️ Edit** 按钮,然后点击 **+ Add row** 来添加新行,或点击垃圾桶图标删除行。")
|
| 89 |
|
| 90 |
# -------------------- 主要修改部分在这里 --------------------
|
| 91 |
+
# 移除了 'height' 参数,因为它在旧版Gradio中不受支持。
|
| 92 |
+
# 我们通过增加 row_count 的第一个参数(初始行数)来为表格提供更多初始垂直空间。
|
| 93 |
data_input = gr.DataFrame(
|
| 94 |
label="数据表格",
|
| 95 |
headers=["subject", "relation", "object", "prompt_source"],
|
| 96 |
value=DEFAULT_DATAFRAME_VALUE,
|
| 97 |
+
# 将初始行数从3增加到5,使表格在加载时就更高。
|
| 98 |
+
# (要显示的行数, "dynamic"表示用户可以增删行)
|
| 99 |
+
row_count=(5, "dynamic"),
|
| 100 |
col_count=(4, "fixed"),
|
| 101 |
+
interactive=True
|
| 102 |
+
# 已移除不受支持的 'height' 参数
|
| 103 |
)
|
| 104 |
# -------------------- 修改结束 --------------------
|
| 105 |
|
| 106 |
submit_btn = gr.Button("提交推理", variant="primary")
|
| 107 |
|
|
|
|
| 108 |
with gr.Column():
|
| 109 |
gr.Markdown("### 3. 查看推理结果")
|
| 110 |
with gr.Tabs():
|
|
|
|
| 112 |
output_dataframe = gr.DataFrame(
|
| 113 |
label="推理结果详情",
|
| 114 |
interactive=False,
|
|
|
|
| 115 |
headers=["id", "subject", "relation", "object", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
|
| 116 |
)
|
| 117 |
with gr.TabItem("原始JSON响应"):
|
| 118 |
output_json = gr.JSON(label="API 返回的完整JSON")
|
| 119 |
|
|
|
|
| 120 |
submit_btn.click(
|
| 121 |
fn=call_api,
|
| 122 |
+
inputs=[only_final_result_input, data_input],
|
| 123 |
outputs=[output_json, output_dataframe]
|
| 124 |
)
|
| 125 |
|