File size: 4,680 Bytes
a40cb3b 68d7012 8113b74 a40cb3b 8113b74 a40cb3b 8113b74 68d7012 8113b74 a40cb3b 68d7012 a40cb3b 68d7012 a40cb3b 68d7012 a40cb3b c6dce4f a40cb3b 8113b74 a40cb3b 8113b74 a40cb3b 8113b74 a40cb3b c6dce4f 8113b74 c6dce4f 8113b74 c6dce4f 8113b74 c6dce4f 8113b74 c6dce4f a40cb3b c6dce4f 68d7012 c6dce4f 68d7012 c6dce4f 68d7012 c6dce4f a40cb3b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import torch
import gradio as gr
import networkx as nx
import matplotlib.pyplot as plt
import logging
import io
import numpy as np
from transformers import GPT2Model, GPT2Tokenizer
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
# Logging: records emitted through ``logger`` are formatted and written into
# the in-memory ``log_capture`` buffer (in addition to whatever the root
# configuration from ``basicConfig`` does with them).
log_capture = io.StringIO()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("DFA_Probe")
handler = logging.StreamHandler(stream=log_capture)
handler.setFormatter(
    logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(message)s')
)
logger.addHandler(handler)

# Model: pick CUDA when torch reports it as available, otherwise the CPU,
# then load the GPT-2 tokenizer and weights once at import time.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2Model.from_pretrained(model_name)
model = model.to(device)
def get_hidden_state(sequence_str):
    """Run *sequence_str* through the model and return one vector for its end.

    The text is tokenized, moved to ``device``, and passed to the module-level
    ``model`` with gradient tracking disabled. The last entry of the returned
    ``hidden_states`` is then sliced at batch index 0 and the final sequence
    position.

    Returns:
        That slice as a NumPy array, after copying it to the CPU.
    """
    encoded = tokenizer(sequence_str, return_tensors="pt").to(device)
    with torch.no_grad():
        result = model(**encoded, output_hidden_states=True)
    final_layer = result.hidden_states[-1]
    last_token = final_layer[0, -1, :]
    return last_token.cpu().numpy()
def _parse_moves(input_text):
    """Split comma-separated text into stripped, non-empty move tokens."""
    return [tok.strip() for tok in (input_text or "").split(",") if tok.strip()]


def _encode_prefixes(moves):
    """Return an (n_moves, hidden_dim) array: one vector per growing prompt."""
    history = ""
    vectors = []
    for step, move in enumerate(moves):
        history += f" Move {move}."
        vectors.append(get_hidden_state(history))
        logger.info("Step %d: encoded prompt %r", step, history)
    return np.vstack(vectors)


