lllyx commited on
Commit
86e2749
·
verified ·
1 Parent(s): 5d33dd6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -46
app.py CHANGED
@@ -1,82 +1,147 @@
1
  import gradio as gr
2
  import requests
3
  import json
 
 
4
 
5
  # ----------------- 配置你的服务器API地址 -----------------
6
  # 将IP地址和端口替换为你自己的
7
  SERVER_URL = "http://103.235.229.133:7000/predict"
8
  # -------------------------------------------------------
9
 
10
- # 默认的输入数据,方便用户测试
11
- DEFAULT_CSV_DATA = "id,subject,relation,object,prompt_source\n001,France,capital city of,Paris,she traveled to France for the first time\n002,Germany,capital city of,Berlin,the capital of Germany is Berlin"
 
 
 
 
12
 
13
- def call_api(csv_data):
14
  """
15
- 这个函数会被Gradio调用,它负责向后端API发送请求。
 
16
  """
17
- if not csv_data.strip():
18
- return "错误:输入数据不能为空!", None
 
 
 
 
19
 
20
- # 构建发送到API的JSON数据体
 
 
 
 
 
 
 
 
 
 
 
 
21
  payload = {
22
- "model_name": "/data/tsq/MODELS/Qwen2.5-1.5B-Instruct",
23
  "data_csv": csv_data,
24
- "only_final_result": False # 你可以根据需要修改这个参数
25
  }
26
 
27
  try:
28
- # 发送POST请求
29
- response = requests.post(SERVER_URL, json=payload, timeout=60) # 设置60秒超时
 
 
30
 
31
- # 检查响应状态码
32
- if response.status_code == 200:
33
- # 解析返回的JSON数据
34
- result_json = response.json()
35
- # 为了美观地显示,我们将JSON格式化
36
- pretty_json = json.dumps(result_json, indent=2, ensure_ascii=False)
37
-
38
- # 假设结果在 'result' 键下,并且是一个列表
39
- # 我们可以直接将其传递给Gradio的DataFrame组件
40
- dataframe_result = result_json.get("result", [])
41
-
42
- return pretty_json, dataframe_result
43
  else:
44
- # 如果服务器返回错误
45
- error_message = f"API请求失败,状态码: {response.status_code}\n错误详情: {response.text}"
46
- return error_message, None
47
 
 
 
 
 
48
  except requests.exceptions.RequestException as e:
49
- # 如果网络连接等问题出错
50
- error_message = f"连接服务器失败: {e}"
51
- return error_message, None
 
 
 
52
 
53
- # 使用Gradio Blocks创建更灵活的界面
54
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
55
- gr.Markdown("# 🚀 Ml_patch 函数推理接口")
56
- gr.Markdown("在下方文本框中输入CSV格式的数据,然后点击“提交”按钮,调用云服务器进行推理。")
 
 
 
 
57
 
58
  with gr.Row():
59
- with gr.Column(scale=1):
60
- input_csv = gr.Textbox(
61
- lines=10,
62
- label="输入数据 (CSV格式)",
63
- info="第一行必须是表头,例如: id,subject,relation,object,prompt_source",
64
- value=DEFAULT_CSV_DATA
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  )
66
- submit_btn = gr.Button("提交推理", variant="primary")
 
67
 
68
- with gr.Column(scale=2):
69
- gr.Markdown("### 原始JSON输出")
70
- output_json = gr.Textbox(label="API 返回的完整JSON结果", interactive=False, lines=10)
71
- gr.Markdown("### 结果表格展示")
72
- output_dataframe = gr.DataFrame(label="推理结果", interactive=False)
 
 
 
 
 
 
 
 
 
73
 
 
74
  submit_btn.click(
75
  fn=call_api,
76
- inputs=input_csv,
77
  outputs=[output_json, output_dataframe]
78
  )
79
 
80
  # 启动Gradio应用
81
  if __name__ == "__main__":
 
82
  demo.launch()
 
1
  import gradio as gr
2
  import requests
3
  import json
4
+ import pandas as pd
5
+ import io
6
 
7
  # ----------------- 配置你的服务器API地址 -----------------
8
  # 将IP地址和端口替换为你自己的
9
  SERVER_URL = "http://103.235.229.133:7000/predict"
10
  # -------------------------------------------------------
11
 
12
+ # 默认的模型名称和输入数据,方便用户测试
13
+ DEFAULT_MODEL_NAME = "/data/tsq/MODELS/Qwen2.5-1.5B-Instruct"
14
+ DEFAULT_DATAFRAME_VALUE = [
15
+ ["France", "capital city of", "Paris", "she traveled to France for the first time"],
16
+ ["Germany", "capital city of", "Berlin", "the capital of Germany is Berlin"]
17
+ ]
18
 
19
+ def call_api(model_name, only_final_result, data_df):
20
  """
21
+ 这个函数会被Gradio调用,它负责收集UI上的输入,
22
+ 将其转换成API需要的格式,然后发送请求。
23
  """
24
+ # 1. 数据校验
25
+ if not model_name.strip():
26
+ # 使用Gradio的错误提示功能
27
+ raise gr.Error("模型名称不能为空!")
28
+ if data_df is None or data_df.empty:
29
+ raise gr.Error("输入数据不能为空!请在表格中至少添加一行数据。")
30
 
