"""World Model Hybrid Probe.

Feeds a comma-separated sequence of "moves" through GPT-2, captures the
final-layer hidden state after each move, and visualizes the implied state
machine two ways: discretely (KMeans clusters rendered as DFA states) and
geometrically (a 2-D PCA projection of the hidden-state trajectory).
"""

import io
import logging

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import torch
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from transformers import GPT2Model, GPT2Tokenizer

# --- Logging: capture probe messages in-memory so the UI can display them ---
log_capture = io.StringIO()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("DFA_Probe")
handler = logging.StreamHandler(log_capture)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)

# --- Load GPT-2 once at module import (CPU fallback when no GPU) ---
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2Model.from_pretrained(model_name).to(device)


def get_hidden_state(sequence_str):
    """Return the last layer's hidden state of the last token as a 1-D numpy array."""
    inputs = tokenizer(sequence_str, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    # hidden_states[-1] = final layer; [0, -1, :] = batch 0, last token.
    return outputs.hidden_states[-1][0, -1, :].cpu().numpy()


def _plot_kmeans_dfa(moves, states_vectors):
    """Cluster hidden states and draw the induced discrete state machine.

    Returns (plot_path, cluster_labels).
    """
    num_clusters = min(len(moves), 4)
    kmeans = KMeans(n_clusters=num_clusters, n_init=10).fit(states_vectors)
    km_labels = kmeans.labels_
    logger.info("KMeans assigned %d clusters: %s", num_clusters, km_labels.tolist())

    graph = nx.DiGraph()
    # Edge S[i] -> S[i+1] is labeled with the move that caused the transition.
    for i in range(len(moves) - 1):
        graph.add_edge(f"S{km_labels[i]}", f"S{km_labels[i + 1]}", label=moves[i + 1])

    plt.figure(figsize=(8, 6))
    pos = nx.spring_layout(graph)
    nx.draw(graph, pos, with_labels=True, node_color='lightblue',
            node_size=2500, font_size=12, font_weight='bold')
    nx.draw_networkx_edge_labels(
        graph, pos, edge_labels=nx.get_edge_attributes(graph, 'label'), font_size=10)
    plt.title("Logical State Machine (KMeans)")
    km_plot = "km_plot.png"
    plt.savefig(km_plot, dpi=150)
    plt.close()
    return km_plot, km_labels


def _plot_pca_trajectory(moves, states_vectors):
    """Project hidden states to 2-D with PCA and draw the trajectory plot.

    Returns the plot file path.
    """
    pca = PCA(n_components=2)
    coords = pca.fit_transform(states_vectors)
    logger.info("PCA explained variance ratio: %s", pca.explained_variance_ratio_)

    plt.figure(figsize=(10, 8))  # Increased size for better visibility
    plt.scatter(coords[:, 0], coords[:, 1], c=range(len(moves)),
                cmap='viridis', s=200, edgecolors='black')
    # Arrows trace the "mental path" between consecutive hidden states.
    for i in range(len(moves) - 1):
        plt.arrow(coords[i, 0], coords[i, 1],
                  coords[i + 1, 0] - coords[i, 0],
                  coords[i + 1, 1] - coords[i, 1],
                  head_width=5, length_includes_head=True, alpha=0.5, color='gray')
    for i, move in enumerate(moves):
        plt.annotate(f"Step {i}: {move}", (coords[i, 0], coords[i, 1]),
                     xytext=(5, 5), textcoords='offset points',
                     fontsize=9, fontweight='bold')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.title("Geometric State Machine (Linear Probe PCA)")
    plt.xlabel("Principal Component 1 (Primary Axis of Variance)")
    plt.ylabel("Principal Component 2 (Secondary Axis of Variance)")
    pca_plot = "pca_plot.png"
    plt.savefig(pca_plot, dpi=150)
    plt.close()
    return pca_plot


def analyze_dfa(input_text):
    """Run the hybrid probe on a comma-separated list of moves.

    Returns (kmeans_plot_path, pca_plot_path, labels_text, captured_logs).
    On fewer than two valid moves, the plot paths are None and labels_text
    explains the problem (PCA needs >= 2 samples for 2 components).
    """
    # Reset the in-memory log buffer so each run shows only its own messages.
    log_capture.truncate(0)
    log_capture.seek(0)

    # Drop blank entries so stray commas don't produce empty probe steps.
    moves = [m.strip() for m in input_text.split(",") if m.strip()]
    if len(moves) < 2:
        logger.warning("Need at least two moves, got %d.", len(moves))
        return None, None, "Please enter at least two comma-separated moves.", log_capture.getvalue()

    logger.info("Probing %d moves: %s", len(moves), moves)
    history = ""
    states_vectors = []
    for move in moves:
        # The prompt grows cumulatively; each vector encodes the full history.
        history += f" Move {move}."
        states_vectors.append(get_hidden_state(history))

    km_plot, km_labels = _plot_kmeans_dfa(moves, states_vectors)
    pca_plot = _plot_pca_trajectory(moves, states_vectors)
    return km_plot, pca_plot, f"Labels: {km_labels}", log_capture.getvalue()


# --- Gradio interface: two side-by-side panels (logic vs geometry) ---
with gr.Blocks(title="World Model Hybrid Probe") as demo:
    gr.Markdown("# 🛰️ World Model Hybrid Probe")
    gr.Markdown("Comparing **Logical Categorization** (KMeans) vs **Spatial Intuition** (Linear PCA).")
    with gr.Row():
        input_box = gr.Textbox(label="Input Moves", placeholder="Up, Up, Right, Left", scale=4)
        submit_btn = gr.Button("Analyze", variant="primary", scale=1)
    with gr.Row():
        # Box 1: Logic
        with gr.Column(variant="panel"):
            gr.Markdown("### 1. Discrete State Logic (DFA)")
            output_km = gr.Image(label="KMeans DFA", type="filepath")
            analysis_text = gr.Textbox(label="Cluster Labels", interactive=False)
        # Box 2: Geometry (The Clearer Linear Probe)
        with gr.Column(variant="panel"):
            gr.Markdown("### 2. Geometric Trajectory (Linear Probe)")
            output_pca = gr.Image(label="Spatial PCA Map", type="filepath")
            gr.Markdown("*This map shows the 'Mental Path' GPT-2 takes through its vector space.*")
    log_box = gr.Textbox(label="Probe Logs", lines=5, interactive=False)
    submit_btn.click(analyze_dfa, input_box, [output_km, output_pca, analysis_text, log_box])

demo.launch()