File size: 4,680 Bytes
a40cb3b
 
 
 
68d7012
 
8113b74
a40cb3b
 
8113b74
a40cb3b
8113b74
68d7012
 
 
 
 
 
 
8113b74
a40cb3b
68d7012
a40cb3b
 
 
 
 
 
 
 
 
 
68d7012
 
 
a40cb3b
 
 
 
68d7012
a40cb3b
 
 
 
c6dce4f
a40cb3b
 
8113b74
a40cb3b
8113b74
a40cb3b
8113b74
a40cb3b
c6dce4f
8113b74
c6dce4f
 
 
 
 
 
8113b74
c6dce4f
8113b74
 
 
c6dce4f
 
 
 
 
 
 
 
 
 
8113b74
c6dce4f
 
 
 
 
 
 
a40cb3b
c6dce4f
 
68d7012
 
c6dce4f
68d7012
c6dce4f
 
 
 
 
 
 
 
 
68d7012
c6dce4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a40cb3b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import gradio as gr
import networkx as nx
import matplotlib.pyplot as plt
import logging
import io
import numpy as np
from transformers import GPT2Model, GPT2Tokenizer
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

# --- Logging -----------------------------------------------------------
# Records emitted on the "DFA_Probe" logger are additionally written to an
# in-memory text buffer (log_capture) so their text can be read back later.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'

logging.basicConfig(level=logging.INFO)
log_capture = io.StringIO()
handler = logging.StreamHandler(log_capture)
handler.setFormatter(logging.Formatter(_LOG_FORMAT))
logger = logging.getLogger("DFA_Probe")
logger.addHandler(handler)

# --- Model -------------------------------------------------------------
# Load the GPT-2 tokenizer and base model once at import time and place the
# model on the GPU when one is available, otherwise on the CPU.
model_name = "gpt2"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2Model.from_pretrained(model_name)
model = model.to(device)

def get_hidden_state(sequence_str):
    """Return the last hidden-state vector of the final token of a string.

    The string is tokenized, run through the module-level ``model`` with
    gradient tracking disabled, and the last entry of ``hidden_states`` is
    indexed at batch item 0, final token position.

    Args:
        sequence_str: Text to encode.

    Returns:
        A NumPy array holding that token's hidden-state vector (moved to CPU).
    """
    encoded = tokenizer(sequence_str, return_tensors="pt").to(device)
    with torch.no_grad():
        result = model(**encoded, output_hidden_states=True)
    final_layer = result.hidden_states[-1]
    last_token_vector = final_layer[0, -1, :]
    return last_token_vector.cpu().numpy()

def _build_transition_graph(labels, moves):
    """Build the cluster-level transition graph for a move sequence.

    Nodes are named ``S<cluster id>``. An edge runs from the cluster of step
    ``i`` to the cluster of step ``i + 1`` and is labelled with every distinct
    move that caused that transition (comma-joined), so a transition taken
    more than once does not lose its earlier labels.
    """
    graph = nx.DiGraph()
    # Register nodes explicitly so a one-move input still shows its state.
    graph.add_nodes_from(f"S{label}" for label in labels)

    moves_per_edge = {}
    for src, dst, move in zip(labels[:-1], labels[1:], moves[1:]):
        seen = moves_per_edge.setdefault((f"S{src}", f"S{dst}"), [])
        if move not in seen:
            seen.append(move)

    for (src, dst), seen in moves_per_edge.items():
        graph.add_edge(src, dst, label=", ".join(seen))
    return graph


