"""Gradio front end for the remote ``Ml_patch`` inference API.

The UI collects (subject, relation, object, prompt_source) rows, serialises
them to CSV, POSTs them to ``SERVER_URL`` and shows the JSON reply both raw
and as a table.
"""

import io
import json
import os

import gradio as gr
import pandas as pd
import requests

# ----------------- Configure Your Server API Address -----------------
# Set the ML_PATCH_SERVER_URL environment variable, or replace the default
# IP address and port below with your own.
SERVER_URL = os.environ.get("ML_PATCH_SERVER_URL", "http://103.235.229.133:7000/predict")
# Seconds passed to ``requests.post(timeout=...)``.
REQUEST_TIMEOUT_SECONDS = 300
# ---------------------------------------------------------------------

# ----------------- Fixed Model Name -----------------
# As per requirements, the model name is fixed and not a UI input.
# NOTE(review): it is only displayed, never sent in the payload; presumably
# the server is configured with the same model — confirm.
FIXED_MODEL_NAME = "Qwen2.5-1.5B-Instruct"
# ----------------------------------------------------

# Default input data for easy testing
DEFAULT_DATAFRAME_VALUE = [
    ["France", "capital city of", "Paris", "she traveled to France for the first time"],
    ["Germany", "capital city of", "Berlin", "Germany is a country in Central Europe"],
    ["Japan", "capital city of", "Tokyo", "Tokyo is the largest city in Japan"],
]

_EMPTY_INPUT_MESSAGE = "Input data cannot be empty! Please add at least one row to the table."


def _drop_blank_rows(df: pd.DataFrame) -> pd.DataFrame:
    """Return *df* without rows whose every cell is missing or whitespace-only.

    Rows added with "+ Add row" but never filled in would otherwise be sent
    to the server as meaningless records.
    """
    blank = pd.Series(True, index=df.index)
    # items() works positionally, so it is safe even with duplicate headers.
    for _, column in df.items():
        blank &= column.isna() | column.astype(str).str.strip().eq("")
    return df.loc[~blank].reset_index(drop=True)


def _build_payload(data_df: pd.DataFrame, only_final_result: bool) -> dict:
    """Prefix a zero-padded ``id`` column and package the table as CSV for the API."""
    data_df_with_id = data_df.copy()
    data_df_with_id.insert(0, "id", [f"{i:03d}" for i in range(len(data_df_with_id))])
    return {
        # to_csv() without a path returns the CSV text directly.
        "data_csv": data_df_with_id.to_csv(index=False),
        "only_final_result": only_final_result,
    }


def _post_to_server(payload: dict):
    """POST *payload* to ``SERVER_URL`` and return the decoded JSON body.

    Raises:
        gr.Error: on HTTP error status, connection/transport failure, or a
            response body that is not valid JSON.
    """
    try:
        response = requests.post(SERVER_URL, json=payload, timeout=REQUEST_TIMEOUT_SECONDS)
        response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
    except requests.exceptions.HTTPError as e:
        status = e.response.status_code if e.response is not None else "unknown"
        body = e.response.text if e.response is not None else ""
        error_details = f"API returned an error (Status Code: {status}):\n{body}"
        raise gr.Error(f"Request failed: {error_details}") from e
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Failed to connect to the server: {e}") from e

    # Decoded separately. In requests >= 2.27 a JSON decode error is itself a
    # RequestException and would otherwise be misreported as a connection failure.
    try:
        return response.json()
    except ValueError as e:
        raise gr.Error(
            f"The server returned a response that is not valid JSON:\n{response.text[:500]}"
        ) from e


def _result_to_dataframe(result_json) -> pd.DataFrame:
    """Build the results table from the ``"result"`` list of the API reply.

    Returns an empty DataFrame when the reply is not a JSON object or has no
    non-empty ``"result"`` list.
    """
    rows = result_json.get("result", []) if isinstance(result_json, dict) else []
    if isinstance(rows, list) and rows:
        return pd.DataFrame(rows)
    return pd.DataFrame()


def call_api(only_final_result, data_df):
    """
    This function is called by Gradio. It collects inputs from the UI,
    formats them for the API, and sends the request.

    Args:
        only_final_result: value of the "Only Return Final Result" checkbox.
        data_df: the input table as a pandas DataFrame (Gradio's default
            DataFrame type).

    Returns:
        (result_json, output_df): the decoded API reply for the JSON tab and
        its ``"result"`` list as a DataFrame for the table tab.

    Raises:
        gr.Error: on empty input or any request/response failure.
    """
    if data_df is None or data_df.empty:
        raise gr.Error(_EMPTY_INPUT_MESSAGE)
    data_df = _drop_blank_rows(data_df)
    if data_df.empty:
        raise gr.Error(_EMPTY_INPUT_MESSAGE)

    payload = _build_payload(data_df, only_final_result)
    result_json = _post_to_server(payload)

    try:
        output_df = _result_to_dataframe(result_json)
    except Exception as e:  # UI boundary: surface anything unexpected to the user
        raise gr.Error(f"An unknown error occurred: {e}") from e
    return result_json, output_df


# --- Use Gradio Blocks for a more flexible and beautiful interface ---
with gr.Blocks(
    theme=gr.themes.Glass(),
    css=".gradio-container {max-width: 1280px !important; margin: auto;}",
) as demo:
    gr.Markdown(
        """
        # 🚀 Ml_patch Function Inference Interface
        An interactive web interface designed for the `Ml_patch` function.
        Please configure the model parameters and enter your test data below.
        """
    )

    with gr.Column():
        gr.Markdown("### 1. Configure Parameters")
        with gr.Group():
            gr.Markdown(f"**Current Model (Fixed):** `{FIXED_MODEL_NAME}`")
            only_final_result_input = gr.Checkbox(
                label="Only Return Final Result (only_final_result)",
                value=False,
                info="If checked, the API will only return the final aggregated result, not the detailed information for all layers.",
            )

        gr.Markdown("### 2. Input Data")
        gr.Markdown(
            "ℹ️ **Instructions**: Click the **✏️ Edit** button at the bottom right of the table, "
            "then click **+ Add row** to add a new row, or the trash can icon to delete a row."
        )
        data_input = gr.DataFrame(
            label="Data Table",
            headers=["subject", "relation", "object", "prompt_source"],
            value=DEFAULT_DATAFRAME_VALUE,
            # Set the initial number of rows to 5 to give the table more vertical space on load.
            # (rows_to_display, "dynamic" allows user to add/remove rows)
            row_count=(5, "dynamic"),
            col_count=(4, "fixed"),
            interactive=True,
        )

        submit_btn = gr.Button("Submit for Inference", variant="primary")

    with gr.Column():
        gr.Markdown("### 3. View Inference Results")
        with gr.Tabs():
            with gr.TabItem("Results Table"):
                output_dataframe = gr.DataFrame(
                    label="Inference Result Details",
                    interactive=False,
                    headers=[
                        "id", "subject", "relation", "object", "prompt_source",
                        "prompt_target", "layer_source", "is_correct_patched", "generations",
                    ],
                )
            with gr.TabItem("Raw JSON Response"):
                output_json = gr.JSON(label="Full JSON response from API")

    submit_btn.click(
        fn=call_api,
        inputs=[only_final_result_input, data_input],
        outputs=[output_json, output_dataframe],
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()