lllyx commited on
Commit
083841f
·
verified ·
1 Parent(s): 2fc7baa

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import requests
4
+ import os
5
+
6
+ # 从Secrets获取API地址。URL现在指向新的端点 /predict_df
7
+ API_URL = os.getenv("API_URL", "http://YOUR_SERVER_IP:8000/predict_df")
8
+
9
+ # 固定的model_name,也可以做成下拉框让用户选择
10
+ MODEL_NAME = "/data/tsq/MODELS/Qwen2.5-1.5B-Instruct"
11
+
12
+ def call_ml_server(csv_text_input, final_result_only_checkbox):
13
+ """
14
+ 调用后端API的函数
15
+ """
16
+ if "YOUR_SERVER_IP" in API_URL:
17
+ raise gr.Error("错误:API服务器地址未配置!请在Space的Settings -> Secrets中设置API_URL。")
18
+
19
+ # 1. 检查输入是否为空
20
+ if not csv_text_input.strip():
21
+ return None # 如果输入为空,返回None,Gradio会清空输出
22
+
23
+ # 2. 构建发送到API的JSON payload
24
+ payload = {
25
+ "csv_data": csv_text_input,
26
+ "model_name": MODEL_NAME,
27
+ "only_final_result": final_result_only_checkbox
28
+ }
29
+
30
+ try:
31
+ # 3. 发送POST请求
32
+ response = requests.post(API_URL, json=payload, timeout=60) # 延长超时
33
+ response.raise_for_status()
34
+
35
+ # 4. 直接返回JSON结果,Gradio的Dataframe组件会自动处理
36
+ return response.json()
37
+
38
+ except requests.exceptions.HTTPError as e:
39
+ # 如果API返回错误(如500),尝试解析错误详情
40
+ error_detail = e.response.json().get('detail', e.response.text)
41
+ raise gr.Error(f"API返回错误: {error_detail}")
42
+ except requests.exceptions.RequestException as e:
43
+ raise gr.Error(f"请求API时出错: {e}")
44
+ except Exception as e:
45
+ raise gr.Error(f"发生未知错误: {e}")
46
+
47
+ # 示例数据
48
+ example_data = "id,subject,relation,object,prompt_source\n001,France,capital city of,Paris,she traveled to France for the first time\n002,Japan,capital city of,Tokyo,he is planning a trip to Japan"
49
+
50
+ # 创建Gradio界面
51
+ with gr.Blocks() as iface:
52
+ gr.Markdown("## 云服务器DataFrame推理前端")
53
+ gr.Markdown("在下方文本框中输入CSV格式的数据,然后点击“提交”。数据将被发送到私有云服务器进行处理。")
54
+
55
+ with gr.Row():
56
+ with gr.Column(scale=2):
57
+ csv_input = gr.Textbox(
58
+ lines=10,
59
+ label="输入CSV数据",
60
+ placeholder="请在这里粘贴你的CSV数据...",
61
+ value=example_data # 设置默认值
62
+ )
63
+ final_result_checkbox = gr.Checkbox(label="Only Final Result", value=False)
64
+ submit_btn = gr.Button("提交", variant="primary")
65
+
66
+ with gr.Column(scale=3):
67
+ output_df = gr.Dataframe(label="服务器返回结果", wrap=True)
68
+
69
+ gr.Examples(
70
+ examples=[
71
+ [example_data, False],
72
+ ["id,subject,relation,object,prompt_source\n003,Germany,capital city of,Berlin,Germany is a country in Europe", True]
73
+ ],
74
+ inputs=[csv_input, final_result_checkbox]
75
+ )
76
+
77
+ submit_btn.click(
78
+ fn=call_ml_server,
79
+ inputs=[csv_input, final_result_checkbox],
80
+ outputs=output_df
81
+ )
82
+
83
+ # 启动界面
84
+ iface.launch()