Update app.py
Browse files
app.py
CHANGED
|
@@ -23,19 +23,14 @@ def call_api(model_name, only_final_result, data_df):
|
|
| 23 |
"""
|
| 24 |
# 1. 数据校验
|
| 25 |
if not model_name.strip():
|
| 26 |
-
# 使用Gradio的错误提示功能
|
| 27 |
raise gr.Error("模型名称不能为空!")
|
| 28 |
if data_df is None or data_df.empty:
|
| 29 |
raise gr.Error("输入数据不能为空!请在表格中至少添加一行数据。")
|
| 30 |
|
| 31 |
# 2. 将输入的DataFrame转换成API需要的CSV字符串
|
| 32 |
-
# API需要 'id' 列,我们在这里自动生成
|
| 33 |
data_df_with_id = data_df.copy()
|
| 34 |
-
# 为id列生成唯一的、格式化的ID,例如 000, 001, 002...
|
| 35 |
data_df_with_id.insert(0, 'id', [f'{i:03d}' for i in range(len(data_df_with_id))])
|
| 36 |
|
| 37 |
-
# 将DataFrame转换为CSV字符串
|
| 38 |
-
# 使用io.StringIO来在内存中处理字符串,避免写入文件
|
| 39 |
output = io.StringIO()
|
| 40 |
data_df_with_id.to_csv(output, index=False)
|
| 41 |
csv_data = output.getvalue()
|
|
@@ -49,34 +44,27 @@ def call_api(model_name, only_final_result, data_df):
|
|
| 49 |
|
| 50 |
try:
|
| 51 |
# 4. 发送POST请求
|
| 52 |
-
# 添加流式传输和更长的超时时间以应对长时间运行的任务
|
| 53 |
response = requests.post(SERVER_URL, json=payload, timeout=300, stream=True)
|
| 54 |
-
response.raise_for_status()
|
| 55 |
|
| 56 |
-
# 解析返回的JSON数据
|
| 57 |
result_json = response.json()
|
| 58 |
|
| 59 |
# 5. 处理和返回结果
|
| 60 |
-
# 从结果中提取'result'键,它应该是一个字典列表
|
| 61 |
dataframe_result_data = result_json.get("result", [])
|
| 62 |
|
| 63 |
-
# 将字典列表转换为Pandas DataFrame,以便Gradio的DataFrame组件更好地显示
|
| 64 |
if isinstance(dataframe_result_data, list) and len(dataframe_result_data) > 0:
|
| 65 |
output_df = pd.DataFrame(dataframe_result_data)
|
| 66 |
else:
|
| 67 |
-
output_df = pd.DataFrame()
|
| 68 |
|
| 69 |
return result_json, output_df
|
| 70 |
|
| 71 |
except requests.exceptions.HTTPError as e:
|
| 72 |
-
# 更详细地捕获HTTP错误
|
| 73 |
error_details = f"API返回错误 (状态码: {e.response.status_code}):\n{e.response.text}"
|
| 74 |
raise gr.Error(f"请求失败: {error_details}")
|
| 75 |
except requests.exceptions.RequestException as e:
|
| 76 |
-
# 捕获网络连接等问题
|
| 77 |
raise gr.Error(f"连接服务器失败: {e}")
|
| 78 |
except Exception as e:
|
| 79 |
-
# 捕获其他未知错误
|
| 80 |
raise gr.Error(f"发生未知错误: {e}")
|
| 81 |
|
| 82 |
|
|
@@ -106,16 +94,21 @@ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280p
|
|
| 106 |
)
|
| 107 |
|
| 108 |
gr.Markdown("### 2. 输入数据")
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
data_input = gr.DataFrame(
|
| 111 |
label="数据表格",
|
| 112 |
headers=["subject", "relation", "object", "prompt_source"],
|
| 113 |
value=DEFAULT_DATAFRAME_VALUE,
|
| 114 |
-
row_count=(1, "dynamic"),
|
| 115 |
col_count=(4, "fixed"),
|
| 116 |
-
interactive=True
|
| 117 |
-
|
| 118 |
)
|
|
|
|
| 119 |
|
| 120 |
submit_btn = gr.Button("提交推理", variant="primary", scale=1)
|
| 121 |
|
|
@@ -127,11 +120,9 @@ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280p
|
|
| 127 |
output_dataframe = gr.DataFrame(
|
| 128 |
label="推理结果详情",
|
| 129 |
interactive=False,
|
| 130 |
-
# 根据您的示例输出定义列,Gradio会自动匹配
|
| 131 |
headers=["subject", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
|
| 132 |
)
|
| 133 |
with gr.TabItem("原始JSON响应"):
|
| 134 |
-
# 使用gr.JSON组件可以更好地格式化显示JSON
|
| 135 |
output_json = gr.JSON(label="API 返回的完整JSON")
|
| 136 |
|
| 137 |
# 定义按钮点击事件
|
|
@@ -143,5 +134,4 @@ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280p
|
|
| 143 |
|
| 144 |
# 启动Gradio应用
|
| 145 |
if __name__ == "__main__":
|
| 146 |
-
# 使用share=True可以生成一个公开链接,方便他人访问
|
| 147 |
demo.launch()
|
|
|
|
| 23 |
"""
|
| 24 |
# 1. 数据校验
|
| 25 |
if not model_name.strip():
|
|
|
|
| 26 |
raise gr.Error("模型名称不能为空!")
|
| 27 |
if data_df is None or data_df.empty:
|
| 28 |
raise gr.Error("输入数据不能为空!请在表格中至少添加一行数据。")
|
| 29 |
|
| 30 |
# 2. 将输入的DataFrame转换成API需要的CSV字符串
|
|
|
|
| 31 |
data_df_with_id = data_df.copy()
|
|
|
|
| 32 |
data_df_with_id.insert(0, 'id', [f'{i:03d}' for i in range(len(data_df_with_id))])
|
| 33 |
|
|
|
|
|
|
|
| 34 |
output = io.StringIO()
|
| 35 |
data_df_with_id.to_csv(output, index=False)
|
| 36 |
csv_data = output.getvalue()
|
|
|
|
| 44 |
|
| 45 |
try:
|
| 46 |
# 4. 发送POST请求
|
|
|
|
| 47 |
response = requests.post(SERVER_URL, json=payload, timeout=300, stream=True)
|
| 48 |
+
response.raise_for_status()
|
| 49 |
|
|
|
|
| 50 |
result_json = response.json()
|
| 51 |
|
| 52 |
# 5. 处理和返回结果
|
|
|
|
| 53 |
dataframe_result_data = result_json.get("result", [])
|
| 54 |
|
|
|
|
| 55 |
if isinstance(dataframe_result_data, list) and len(dataframe_result_data) > 0:
|
| 56 |
output_df = pd.DataFrame(dataframe_result_data)
|
| 57 |
else:
|
| 58 |
+
output_df = pd.DataFrame()
|
| 59 |
|
| 60 |
return result_json, output_df
|
| 61 |
|
| 62 |
except requests.exceptions.HTTPError as e:
|
|
|
|
| 63 |
error_details = f"API返回错误 (状态码: {e.response.status_code}):\n{e.response.text}"
|
| 64 |
raise gr.Error(f"请求失败: {error_details}")
|
| 65 |
except requests.exceptions.RequestException as e:
|
|
|
|
| 66 |
raise gr.Error(f"连接服务器失败: {e}")
|
| 67 |
except Exception as e:
|
|
|
|
| 68 |
raise gr.Error(f"发生未知错误: {e}")
|
| 69 |
|
| 70 |
|
|
|
|
| 94 |
)
|
| 95 |
|
| 96 |
gr.Markdown("### 2. 输入数据")
|
| 97 |
+
|
| 98 |
+
# -------------------- 主要修改部分在这里 --------------------
|
| 99 |
+
# 将原来的 info 参数内容,移到一个独立的 Markdown 组件中
|
| 100 |
+
gr.Markdown("ℹ️ **操作提示**: 点击表格右下方的 **✏️ Edit** 按钮,然后点击 **+ Add row** 来添加新行,或点击垃圾桶图标删除行。")
|
| 101 |
+
|
| 102 |
data_input = gr.DataFrame(
|
| 103 |
label="数据表格",
|
| 104 |
headers=["subject", "relation", "object", "prompt_source"],
|
| 105 |
value=DEFAULT_DATAFRAME_VALUE,
|
| 106 |
+
row_count=(1, "dynamic"),
|
| 107 |
col_count=(4, "fixed"),
|
| 108 |
+
interactive=True
|
| 109 |
+
# 已移除 'info' 参数
|
| 110 |
)
|
| 111 |
+
# -------------------- 修改结束 --------------------
|
| 112 |
|
| 113 |
submit_btn = gr.Button("提交推理", variant="primary", scale=1)
|
| 114 |
|
|
|
|
| 120 |
output_dataframe = gr.DataFrame(
|
| 121 |
label="推理结果详情",
|
| 122 |
interactive=False,
|
|
|
|
| 123 |
headers=["subject", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
|
| 124 |
)
|
| 125 |
with gr.TabItem("原始JSON响应"):
|
|
|
|
| 126 |
output_json = gr.JSON(label="API 返回的完整JSON")
|
| 127 |
|
| 128 |
# 定义按钮点击事件
|
|
|
|
| 134 |
|
| 135 |
# 启动Gradio应用
|
| 136 |
if __name__ == "__main__":
|
|
|
|
| 137 |
demo.launch()
|