# Author: Phoenix21 · commit 68d7012 "Update app.py" (verified)
import torch
import gradio as gr
import networkx as nx
import matplotlib.pyplot as plt
import logging
import io
from transformers import GPT2Model, GPT2Tokenizer
from sklearn.cluster import KMeans
import lightning as L # Using Lightning for structural logging
# 1. Setup Logging Buffer
# Records sent to the "DFA_Probe" logger are mirrored into an in-memory
# text buffer so the app can hand the accumulated log text back to the UI.
logging.basicConfig(level=logging.INFO)
log_capture = io.StringIO()
handler = logging.StreamHandler(log_capture)
handler.setFormatter(
    logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
)
logger = logging.getLogger("DFA_Probe")
logger.addHandler(handler)
# 2. Model & Tokenizer Initialization
# Use the GPU when torch reports one; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# nn.Module.to() moves the parameters and returns the same module.
model = GPT2Model.from_pretrained(model_name)
model = model.to(device)
def get_hidden_state(sequence_str):
    """Return GPT-2's final-layer hidden state for the last token of *sequence_str*.

    The text is tokenized, moved to the module-level ``device`` and run
    through the module-level ``model`` with gradient tracking disabled. The
    vector at the last token position of the single batch element is copied
    to host memory and returned as a NumPy array.

    For ``GPT2Model`` the last entry of ``hidden_states`` is the same tensor as
    ``last_hidden_state`` (the output after the final layer norm). Reading
    ``last_hidden_state`` directly therefore gives identical values without
    asking the model to keep every intermediate layer's output.

    Raises:
        ValueError: if *sequence_str* tokenizes to zero tokens (e.g. ``""``),
            since there is no "last token" to read.

    NOTE(review): inputs longer than GPT-2's context window are not
    truncated here; confirm callers keep sequences short enough.
    """
    inputs = tokenizer(sequence_str, return_tensors="pt").to(device)
    if inputs["input_ids"].shape[-1] == 0:
        raise ValueError("get_hidden_state needs non-empty text (got no tokens)")
    with torch.no_grad():
        outputs = model(**inputs)
    # last_hidden_state is (batch, seq_len, hidden_size) per the transformers
    # docs; take batch 0 and the final token position.
    return outputs.last_hidden_state[0, -1, :].cpu().numpy()
def _build_dfa_graph(labels, moves):
    """Build the state-transition graph implied by per-step cluster labels.

    Node ``S<k>`` stands for cluster ``k``. An edge ``S<a> -> S<b>`` labelled
    ``m`` means step ``i`` landed in cluster ``a`` and, after the next move
    ``m = moves[i + 1]``, step ``i + 1`` landed in cluster ``b``. When several
    different moves connect the same pair of states, they are joined into one
    comma-separated label rather than the last one overwriting the others.
    """
    graph = nx.DiGraph()
    # Add every visited state up front so inputs with a single move (and
    # hence no transitions) still show their state in the plot.
    graph.add_nodes_from(f"S{label}" for label in labels)
    for src_label, dst_label, move in zip(labels, labels[1:], moves[1:]):
        src, dst = f"S{src_label}", f"S{dst_label}"
        if graph.has_edge(src, dst):
            edge = graph[src][dst]
            if move not in edge["moves"]:
                edge["moves"].append(move)
                edge["label"] = ", ".join(edge["moves"])
        else:
            graph.add_edge(src, dst, moves=[move], label=move)
    return graph


def _render_graph(graph, plot_path):
    """Draw *graph* with its edge labels and save the figure to *plot_path*."""
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        # Fixed seed: the same DFA is laid out identically on every run.
        pos = nx.spring_layout(graph, seed=42)
        nx.draw(graph, pos, ax=ax, with_labels=True, node_color='lightblue', node_size=2000)
        edge_labels = nx.get_edge_attributes(graph, 'label')
        nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, ax=ax)
        fig.savefig(plot_path)
    finally:
        # Release the figure even if drawing fails, so a long-running server
        # does not accumulate open figures.
        plt.close(fig)


