File size: 6,144 Bytes
083841f 10705b6 86e2749 083841f f63b31c 5d33dd6 f63b31c 083841f f63b31c a628b4c f63b31c 9b2d6e3 f63b31c 86e2749 083841f 16d6124 083841f f63b31c 083841f 86e2749 f63b31c 10705b6 f63b31c 86e2749 f63b31c 86e2749 16d6124 083841f 10705b6 16d6124 86e2749 083841f 10705b6 083841f 86e2749 16d6124 86e2749 10705b6 041cfb4 86e2749 083841f 86e2749 f63b31c 10705b6 f63b31c 86e2749 f63b31c 86e2749 10705b6 f63b31c 86e2749 f63b31c 86e2749 083841f 9b2d6e3 f63b31c 9b2d6e3 f63b31c 16d6124 ac28d4d 16d6124 9b2d6e3 f63b31c 9b2d6e3 16d6124 083841f 9b2d6e3 f63b31c 3d60ca8 9b2d6e3 f63b31c 9b2d6e3 a57465e 9b2d6e3 a57465e 9b2d6e3 f63b31c 10705b6 9b2d6e3 f63b31c 3d60ca8 9b2d6e3 f63b31c 9b2d6e3 f63b31c 9b2d6e3 16d6124 9b2d6e3 f63b31c 083841f 16d6124 083841f 10705b6 16d6124 10705b6 083841f f63b31c 10705b6 16d6124 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import io
import json
import os

import gradio as gr
import pandas as pd
import requests
# ----------------- Configure Your Server API Address -----------------
# Endpoint that call_api POSTs to. Set the ML_PATCH_SERVER_URL environment
# variable to target another deployment without editing this file; an unset
# or empty variable falls back to the default below.
SERVER_URL = (
    os.environ.get("ML_PATCH_SERVER_URL")
    or "http://103.235.229.133:7000/predict"
)
# ---------------------------------------------------------------------
# ----------------- Fixed Model Name -----------------
# Shown in the UI only; it is not part of the request payload.
FIXED_MODEL_NAME = "Qwen2.5-1.5B-Instruct"
# ----------------------------------------------------
# Default input data for easy testing: one row in the table's column order
# (subject, relation, object, prompt_source).
DEFAULT_DATAFRAME_VALUE = [
    ["France", "capital city of", "Paris", "she traveled to France for the first time"],
]
_EMPTY_INPUT_MESSAGE = "Input data cannot be empty! Please add at least one row to the table."


def _drop_blank_rows(data_df):
    """Return a copy of *data_df* without rows in which every cell is missing or whitespace."""
    # A cell is blank when it is NaN/None or its string form is empty after stripping.
    blank_cells = data_df.isna() | data_df.apply(
        lambda col: col.astype(str).str.strip().eq("")
    )
    return data_df.loc[~blank_cells.all(axis=1)].copy()


def _result_to_dataframe(result_json):
    """Build the results table from the server's decoded JSON response.

    Returns an empty DataFrame when the ``"result"`` entry is missing, empty,
    or not a list.

    Raises:
        gr.Error: If the response is not a JSON object, or its ``"result"``
            rows cannot be turned into a DataFrame.
    """
    if not isinstance(result_json, dict):
        raise gr.Error(
            "Unexpected response from the server: expected a JSON object, "
            f"got {type(result_json).__name__}."
        )
    rows = result_json.get("result", [])
    if not isinstance(rows, list) or not rows:
        return pd.DataFrame()
    try:
        return pd.DataFrame(rows)
    except (TypeError, ValueError) as e:
        raise gr.Error(f"Could not build the results table from the server response: {e}") from e


