"""Gradio front end for the remote ``Ml_patch`` inference API.

The UI collects a table of factual triples plus a source prompt, posts them
as CSV to ``SERVER_URL`` and renders the JSON reply both raw and as a table.
"""

import gradio as gr
import requests
import json
import os
import pandas as pd
import io

# ----------------- Configure Your Server API Address -----------------
# Can be overridden without editing this file via the ML_PATCH_SERVER_URL
# environment variable; the literal below is the default.
SERVER_URL = os.environ.get("ML_PATCH_SERVER_URL", "http://103.235.229.133:7000/predict")
# Seconds to wait for the server before giving up.
REQUEST_TIMEOUT_SECONDS = 300
# ---------------------------------------------------------------------

# ----------------- Fixed Model Name -----------------
FIXED_MODEL_NAME = "Qwen2.5-1.5B-Instruct"
# ----------------------------------------------------

# Default input data for easy testing
DEFAULT_DATAFRAME_VALUE = [
    ["France", "capital city of", "Paris", "she traveled to France for the first time"],
]

_EMPTY_INPUT_MESSAGE = "Input data cannot be empty! Please add at least one row to the table."


def _drop_blank_rows(df):
    """Return *df* without rows whose cells are all empty, whitespace-only or NaN, re-indexed from 0."""
    as_text = df.fillna("").astype(str).apply(lambda col: col.str.strip())
    return df[as_text.ne("").any(axis=1)].reset_index(drop=True)


def _to_csv_with_ids(df):
    """Serialize *df* to a CSV string with a leading zero-padded ``id`` column (000, 001, ...)."""
    with_id = df.copy()
    with_id.insert(0, "id", [f"{i:03d}" for i in range(len(with_id))])
    return with_id.to_csv(index=False)


def call_api(patched_layer_number, only_final_result, data_df):
    """Send the UI inputs to the inference server and return its reply.

    Args:
        patched_layer_number: Target layer for patching, from the ``gr.Number`` input.
        only_final_result: Whether the server should return only the final aggregated result.
        data_df: Rows of subject / relation / object / prompt_source from the input table.

    Returns:
        A ``(result_json, output_df)`` tuple: the parsed JSON body and a DataFrame
        built from its ``"result"`` list (an empty DataFrame when that list is
        missing or empty).

    Raises:
        gr.Error: On invalid input, HTTP errors, connection failures or a
            malformed response, so Gradio shows the message in the UI.
    """
    if data_df is None or data_df.empty:
        raise gr.Error(_EMPTY_INPUT_MESSAGE)
    # The editable table can contain rows the user never filled in; do not
    # send those to the server as real samples.
    data_df = _drop_blank_rows(data_df)
    if data_df.empty:
        raise gr.Error(_EMPTY_INPUT_MESSAGE)
    if patched_layer_number is None:
        raise gr.Error("Please specify the patched layer number.")

    payload = {
        "data_csv": _to_csv_with_ids(data_df),
        # gr.Number may hand back a float such as 3.0; send an integer layer index.
        "patched_layer_number": int(patched_layer_number),
        "only_final_result": bool(only_final_result),
    }

    try:
        # No stream=True: the body is read eagerly, so the pooled connection is
        # released even when raise_for_status() throws.
        response = requests.post(SERVER_URL, json=payload, timeout=REQUEST_TIMEOUT_SECONDS)
        response.raise_for_status()
    except requests.exceptions.HTTPError as e:
        error_details = f"API returned an error (Status Code: {e.response.status_code}):\n{e.response.text}"
        raise gr.Error(f"Request failed: {error_details}") from e
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Failed to connect to the server: {e}") from e

    try:
        result_json = response.json()
    except ValueError as e:
        # Covers json.JSONDecodeError and requests' JSONDecodeError (both ValueError
        # subclasses); previously this was misreported as a connection failure.
        raise gr.Error(f"The server returned a response that is not valid JSON: {e}") from e

    if not isinstance(result_json, dict):
        raise gr.Error(f"Unexpected response format from the server: {result_json!r:.200}")

    dataframe_result_data = result_json.get("result", [])
    if isinstance(dataframe_result_data, list) and dataframe_result_data:
        try:
            output_df = pd.DataFrame(dataframe_result_data)
        except (ValueError, TypeError) as e:
            raise gr.Error(f"Could not build a results table from the server response: {e}") from e
    else:
        output_df = pd.DataFrame()

    return result_json, output_df


# --- Use Gradio Blocks for a more flexible and beautiful interface ---
with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280px !important; margin: auto;}") as demo:
    gr.Markdown(
        """
        # 🚀 Ml_patch Function Inference Interface
        An interactive web interface designed for the `Ml_patch` function. Please configure the model parameters and enter your test data below.
        """
    )

    with gr.Column():
        gr.Markdown("### 1. Configure Parameters")
        with gr.Group():
            gr.Markdown(f"**Current Model (Fixed):** `{FIXED_MODEL_NAME}`")
            patched_layer_number_input = gr.Number(
                label="Patched Layer Number (patched_layer_number)",
                value=3,  # Reasonable default layer
                precision=0,  # Round the input to a whole number
                info="Specify the target layer number for patching.",
            )
            only_final_result_input = gr.Checkbox(
                label="Only Return Final Result (only_final_result)",
                value=False,
                info="If checked, the API will only return the final aggregated result, not detailed info for all layers.",
            )

        gr.Markdown("### 2. Input Data")
        gr.Markdown("ℹ️ **Instructions**: Click the **✏️ Edit** button at the bottom right of the table, then click **+ Add row** to add a new row, or the trash can icon to delete a row.")
        gr.Markdown("**[subject,relation,object]**: A factual triple.")
        gr.Markdown("**prompt_source**: A sentence that contains the subject.")
        data_input = gr.DataFrame(
            label="Data Table",
            headers=["subject", "relation", "object", "prompt_source"],
            value=DEFAULT_DATAFRAME_VALUE,
            row_count=(5, "dynamic"),
            col_count=(4, "fixed"),
            interactive=True,
        )

        submit_btn = gr.Button("Submit for Inference", variant="primary")

    with gr.Column():
        gr.Markdown("### 3. View Inference Results")
        gr.Markdown("**layer_source**: Starting from this layer of **prompt_source**, we extracted the hidden states corresponding to the subject.")
        gr.Markdown("**is_correct_patched**: Whether the final answer is correct. The answer is considered correct if **object** appears in **generations**.")
        gr.Markdown("**generations**: The output obtained after the model continues inference following a patch operation.")
        with gr.Tabs():
            with gr.TabItem("Results Table"):
                output_dataframe = gr.DataFrame(
                    label="Inference Result Details",
                    interactive=False,
                    # NOTE: the returned columns may differ when only_final_result
                    # is set; these headers describe the detailed per-layer output.
                    headers=["id", "subject", "relation", "object", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"],
                )
            with gr.TabItem("Raw JSON Response"):
                output_json = gr.JSON(label="Full JSON response from API")

    # The order of `inputs` MUST match the positional parameters of `call_api`,
    # and `outputs` must match the order of its returned tuple.
    submit_btn.click(
        fn=call_api,
        inputs=[patched_layer_number_input, only_final_result_input, data_input],
        outputs=[output_json, output_dataframe],
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()