File size: 5,268 Bytes
083841f
 
10705b6
86e2749
 
083841f
f63b31c
 
5d33dd6
f63b31c
083841f
f63b31c
 
a628b4c
f63b31c
9b2d6e3
f63b31c
86e2749
 
9b2d6e3
 
86e2749
083841f
9b2d6e3
083841f
f63b31c
 
083841f
86e2749
f63b31c
10705b6
f63b31c
86e2749
 
 
f63b31c
86e2749
 
 
 
f63b31c
083841f
10705b6
86e2749
083841f
10705b6
083841f
f63b31c
86e2749
f63b31c
86e2749
 
f63b31c
86e2749
 
 
 
10705b6
041cfb4
86e2749
 
083841f
86e2749
f63b31c
 
10705b6
f63b31c
86e2749
f63b31c
86e2749
10705b6
f63b31c
86e2749
 
 
f63b31c
 
 
86e2749
 
083841f
9b2d6e3
f63b31c
9b2d6e3
f63b31c
9b2d6e3
f63b31c
9b2d6e3
f63b31c
083841f
9b2d6e3
f63b31c
 
9b2d6e3
 
f63b31c
9b2d6e3
 
f63b31c
 
a57465e
9b2d6e3
a57465e
9b2d6e3
 
f63b31c
10705b6
9b2d6e3
f63b31c
9b2d6e3
f63b31c
9b2d6e3
f63b31c
9b2d6e3
 
 
f63b31c
 
083841f
 
10705b6
a57465e
10705b6
083841f
 
f63b31c
10705b6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import io
import json
import os

import gradio as gr
import pandas as pd
import requests

# ----------------- Configure Your Server API Address -----------------
# Replace the IP address and port with your own
SERVER_URL = "http://103.235.229.133:7000/predict"
# ---------------------------------------------------------------------

# ----------------- Fixed Model Name -----------------
# As per requirements, the model name is fixed and not a UI input.
FIXED_MODEL_NAME = "Qwen2.5-1.5B-Instruct"
# ----------------------------------------------------

# Default input data for easy testing
DEFAULT_DATAFRAME_VALUE = [
    ["France", "capital city of", "Paris", "she traveled to France for the first time"],
    ["Germany", "capital city of", "Berlin", "Germany is a country in Central Europe"],
    ["Japan", "capital city of", "Tokyo", "Tokyo is the largest city in Japan"]
]

def _drop_blank_rows(df):
    """Return *df* without rows whose cells are all missing, empty or whitespace."""
    if df.empty:
        return df
    # Rows added via "+ Add row" and never filled in arrive as all-blank rows.
    missing = df.isna()
    whitespace_only = df.astype(str).apply(lambda col: col.str.strip() == "")
    return df.loc[~(missing | whitespace_only).all(axis=1)]


def _to_csv_with_ids(df):
    """Serialise *df* to CSV, prepending a zero-padded 'id' column (000, 001, ...)."""
    with_ids = df.copy()
    with_ids.insert(0, "id", [f"{i:03d}" for i in range(len(with_ids))])
    # With no path argument, DataFrame.to_csv returns the CSV text directly.
    return with_ids.to_csv(index=False)


def call_api(only_final_result, data_df):
    """
    Send the input table to the inference server and return its reply.

    Gradio calls this when the submit button is clicked.

    Args:
        only_final_result: Value of the "only_final_result" checkbox; forwarded
            unchanged in the request payload.
        data_df: pandas DataFrame taken from the input table. Rows whose cells
            are all blank are ignored.

    Returns:
        A ``(result_json, output_df)`` tuple. ``result_json`` is the parsed JSON
        object returned by the server. ``output_df`` is a DataFrame built from
        its ``"result"`` entry; it is empty when that entry is missing, is not
        a list, or is an empty list.

    Raises:
        gr.Error: if the table has no non-blank rows, the server cannot be
            reached, the server answers with a 4xx/5xx status, or the response
            body is not a JSON object.
    """
    if data_df is not None:
        data_df = _drop_blank_rows(data_df)
    if data_df is None or data_df.empty:
        raise gr.Error("Input data cannot be empty! Please add at least one row to the table.")

    payload = {
        "data_csv": _to_csv_with_ids(data_df),
        "only_final_result": only_final_result,
    }

    # Keep the try body limited to the network call and JSON decoding so each
    # failure mode maps to its own message.
    try:
        response = requests.post(SERVER_URL, json=payload, timeout=300)
        response.raise_for_status()  # Raise for 4xx / 5xx status codes
        result_json = response.json()
    except requests.exceptions.HTTPError as e:
        status = e.response.status_code if e.response is not None else "unknown"
        body = e.response.text if e.response is not None else ""
        error_details = f"API returned an error (Status Code: {status}):\n{body}"
        raise gr.Error(f"Request failed: {error_details}") from e
    except ValueError as e:
        # Must precede RequestException: requests' JSONDecodeError subclasses
        # both, and a malformed body is not a connection failure.
        raise gr.Error(f"Server returned a response that is not valid JSON: {e}") from e
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Failed to connect to the server: {e}") from e

    if not isinstance(result_json, dict):
        raise gr.Error(
            f"Unexpected response format: expected a JSON object, got {type(result_json).__name__}"
        )

    rows = result_json.get("result", [])
    try:
        output_df = pd.DataFrame(rows) if isinstance(rows, list) and rows else pd.DataFrame()
    except Exception as e:  # UI boundary: report rather than crash the app
        raise gr.Error(f"An unknown error occurred: {e}") from e

    return result_json, output_df