def call_api(patched_layer_number, only_final_result, data_df):
    """Send the input table and patching options to the Ml_patch server.

    Called by Gradio when the submit button is clicked. Rows whose cells are
    all blank are dropped, each remaining row gets a zero-padded ``id``
    column ("000", "001", ...), and the table is sent as CSV inside a JSON
    payload to ``SERVER_URL``.

    Args:
        patched_layer_number: Target layer for patching, from the number
            input. ``None`` if the user cleared the field.
        only_final_result: Forwarded unchanged as ``only_final_result`` in
            the payload.
        data_df: pandas DataFrame from the editable data table (or ``None``).

    Returns:
        A ``(result_json, output_df)`` tuple: the decoded JSON response, and a
        DataFrame built from its ``"result"`` list (empty if absent or empty).

    Raises:
        gr.Error: For empty input, a missing layer number, an HTTP error
            status, a connection/transport failure, or a response that is not
            a JSON object.
    """
    if data_df is None or data_df.empty:
        raise gr.Error(_EMPTY_INPUT_MESSAGE)
    # Rows left completely blank in the editable table carry no query; don't send them.
    data_df_with_id = _drop_blank_rows(data_df)
    if data_df_with_id.empty:
        raise gr.Error(_EMPTY_INPUT_MESSAGE)
    if patched_layer_number is None:
        raise gr.Error("Patched layer number cannot be empty! Please enter an integer layer index.")

    data_df_with_id.insert(0, 'id', [f'{i:03d}' for i in range(len(data_df_with_id))])

    payload = {
        "data_csv": data_df_with_id.to_csv(index=False),
        # Coerce defensively so the server always receives a JSON integer (e.g. 3, not 3.0).
        "patched_layer_number": int(patched_layer_number),
        "only_final_result": only_final_result,
    }

    try:
        # No stream=True: the body is downloaded inside this try, so any transport
        # error while reading it is reported here as a connection problem.
        response = requests.post(SERVER_URL, json=payload, timeout=300)
        response.raise_for_status()
    except requests.exceptions.HTTPError as e:
        error_details = f"API returned an error (Status Code: {e.response.status_code}):\n{e.response.text}"
        raise gr.Error(f"Request failed: {error_details}") from e
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Failed to connect to the server: {e}") from e

    try:
        result_json = response.json()
    except ValueError as e:
        # Kept separate from the request handlers: since requests 2.27 the JSON
        # decode error also subclasses RequestException and would otherwise be
        # misreported as a connection failure.
        raise gr.Error(f"The server returned a response that is not valid JSON: {e}") from e

    return result_json, _result_to_dataframe(result_json)
# Static text, styling and table headers used by the interface below.
_PAGE_CSS = ".gradio-container {max-width: 1280px !important; margin: auto;}"

_INTRO_MARKDOWN = """
# 🚀 Ml_patch Function Inference Interface
An interactive web interface designed for the `Ml_patch` function.
Please configure the model parameters and enter your test data below.
"""

_INPUT_NOTES = (
    "### 2. Input Data",
    "ℹ️ **Instructions**: Click the **✏️ Edit** button at the bottom right of the table, then click **+ Add row** to add a new row, or the trash can icon to delete a row.",
    "**[subject,relation,object]**: A factual triple.",
    "**prompt_source**: A sentence which contains subject.",
)

_RESULT_NOTES = (
    "### 3. View Inference Results",
    "**layer_source**: Starting from this layer of **prompt_source**, we extracted the hidden states corresponding to the subject.",
    "**is_correct_patched**: Whether the final answer is correct. If **object** is in **generations**, we think the final answer is correct.",
    "**generations**: The output obtained after the model continues inference following a patch operation.",
)

_INPUT_HEADERS = ["subject", "relation", "object", "prompt_source"]

# NOTE(review): these are only the initial result headers; confirm they match
# the columns the server returns (possibly fewer when only_final_result is set).
_RESULT_HEADERS = [
    "id", "subject", "relation", "object", "prompt_source",
    "prompt_target", "layer_source", "is_correct_patched", "generations",
]

# Page layout: a parameters/input column, then a results column with a table
# tab and a raw-JSON tab. The submit button feeds call_api.
with gr.Blocks(theme=gr.themes.Glass(), css=_PAGE_CSS) as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Column():
        gr.Markdown("### 1. Configure Parameters")
        with gr.Group():
            gr.Markdown(f"**Current Model (Fixed):** `{FIXED_MODEL_NAME}`")
            patched_layer_number_input = gr.Number(
                label="Patched Layer Number (patched_layer_number)",
                info="Specify the target layer number for patching.",
                value=3,
                precision=0,  # integer-only input
            )
            only_final_result_input = gr.Checkbox(
                label="Only Return Final Result (only_final_result)",
                info="If checked, the API will only return the final aggregated result, not detailed info for all layers.",
                value=False,
            )
        for note in _INPUT_NOTES:
            gr.Markdown(note)
        data_input = gr.DataFrame(
            value=DEFAULT_DATAFRAME_VALUE,
            headers=_INPUT_HEADERS,
            label="Data Table",
            interactive=True,
            row_count=(5, "dynamic"),
            col_count=(4, "fixed"),
        )
        submit_btn = gr.Button("Submit for Inference", variant="primary")

    with gr.Column():
        for note in _RESULT_NOTES:
            gr.Markdown(note)
        with gr.Tabs():
            with gr.TabItem("Results Table"):
                output_dataframe = gr.DataFrame(
                    headers=_RESULT_HEADERS,
                    label="Inference Result Details",
                    interactive=False,
                )
            with gr.TabItem("Raw JSON Response"):
                output_json = gr.JSON(label="Full JSON response from API")

    # The inputs list must follow call_api's parameter order; the outputs
    # follow the order of its returned tuple.
    submit_btn.click(
        fn=call_api,
        inputs=[patched_layer_number_input, only_final_result_input, data_input],
        outputs=[output_json, output_dataframe],
    )
def main():
    """Start the Gradio server for the module-level ``demo`` app."""
    demo.launch()


if __name__ == "__main__":
    main()
|