codelion commited on
Commit
dd018ef
·
verified ·
1 Parent(s): 4f48e62

Upload 2 files

Browse files
Files changed (1) hide show
  1. app.py +32 -24
app.py CHANGED
@@ -933,6 +933,26 @@ def get_token_details(idx: int) -> Tuple[str, go.Figure]:
933
  return html, chart
934
 
935
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
936
  def update_graph_visualization(query_dropdown: str = None):
937
  """Update the thought anchor graph."""
938
  dataset_type = current_data.get("type", "unknown")
@@ -945,7 +965,10 @@ def update_graph_visualization(query_dropdown: str = None):
945
  )
946
  fig.update_layout(template="plotly_dark", height=400)
947
  return fig
948
- return create_thought_anchor_graph(current_data["df"], query_dropdown)
 
 
 
949
 
950
 
951
  def update_embedding_visualization(color_by: str):
@@ -992,16 +1015,14 @@ def get_query_list():
992
  return gr.update(choices=[], value=None)
993
 
994
  queries = df['query'].unique().tolist()
995
- # Return tuples of (truncated_label, full_value) for dropdown
996
- # Gradio will show the label but pass the value
997
  truncated_queries = []
998
  for i, q in enumerate(queries):
999
  q_str = str(q) if q is not None else ""
1000
  if len(q_str) > 80:
1001
- label = f"[{i+1}] {q_str[:77]}..."
1002
  else:
1003
- label = f"[{i+1}] {q_str}"
1004
- truncated_queries.append((label, q_str))
1005
 
1006
  return gr.update(choices=truncated_queries, value=None)
1007
 
@@ -1043,26 +1064,13 @@ HF_DATASETS = [
1043
  "codelion/DeepSeek-R1-Distill-Qwen-1.5B-pts-steering-vectors",
1044
  ]
1045
 
1046
- # Theme and CSS configuration
1047
- THEME = gr.themes.Soft(
1048
- primary_hue="indigo",
1049
- secondary_hue="emerald",
1050
- neutral_hue="slate"
1051
- )
1052
  CSS = """
1053
  .gradio-container { max-width: 1400px !important; }
1054
  .main-header { text-align: center; margin-bottom: 20px; }
1055
  """
1056
 
1057
- # Use try/except for Gradio version compatibility
1058
- try:
1059
- # Gradio 4.x style
1060
- demo_context = gr.Blocks(title="PTS Visualizer", theme=THEME, css=CSS)
1061
- except TypeError:
1062
- # Gradio 6.x style (theme/css moved to launch)
1063
- demo_context = gr.Blocks(title="PTS Visualizer")
1064
-
1065
- with demo_context as demo:
1066
 
1067
  # Header
1068
  gr.Markdown("""
@@ -1088,8 +1096,8 @@ with demo_context as demo:
1088
  with gr.Column(scale=3):
1089
  dataset_dropdown = gr.Dropdown(
1090
  choices=HF_DATASETS,
 
1091
  label="Select Dataset",
1092
- allow_custom_value=True,
1093
  info="Choose a pre-defined dataset or enter your own HuggingFace dataset ID"
1094
  )
1095
  with gr.Column(scale=1):
@@ -1139,8 +1147,8 @@ with demo_context as demo:
1139
  with gr.Row():
1140
  query_filter = gr.Dropdown(
1141
  choices=[],
1142
- label="Filter by Query",
1143
- allow_custom_value=True
1144
  )
1145
  graph_plot = gr.Plot()
1146
 
 
933
  return html, chart
934
 
935
 
936
+ def get_original_query_from_label(label: str) -> str:
937
+ """Extract original query from truncated dropdown label like '[1] query...'"""
938
+ if not label or not isinstance(label, str):
939
+ return None
940
+
941
+ df = current_data["df"]
942
+ if df.empty or 'query' not in df.columns:
943
+ return None
944
+
945
+ # Extract index from "[N] query..." format
946
+ match = re.match(r'\[(\d+)\]', label)
947
+ if match:
948
+ idx = int(match.group(1)) - 1 # Convert to 0-based index
949
+ queries = df['query'].unique().tolist()
950
+ if 0 <= idx < len(queries):
951
+ return queries[idx]
952
+
953
+ return None
954
+
955
+
956
  def update_graph_visualization(query_dropdown: str = None):
957
  """Update the thought anchor graph."""
958
  dataset_type = current_data.get("type", "unknown")
 
965
  )
966
  fig.update_layout(template="plotly_dark", height=400)
967
  return fig
968
+
969
+ # Convert truncated label back to original query
970
+ original_query = get_original_query_from_label(query_dropdown)
971
+ return create_thought_anchor_graph(current_data["df"], original_query)
972
 
973
 
974
  def update_embedding_visualization(color_by: str):
 
1015
  return gr.update(choices=[], value=None)
1016
 
1017
  queries = df['query'].unique().tolist()
1018
+ # Return simple truncated strings for dropdown choices
 
1019
  truncated_queries = []
1020
  for i, q in enumerate(queries):
1021
  q_str = str(q) if q is not None else ""
1022
  if len(q_str) > 80:
1023
+ truncated_queries.append(f"[{i+1}] {q_str[:77]}...")
1024
  else:
1025
+ truncated_queries.append(f"[{i+1}] {q_str}")
 
1026
 
1027
  return gr.update(choices=truncated_queries, value=None)
1028
 
 
1064
  "codelion/DeepSeek-R1-Distill-Qwen-1.5B-pts-steering-vectors",
1065
  ]
1066
 
1067
+ # CSS configuration
 
 
 
 
 
1068
  CSS = """
1069
  .gradio-container { max-width: 1400px !important; }
1070
  .main-header { text-align: center; margin-bottom: 20px; }
1071
  """
1072
 
1073
+ with gr.Blocks(title="PTS Visualizer", css=CSS) as demo:
 
 
 
 
 
 
 
 
1074
 
1075
  # Header
1076
  gr.Markdown("""
 
1096
  with gr.Column(scale=3):
1097
  dataset_dropdown = gr.Dropdown(
1098
  choices=HF_DATASETS,
1099
+ value=HF_DATASETS[0],
1100
  label="Select Dataset",
 
1101
  info="Choose a pre-defined dataset or enter your own HuggingFace dataset ID"
1102
  )
1103
  with gr.Column(scale=1):
 
1147
  with gr.Row():
1148
  query_filter = gr.Dropdown(
1149
  choices=[],
1150
+ value=None,
1151
+ label="Filter by Query"
1152
  )
1153
  graph_plot = gr.Plot()
1154