lllyx commited on
Commit
041cfb4
·
verified ·
1 Parent(s): 86e2749

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -21
app.py CHANGED
@@ -23,19 +23,14 @@ def call_api(model_name, only_final_result, data_df):
23
  """
24
  # 1. 数据校验
25
  if not model_name.strip():
26
- # 使用Gradio的错误提示功能
27
  raise gr.Error("模型名称不能为空!")
28
  if data_df is None or data_df.empty:
29
  raise gr.Error("输入数据不能为空!请在表格中至少添加一行数据。")
30
 
31
  # 2. 将输入的DataFrame转换成API需要的CSV字符串
32
- # API需要 'id' 列,我们在这里自动生成
33
  data_df_with_id = data_df.copy()
34
- # 为id列生成唯一的、格式化的ID,例如 000, 001, 002...
35
  data_df_with_id.insert(0, 'id', [f'{i:03d}' for i in range(len(data_df_with_id))])
36
 
37
- # 将DataFrame转换为CSV字符串
38
- # 使用io.StringIO来在内存中处理字符串,避免写入文件
39
  output = io.StringIO()
40
  data_df_with_id.to_csv(output, index=False)
41
  csv_data = output.getvalue()
@@ -49,34 +44,27 @@ def call_api(model_name, only_final_result, data_df):
49
 
50
  try:
51
  # 4. 发送POST请求
52
- # 添加流式传输和更长的超时时间以应对长时间运行的任务
53
  response = requests.post(SERVER_URL, json=payload, timeout=300, stream=True)
54
- response.raise_for_status() # 如果状态码不是2xx,则抛出HTTPError
55
 
56
- # 解析返回的JSON数据
57
  result_json = response.json()
58
 
59
  # 5. 处理和返回结果
60
- # 从结果中提取'result'键,它应该是一个字典列表
61
  dataframe_result_data = result_json.get("result", [])
62
 
63
- # 将字典列表转换为Pandas DataFrame,以便Gradio的DataFrame组件更好地显示
64
  if isinstance(dataframe_result_data, list) and len(dataframe_result_data) > 0:
65
  output_df = pd.DataFrame(dataframe_result_data)
66
  else:
67
- output_df = pd.DataFrame() # 如果没有结果,返回一个空的DataFrame
68
 
69
  return result_json, output_df
70
 
71
  except requests.exceptions.HTTPError as e:
72
- # 更详细地捕获HTTP错误
73
  error_details = f"API返回错误 (状态码: {e.response.status_code}):\n{e.response.text}"
74
  raise gr.Error(f"请求失败: {error_details}")
75
  except requests.exceptions.RequestException as e:
76
- # 捕获网络连接等问题
77
  raise gr.Error(f"连接服务器失败: {e}")
78
  except Exception as e:
79
- # 捕获其他未知错误
80
  raise gr.Error(f"发生未知错误: {e}")
81
 
82
 
@@ -106,16 +94,21 @@ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280p
106
  )
107
 
108
  gr.Markdown("### 2. 输入数据")
109
- # 使用可编辑的DataFrame作为输入,用户可以动态增删行
 
 
 
 
110
  data_input = gr.DataFrame(
111
  label="数据表格",
112
  headers=["subject", "relation", "object", "prompt_source"],
113
  value=DEFAULT_DATAFRAME_VALUE,
114
- row_count=(1, "dynamic"), # 至少1行,动态增加
115
  col_count=(4, "fixed"),
116
- interactive=True,
117
- info="点击表格下方的 'Edit' (铅笔图标),然后点击 '+' 来添加新行,或点击垃圾桶图标删除行。"
118
  )
 
119
 
120
  submit_btn = gr.Button("提交推理", variant="primary", scale=1)
121
 
@@ -127,11 +120,9 @@ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280p
127
  output_dataframe = gr.DataFrame(
128
  label="推理结果详情",
129
  interactive=False,
130
- # 根据您的示例输出定义列,Gradio会自动匹配
131
  headers=["subject", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
132
  )
133
  with gr.TabItem("原始JSON响应"):
134
- # 使用gr.JSON组件可以更好地格式化显示JSON
135
  output_json = gr.JSON(label="API 返回的完整JSON")
136
 
137
  # 定义按钮点击事件
@@ -143,5 +134,4 @@ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280p
143
 
144
  # 启动Gradio应用
145
  if __name__ == "__main__":
146
- # 使用share=True可以生成一个公开链接,方便他人访问
147
  demo.launch()
 
23
  """
24
  # 1. 数据校验
25
  if not model_name.strip():
 
26
  raise gr.Error("模型名称不能为空!")
27
  if data_df is None or data_df.empty:
28
  raise gr.Error("输入数据不能为空!请在表格中至少添加一行数据。")
29
 
30
  # 2. 将输入的DataFrame转换成API需要的CSV字符串
 
31
  data_df_with_id = data_df.copy()
 
32
  data_df_with_id.insert(0, 'id', [f'{i:03d}' for i in range(len(data_df_with_id))])
33
 
 
 
34
  output = io.StringIO()
35
  data_df_with_id.to_csv(output, index=False)
36
  csv_data = output.getvalue()
 
44
 
45
  try:
46
  # 4. 发送POST请求
 
47
  response = requests.post(SERVER_URL, json=payload, timeout=300, stream=True)
48
+ response.raise_for_status()
49
 
 
50
  result_json = response.json()
51
 
52
  # 5. 处理和返回结果
 
53
  dataframe_result_data = result_json.get("result", [])
54
 
 
55
  if isinstance(dataframe_result_data, list) and len(dataframe_result_data) > 0:
56
  output_df = pd.DataFrame(dataframe_result_data)
57
  else:
58
+ output_df = pd.DataFrame()
59
 
60
  return result_json, output_df
61
 
62
  except requests.exceptions.HTTPError as e:
 
63
  error_details = f"API返回错误 (状态码: {e.response.status_code}):\n{e.response.text}"
64
  raise gr.Error(f"请求失败: {error_details}")
65
  except requests.exceptions.RequestException as e:
 
66
  raise gr.Error(f"连接服务器失败: {e}")
67
  except Exception as e:
 
68
  raise gr.Error(f"发生未知错误: {e}")
69
 
70
 
 
94
  )
95
 
96
  gr.Markdown("### 2. 输入数据")
97
+
98
+ # -------------------- 主要修改部分在这里 --------------------
99
+ # 将原来的 info 参数内容,移到一个独立的 Markdown 组件中
100
+ gr.Markdown("ℹ️ **操作提示**: 点击表格右下方的 **✏️ Edit** 按钮,然后点击 **+ Add row** 来添加新行,或点击垃圾桶图标删除行。")
101
+
102
  data_input = gr.DataFrame(
103
  label="数据表格",
104
  headers=["subject", "relation", "object", "prompt_source"],
105
  value=DEFAULT_DATAFRAME_VALUE,
106
+ row_count=(1, "dynamic"),
107
  col_count=(4, "fixed"),
108
+ interactive=True
109
+ # 已移除 'info' 参数
110
  )
111
+ # -------------------- 修改结束 --------------------
112
 
113
  submit_btn = gr.Button("提交推理", variant="primary", scale=1)
114
 
 
120
  output_dataframe = gr.DataFrame(
121
  label="推理结果详情",
122
  interactive=False,
 
123
  headers=["subject", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
124
  )
125
  with gr.TabItem("原始JSON响应"):
 
126
  output_json = gr.JSON(label="API 返回的完整JSON")
127
 
128
  # 定义按钮点击事件
 
134
 
135
  # 启动Gradio应用
136
  if __name__ == "__main__":
 
137
  demo.launch()