|
|
import torch |
|
|
import gradio as gr |
|
|
import networkx as nx |
|
|
import matplotlib.pyplot as plt |
|
|
import logging |
|
|
import io |
|
|
import numpy as np |
|
|
from transformers import GPT2Model, GPT2Tokenizer |
|
|
from sklearn.cluster import KMeans |
|
|
from sklearn.decomposition import PCA |
|
|
|
|
|
|
|
|
# Root logging: INFO and above go to the default stderr handler (when the root
# logger has not been configured already).
logging.basicConfig(level=logging.INFO)

# In-memory text buffer that mirrors this app's log records so the UI can show them.
log_capture = io.StringIO()

# Handler that writes "<time> - <LEVEL> - <message>" lines into the buffer above.
_capture_formatter = logging.Formatter(fmt="%(asctime)s - %(levelname)s - %(message)s")
handler = logging.StreamHandler(stream=log_capture)
handler.setFormatter(_capture_formatter)

# Application logger; its records reach both the capture buffer and the root handlers.
logger = logging.getLogger("DFA_Probe")
logger.addHandler(handler)
|
|
|
|
|
|
|
|
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face checkpoint id shared by the tokenizer and the model.
model_name = "gpt2"

tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Base GPT-2 transformer (no LM head), moved onto the selected device.
model = GPT2Model.from_pretrained(model_name)
model = model.to(device)
|
|
|
|
|
def get_hidden_state(sequence_str):
    """Return GPT-2's final-layer hidden state at the last token of *sequence_str*.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        sequence_str: Text to encode.

    Returns:
        A 1-D NumPy array, the last-layer hidden vector of the final token.

    Raises:
        ValueError: If *sequence_str* tokenizes to zero tokens.
    """
    inputs = tokenizer(sequence_str, return_tensors="pt")
    if inputs["input_ids"].shape[1] == 0:
        # Without this guard the [0, -1, :] index below fails with an opaque IndexError.
        raise ValueError("sequence_str must encode to at least one token")

    # GPT-2 has a fixed number of learned positions; longer inputs would index past the
    # position-embedding table, so keep only the most recent tokens (the last token,
    # whose state we return, is always kept).
    max_len = model.config.n_positions
    inputs = {key: tensor[:, -max_len:].to(device) for key, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # last_hidden_state is the same tensor as hidden_states[-1] (after the final layer
    # norm), but avoids materializing every intermediate layer's activations.
    return outputs.last_hidden_state[0, -1, :].cpu().numpy()
|
|
|
|
|
def _parse_moves(input_text):
    """Split a comma-separated move string into stripped, non-empty move names."""
    return [part.strip() for part in (input_text or "").split(",") if part.strip()]


def _collect_states(moves):
    """Return a 2-D array with one GPT-2 state vector per prefix of *moves*."""
    history = ""
    vectors = []
    for move in moves:
        # Each state is the hidden vector after reading the history up to this move.
        history += f" Move {move}."
        vectors.append(get_hidden_state(history))
    return np.vstack(vectors)


def _plot_kmeans_graph(moves, labels, path):
    """Draw the cluster-transition graph implied by *labels* and save it to *path*."""
    # Collect every distinct move seen on each (source, target) cluster transition.
    # A plain DiGraph.add_edge would silently overwrite earlier labels.
    transitions = {}
    for src, dst, move in zip(labels[:-1], labels[1:], moves[1:]):
        names = transitions.setdefault((f"S{src}", f"S{dst}"), [])
        if move not in names:
            names.append(move)

    graph = nx.DiGraph()
    graph.add_nodes_from(f"S{label}" for label in labels)
    for (src, dst), names in transitions.items():
        graph.add_edge(src, dst, label=", ".join(names))

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        # Seeded layout so the same input gives the same picture on every run.
        pos = nx.spring_layout(graph, seed=0)
        nx.draw(graph, pos, ax=ax, with_labels=True, node_color='lightblue',
                node_size=2500, font_size=12, font_weight='bold')
        nx.draw_networkx_edge_labels(graph, pos, ax=ax,
                                     edge_labels=nx.get_edge_attributes(graph, 'label'),
                                     font_size=10)
        ax.set_title("Logical State Machine (KMeans)")
        fig.savefig(path, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    return path


def _plot_pca_trajectory(moves, coords, path):
    """Plot the 2-D PCA trajectory of the states and save it to *path*."""
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        ax.scatter(coords[:, 0], coords[:, 1], c=range(len(moves)), cmap='viridis',
                   s=200, edgecolors='black')

        # annotate() arrows are sized in screen points, so the heads stay visible at
        # any PCA scale. plt.arrow's head_width is in data units and is not.
        for start, end in zip(coords[:-1], coords[1:]):
            ax.annotate("", xy=(end[0], end[1]), xytext=(start[0], start[1]),
                        arrowprops={"arrowstyle": "-|>", "color": "gray", "alpha": 0.5})

        for i, (move, point) in enumerate(zip(moves, coords)):
            ax.annotate(f"Step {i}: {move}", (point[0], point[1]),
                        xytext=(5, 5), textcoords='offset points', fontsize=9,
                        fontweight='bold')

        ax.grid(True, linestyle='--', alpha=0.6)
        ax.set_title("Geometric State Machine (Linear Probe PCA)")
        ax.set_xlabel("Principal Component 1 (Primary Axis of Variance)")
        ax.set_ylabel("Principal Component 2 (Secondary Axis of Variance)")
        fig.savefig(path, dpi=150)
    finally:
        plt.close(fig)
    return path


def analyze_dfa(input_text):
    """Probe GPT-2's states over a comma-separated move sequence.

    For each prefix "Move a. Move b. ..." the final-token hidden state is taken.
    The states are then clustered with KMeans (at most 4 clusters) into a
    transition graph, and projected to 2-D with PCA as a trajectory plot.

    Args:
        input_text: Comma-separated moves, e.g. ``"Up, Up, Right, Left"``.
            Blank entries are ignored.

    Returns:
        A 4-tuple for the Gradio outputs:
        ``(kmeans_plot_path, pca_plot_path, labels_text, log_text)``.
        If fewer than two moves are given, both plot paths are ``None`` and
        ``labels_text`` explains the problem.
    """
    # Reset the shared capture buffer so the UI only shows logs from this run.
    log_capture.seek(0)
    log_capture.truncate(0)

    moves = _parse_moves(input_text)
    logger.info("Parsed %d move(s): %s", len(moves), moves)
    if len(moves) < 2:
        # A 2-component PCA and a transition graph both need at least two states.
        logger.warning("At least two moves are required; got %d.", len(moves))
        return (None, None,
                "Please enter at least two comma-separated moves (e.g. Up, Up, Right, Left).",
                log_capture.getvalue())

    states_vectors = _collect_states(moves)

    num_clusters = min(len(moves), 4)
    # random_state makes the cluster ids (and thus the graph) reproducible across runs.
    kmeans = KMeans(n_clusters=num_clusters, n_init=10, random_state=0).fit(states_vectors)
    km_labels = kmeans.labels_
    logger.info("KMeans (%d clusters) labels: %s", num_clusters, km_labels.tolist())
    km_plot = _plot_kmeans_graph(moves, km_labels, "km_plot.png")

    pca = PCA(n_components=2)
    coords = pca.fit_transform(states_vectors)
    logger.info("PCA explained variance ratio: %s",
                np.round(pca.explained_variance_ratio_, 4).tolist())
    pca_plot = _plot_pca_trajectory(moves, coords, "pca_plot.png")

    return km_plot, pca_plot, f"Labels: {km_labels}", log_capture.getvalue()
|
|
|
|
|
|
|
|
with gr.Blocks(title="World Model Hybrid Probe") as demo:
    # Page header plus a one-line summary of the two views being compared.
    gr.Markdown("# 🛰️ World Model Hybrid Probe")
    gr.Markdown("Comparing **Logical Categorization** (KMeans) vs **Spatial Intuition** (Linear PCA).")

    # Input row: the move list takes most of the width, the trigger button the rest.
    with gr.Row():
        input_box = gr.Textbox(
            label="Input Moves",
            placeholder="Up, Up, Right, Left",
            scale=4,
        )
        submit_btn = gr.Button("Analyze", variant="primary", scale=1)

    # Results row: the discrete clustering view next to the continuous PCA view.
    with gr.Row():
        with gr.Column(variant="panel"):
            gr.Markdown("### 1. Discrete State Logic (DFA)")
            output_km = gr.Image(label="KMeans DFA", type="filepath")
            analysis_text = gr.Textbox(label="Cluster Labels", interactive=False)

        with gr.Column(variant="panel"):
            gr.Markdown("### 2. Geometric Trajectory (Linear Probe)")
            output_pca = gr.Image(label="Spatial PCA Map", type="filepath")
            gr.Markdown("*This map shows the 'Mental Path' GPT-2 takes through its vector space.*")

    log_box = gr.Textbox(label="Probe Logs", lines=5, interactive=False)

    # Output order matches analyze_dfa's return tuple:
    # (kmeans image, pca image, labels text, logs).
    submit_btn.click(
        fn=analyze_dfa,
        inputs=input_box,
        outputs=[output_km, output_pca, analysis_text, log_box],
    )

demo.launch()