Update app.py
Browse files
app.py
CHANGED
|
@@ -42,7 +42,7 @@ def analyze_dfa(input_text):
|
|
| 42 |
vec = get_hidden_state(history)
|
| 43 |
states_vectors.append(vec)
|
| 44 |
|
| 45 |
-
# --- 1. KMeans Graph (
|
| 46 |
num_clusters = min(len(moves), 4)
|
| 47 |
kmeans = KMeans(n_clusters=num_clusters, n_init=10).fit(states_vectors)
|
| 48 |
km_labels = kmeans.labels_
|
|
@@ -51,40 +51,68 @@ def analyze_dfa(input_text):
|
|
| 51 |
for i in range(len(moves)-1):
|
| 52 |
G_km.add_edge(f"S{km_labels[i]}", f"S{km_labels[i+1]}", label=moves[i+1])
|
| 53 |
|
| 54 |
-
plt.figure(figsize=(
|
| 55 |
-
plt.subplot(1, 2, 1)
|
| 56 |
pos_km = nx.spring_layout(G_km)
|
| 57 |
-
nx.draw(G_km, pos_km, with_labels=True, node_color='lightblue', node_size=
|
| 58 |
-
nx.draw_networkx_edge_labels(G_km, pos_km, edge_labels=nx.get_edge_attributes(G_km, 'label'))
|
| 59 |
-
plt.title("
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
-
# --- 2. Linear Probe
|
| 62 |
-
logger.info("📐 Running Linear Probe (PCA) to find the 'Spatial Axis'...")
|
| 63 |
pca = PCA(n_components=2)
|
| 64 |
coords = pca.fit_transform(states_vectors)
|
| 65 |
|
| 66 |
-
plt.
|
| 67 |
-
plt.scatter(coords[:, 0], coords[:, 1], c=range(len(moves)), cmap='viridis', s=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
for i, move in enumerate(moves):
|
| 69 |
-
plt.annotate(f"{i}:{move}", (coords[i, 0], coords[i, 1])
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
-
|
| 74 |
-
plt.savefig(
|
| 75 |
plt.close()
|
| 76 |
|
| 77 |
-
return
|
| 78 |
|
| 79 |
-
#
|
| 80 |
-
with gr.Blocks() as demo:
|
| 81 |
-
gr.Markdown("#
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
with gr.Row():
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
demo.launch()
|
|
|
|
| 42 |
vec = get_hidden_state(history)
|
| 43 |
states_vectors.append(vec)
|
| 44 |
|
| 45 |
+
# --- 1. KMeans Graph (Discrete State Machine) ---
|
| 46 |
num_clusters = min(len(moves), 4)
|
| 47 |
kmeans = KMeans(n_clusters=num_clusters, n_init=10).fit(states_vectors)
|
| 48 |
km_labels = kmeans.labels_
|
|
|
|
| 51 |
for i in range(len(moves)-1):
|
| 52 |
G_km.add_edge(f"S{km_labels[i]}", f"S{km_labels[i+1]}", label=moves[i+1])
|
| 53 |
|
| 54 |
+
plt.figure(figsize=(8, 6))
|
|
|
|
| 55 |
pos_km = nx.spring_layout(G_km)
|
| 56 |
+
nx.draw(G_km, pos_km, with_labels=True, node_color='lightblue', node_size=2500, font_size=12, font_weight='bold')
|
| 57 |
+
nx.draw_networkx_edge_labels(G_km, pos_km, edge_labels=nx.get_edge_attributes(G_km, 'label'), font_size=10)
|
| 58 |
+
plt.title("Logical State Machine (KMeans)")
|
| 59 |
+
km_plot = "km_plot.png"
|
| 60 |
+
plt.savefig(km_plot, dpi=150)
|
| 61 |
+
plt.close()
|
| 62 |
|
| 63 |
+
# --- 2. Linear Probe PCA (Geometric State Machine) ---
|
|
|
|
| 64 |
pca = PCA(n_components=2)
|
| 65 |
coords = pca.fit_transform(states_vectors)
|
| 66 |
|
| 67 |
+
plt.figure(figsize=(10, 8)) # Increased size for better visibility
|
| 68 |
+
plt.scatter(coords[:, 0], coords[:, 1], c=range(len(moves)), cmap='viridis', s=200, edgecolors='black')
|
| 69 |
+
|
| 70 |
+
# Drawing arrows between coordinates (The Linear Probe "State Machine")
|
| 71 |
+
for i in range(len(moves)-1):
|
| 72 |
+
plt.arrow(coords[i, 0], coords[i, 1],
|
| 73 |
+
coords[i+1, 0] - coords[i, 0],
|
| 74 |
+
coords[i+1, 1] - coords[i, 1],
|
| 75 |
+
head_width=5, length_includes_head=True, alpha=0.5, color='gray')
|
| 76 |
+
|
| 77 |
for i, move in enumerate(moves):
|
| 78 |
+
plt.annotate(f"Step {i}: {move}", (coords[i, 0], coords[i, 1]),
|
| 79 |
+
xytext=(5, 5), textcoords='offset points', fontsize=9, fontweight='bold')
|
| 80 |
+
|
| 81 |
+
plt.grid(True, linestyle='--', alpha=0.6)
|
| 82 |
+
plt.title("Geometric State Machine (Linear Probe PCA)")
|
| 83 |
+
plt.xlabel("Principal Component 1 (Primary Axis of Variance)")
|
| 84 |
+
plt.ylabel("Principal Component 2 (Secondary Axis of Variance)")
|
| 85 |
|
| 86 |
+
pca_plot = "pca_plot.png"
|
| 87 |
+
plt.savefig(pca_plot, dpi=150)
|
| 88 |
plt.close()
|
| 89 |
|
| 90 |
+
return km_plot, pca_plot, f"Labels: {km_labels}", log_capture.getvalue()
|
| 91 |
|
| 92 |
+
# --- Gradio UI: the KMeans graph and the PCA trajectory get separate panels ---
# NOTE(review): the nesting below was reconstructed from a flattened diff view.
# Confirm that the log box belongs under the two panels (Blocks level) rather
# than inside the panel row.
# Components are created in display order; do not reorder these statements.
with gr.Blocks(title="World Model Hybrid Probe") as demo:
    gr.Markdown("# 🛰️ World Model Hybrid Probe")
    gr.Markdown(
        "Comparing **Logical Categorization** (KMeans) vs **Spatial Intuition** (Linear PCA)."
    )

    # Input row: a wide text field next to a narrow submit button (scale 4:1).
    with gr.Row():
        input_box = gr.Textbox(
            label="Input Moves",
            placeholder="Up, Up, Right, Left",
            scale=4,
        )
        submit_btn = gr.Button("Analyze", variant="primary", scale=1)

    # Result row: two side-by-side panels.
    with gr.Row():
        # Left panel: cluster-transition graph image plus the cluster labels text.
        with gr.Column(variant="panel"):
            gr.Markdown("### 1. Discrete State Logic (DFA)")
            output_km = gr.Image(label="KMeans DFA", type="filepath")
            analysis_text = gr.Textbox(label="Cluster Labels", interactive=False)

        # Right panel: 2-D PCA scatter image.
        with gr.Column(variant="panel"):
            gr.Markdown("### 2. Geometric Trajectory (Linear Probe)")
            output_pca = gr.Image(label="Spatial PCA Map", type="filepath")
            gr.Markdown(
                "*This map shows the 'Mental Path' GPT-2 takes through its vector space.*"
            )

    log_box = gr.Textbox(label="Probe Logs", lines=5, interactive=False)

    # The output order must match analyze_dfa's return tuple:
    # (KMeans image path, PCA image path, labels text, captured log text).
    submit_btn.click(
        fn=analyze_dfa,
        inputs=input_box,
        outputs=[output_km, output_pca, analysis_text, log_box],
    )

demo.launch()
|