31
+ # 2. 将输入的DataFrame转换成API需要的CSV字符串
32
+ # API需要 'id' 列,我们在这里自动生成
33
+ data_df_with_id = data_df.copy()
34
+ # 为id列生成唯一的、格式化的ID,例如 000, 001, 002...
35
+ data_df_with_id.insert(0, 'id', [f'{i:03d}' for i in range(len(data_df_with_id))])
36
+
37
+ # 将DataFrame转换为CSV字符串
38
+ # 使用io.StringIO来在内存中处理字符串,避免写入文件
39
+ output = io.StringIO()
40
+ data_df_with_id.to_csv(output, index=False)
41
+ csv_data = output.getvalue()
42
+
43
+ # 3. 构建发送到API的JSON数据体
44
  payload = {
45
+ "model_name": model_name,
46
  "data_csv": csv_data,
47
+ "only_final_result": only_final_result
48
  }
49
 
50
  try:
51
+ # 4. 发送POST请求
52
+ # 添加流式传输和更长的超时时间以应对长时间运行的任务
53
+ response = requests.post(SERVER_URL, json=payload, timeout=300, stream=True)
54
+ response.raise_for_status() # 如果状态码不是2xx,则抛出HTTPError
55
 
56
+ # 解析返回的JSON数据
57
+ result_json = response.json()
58
+
59
+ # 5. 处理和返回结果
60
+ # 从结果中提取'result'键,它应该是一个字典列表
61
+ dataframe_result_data = result_json.get("result", [])
62
+
63
+ # 将字典列表转换为Pandas DataFrame,以便Gradio的DataFrame组件更好地显示
64
+ if isinstance(dataframe_result_data, list) and len(dataframe_result_data) > 0:
65
+ output_df = pd.DataFrame(dataframe_result_data)
 
 
66
  else:
67
+ output_df = pd.DataFrame() # 如果没有结果,返回一个空的DataFrame
68
+
69
+ return result_json, output_df
70
 
71
+ except requests.exceptions.HTTPError as e:
72
+ # 更详细地捕获HTTP错误
73
+ error_details = f"API返回错误 (状态码: {e.response.status_code}):\n{e.response.text}"
74
+ raise gr.Error(f"请求失败: {error_details}")
75
  except requests.exceptions.RequestException as e:
76
+ # 捕获网络连接等问题
77
+ raise gr.Error(f"连接服务器失败: {e}")
78
+ except Exception as e:
79
+ # 捕获其他未知错误
80
+ raise gr.Error(f"发生未知错误: {e}")
81
+
82
 
83
+ # --- 使用Gradio Blocks创建更灵活、更美观的界面 ---
84
+ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280px !important; margin: auto;}") as demo:
85
+ gr.Markdown(
86
+ """
87
+ # 🚀 Ml_patch 函数推理接口
88
+ 一个为 `Ml_patch` 函数设计的交互式Web界面。请在下方配置模型参数并输入您要测试的数据。
89
+ """
90
+ )
91
 
92
  with gr.Row():
93
+ # 左侧列:输入配置
94
+ with gr.Column(scale=2):
95
+ gr.Markdown("### 1. 配置模型和参数")
96
+ with gr.Group():
97
+ model_name_input = gr.Textbox(
98
+ label="模型名称 (model_name)",
99
+ value=DEFAULT_MODEL_NAME,
100
+ info="请输入部署在服务器上的模型路径。"
101
+ )
102
+ only_final_result_input = gr.Checkbox(
103
+ label="仅返回最终结果 (only_final_result)",
104
+ value=False,
105
+ info="勾选此项,API将只返回最终聚合结果,而不是所有层的详细信息。"
106
+ )
107
+
108
+ gr.Markdown("### 2. 输入数据")
109
+ # 使用可编辑的DataFrame作为输入,用户可以动态增删行
110
+ data_input = gr.DataFrame(
111
+ label="数据表格",
112
+ headers=["subject", "relation", "object", "prompt_source"],
113
+ value=DEFAULT_DATAFRAME_VALUE,
114
+ row_count=(1, "dynamic"), # 至少1行,动态增加
115
+ col_count=(4, "fixed"),
116
+ interactive=True,
117
+ info="点击表格下方的 'Edit' (铅笔图标),然后点击 '+' 来添加新行,或点击垃圾桶图标删除行。"
118
  )
119
+
120
+ submit_btn = gr.Button("提交推理", variant="primary", scale=1)
121
 
122
+ # 右侧列:输出结果
123
+ with gr.Column(scale=3):
124
+ gr.Markdown("### 3. 查看推理结果")
125
+ with gr.Tabs():
126
+ with gr.TabItem("结果表格"):
127
+ output_dataframe = gr.DataFrame(
128
+ label="推理结果详情",
129
+ interactive=False,
130
+ # 根据您的示例输出定义列,Gradio会自动匹配
131
+ headers=["subject", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
132
+ )
133
+ with gr.TabItem("原始JSON响应"):
134
+ # 使用gr.JSON组件可以更好地格式化显示JSON
135
+ output_json = gr.JSON(label="API 返回的完整JSON")
136
 
137
+ # 定义按钮点击事件
138
  submit_btn.click(
139
  fn=call_api,
140
+ inputs=[model_name_input, only_final_result_input, data_input],
141
  outputs=[output_json, output_dataframe]
142
  )
143
 
144
  # 启动Gradio应用
145
  if __name__ == "__main__":
146
+ # 使用share=True可以生成一个公开链接,方便他人访问
147
  demo.launch()