Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -858,21 +858,30 @@ def create_statistics_dashboard(df: pd.DataFrame) -> Tuple[str, go.Figure]:
|
|
| 858 |
current_data = {"df": pd.DataFrame(), "type": "unknown"}
|
| 859 |
|
| 860 |
|
| 861 |
-
def load_dataset_action(source_type: str, dataset_id: str, file_upload) -> Tuple[str, str]:
|
| 862 |
-
"""Handle dataset loading."""
|
| 863 |
global current_data
|
| 864 |
|
| 865 |
if source_type == "HuggingFace Hub":
|
| 866 |
if not dataset_id:
|
| 867 |
-
|
|
|
|
|
|
|
|
|
|
| 868 |
df, msg = load_hf_dataset(dataset_id)
|
| 869 |
else: # Local File
|
| 870 |
if file_upload is None:
|
| 871 |
-
|
|
|
|
|
|
|
|
|
|
| 872 |
df, msg = load_jsonl_file(file_upload.name)
|
| 873 |
|
| 874 |
if df.empty:
|
| 875 |
-
|
|
|
|
|
|
|
|
|
|
| 876 |
|
| 877 |
current_data["df"] = df
|
| 878 |
current_data["type"] = detect_dataset_type(df)
|
|
@@ -881,7 +890,27 @@ def load_dataset_action(source_type: str, dataset_id: str, file_upload) -> Tuple
|
|
| 881 |
if len(df.columns) > 10:
|
| 882 |
columns_info += f" ... and {len(df.columns) - 10} more"
|
| 883 |
|
| 884 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 885 |
|
| 886 |
|
| 887 |
def get_token_details(idx: int) -> Tuple[str, go.Figure]:
|
|
@@ -1180,19 +1209,7 @@ with gr.Blocks(title="PTS Visualizer", css=CSS) as demo:
|
|
| 1180 |
load_btn.click(
|
| 1181 |
fn=load_dataset_action,
|
| 1182 |
inputs=[source_type, dataset_dropdown, file_upload],
|
| 1183 |
-
outputs=[load_status, dataset_info],
|
| 1184 |
-
api_name=False
|
| 1185 |
-
).then(
|
| 1186 |
-
fn=refresh_all,
|
| 1187 |
-
outputs=[stats_html, stats_chart, graph_plot, embed_plot, circuit_html, circuit_chart],
|
| 1188 |
-
api_name=False
|
| 1189 |
-
).then(
|
| 1190 |
-
fn=lambda: gr.update(maximum=max(0, len(current_data["df"]) - 1)),
|
| 1191 |
-
outputs=[token_slider],
|
| 1192 |
-
api_name=False
|
| 1193 |
-
).then(
|
| 1194 |
-
fn=get_query_list,
|
| 1195 |
-
outputs=[query_filter],
|
| 1196 |
api_name=False
|
| 1197 |
)
|
| 1198 |
|
|
|
|
| 858 |
current_data = {"df": pd.DataFrame(), "type": "unknown"}
|
| 859 |
|
| 860 |
|
| 861 |
+
def load_dataset_action(source_type: str, dataset_id: str, file_upload):
|
| 862 |
+
"""Handle dataset loading and return all visualization updates."""
|
| 863 |
global current_data
|
| 864 |
|
| 865 |
if source_type == "HuggingFace Hub":
|
| 866 |
if not dataset_id:
|
| 867 |
+
empty_fig = go.Figure()
|
| 868 |
+
empty_fig.update_layout(template="plotly_dark")
|
| 869 |
+
return ("Please enter a dataset ID", "", "No data", empty_fig, empty_fig, empty_fig, "No data", empty_fig,
|
| 870 |
+
gr.update(maximum=0), gr.update(choices=[], value=None))
|
| 871 |
df, msg = load_hf_dataset(dataset_id)
|
| 872 |
else: # Local File
|
| 873 |
if file_upload is None:
|
| 874 |
+
empty_fig = go.Figure()
|
| 875 |
+
empty_fig.update_layout(template="plotly_dark")
|
| 876 |
+
return ("Please upload a file", "", "No data", empty_fig, empty_fig, empty_fig, "No data", empty_fig,
|
| 877 |
+
gr.update(maximum=0), gr.update(choices=[], value=None))
|
| 878 |
df, msg = load_jsonl_file(file_upload.name)
|
| 879 |
|
| 880 |
if df.empty:
|
| 881 |
+
empty_fig = go.Figure()
|
| 882 |
+
empty_fig.update_layout(template="plotly_dark")
|
| 883 |
+
return (msg, "", "No data", empty_fig, empty_fig, empty_fig, "No data", empty_fig,
|
| 884 |
+
gr.update(maximum=0), gr.update(choices=[], value=None))
|
| 885 |
|
| 886 |
current_data["df"] = df
|
| 887 |
current_data["type"] = detect_dataset_type(df)
|
|
|
|
| 890 |
if len(df.columns) > 10:
|
| 891 |
columns_info += f" ... and {len(df.columns) - 10} more"
|
| 892 |
|
| 893 |
+
# Generate all visualizations
|
| 894 |
+
stats_html, stats_fig = create_statistics_dashboard(df)
|
| 895 |
+
graph_fig = create_thought_anchor_graph(df)
|
| 896 |
+
embed_fig = create_embedding_visualization(df)
|
| 897 |
+
circuit_html, circuit_fig = create_circuit_visualization(df)
|
| 898 |
+
|
| 899 |
+
# Generate query list
|
| 900 |
+
query_choices = []
|
| 901 |
+
if 'query' in df.columns:
|
| 902 |
+
queries = df['query'].unique().tolist()
|
| 903 |
+
for i, q in enumerate(queries):
|
| 904 |
+
q_str = str(q) if q is not None else ""
|
| 905 |
+
if len(q_str) > 80:
|
| 906 |
+
query_choices.append(f"[{i+1}] {q_str[:77]}...")
|
| 907 |
+
else:
|
| 908 |
+
query_choices.append(f"[{i+1}] {q_str}")
|
| 909 |
+
|
| 910 |
+
return (msg, f"Dataset type: {current_data['type']}\n{columns_info}",
|
| 911 |
+
stats_html, stats_fig, graph_fig, embed_fig, circuit_html, circuit_fig,
|
| 912 |
+
gr.update(maximum=max(0, len(df) - 1)),
|
| 913 |
+
gr.update(choices=query_choices, value=None))
|
| 914 |
|
| 915 |
|
| 916 |
def get_token_details(idx: int) -> Tuple[str, go.Figure]:
|
|
|
|
| 1209 |
load_btn.click(
|
| 1210 |
fn=load_dataset_action,
|
| 1211 |
inputs=[source_type, dataset_dropdown, file_upload],
|
| 1212 |
+
outputs=[load_status, dataset_info, stats_html, stats_chart, graph_plot, embed_plot, circuit_html, circuit_chart, token_slider, query_filter],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1213 |
api_name=False
|
| 1214 |
)
|
| 1215 |
|