lllyx commited on
Commit
9b2d6e3
·
verified ·
1 Parent(s): 041cfb4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -55
app.py CHANGED
@@ -9,21 +9,25 @@ import io
9
  SERVER_URL = "http://103.235.229.133:7000/predict"
10
  # -------------------------------------------------------
11
 
12
- # 默认的模型名称和输入数据,方便用户测试
13
- DEFAULT_MODEL_NAME = "/data/tsq/MODELS/Qwen2.5-1.5B-Instruct"
 
 
 
 
14
  DEFAULT_DATAFRAME_VALUE = [
15
  ["France", "capital city of", "Paris", "she traveled to France for the first time"],
16
- ["Germany", "capital city of", "Berlin", "the capital of Germany is Berlin"]
 
17
  ]
18
 
19
- def call_api(model_name, only_final_result, data_df):
20
  """
21
  这个函数会被Gradio调用,它负责收集UI上的输入,
22
  将其转换成API需要的格式,然后发送请求。
 
23
  """
24
- # 1. 数据校验
25
- if not model_name.strip():
26
- raise gr.Error("模型名称不能为空!")
27
  if data_df is None or data_df.empty:
28
  raise gr.Error("输入数据不能为空!请在表格中至少添加一行数据。")
29
 
@@ -35,9 +39,9 @@ def call_api(model_name, only_final_result, data_df):
35
  data_df_with_id.to_csv(output, index=False)
36
  csv_data = output.getvalue()
37
 
38
- # 3. 构建发送到API的JSON数据体
39
  payload = {
40
- "model_name": model_name,
41
  "data_csv": csv_data,
42
  "only_final_result": only_final_result
43
  }
@@ -77,58 +81,56 @@ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280p
77
  """
78
  )
79
 
80
- with gr.Row():
81
- # 左侧列:输入配置
82
- with gr.Column(scale=2):
83
- gr.Markdown("### 1. 配置模型和参数")
84
- with gr.Group():
85
- model_name_input = gr.Textbox(
86
- label="模型名称 (model_name)",
87
- value=DEFAULT_MODEL_NAME,
88
- info="请输入部署在服务器上的模型路径。"
89
- )
90
- only_final_result_input = gr.Checkbox(
91
- label="仅返回最终结果 (only_final_result)",
92
- value=False,
93
- info="勾选此项,API将只返回最终聚合结果,而不是所有层的详细信息。"
94
- )
95
-
96
- gr.Markdown("### 2. 输入数据")
97
-
98
- # -------------------- 主要修改部分在这里 --------------------
99
- # 将原来的 info 参数内容,移到一个独立的 Markdown 组件中
100
- gr.Markdown("ℹ️ **操作提示**: 点击表格右下方的 **✏️ Edit** 按钮,然后点击 **+ Add row** 来添加新行,或点击垃圾桶图标删除行。")
101
 
102
- data_input = gr.DataFrame(
103
- label="数据表格",
104
- headers=["subject", "relation", "object", "prompt_source"],
105
- value=DEFAULT_DATAFRAME_VALUE,
106
- row_count=(1, "dynamic"),
107
- col_count=(4, "fixed"),
108
- interactive=True
109
- # 已移除 'info' 参数
110
  )
111
- # -------------------- 修改结束 --------------------
112
-
113
- submit_btn = gr.Button("提交推理", variant="primary", scale=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
- # 右侧列:输出结果
116
- with gr.Column(scale=3):
117
- gr.Markdown("### 3. 查看推理结果")
118
- with gr.Tabs():
119
- with gr.TabItem("结果表格"):
120
- output_dataframe = gr.DataFrame(
121
- label="推理结果详情",
122
- interactive=False,
123
- headers=["subject", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
124
- )
125
- with gr.TabItem("原始JSON响应"):
126
- output_json = gr.JSON(label="API 返回的完整JSON")
 
127
 
128
- # 定义按钮点击事件
129
  submit_btn.click(
130
  fn=call_api,
131
- inputs=[model_name_input, only_final_result_input, data_input],
132
  outputs=[output_json, output_dataframe]
133
  )
134
 
 
9
  SERVER_URL = "http://103.235.229.133:7000/predict"
10
  # -------------------------------------------------------
11
 
12
+ # ----------------- 固定的模型名称 -----------------
13
+ # 根据要求,模型名称是固定的,不再作为UI输入
14
+ FIXED_MODEL_NAME = "/data/tsq/MODELS/Qwen2.5-1.5B-Instruct"
15
+ # -------------------------------------------------------
16
+
17
+ # 默认的输入数据,方便用户测试
18
  DEFAULT_DATAFRAME_VALUE = [
19
  ["France", "capital city of", "Paris", "she traveled to France for the first time"],
20
+ ["Germany", "capital city of", "Berlin", "Germany is a country in Central Europe"],
21
+ ["Japan", "capital city of", "Tokyo", "Tokyo is the largest city in Japan"]
22
  ]
23
 
24
+ def call_api(only_final_result, data_df):
25
  """
