# ML_Patch / app.py
# Source: Hugging Face Space "ML_Patch" (commit f63b31c, "Update app.py" by lllyx).
# The original page-viewer chrome ("raw / history / blame", file size) has been
# removed so the file is valid, runnable Python.
import gradio as gr
import requests
import json
import pandas as pd
import io
# ----------------- Configure Your Server API Address -----------------
# Replace the IP address and port with your own
# POST endpoint expecting a JSON body: {"data_csv": str, "only_final_result": bool}.
SERVER_URL = "http://103.235.229.133:7000/predict"
# ---------------------------------------------------------------------
# ----------------- Fixed Model Name -----------------
# As per requirements, the model name is fixed and not a UI input.
# NOTE(review): displayed in the UI only; it is not included in the request
# payload — presumably the server is already configured with this model.
FIXED_MODEL_NAME = "Qwen2.5-1.5B-Instruct"
# ----------------------------------------------------
# Default input data for easy testing
# Column order matches the input table headers: subject, relation, object, prompt_source.
DEFAULT_DATAFRAME_VALUE = [
    ["France", "capital city of", "Paris", "she traveled to France for the first time"],
    ["Germany", "capital city of", "Berlin", "Germany is a country in Central Europe"],
    ["Japan", "capital city of", "Tokyo", "Tokyo is the largest city in Japan"]
]
def call_api(only_final_result, data_df):
    """Send the UI inputs to the inference server and return its response.

    Collects the checkbox value and the editable table, serializes the table
    as CSV (with a generated zero-padded ``id`` column prepended), POSTs it
    to ``SERVER_URL``, and returns both the raw JSON and a results DataFrame.

    Args:
        only_final_result: Checkbox value; forwarded verbatim to the API.
        data_df: pandas DataFrame from the Gradio table with columns
            subject / relation / object / prompt_source.

    Returns:
        tuple: (raw JSON response dict, DataFrame built from the response's
        "result" list — empty DataFrame if that list is missing or empty).

    Raises:
        gr.Error: on empty input, HTTP error status, connection failure,
            or any other unexpected failure (rendered as a UI message).
    """
    if data_df is None or data_df.empty:
        raise gr.Error("Input data cannot be empty! Please add at least one row to the table.")

    # The dynamic Gradio table pads the view with blank rows (row_count starts
    # at 5 while the default data has 3 rows); drop rows that are completely
    # empty so they are not shipped to the server as bogus records.
    cleaned_df = data_df.replace("", pd.NA).dropna(how="all")
    if cleaned_df.empty:
        raise gr.Error("Input data cannot be empty! Please add at least one row to the table.")

    # Create a copy and prepend a zero-padded 'id' column ("000", "001", ...).
    data_df_with_id = cleaned_df.copy()
    data_df_with_id.insert(0, 'id', [f'{i:03d}' for i in range(len(data_df_with_id))])

    # Serialize the table to a CSV string for the JSON payload.
    output = io.StringIO()
    data_df_with_id.to_csv(output, index=False)
    payload = {
        "data_csv": output.getvalue(),
        "only_final_result": only_final_result,
    }

    try:
        # Plain (non-streaming) request: the body is consumed in a single
        # .json() call below, so stream=True would only delay releasing the
        # connection back to the pool.
        response = requests.post(SERVER_URL, json=payload, timeout=300)
        response.raise_for_status()  # Raise an exception for 4xx/5xx status codes
        result_json = response.json()

        # Build the results table from the "result" list, if present.
        dataframe_result_data = result_json.get("result", [])
        if isinstance(dataframe_result_data, list) and dataframe_result_data:
            output_df = pd.DataFrame(dataframe_result_data)
        else:
            output_df = pd.DataFrame()
        return result_json, output_df
    except requests.exceptions.HTTPError as e:
        error_details = f"API returned an error (Status Code: {e.response.status_code}):\n{e.response.text}"
        raise gr.Error(f"Request failed: {error_details}") from e
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Failed to connect to the server: {e}") from e
    except Exception as e:
        raise gr.Error(f"An unknown error occurred: {e}") from e
# --- Use Gradio Blocks for a more flexible and beautiful interface ---
with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280px !important; margin: auto;}") as demo:
    # Page header / usage instructions.
    gr.Markdown(
        """
        # 🚀 Ml_patch Function Inference Interface
        An interactive web interface designed for the `Ml_patch` function.
        Please configure the model parameters and enter your test data below.
        """
    )
    # Input section: parameters + editable data table + submit button.
    with gr.Column():
        gr.Markdown("### 1. Configure Parameters")
        with gr.Group():
            # The model is fixed (see FIXED_MODEL_NAME); it is displayed but
            # not collected as an input.
            gr.Markdown(f"**Current Model (Fixed):** `{FIXED_MODEL_NAME}`")
            only_final_result_input = gr.Checkbox(
                label="Only Return Final Result (only_final_result)",
                value=False,
                info="If checked, the API will only return the final aggregated result, not the detailed information for all layers."
            )
        gr.Markdown("### 2. Input Data")
        gr.Markdown("ℹ️ **Instructions**: Click the **✏️ Edit** button at the bottom right of the table, then click **+ Add row** to add a new row, or the trash can icon to delete a row.")
        # Editable input table; column order mirrors the CSV schema call_api sends.
        data_input = gr.DataFrame(
            label="Data Table",
            headers=["subject", "relation", "object", "prompt_source"],
            value=DEFAULT_DATAFRAME_VALUE,
            # Set the initial number of rows to 5 to give the table more vertical space on load.
            # (rows_to_display, "dynamic" allows user to add/remove rows)
            row_count=(5, "dynamic"),
            col_count=(4, "fixed"),
            interactive=True
        )
        submit_btn = gr.Button("Submit for Inference", variant="primary")
    # Output section: results table and raw JSON, in tabs.
    with gr.Column():
        gr.Markdown("### 3. View Inference Results")
        with gr.Tabs():
            with gr.TabItem("Results Table"):
                # Read-only table populated from the API response's "result" list.
                output_dataframe = gr.DataFrame(
                    label="Inference Result Details",
                    interactive=False,
                    headers=["id", "subject", "relation", "object", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
                )
            with gr.TabItem("Raw JSON Response"):
                output_json = gr.JSON(label="Full JSON response from API")
    # Wire the submit button to call_api; its two return values fill the
    # JSON view and the results table respectively.
    submit_btn.click(
        fn=call_api,
        inputs=[only_final_result_input, data_input],
        outputs=[output_json, output_dataframe]
    )
# Launch the Gradio app
if __name__ == "__main__":
    # Start the local Gradio server (blocking call); guarded so importing
    # this module does not launch the UI.
    demo.launch()