File size: 6,144 Bytes
083841f
 
10705b6
86e2749
 
083841f
f63b31c
5d33dd6
f63b31c
083841f
f63b31c
a628b4c
f63b31c
9b2d6e3
f63b31c
86e2749
 
 
083841f
16d6124
 
083841f
f63b31c
 
083841f
86e2749
f63b31c
10705b6
f63b31c
86e2749
 
 
f63b31c
86e2749
 
 
 
16d6124
083841f
10705b6
16d6124
86e2749
083841f
10705b6
083841f
86e2749
16d6124
86e2749
 
 
 
 
 
10705b6
041cfb4
86e2749
 
083841f
86e2749
f63b31c
 
10705b6
f63b31c
86e2749
f63b31c
86e2749
10705b6
f63b31c
86e2749
 
 
f63b31c
 
 
86e2749
 
083841f
9b2d6e3
f63b31c
9b2d6e3
f63b31c
16d6124
 
 
 
ac28d4d
16d6124
 
 
 
9b2d6e3
f63b31c
9b2d6e3
16d6124
083841f
9b2d6e3
f63b31c
 
3d60ca8
 
9b2d6e3
f63b31c
9b2d6e3
 
a57465e
9b2d6e3
a57465e
9b2d6e3
 
f63b31c
10705b6
9b2d6e3
f63b31c
3d60ca8
 
 
9b2d6e3
f63b31c
9b2d6e3
f63b31c
9b2d6e3
16d6124
9b2d6e3
 
f63b31c
 
083841f
16d6124
 
083841f
10705b6
16d6124
10705b6
083841f
 
f63b31c
10705b6
16d6124
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio as gr
import requests
import json
import pandas as pd
import io

# ----------------- Configure Your Server API Address -----------------
# NOTE(review): hard-coded public IP and port — consider loading from an
# environment variable or config file before deploying.
SERVER_URL = "http://103.235.229.133:7000/predict"
# ---------------------------------------------------------------------

# ----------------- Fixed Model Name -----------------
# Shown in the UI header only; it is NOT included in the API payload,
# so the server decides which model actually runs.
FIXED_MODEL_NAME = "Qwen2.5-1.5B-Instruct"
# ----------------------------------------------------

# Default input data for easy testing.
# Column order is [subject, relation, object, prompt_source], matching the
# headers of the gr.DataFrame input component defined below.
DEFAULT_DATAFRAME_VALUE = [
    ["France", "capital city of", "Paris", "she traveled to France for the first time"],
]

# --- KEY CHANGE 1: Update the function signature to accept the new parameter ---
def call_api(patched_layer_number, only_final_result, data_df):
    """
    Collect inputs from the Gradio UI, format them for the API, and POST
    them to SERVER_URL.

    Args:
        patched_layer_number: Integer layer index to patch (from gr.Number).
        only_final_result: Bool; if True the API returns only the final
            aggregated result instead of per-layer details (from gr.Checkbox).
        data_df: pandas DataFrame from the gr.DataFrame component with
            columns [subject, relation, object, prompt_source].

    Returns:
        Tuple (result_json, output_df): the raw JSON response dict and a
        DataFrame built from its "result" field (empty DataFrame when the
        field is missing, empty, or not a list).

    Raises:
        gr.Error: when the input table is empty, the API returns an HTTP
            error, the server is unreachable, or any other failure occurs.
    """
    if data_df is None or data_df.empty:
        raise gr.Error("Input data cannot be empty! Please add at least one row to the table.")

    # Gradio's dynamic table can submit placeholder rows whose cells are all
    # blank; drop them so they are not sent to the server, then re-check
    # that real data remains.
    cleaned_df = data_df.copy()
    keep_mask = cleaned_df.astype(str).apply(
        lambda row: any(cell.strip() for cell in row), axis=1
    )
    cleaned_df = cleaned_df[keep_mask]
    if cleaned_df.empty:
        raise gr.Error("Input data cannot be empty! Please add at least one row to the table.")

    # Prepend a zero-padded 'id' column (000, 001, ...) expected by the server.
    cleaned_df.insert(0, 'id', [f'{i:03d}' for i in range(len(cleaned_df))])

    # Serialize the table as CSV text so it can travel inside the JSON payload.
    output = io.StringIO()
    cleaned_df.to_csv(output, index=False)
    csv_data = output.getvalue()

    # --- KEY CHANGE 2: Add the new parameter to the API payload ---
    payload = {
        "data_csv": csv_data,
        "patched_layer_number": patched_layer_number,
        "only_final_result": only_final_result
    }

    try:
        # stream=True removed: response.json() must read the full body anyway,
        # so streaming only delayed the read without any benefit.
        response = requests.post(SERVER_URL, json=payload, timeout=300)
        response.raise_for_status()
        result_json = response.json()

        dataframe_result_data = result_json.get("result", [])

        # Only build a table when the server returned a non-empty list of rows.
        if isinstance(dataframe_result_data, list) and len(dataframe_result_data) > 0:
            output_df = pd.DataFrame(dataframe_result_data)
        else:
            output_df = pd.DataFrame()

        return result_json, output_df

    except requests.exceptions.HTTPError as e:
        error_details = f"API returned an error (Status Code: {e.response.status_code}):\n{e.response.text}"
        raise gr.Error(f"Request failed: {error_details}")
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Failed to connect to the server: {e}")
    except Exception as e:
        raise gr.Error(f"An unknown error occurred: {e}")


