lllyx commited on
Commit
16d6124
·
verified ·
1 Parent(s): b553c66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -14
app.py CHANGED
@@ -5,23 +5,20 @@ import pandas as pd
5
  import io
6
 
7
  # ----------------- Configure Your Server API Address -----------------
8
- # Replace the IP address and port with your own
9
  SERVER_URL = "http://103.235.229.133:7000/predict"
10
  # ---------------------------------------------------------------------
11
 
12
  # ----------------- Fixed Model Name -----------------
13
- # As per requirements, the model name is fixed and not a UI input.
14
  FIXED_MODEL_NAME = "Qwen2.5-1.5B-Instruct"
15
  # ----------------------------------------------------
16
 
17
  # Default input data for easy testing
18
  DEFAULT_DATAFRAME_VALUE = [
19
  ["France", "capital city of", "Paris", "she traveled to France for the first time"],
20
- # ["India", "official currency of", "Rupee", "caves of southern India and similar evidence"],
21
- # ["Australia", "largest city in", "Sydney", "Grainger left Australia at the age of 13"]
22
  ]
23
 
24
- def call_api(only_final_result, data_df):
 
25
  """
26
  This function is called by Gradio. It collects inputs from the UI,
27
  formats them for the API, and sends the request.
@@ -38,19 +35,18 @@ def call_api(only_final_result, data_df):
38
  data_df_with_id.to_csv(output, index=False)
39
  csv_data = output.getvalue()
40
 
41
- # Prepare the payload for the API
42
  payload = {
43
  "data_csv": csv_data,
 
44
  "only_final_result": only_final_result
45
  }
46
 
47
  try:
48
- # Send the request to the server
49
  response = requests.post(SERVER_URL, json=payload, timeout=300, stream=True)
50
- response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
51
  result_json = response.json()
52
 
53
- # Extract the result data for the DataFrame
54
  dataframe_result_data = result_json.get("result", [])
55
 
56
  if isinstance(dataframe_result_data, list) and len(dataframe_result_data) > 0:
@@ -83,10 +79,19 @@ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280p
83
  gr.Markdown("### 1. Configure Parameters")
84
  with gr.Group():
85
  gr.Markdown(f"**Current Model (Fixed):** `{FIXED_MODEL_NAME}`")
 
 
 
 
 
 
 
 
 
86
  only_final_result_input = gr.Checkbox(
87
  label="Only Return Final Result (only_final_result)",
88
  value=False,
89
- info="If checked, the API will only return the final aggregated result, not the detailed information for all layers."
90
  )
91
 
92
  gr.Markdown("### 2. Input Data")
@@ -96,8 +101,6 @@ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280p
96
  label="Data Table",
97
  headers=["subject", "relation", "object", "prompt_source"],
98
  value=DEFAULT_DATAFRAME_VALUE,
99
- # Set the initial number of rows to 5 to give the table more vertical space on load.
100
- # (rows_to_display, "dynamic" allows user to add/remove rows)
101
  row_count=(5, "dynamic"),
102
  col_count=(4, "fixed"),
103
  interactive=True
@@ -112,17 +115,20 @@ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280p
112
  output_dataframe = gr.DataFrame(
113
  label="Inference Result Details",
114
  interactive=False,
 
115
  headers=["id", "subject", "relation", "object", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
116
  )
117
  with gr.TabItem("Raw JSON Response"):
118
  output_json = gr.JSON(label="Full JSON response from API")
119
 
 
 
120
  submit_btn.click(
121
  fn=call_api,
122
- inputs=[only_final_result_input, data_input],
123
  outputs=[output_json, output_dataframe]
124
  )
125
 
126
  # Launch the Gradio app
127
  if __name__ == "__main__":
128
- demo.launch()
 
5
  import io
6
 
7
  # ----------------- Configure Your Server API Address -----------------
 
8
  SERVER_URL = "http://103.235.229.133:7000/predict"
9
  # ---------------------------------------------------------------------
10
 
11
  # ----------------- Fixed Model Name -----------------
 
12
  FIXED_MODEL_NAME = "Qwen2.5-1.5B-Instruct"
13
  # ----------------------------------------------------
14
 
15
  # Default input data for easy testing
16
  DEFAULT_DATAFRAME_VALUE = [
17
  ["France", "capital city of", "Paris", "she traveled to France for the first time"],
 
 
18
  ]
19
 
20
+ # --- KEY CHANGE 1: Update the function signature to accept the new parameter ---
21
+ def call_api(patched_layer_number, only_final_result, data_df):
22
  """
23
  This function is called by Gradio. It collects inputs from the UI,
24
  formats them for the API, and sends the request.
 
35
  data_df_with_id.to_csv(output, index=False)
36
  csv_data = output.getvalue()
37
 
38
+ # --- KEY CHANGE 2: Add the new parameter to the API payload ---
39
  payload = {
40
  "data_csv": csv_data,
41
+ "patched_layer_number": patched_layer_number,
42
  "only_final_result": only_final_result
43
  }
44
 
45
  try:
 
46
  response = requests.post(SERVER_URL, json=payload, timeout=300, stream=True)
47
+ response.raise_for_status()
48
  result_json = response.json()
49
 
 
50
  dataframe_result_data = result_json.get("result", [])
51
 
52
  if isinstance(dataframe_result_data, list) and len(dataframe_result_data) > 0:
 
79
  gr.Markdown("### 1. Configure Parameters")
80
  with gr.Group():
81
  gr.Markdown(f"**Current Model (Fixed):** `{FIXED_MODEL_NAME}`")
82
+
83
+ # --- KEY CHANGE 3: Add the new Gradio input component ---
84
+ patched_layer_number_input = gr.Number(
85
+ label="Patched Layer Number (patched_layer_number)",
86
+ value=15, # Set a reasonable default value
87
+ precision=0, # This ensures the input is an integer
88
+ info="Specify the target layer number for patching."
89
+ )
90
+
91
  only_final_result_input = gr.Checkbox(
92
  label="Only Return Final Result (only_final_result)",
93
  value=False,
94
+ info="If checked, the API will only return the final aggregated result, not detailed info for all layers."
95
  )
96
 
97
  gr.Markdown("### 2. Input Data")
 
101
  label="Data Table",
102
  headers=["subject", "relation", "object", "prompt_source"],
103
  value=DEFAULT_DATAFRAME_VALUE,
 
 
104
  row_count=(5, "dynamic"),
105
  col_count=(4, "fixed"),
106
  interactive=True
 
115
  output_dataframe = gr.DataFrame(
116
  label="Inference Result Details",
117
  interactive=False,
118
+ # You might want to update headers if the output changes based on this parameter
119
  headers=["id", "subject", "relation", "object", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
120
  )
121
  with gr.TabItem("Raw JSON Response"):
122
  output_json = gr.JSON(label="Full JSON response from API")
123
 
124
+ # --- KEY CHANGE 4: Update the click event to include the new input ---
125
+ # The order of inputs here MUST match the order of arguments in the `call_api` function.
126
  submit_btn.click(
127
  fn=call_api,
128
+ inputs=[patched_layer_number_input, only_final_result_input, data_input],
129
  outputs=[output_json, output_dataframe]
130
  )
131
 
132
  # Launch the Gradio app
133
  if __name__ == "__main__":
134
+ demo.launch()