26
  这个函数会被Gradio调用,它负责收集UI上的输入,
27
  将其转换成API需要的格式,然后发送请求。
28
+ (已修改:不再接收 model_name 参数)
29
  """
30
+ # 1. 数据校验 (不再需要校验模型名称)
 
 
31
  if data_df is None or data_df.empty:
32
  raise gr.Error("输入数据不能为空!请在表格中至少添加一行数据。")
33
 
 
39
  data_df_with_id.to_csv(output, index=False)
40
  csv_data = output.getvalue()
41
 
42
+ # 3. 构建发送到API的JSON数据体 (使用固定的模型名称)
43
  payload = {
44
+ "model_name": FIXED_MODEL_NAME,
45
  "data_csv": csv_data,
46
  "only_final_result": only_final_result
47
  }
 
81
  """
82
  )
83
 
84
+ # ---------- 新布局:输入部分 ----------
85
+ with gr.Column():
86
+ gr.Markdown("### 1. 配置参数")
87
+ with gr.Group():
88
+ # 使用 Markdown 或不可编辑的 Textbox 展示固定的模型名称
89
+ gr.Markdown(f"**当前模型 (固定):** `{FIXED_MODEL_NAME}`")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
+ only_final_result_input = gr.Checkbox(
92
+ label="仅返回最终结果 (only_final_result)",
93
+ value=False,
94
+ info="勾选此项,API将只返回最终聚合结果,而不是所有层的详细信息。"
 
 
 
 
95
  )
96
+
97
+ gr.Markdown("### 2. 输入数据")
98
+
99
+ gr.Markdown("ℹ️ **操作提示**: 点击表格右下方的 **✏️ Edit** 按钮,然后点击 **+ Add row** 来添加新行,或点击垃圾桶图标删除行。")
100
+
101
+ # -------------------- 主要修改部分在这里 --------------------
102
+ # 增加了 height 参数来让表格更高,并且放在了更宽的布局中
103
+ data_input = gr.DataFrame(
104
+ label="数据表格",
105
+ headers=["subject", "relation", "object", "prompt_source"],
106
+ value=DEFAULT_DATAFRAME_VALUE,
107
+ row_count=(3, "dynamic"), # 默认显示3行
108
+ col_count=(4, "fixed"),
109
+ interactive=True,
110
+ height=300 # 明确设置组件高度,单位为像素
111
+ )
112
+ # -------------------- 修改结束 --------------------
113
+
114
+ submit_btn = gr.Button("提交推理", variant="primary")
115
 
116
+ # ---------- 新布局:输出部分 ----------
117
+ with gr.Column():
118
+ gr.Markdown("### 3. 查看推理结果")
119
+ with gr.Tabs():
120
+ with gr.TabItem("结果表格"):
121
+ output_dataframe = gr.DataFrame(
122
+ label="推理结果详情",
123
+ interactive=False,
124
+ # 增加了列数以适应可能的更多输出
125
+ headers=["id", "subject", "relation", "object", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
126
+ )
127
+ with gr.TabItem("原始JSON响应"):
128
+ output_json = gr.JSON(label="API 返回的完整JSON")
129
 
130
+ # 定义按钮点击事件 (已更新 inputs 列表)
131
  submit_btn.click(
132
  fn=call_api,
133
+ inputs=[only_final_result_input, data_input], # 移除了 model_name_input
134
  outputs=[output_json, output_dataframe]
135
  )
136