def _save_transition_plot(graph, path):
    """Draw the transition graph and save it to *path* as an image."""
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        # Fixed seed keeps the layout stable between runs on the same input.
        layout = nx.spring_layout(graph, seed=0)
        nx.draw(graph, layout, ax=ax, with_labels=True, node_color='lightblue',
                node_size=2500, font_size=12, font_weight='bold')
        nx.draw_networkx_edge_labels(
            graph, layout, ax=ax, font_size=10,
            edge_labels=nx.get_edge_attributes(graph, 'label'))
        ax.set_title("Logical State Machine (KMeans)")
        fig.savefig(path, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


def _save_trajectory_plot(coords, moves, path):
    """Plot the 2-D trajectory of the steps and save it to *path*."""
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        ax.scatter(coords[:, 0], coords[:, 1], c=np.arange(len(moves)),
                   cmap='viridis', s=200, edgecolors='black')

        # Arrow heads are sized in display points via annotate(), so they stay
        # readable whatever the numeric scale of the projected coordinates is.
        for start, end in zip(coords[:-1], coords[1:]):
            ax.annotate("", xy=(end[0], end[1]), xytext=(start[0], start[1]),
                        arrowprops=dict(arrowstyle="-|>", color='gray',
                                        alpha=0.5, lw=1.5))

        for step, (move, point) in enumerate(zip(moves, coords)):
            ax.annotate(f"Step {step}: {move}", (point[0], point[1]),
                        xytext=(5, 5), textcoords='offset points',
                        fontsize=9, fontweight='bold')

        ax.grid(True, linestyle='--', alpha=0.6)
        ax.set_title("Geometric State Machine (Linear Probe PCA)")
        ax.set_xlabel("Principal Component 1 (Primary Axis of Variance)")
        ax.set_ylabel("Principal Component 2 (Secondary Axis of Variance)")
        fig.savefig(path, dpi=150)
    finally:
        plt.close(fig)


def analyze_dfa(input_text):
    """Probe the hidden states for a comma-separated move sequence.

    Each move is appended to a growing prompt (`` Move <move>.``) and the
    hidden-state vector returned by ``get_hidden_state`` is collected per
    step. The vectors are then (1) clustered with KMeans (at most 4 clusters)
    and drawn as a directed transition graph, and (2) projected to 2-D with
    PCA and drawn as a trajectory.

    Args:
        input_text: Comma-separated moves, e.g. ``"Up, Up, Right, Left"``.
            Blank entries (stray or trailing commas) are ignored.

    Returns:
        A 4-tuple ``(km_plot, pca_plot, labels_text, logs)``: the two image
        file paths, a ``"Labels: ..."`` summary string, and the text captured
        from the probe logger during this call. When no moves are supplied the
        two paths are ``None``.
    """
    # Start each call with an empty log buffer.
    log_capture.truncate(0)
    log_capture.seek(0)

    moves = [m.strip() for m in (input_text or "").split(",")]
    moves = [m for m in moves if m]
    if not moves:
        logger.warning("No moves supplied; nothing to analyze.")
        return None, None, "Labels: (no moves provided)", log_capture.getvalue()

    history = ""
    states_vectors = []
    for step, move in enumerate(moves):
        history += f" Move {move}."
        states_vectors.append(get_hidden_state(history))
        logger.info("Step %d: encoded move %r", step, move)
    states = np.asarray(states_vectors)

    # --- 1. KMeans Graph (Discrete State Machine) ---
    num_clusters = min(len(moves), 4)
    # random_state makes cluster ids reproducible for identical input.
    kmeans = KMeans(n_clusters=num_clusters, n_init=10, random_state=0).fit(states)
    km_labels = kmeans.labels_
    logger.info("KMeans with %d cluster(s): labels=%s", num_clusters, km_labels)

    km_plot = "km_plot.png"
    _save_transition_plot(_build_transition_graph(km_labels, moves), km_plot)

    # --- 2. Linear Probe PCA (Geometric State Machine) ---
    if len(moves) >= 2:
        coords = PCA(n_components=2).fit_transform(states)
    else:
        # PCA cannot produce 2 components from a single sample; place the
        # lone step at the origin instead of failing.
        coords = np.zeros((1, 2))
        logger.info("Only one step: skipping PCA projection.")

    pca_plot = "pca_plot.png"
    _save_trajectory_plot(coords, moves, pca_plot)
    logger.info("Saved plots to %s and %s", km_plot, pca_plot)

    return km_plot, pca_plot, f"Labels: {km_labels}", log_capture.getvalue()

# Gradio Interface with Separated Columns
# Layout: an input row (textbox + button), a row with two result panels
# (KMeans graph + cluster labels | PCA trajectory), and a log box underneath.
with gr.Blocks(title="World Model Hybrid Probe") as demo:
    gr.Markdown("# 🛰️ World Model Hybrid Probe")
    gr.Markdown("Comparing **Logical Categorization** (KMeans) vs **Spatial Intuition** (Linear PCA).")

    with gr.Row():
        input_box = gr.Textbox(label="Input Moves", placeholder="Up, Up, Right, Left", scale=4)
        submit_btn = gr.Button("Analyze", variant="primary", scale=1)

    with gr.Row():
        # Box 1: Logic
        with gr.Column(variant="panel"):
            gr.Markdown("### 1. Discrete State Logic (DFA)")
            output_km = gr.Image(label="KMeans DFA", type="filepath")
            analysis_text = gr.Textbox(label="Cluster Labels", interactive=False)

        # Box 2: Geometry (The Clearer Linear Probe)
        with gr.Column(variant="panel"):
            gr.Markdown("### 2. Geometric Trajectory (Linear Probe)")
            output_pca = gr.Image(label="Spatial PCA Map", type="filepath")
            gr.Markdown("*This map shows the 'Mental Path' GPT-2 takes through its vector space.*")

    log_box = gr.Textbox(label="Probe Logs", lines=5, interactive=False)

    # Output order must match the tuple returned by analyze_dfa.
    submit_btn.click(
        fn=analyze_dfa,
        inputs=input_box,
        outputs=[output_km, output_pca, analysis_text, log_box],
    )

# Start the server only when run as a script, so importing this module
# (for example to reuse `demo` or `analyze_dfa`) does not launch it.
if __name__ == "__main__":
    demo.launch()