# app.py — World Model Hybrid Probe (Gradio demo)
# Source: Phoenix21's Hugging Face Space, commit c6dce4f (verified)
import torch
import gradio as gr
import networkx as nx
import matplotlib.pyplot as plt
import logging
import io
import numpy as np
from transformers import GPT2Model, GPT2Tokenizer
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
# Setup Logging
# Capture log output into an in-memory buffer so the Gradio UI can display
# it after each analysis run (see analyze_dfa, which reads log_capture).
log_capture = io.StringIO()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("DFA_Probe")
# Route this logger's records into the StringIO buffer in addition to the
# root handler configured by basicConfig above.
handler = logging.StreamHandler(log_capture)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
# Load GPT-2
# Prefer GPU when available; the model and all inputs are moved to `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# GPT2Model (no LM head) — we only need hidden states, not logits.
model = GPT2Model.from_pretrained(model_name).to(device)
def get_hidden_state(sequence_str):
    """Embed *sequence_str* with GPT-2 and return the last token's state.

    Runs a forward pass (no gradients) and extracts the final-layer hidden
    state of the last token as a 1-D numpy array on the CPU.
    """
    encoded = tokenizer(sequence_str, return_tensors="pt").to(device)
    with torch.no_grad():
        result = model(**encoded, output_hidden_states=True)
    # hidden_states[-1] is the final layer: (batch, seq_len, hidden_dim).
    final_layer = result.hidden_states[-1]
    last_token_vec = final_layer[0, -1, :]
    return last_token_vec.cpu().numpy()
def analyze_dfa(input_text):
    """Analyze a comma-separated move sequence through GPT-2's hidden states.

    For each prefix of the move sequence, the last-token hidden state is
    collected, then visualized two ways:
      1. a discrete state machine (KMeans clusters as DFA states), and
      2. a geometric trajectory (2-D PCA of the states).

    Parameters:
        input_text: comma-separated moves, e.g. "Up, Up, Right, Left".

    Returns:
        (km_plot_path, pca_plot_path, labels_text, captured_logs).
        The image entries are None when no moves were provided.
    """
    # Reset the shared log buffer so each run shows only its own output.
    log_capture.truncate(0)
    log_capture.seek(0)
    # Drop empty tokens so trailing commas or blank input cannot feed
    # empty "moves" into KMeans (which would crash or skew clustering).
    moves = [m.strip() for m in input_text.split(",") if m.strip()]
    if not moves:
        return None, None, "No moves provided.", log_capture.getvalue()
    # One hidden-state vector per growing prefix of the move history.
    history = ""
    states_vectors = []
    for move in moves:
        history += f" Move {move}."
        states_vectors.append(get_hidden_state(history))
    km_plot, km_labels = _plot_kmeans_graph(moves, states_vectors)
    pca_plot = _plot_pca_trajectory(moves, states_vectors)
    return km_plot, pca_plot, f"Labels: {km_labels}", log_capture.getvalue()


def _plot_kmeans_graph(moves, states_vectors):
    """Cluster the state vectors and draw the induced transition graph.

    Returns (saved_png_path, cluster_labels).
    """
    # At most 4 DFA states, never more clusters than samples.
    num_clusters = min(len(moves), 4)
    kmeans = KMeans(n_clusters=num_clusters, n_init=10).fit(states_vectors)
    km_labels = kmeans.labels_
    G_km = nx.DiGraph()
    # Edge i -> i+1 is labeled with the move that caused the transition.
    for i in range(len(moves) - 1):
        G_km.add_edge(f"S{km_labels[i]}", f"S{km_labels[i+1]}", label=moves[i + 1])
    plt.figure(figsize=(8, 6))
    pos_km = nx.spring_layout(G_km)
    nx.draw(G_km, pos_km, with_labels=True, node_color='lightblue',
            node_size=2500, font_size=12, font_weight='bold')
    nx.draw_networkx_edge_labels(
        G_km, pos_km,
        edge_labels=nx.get_edge_attributes(G_km, 'label'),
        font_size=10)
    plt.title("Logical State Machine (KMeans)")
    km_plot = "km_plot.png"
    plt.savefig(km_plot, dpi=150)
    plt.close()
    return km_plot, km_labels


def _plot_pca_trajectory(moves, states_vectors):
    """Project the state vectors to 2-D with PCA and draw the trajectory.

    Returns the saved PNG path.
    """
    pca = PCA(n_components=2)
    coords = pca.fit_transform(states_vectors)
    plt.figure(figsize=(10, 8))  # Increased size for better visibility
    plt.scatter(coords[:, 0], coords[:, 1], c=range(len(moves)),
                cmap='viridis', s=200, edgecolors='black')
    # Drawing arrows between coordinates (The Linear Probe "State Machine")
    for i in range(len(moves) - 1):
        plt.arrow(coords[i, 0], coords[i, 1],
                  coords[i + 1, 0] - coords[i, 0],
                  coords[i + 1, 1] - coords[i, 1],
                  head_width=5, length_includes_head=True,
                  alpha=0.5, color='gray')
    for i, move in enumerate(moves):
        plt.annotate(f"Step {i}: {move}", (coords[i, 0], coords[i, 1]),
                     xytext=(5, 5), textcoords='offset points',
                     fontsize=9, fontweight='bold')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.title("Geometric State Machine (Linear Probe PCA)")
    plt.xlabel("Principal Component 1 (Primary Axis of Variance)")
    plt.ylabel("Principal Component 2 (Secondary Axis of Variance)")
    pca_plot = "pca_plot.png"
    plt.savefig(pca_plot, dpi=150)
    plt.close()
    return pca_plot
# Gradio Interface with Separated Columns
with gr.Blocks(title="World Model Hybrid Probe") as demo:
    gr.Markdown("# 🛰️ World Model Hybrid Probe")
    gr.Markdown("Comparing **Logical Categorization** (KMeans) vs **Spatial Intuition** (Linear PCA).")
    # Input row: comma-separated move list plus the trigger button.
    with gr.Row():
        input_box = gr.Textbox(label="Input Moves", placeholder="Up, Up, Right, Left", scale=4)
        submit_btn = gr.Button("Analyze", variant="primary", scale=1)
    # Results row: two side-by-side panels, one per visualization.
    with gr.Row():
        # Box 1: Logic
        with gr.Column(variant="panel"):
            gr.Markdown("### 1. Discrete State Logic (DFA)")
            output_km = gr.Image(label="KMeans DFA", type="filepath")
            analysis_text = gr.Textbox(label="Cluster Labels", interactive=False)
        # Box 2: Geometry (The Clearer Linear Probe)
        with gr.Column(variant="panel"):
            gr.Markdown("### 2. Geometric Trajectory (Linear Probe)")
            output_pca = gr.Image(label="Spatial PCA Map", type="filepath")
            gr.Markdown("*This map shows the 'Mental Path' GPT-2 takes through its vector space.*")
    # Captured log output from the DFA_Probe logger (see logging setup above).
    log_box = gr.Textbox(label="Probe Logs", lines=5, interactive=False)
    # Wire the button: analyze_dfa returns (km image, pca image, labels, logs),
    # matching the output component order below.
    submit_btn.click(analyze_dfa, input_box, [output_km, output_pca, analysis_text, log_box])
demo.launch()