def _draw_cluster_graph(moves, labels, out_path):
    """Draw the cluster-to-cluster transition graph and save it to *out_path*."""
    # Collect every move seen on each (source, target) pair first: calling
    # add_edge twice on the same pair would replace the earlier label.
    transitions = {}
    for src, dst, move in zip(labels, labels[1:], moves[1:]):
        seen = transitions.setdefault((f"S{src}", f"S{dst}"), [])
        if move not in seen:
            seen.append(move)

    graph = nx.DiGraph()
    # Register the start state explicitly so a one-move input still shows a node.
    graph.add_node(f"S{labels[0]}")
    for (src, dst), seen in transitions.items():
        graph.add_edge(src, dst, label=", ".join(seen))

    fig = plt.figure(figsize=(8, 6))
    try:
        layout = nx.spring_layout(graph)
        nx.draw(graph, layout, with_labels=True, node_color='lightblue',
                node_size=2500, font_size=12, font_weight='bold')
        nx.draw_networkx_edge_labels(
            graph, layout,
            edge_labels=nx.get_edge_attributes(graph, 'label'),
            font_size=10,
        )
        plt.title("Logical State Machine (KMeans)")
        plt.savefig(out_path, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


def _draw_pca_trajectory(moves, vectors, out_path):
    """Project *vectors* to 2-D, plot the step-by-step path, save to *out_path*."""
    if len(moves) >= 2:
        coords = PCA(n_components=2).fit_transform(vectors)
    else:
        # PCA cannot extract 2 components from a single sample; a lone point
        # has no trajectory, so it is simply placed at the origin.
        coords = np.zeros((len(moves), 2))

    fig = plt.figure(figsize=(10, 8))
    try:
        plt.scatter(coords[:, 0], coords[:, 1], c=range(len(moves)),
                    cmap='viridis', s=200, edgecolors='black')
        # Arrows connect consecutive steps.
        # NOTE(review): head_width is in data units, so the arrow heads scale
        # with the PCA coordinate range - confirm 5 looks right on real data.
        for start, end in zip(coords, coords[1:]):
            plt.arrow(start[0], start[1],
                      end[0] - start[0],
                      end[1] - start[1],
                      head_width=5, length_includes_head=True,
                      alpha=0.5, color='gray')
        for step, move in enumerate(moves):
            plt.annotate(f"Step {step}: {move}", (coords[step, 0], coords[step, 1]),
                         xytext=(5, 5), textcoords='offset points',
                         fontsize=9, fontweight='bold')
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.title("Geometric State Machine (Linear Probe PCA)")
        plt.xlabel("Principal Component 1 (Primary Axis of Variance)")
        plt.ylabel("Principal Component 2 (Secondary Axis of Variance)")
        plt.savefig(out_path, dpi=150)
    finally:
        plt.close(fig)


def analyze_dfa(input_text):
    """Probe the model on a comma-separated move list and render two views.

    Each move extends a running prompt (`` Move <m>.``); the hidden-state
    vector returned by ``get_hidden_state`` for every prefix is clustered with
    KMeans (at most 4 clusters) and projected with PCA.

    Args:
        input_text: Comma-separated moves, e.g. ``"Up, Up, Right, Left"``.
            Blank entries (stray or trailing commas) are ignored.

    Returns:
        A 4-tuple ``(km_plot_path, pca_plot_path, labels_text, log_text)``:
        the two saved PNG file names, a string with the KMeans labels, and
        the log lines captured during this call.

    Raises:
        gr.Error: If *input_text* contains no moves.
    """
    # Reset the shared capture buffer so the returned log covers this call only.
    log_capture.truncate(0)
    log_capture.seek(0)

    moves = _parse_moves(input_text)
    if not moves:
        raise gr.Error("Enter at least one move, e.g. 'Up, Up, Right, Left'.")
    logger.info("Parsed %d move(s): %s", len(moves), moves)

    states_vectors = _encode_prefixes(moves)

    # --- 1. KMeans Graph (Discrete State Machine) ---
    num_clusters = min(len(moves), 4)
    kmeans = KMeans(n_clusters=num_clusters, n_init=10).fit(states_vectors)
    km_labels = kmeans.labels_
    logger.info("KMeans (%d cluster(s)) labels: %s", num_clusters, km_labels.tolist())
    km_plot = "km_plot.png"
    _draw_cluster_graph(moves, km_labels, km_plot)

    # --- 2. Linear Probe PCA (Geometric State Machine) ---
    pca_plot = "pca_plot.png"
    _draw_pca_trajectory(moves, states_vectors, pca_plot)
    logger.info("Saved plots to %s and %s", km_plot, pca_plot)

    return km_plot, pca_plot, f"Labels: {km_labels}", log_capture.getvalue()
# Gradio UI: an input row on top, two side-by-side panels for the KMeans graph
# and the PCA map, and a log box underneath.
with gr.Blocks(title="World Model Hybrid Probe") as demo:
    gr.Markdown("# 🛰️ World Model Hybrid Probe")
    gr.Markdown("Comparing **Logical Categorization** (KMeans) vs **Spatial Intuition** (Linear PCA).")

    with gr.Row():
        input_box = gr.Textbox(
            label="Input Moves",
            placeholder="Up, Up, Right, Left",
            scale=4,
        )
        submit_btn = gr.Button("Analyze", variant="primary", scale=1)

    with gr.Row():
        # Left panel: discrete cluster graph plus the raw cluster labels.
        with gr.Column(variant="panel"):
            gr.Markdown("### 1. Discrete State Logic (DFA)")
            output_km = gr.Image(label="KMeans DFA", type="filepath")
            analysis_text = gr.Textbox(label="Cluster Labels", interactive=False)

        # Right panel: 2-D PCA trajectory of the same hidden states.
        with gr.Column(variant="panel"):
            gr.Markdown("### 2. Geometric Trajectory (Linear Probe)")
            output_pca = gr.Image(label="Spatial PCA Map", type="filepath")
            gr.Markdown("*This map shows the 'Mental Path' GPT-2 takes through its vector space.*")

    log_box = gr.Textbox(label="Probe Logs", lines=5, interactive=False)

    # The button feeds the textbox to analyze_dfa; its four return values map
    # onto the four output components in order.
    submit_btn.click(
        fn=analyze_dfa,
        inputs=input_box,
        outputs=[output_km, output_pca, analysis_text, log_box],
    )

demo.launch()