def analyze_dfa(input_text, max_states=4):
    """Probe GPT-2 on a comma-separated move sequence and draw the induced DFA.

    For every prefix of the move sequence, ``get_hidden_state`` extracts a
    vector. The vectors are clustered with KMeans, and each cluster is treated
    as one latent state. Consecutive prefixes become transitions labelled with
    the move that connects them.

    Args:
        input_text: Comma-separated moves, e.g. ``"Right, Left, Right"``.
            Blank entries (such as from a trailing comma) are ignored.
        max_states: Upper bound on the number of KMeans clusters. The count
            requested is ``min(number of moves, max_states)``.

    Returns:
        ``(plot_path, summary, logs)``: path of the saved PNG (``None`` when
        no moves were given), a one-line summary, and this run's log text.

    Raises:
        ValueError: if ``max_states`` is smaller than 1.
    """
    # Clear logs for a fresh run
    log_capture.seek(0)
    log_capture.truncate(0)
    logger.info(f"🚀 Starting analysis for input: '{input_text}'")

    if max_states < 1:
        raise ValueError(f"max_states must be >= 1, got {max_states}")

    moves = [m.strip() for m in (input_text or "").split(",") if m.strip()]
    if not moves:
        logger.warning("No moves provided; nothing to analyze.")
        return None, "No moves provided.", log_capture.getvalue()

    # Probing loop: one hidden-state vector per prefix of the move sequence.
    history = ""
    states_vectors = []
    for step, move in enumerate(moves, start=1):
        history += f" Move {move}."
        logger.info(f"Processing Step {step}: Extracting activations for history '{history}'")
        states_vectors.append(get_hidden_state(history))

    # Clustering (The World Model logic)
    logger.info("🧠 Running KMeans clustering to find equivalent latent states...")
    num_clusters = min(len(moves), max_states)
    # random_state pins the initialization so repeated runs give the same states.
    kmeans = KMeans(n_clusters=num_clusters, n_init=10, random_state=0).fit(states_vectors)
    labels = kmeans.labels_
    # With duplicate vectors KMeans can use fewer than num_clusters labels, so
    # report the number of labels actually assigned.
    num_states = len(set(labels.tolist()))
    logger.info(f"📊 State mapping completed: {labels}")

    # Build and Draw DFA
    plot_path = "dfa_plot.png"
    _render_graph(_build_dfa_graph(labels, moves), plot_path)
    logger.info("✅ Analysis finished. DFA plot generated.")
    return plot_path, f"Found {num_states} distinct internal states.", log_capture.getvalue()
# 3. Custom Gradio UI with Log View
# Layout: the input box and buttons sit in a narrow left column, the DFA image
# and summary sit in a wider right column, and a read-only log panel occupies
# its own row below. Components are created inside the ``with`` contexts, so
# the statement order here defines the on-screen layout.
with gr.Blocks(title="World Model DFA Extractor") as demo:
    gr.Markdown("# World Model DFA Extractor")
    gr.Markdown("Probing GPT-2 activations to visualize internal state logic.")
    with gr.Row():
        with gr.Column(scale=1):
            # Comma-separated move list; analyze_dfa splits it on ",".
            input_box = gr.Textbox(
                label="Input Moves",
                placeholder="Right, Left, Right, Left",
                lines=2
            )
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear")
        with gr.Column(scale=2):
            output_img = gr.Image(label="Extracted Model DFA")
            analysis_text = gr.Textbox(label="Result Summary")
    with gr.Row():
        # Dedicated Log Box
        # Non-editable view of the log text returned as analyze_dfa's third value.
        log_box = gr.Textbox(
            label="System & Probe Logs",
            interactive=False,
            lines=10,
            max_lines=15,
            autoscroll=True
        )
    # analyze_dfa returns (plot_path, summary, logs); the values are matched
    # positionally to these three output components.
    submit_btn.click(
        fn=analyze_dfa,
        inputs=input_box,
        outputs=[output_img, analysis_text, log_box]
    )
    # Reset the three output components; the input box is left untouched.
    clear_btn.click(lambda: [None, "", ""], None, [output_img, analysis_text, log_box])
# Starts the Gradio server whenever this module is executed.
demo.launch()