# --- Use Gradio Blocks for a more flexible and beautiful interface ---
# Builds the whole UI as the module-level `demo` object:
#   * a header markdown,
#   * an input column (model name display, only_final_result checkbox,
#     editable 4-column data table, submit button),
#   * a results column with two tabs (table view and raw JSON view),
# and wires the submit button to `call_api`.
# The CSS caps the app container width at 1280px and centres it.
with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280px !important; margin: auto;}") as demo:
    gr.Markdown(
        """
        # 🚀 Ml_patch Function Inference Interface
        An interactive web interface designed for the `Ml_patch` function. 
        Please configure the model parameters and enter your test data below.
        """
    )
    
    # --- Input column: parameters, data table and submit button ---
    with gr.Column():
        gr.Markdown("### 1. Configure Parameters")
        with gr.Group():
            # The model name is display-only; it is not passed to call_api.
            gr.Markdown(f"**Current Model (Fixed):** `{FIXED_MODEL_NAME}`")
            # Its boolean value becomes the first argument of call_api.
            only_final_result_input = gr.Checkbox(
                label="Only Return Final Result (only_final_result)",
                value=False,
                info="If checked, the API will only return the final aggregated result, not the detailed information for all layers."
            )
        
        gr.Markdown("### 2. Input Data")
        gr.Markdown("ℹ️ **Instructions**: Click the **✏️ Edit** button at the bottom right of the table, then click **+ Add row** to add a new row, or the trash can icon to delete a row.")
        
        # Editable input table; its contents become the second argument of
        # call_api. The four headers are the columns sent to the server
        # (call_api prepends an 'id' column before serialising to CSV).
        data_input = gr.DataFrame(
            label="Data Table",
            headers=["subject", "relation", "object", "prompt_source"],
            value=DEFAULT_DATAFRAME_VALUE,
            # Intended: show 5 rows on load to give the table more vertical space.
            # (rows_to_display, "dynamic" allows user to add/remove rows)
            # NOTE(review): DEFAULT_DATAFRAME_VALUE has only 3 rows — confirm how
            # Gradio reconciles the initial value with row_count.
            row_count=(5, "dynamic"),
            # Columns are fixed at 4 so the CSV schema sent to the server stays stable.
            col_count=(4, "fixed"),
            interactive=True
        )
        
        submit_btn = gr.Button("Submit for Inference", variant="primary")

    # --- Results column: table view and raw JSON view of the API response ---
    with gr.Column():
        gr.Markdown("### 3. View Inference Results")
        with gr.Tabs():
            with gr.TabItem("Results Table"):
                # Read-only. The headers list the result columns the UI expects;
                # presumably the DataFrame returned by call_api supplies its own
                # columns when it is displayed — confirm against the server output.
                output_dataframe = gr.DataFrame(
                    label="Inference Result Details", 
                    interactive=False,
                    headers=["id", "subject", "relation", "object", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
                )
            with gr.TabItem("Raw JSON Response"):
                output_json = gr.JSON(label="Full JSON response from API")
            
    # call_api returns (result_json, output_df); the outputs list is ordered to
    # match: the JSON goes to the raw-JSON tab, the DataFrame to the results table.
    submit_btn.click(
        fn=call_api,
        inputs=[only_final_result_input, data_input],
        outputs=[output_json, output_dataframe]
    )

def main():
    """Start the web server for the Gradio app built in this module (``demo``)."""
    demo.launch()


# Only launch when executed as a script, not when imported.
if __name__ == "__main__":
    main()