lllyx commited on
Commit
10705b6
·
verified ·
1 Parent(s): 460cc42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -64
app.py CHANGED
@@ -1,84 +1,82 @@
1
- # app.py
2
  import gradio as gr
3
  import requests
4
- import os
5
 
6
- # 从Secrets获取API地址。URL现在指向新的端点 /predict_df
7
- API_URL = os.getenv("API_URL", "http://10.1.1.25:5000/predict")
 
 
8
 
9
- # 固定的model_name,也可以做成下拉框让用户选择
10
- MODEL_NAME = "/data/tsq/MODELS/Qwen2.5-1.5B-Instruct"
11
 
12
- def call_ml_server(csv_text_input, final_result_only_checkbox):
13
  """
14
- 调用后端API的函数
15
  """
16
- if "YOUR_SERVER_IP" in API_URL:
17
- raise gr.Error("错误:API服务器地址未配置!请在Space的Settings -> Secrets中设置API_URL。")
18
-
19
- # 1. 检查输入是否为空
20
- if not csv_text_input.strip():
21
- return None # 如果输入为空,返回None,Gradio会清空输出
22
-
23
- # 2. 构建发送到API的JSON payload
24
  payload = {
25
- "csv_data": csv_text_input,
26
- "model_name": MODEL_NAME,
27
- "only_final_result": final_result_only_checkbox
28
  }
29
-
30
  try:
31
- # 3. 发送POST请求
32
- response = requests.post(API_URL, json=payload, timeout=60) # 延长超时
33
- response.raise_for_status()
34
-
35
- # 4. 直接返回JSON结果,Gradio的Dataframe组件会自动处理
36
- return response.json()
37
-
38
- except requests.exceptions.HTTPError as e:
39
- # 如果API返回错误(如500),尝试解析错误详情
40
- error_detail = e.response.json().get('detail', e.response.text)
41
- raise gr.Error(f"API返回错误: {error_detail}")
42
- except requests.exceptions.RequestException as e:
43
- raise gr.Error(f"请求API时出错: {e}")
44
- except Exception as e:
45
- raise gr.Error(f"发生未知错误: {e}")
46
 
47
- # 示例数据
48
- example_data = "id,subject,relation,object,prompt_source\n001,France,capital city of,Paris,she traveled to France for the first time\n002,Japan,capital city of,Tokyo,he is planning a trip to Japan"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- # 创建Gradio界面
51
- with gr.Blocks() as iface:
52
- gr.Markdown("## 云服务器DataFrame推理前端")
53
- gr.Markdown("在下方文本框中输入CSV格式的数据,然后点击“提交”。数据将被发送到私有云服务器进行处理。")
 
 
 
 
 
54
 
55
  with gr.Row():
56
- with gr.Column(scale=2):
57
- csv_input = gr.Textbox(
58
  lines=10,
59
- label="输入CSV数据",
60
- placeholder="请在这里粘贴你的CSV数据...",
61
- value=example_data # 设置默认值
62
  )
63
- final_result_checkbox = gr.Checkbox(label="Only Final Result", value=False)
64
- submit_btn = gr.Button("提交", variant="primary")
 
 
 
 
 
65
 
66
- with gr.Column(scale=3):
67
- output_df = gr.Dataframe(label="服务器返回结果", wrap=True)
68
-
69
- gr.Examples(
70
- examples=[
71
- [example_data, False],
72
- ["id,subject,relation,object,prompt_source\n003,Germany,capital city of,Berlin,Germany is a country in Europe", True]
73
- ],
74
- inputs=[csv_input, final_result_checkbox]
75
- )
76
-
77
  submit_btn.click(
78
- fn=call_ml_server,
79
- inputs=[csv_input, final_result_checkbox],
80
- outputs=output_df
81
  )
82
 
83
- # 启动界面
84
- iface.launch()
 
 
 
1
  import gradio as gr
2
  import requests
3
+ import json
4
 
5
+ # ----------------- 配置你的服务器API地址 -----------------
6
+ # 将IP地址和端口替换为你自己的
7
+ SERVER_URL = "http://103.235.229.133:8000/predict"
8
+ # -------------------------------------------------------
9
 
10
+ # 默认的输入数据,方便用户测试
11
+ DEFAULT_CSV_DATA = "id,subject,relation,object,prompt_source\n001,France,capital city of,Paris,she traveled to France for the first time\n002,Germany,capital city of,Berlin,the capital of Germany is Berlin"
12
 
13
+ def call_api(csv_data):
14
  """
15
+ 这个函数会被Gradio调用,它负责向后端API发送请求。
16
  """
17
+ if not csv_data.strip():
18
+ return "错误:输入数据不能为空!", None
19
+
20
+ # 构建发送到API的JSON数据体
 
 
 
 
21
  payload = {
22
+ "model_name": "/data/tsq/MODELS/Qwen2.5-1.5B-Instruct",
23
+ "data_csv": csv_data,
24
+ "only_final_result": False # 你可以根据需要修改这个参数
25
  }
26
+
27
  try:
28
+ # 发送POST请求
29
+ response = requests.post(SERVER_URL, json=payload, timeout=60) # 设置60秒超时
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ # 检查响应状态码
32
+ if response.status_code == 200:
33
+ # 解析返回的JSON数据
34
+ result_json = response.json()
35
+ # 为了美观地显示,我们将JSON格式化
36
+ pretty_json = json.dumps(result_json, indent=2, ensure_ascii=False)
37
+
38
+ # 假设结果在 'result' 键下,并且是一个列表
39
+ # 我们可以直接将其传递给Gradio的DataFrame组件
40
+ dataframe_result = result_json.get("result", [])
41
+
42
+ return pretty_json, dataframe_result
43
+ else:
44
+ # 如果服务器返回错误
45
+ error_message = f"API请求失败,状态码: {response.status_code}\n错误详情: {response.text}"
46
+ return error_message, None
47
 
48
+ except requests.exceptions.RequestException as e:
49
+ # 如果网络连接等问题出错
50
+ error_message = f"连接服务器失败: {e}"
51
+ return error_message, None
52
+
53
+ # 使用Gradio Blocks创建更灵活的界面
54
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
55
+ gr.Markdown("# 🚀 Ml_patch 函数推理接口")
56
+ gr.Markdown("在下方文本框中输入CSV格式的数据,然后点击“提交”按钮,调用云服务器进行推理。")
57
 
58
  with gr.Row():
59
+ with gr.Column(scale=1):
60
+ input_csv = gr.Textbox(
61
  lines=10,
62
+ label="输入数据 (CSV格式)",
63
+ info="第一行必须是表头,例如: id,subject,relation,object,prompt_source",
64
+ value=DEFAULT_CSV_DATA
65
  )
66
+ submit_btn = gr.Button("提交推理", variant="primary")
67
+
68
+ with gr.Column(scale=2):
69
+ gr.Markdown("### 原始JSON输出")
70
+ output_json = gr.Textbox(label="API 返回的完整JSON结果", interactive=False, lines=10)
71
+ gr.Markdown("### 结果表格展示")
72
+ output_dataframe = gr.DataFrame(label="推理结果", interactive=False)
73
 
 
 
 
 
 
 
 
 
 
 
 
74
  submit_btn.click(
75
+ fn=call_api,
76
+ inputs=input_csv,
77
+ outputs=[output_json, output_dataframe]
78
  )
79
 
80
+ # 启动Gradio应用
81
+ if __name__ == "__main__":
82
+ demo.launch()