# --- Use Gradio Blocks for a more flexible and beautiful interface ---
# The CSS caps the page width at 1280px and centers the container.
with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280px !important; margin: auto;}") as demo:
    gr.Markdown(
        """
        # 🚀 Ml_patch Function Inference Interface
        An interactive web interface designed for the `Ml_patch` function. 
        Please configure the model parameters and enter your test data below.
        """
    )
    
    # Sections 1–2: parameter configuration and the editable input table.
    with gr.Column():
        gr.Markdown("### 1. Configure Parameters")
        with gr.Group():
            gr.Markdown(f"**Current Model (Fixed):** `{FIXED_MODEL_NAME}`")
            
            # --- KEY CHANGE 3: Add the new Gradio input component ---
            patched_layer_number_input = gr.Number(
                label="Patched Layer Number (patched_layer_number)",
                value=3,  # Set a reasonable default value
                precision=0,  # This ensures the input is an integer
                info="Specify the target layer number for patching."
            )
            
            only_final_result_input = gr.Checkbox(
                label="Only Return Final Result (only_final_result)",
                value=False,
                info="If checked, the API will only return the final aggregated result, not detailed info for all layers."
            )
        
        gr.Markdown("### 2. Input Data")
        gr.Markdown("ℹ️ **Instructions**: Click the **✏️ Edit** button at the bottom right of the table, then click **+ Add row** to add a new row, or the trash can icon to delete a row.")
        gr.Markdown("**[subject,relation,object]**: A factual triple.")
        gr.Markdown("**prompt_source**: A sentence which contains subject.")
        # Column order here must match what call_api serializes to CSV.
        data_input = gr.DataFrame(
            label="Data Table",
            headers=["subject", "relation", "object", "prompt_source"],
            value=DEFAULT_DATAFRAME_VALUE,
            row_count=(5, "dynamic"),
            col_count=(4, "fixed"),
            interactive=True
        )
        
        submit_btn = gr.Button("Submit for Inference", variant="primary")

    # Section 3: results display — a parsed table plus the raw JSON response.
    with gr.Column():
        gr.Markdown("### 3. View Inference Results")
        gr.Markdown("**layer_source**: Starting from this layer of **prompt_source**, we extracted the hidden states corresponding to the subject.")
        gr.Markdown("**is_correct_patched**: Whether the final answer is correct. If **object** is in **generations**, we think the final answer is correct.")
        gr.Markdown("**generations**: The output obtained after the model continues inference following a patch operation.")
        with gr.Tabs():
            with gr.TabItem("Results Table"):
                output_dataframe = gr.DataFrame(
                    label="Inference Result Details", 
                    interactive=False,
                    # You might want to update headers if the output changes based on this parameter
                    headers=["id", "subject", "relation", "object", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
                )
            with gr.TabItem("Raw JSON Response"):
                output_json = gr.JSON(label="Full JSON response from API")
            
    # --- KEY CHANGE 4: Update the click event to include the new input ---
    # The order of inputs here MUST match the order of arguments in the `call_api` function.
    submit_btn.click(
        fn=call_api,
        inputs=[patched_layer_number_input, only_final_result_input, data_input],
        outputs=[output_json, output_dataframe]
    )

# Launch the Gradio app (only when run as a script, not when imported).
if __name__ == "__main__":
    demo.launch()