Phoenix21 commited on
Commit
c6dce4f
·
verified ·
1 Parent(s): 8113b74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -25
app.py CHANGED
@@ -42,7 +42,7 @@ def analyze_dfa(input_text):
42
  vec = get_hidden_state(history)
43
  states_vectors.append(vec)
44
 
45
- # --- 1. KMeans Graph (Unsupervised State Map) ---
46
  num_clusters = min(len(moves), 4)
47
  kmeans = KMeans(n_clusters=num_clusters, n_init=10).fit(states_vectors)
48
  km_labels = kmeans.labels_
@@ -51,40 +51,68 @@ def analyze_dfa(input_text):
51
  for i in range(len(moves)-1):
52
  G_km.add_edge(f"S{km_labels[i]}", f"S{km_labels[i+1]}", label=moves[i+1])
53
 
54
- plt.figure(figsize=(12, 5))
55
- plt.subplot(1, 2, 1)
56
  pos_km = nx.spring_layout(G_km)
57
- nx.draw(G_km, pos_km, with_labels=True, node_color='lightblue', node_size=1500)
58
- nx.draw_networkx_edge_labels(G_km, pos_km, edge_labels=nx.get_edge_attributes(G_km, 'label'))
59
- plt.title("KMeans DFA (State-Based)")
 
 
 
60
 
61
- # --- 2. Linear Probe / PCA (Geometric Map) ---
62
- logger.info("📐 Running Linear Probe (PCA) to find the 'Spatial Axis'...")
63
  pca = PCA(n_components=2)
64
  coords = pca.fit_transform(states_vectors)
65
 
66
- plt.subplot(1, 2, 2)
67
- plt.scatter(coords[:, 0], coords[:, 1], c=range(len(moves)), cmap='viridis', s=100)
 
 
 
 
 
 
 
 
68
  for i, move in enumerate(moves):
69
- plt.annotate(f"{i}:{move}", (coords[i, 0], coords[i, 1]))
70
- plt.plot(coords[:, 0], coords[:, 1], 'r--', alpha=0.3) # Path line
71
- plt.title("Linear Probe (Spatial Projection)")
 
 
 
 
72
 
73
- plot_path = "comparison_plot.png"
74
- plt.savefig(plot_path)
75
  plt.close()
76
 
77
- return plot_path, f"KMeans Labels: {km_labels}", log_capture.getvalue()
78
 
79
- # Launching with dual display
80
- with gr.Blocks() as demo:
81
- gr.Markdown("# KMeans vs. Linear Probe Analysis")
82
- input_box = gr.Textbox(label="Moves (Right, Left...)")
83
- submit_btn = gr.Button("Compare")
 
 
 
 
84
  with gr.Row():
85
- output_img = gr.Image(label="KMeans (Left) vs Linear PCA (Right)")
86
- analysis_text = gr.Textbox(label="Mapping Results")
87
- log_box = gr.Textbox(label="Probe Logs", lines=5)
88
- submit_btn.click(analyze_dfa, input_box, [output_img, analysis_text, log_box])
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  demo.launch()
 
42
  vec = get_hidden_state(history)
43
  states_vectors.append(vec)
44
 
45
+ # --- 1. KMeans Graph (Discrete State Machine) ---
46
  num_clusters = min(len(moves), 4)
47
  kmeans = KMeans(n_clusters=num_clusters, n_init=10).fit(states_vectors)
48
  km_labels = kmeans.labels_
 
51
  for i in range(len(moves)-1):
52
  G_km.add_edge(f"S{km_labels[i]}", f"S{km_labels[i+1]}", label=moves[i+1])
53
 
54
+ plt.figure(figsize=(8, 6))
 
55
  pos_km = nx.spring_layout(G_km)
56
+ nx.draw(G_km, pos_km, with_labels=True, node_color='lightblue', node_size=2500, font_size=12, font_weight='bold')
57
+ nx.draw_networkx_edge_labels(G_km, pos_km, edge_labels=nx.get_edge_attributes(G_km, 'label'), font_size=10)
58
+ plt.title("Logical State Machine (KMeans)")
59
+ km_plot = "km_plot.png"
60
+ plt.savefig(km_plot, dpi=150)
61
+ plt.close()
62
 
63
+ # --- 2. Linear Probe PCA (Geometric State Machine) ---
 
64
  pca = PCA(n_components=2)
65
  coords = pca.fit_transform(states_vectors)
66
 
67
+ plt.figure(figsize=(10, 8)) # Increased size for better visibility
68
+ plt.scatter(coords[:, 0], coords[:, 1], c=range(len(moves)), cmap='viridis', s=200, edgecolors='black')
69
+
70
+ # Drawing arrows between coordinates (The Linear Probe "State Machine")
71
+ for i in range(len(moves)-1):
72
+ plt.arrow(coords[i, 0], coords[i, 1],
73
+ coords[i+1, 0] - coords[i, 0],
74
+ coords[i+1, 1] - coords[i, 1],
75
+ head_width=5, length_includes_head=True, alpha=0.5, color='gray')
76
+
77
  for i, move in enumerate(moves):
78
+ plt.annotate(f"Step {i}: {move}", (coords[i, 0], coords[i, 1]),
79
+ xytext=(5, 5), textcoords='offset points', fontsize=9, fontweight='bold')
80
+
81
+ plt.grid(True, linestyle='--', alpha=0.6)
82
+ plt.title("Geometric State Machine (Linear Probe PCA)")
83
+ plt.xlabel("Principal Component 1 (Primary Axis of Variance)")
84
+ plt.ylabel("Principal Component 2 (Secondary Axis of Variance)")
85
 
86
+ pca_plot = "pca_plot.png"
87
+ plt.savefig(pca_plot, dpi=150)
88
  plt.close()
89
 
90
+ return km_plot, pca_plot, f"Labels: {km_labels}", log_capture.getvalue()
91
 
92
+ # Gradio Interface with Separated Columns
93
+ with gr.Blocks(title="World Model Hybrid Probe") as demo:
94
+ gr.Markdown("# 🛰️ World Model Hybrid Probe")
95
+ gr.Markdown("Comparing **Logical Categorization** (KMeans) vs **Spatial Intuition** (Linear PCA).")
96
+
97
+ with gr.Row():
98
+ input_box = gr.Textbox(label="Input Moves", placeholder="Up, Up, Right, Left", scale=4)
99
+ submit_btn = gr.Button("Analyze", variant="primary", scale=1)
100
+
101
  with gr.Row():
102
+ # Box 1: Logic
103
+ with gr.Column(variant="panel"):
104
+ gr.Markdown("### 1. Discrete State Logic (DFA)")
105
+ output_km = gr.Image(label="KMeans DFA", type="filepath")
106
+ analysis_text = gr.Textbox(label="Cluster Labels", interactive=False)
107
+
108
+ # Box 2: Geometry (The Clearer Linear Probe)
109
+ with gr.Column(variant="panel"):
110
+ gr.Markdown("### 2. Geometric Trajectory (Linear Probe)")
111
+ output_pca = gr.Image(label="Spatial PCA Map", type="filepath")
112
+ gr.Markdown("*This map shows the 'Mental Path' GPT-2 takes through its vector space.*")
113
+
114
+ log_box = gr.Textbox(label="Probe Logs", lines=5, interactive=False)
115
+
116
+ submit_btn.click(analyze_dfa, input_box, [output_km, output_pca, analysis_text, log_box])
117
 
118
  demo.launch()