import gradio as gr
import requests
import json
import pandas as pd
import io
# ----------------- Configure Your Server API Address -----------------
# Replace the IP address and port with your own
# (the path /predict is the endpoint call_api POSTs to).
SERVER_URL = "http://103.235.229.133:7000/predict"
# ---------------------------------------------------------------------
# ----------------- Fixed Model Name -----------------
# As per requirements, the model name is fixed and not a UI input.
# It is only displayed in the UI (read-only); it is not sent in the payload.
FIXED_MODEL_NAME = "Qwen2.5-1.5B-Instruct"
# ----------------------------------------------------
# Default input data for easy testing
# Row layout matches the DataFrame headers:
# [subject, relation, object, prompt_source]
DEFAULT_DATAFRAME_VALUE = [
    ["France", "capital city of", "Paris", "she traveled to France for the first time"],
    ["Germany", "capital city of", "Berlin", "Germany is a country in Central Europe"],
    ["Japan", "capital city of", "Tokyo", "Tokyo is the largest city in Japan"]
]
def call_api(only_final_result, data_df):
    """Collect UI inputs, POST them to the inference server, and return results.

    Called by Gradio when the submit button is clicked.

    Args:
        only_final_result: Checkbox value; forwarded to the API so the server
            can skip per-layer details when True.
        data_df: pandas DataFrame from the Gradio table, expected columns
            [subject, relation, object, prompt_source].

    Returns:
        Tuple of (result_json, output_df): the raw JSON response dict and a
        DataFrame built from its "result" list (empty DataFrame when the key
        is absent, not a list, or empty).

    Raises:
        gr.Error: on empty input, HTTP error status, connection failure, or
            any other unexpected failure (so the UI shows a friendly message).
    """
    if data_df is None or data_df.empty:
        raise gr.Error("Input data cannot be empty! Please add at least one row to the table.")

    # Work on a copy so the UI's DataFrame is not mutated; prepend the
    # zero-padded 'id' column (000, 001, ...) the server expects.
    data_df_with_id = data_df.copy()
    data_df_with_id.insert(0, 'id', [f'{i:03d}' for i in range(len(data_df_with_id))])

    # Serialize the table to CSV text for transport inside the JSON payload.
    output = io.StringIO()
    data_df_with_id.to_csv(output, index=False)
    csv_data = output.getvalue()

    payload = {
        "data_csv": csv_data,
        "only_final_result": only_final_result
    }

    try:
        # FIX: removed stream=True. The body is always fully consumed via
        # .json(), and per the requests docs a streamed response whose body
        # is never read (e.g. when raise_for_status() fires) cannot release
        # its connection back to the pool.
        response = requests.post(SERVER_URL, json=payload, timeout=300)
        response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
        result_json = response.json()

        # Extract the result data for the DataFrame; tolerate a missing or
        # malformed "result" field by returning an empty table.
        dataframe_result_data = result_json.get("result", [])
        if isinstance(dataframe_result_data, list) and len(dataframe_result_data) > 0:
            output_df = pd.DataFrame(dataframe_result_data)
        else:
            output_df = pd.DataFrame()
        return result_json, output_df
    except requests.exceptions.HTTPError as e:
        error_details = f"API returned an error (Status Code: {e.response.status_code}):\n{e.response.text}"
        raise gr.Error(f"Request failed: {error_details}")
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Failed to connect to the server: {e}")
    except Exception as e:
        raise gr.Error(f"An unknown error occurred: {e}")
# --- Use Gradio Blocks for a more flexible and beautiful interface ---
# Layout: a parameters/input column, a submit button, and a results column
# with two tabs (table view and raw JSON). The CSS centers the page and
# caps its width at 1280px.
with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280px !important; margin: auto;}") as demo:
    gr.Markdown(
        """
        # 🚀 Ml_patch Function Inference Interface
        An interactive web interface designed for the `Ml_patch` function.
        Please configure the model parameters and enter your test data below.
        """
    )
    with gr.Column():
        gr.Markdown("### 1. Configure Parameters")
        with gr.Group():
            # Model name is display-only (see FIXED_MODEL_NAME above); the
            # only configurable parameter is the checkbox below.
            gr.Markdown(f"**Current Model (Fixed):** `{FIXED_MODEL_NAME}`")
            only_final_result_input = gr.Checkbox(
                label="Only Return Final Result (only_final_result)",
                value=False,
                info="If checked, the API will only return the final aggregated result, not the detailed information for all layers."
            )
        gr.Markdown("### 2. Input Data")
        gr.Markdown("ℹ️ **Instructions**: Click the **✏️ Edit** button at the bottom right of the table, then click **+ Add row** to add a new row, or the trash can icon to delete a row.")
        data_input = gr.DataFrame(
            label="Data Table",
            headers=["subject", "relation", "object", "prompt_source"],
            value=DEFAULT_DATAFRAME_VALUE,
            # Set the initial number of rows to 5 to give the table more vertical space on load.
            # (rows_to_display, "dynamic" allows user to add/remove rows)
            row_count=(5, "dynamic"),
            col_count=(4, "fixed"),
            interactive=True
        )
        submit_btn = gr.Button("Submit for Inference", variant="primary")
    with gr.Column():
        gr.Markdown("### 3. View Inference Results")
        with gr.Tabs():
            with gr.TabItem("Results Table"):
                # Headers mirror the columns the server is expected to return;
                # actual columns come from the response DataFrame at runtime.
                output_dataframe = gr.DataFrame(
                    label="Inference Result Details",
                    interactive=False,
                    headers=["id", "subject", "relation", "object", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
                )
            with gr.TabItem("Raw JSON Response"):
                output_json = gr.JSON(label="Full JSON response from API")
    # Wire the button to call_api: inputs in the order call_api expects,
    # outputs in the order call_api returns (json dict, DataFrame).
    submit_btn.click(
        fn=call_api,
        inputs=[only_final_result_input, data_input],
        outputs=[output_json, output_dataframe]
    )
# Launch the Gradio app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()