# NOTE: removed non-code extraction artifacts (file-size header, blame hashes,
# and a line-number gutter) that were not part of the original Python source.
import torch
import gradio as gr
import networkx as nx
import matplotlib.pyplot as plt
import logging
import io
from transformers import GPT2Model, GPT2Tokenizer
from sklearn.cluster import KMeans
import lightning as L # Using Lightning for structural logging
# 1. Setup Logging Buffer
# In-memory buffer that accumulates log records so they can be shown in the
# Gradio "System & Probe Logs" textbox (see analyze_dfa's third return value).
log_capture = io.StringIO()
# basicConfig installs a root stderr handler at INFO; the named logger below
# inherits that level, so records also appear on the console.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("DFA_Probe")
# Second handler routes the same records into the StringIO buffer for the UI.
handler = logging.StreamHandler(log_capture)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
# 2. Model & Tokenizer Initialization
# Prefer GPU when available; all probing in get_hidden_state runs on `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# GPT2Model (no LM head) — only hidden states are needed for probing.
model = GPT2Model.from_pretrained(model_name).to(device)
def get_hidden_state(sequence_str):
    """Return the final-layer activation of the last token of *sequence_str*.

    The string is tokenized, run through GPT-2 with gradients disabled, and
    the hidden state of the last position in the last layer is returned as a
    1-D numpy vector on the CPU.
    """
    encoded = tokenizer(sequence_str, return_tensors="pt").to(device)
    with torch.no_grad():
        result = model(**encoded, output_hidden_states=True)
    final_layer = result.hidden_states[-1]
    # Batch index 0, last token position, full hidden dimension.
    return final_layer[0, -1, :].cpu().numpy()
def analyze_dfa(input_text):
    """Probe GPT-2 over a comma-separated move sequence and render a DFA.

    For each growing prefix of the move history, the last-token hidden state
    is extracted via get_hidden_state; the vectors are clustered with KMeans
    to identify equivalent latent "states", and the induced state-transition
    graph is drawn to a PNG.

    Returns a (plot_path, summary, logs) tuple matching the three Gradio
    outputs; plot_path is None when the input contains no moves.
    """
    # Clear the shared log buffer so each run shows only its own messages.
    log_capture.truncate(0)
    log_capture.seek(0)
    logger.info("🚀 Starting analysis for input: '%s'", input_text)

    # Drop empty tokens so inputs like "" or "Right,,Left," don't probe the
    # model with a meaningless blank move.
    moves = [m.strip() for m in input_text.split(",") if m.strip()]
    if not moves:
        logger.warning("No moves parsed from input; nothing to analyze.")
        return None, "No moves provided.", log_capture.getvalue()

    # Probing loop: one hidden-state vector per history prefix.
    history = ""
    states_vectors = []
    for i, move in enumerate(moves):
        history += f" Move {move}."
        logger.info("Processing Step %d: Extracting activations for history '%s'", i + 1, history)
        states_vectors.append(get_hidden_state(history))

    # Clustering (the world-model logic). n_clusters is capped at len(moves),
    # so it can never exceed the number of samples.
    logger.info("🧠 Running KMeans clustering to find equivalent latent states...")
    num_clusters = min(len(moves), 4)
    # Fixed random_state makes repeated runs on the same input reproducible.
    kmeans = KMeans(n_clusters=num_clusters, n_init=10, random_state=0).fit(states_vectors)
    labels = kmeans.labels_
    logger.info("📊 State mapping completed: %s", labels)

    # Build the transition graph: edge S[labels[i]] -> S[labels[i+1]] is
    # labeled with the move that caused the transition.
    G = nx.DiGraph()
    for i in range(len(moves) - 1):
        G.add_edge(f"S{labels[i]}", f"S{labels[i + 1]}", label=moves[i + 1])

    plot_path = "dfa_plot.png"
    plt.figure(figsize=(6, 4))
    try:
        pos = nx.spring_layout(G)
        nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=2000)
        edge_labels = nx.get_edge_attributes(G, 'label')
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
        plt.savefig(plot_path)
    finally:
        # Always release the figure, even if layout/drawing fails, to avoid
        # accumulating open matplotlib figures across Gradio requests.
        plt.close()

    logger.info("✅ Analysis finished. DFA plot generated.")
    return plot_path, f"Found {num_clusters} distinct internal states.", log_capture.getvalue()
# 3. Custom Gradio UI with Log View
with gr.Blocks(title="World Model DFA Extractor") as demo:
    gr.Markdown("# World Model DFA Extractor")
    gr.Markdown("Probing GPT-2 activations to visualize internal state logic.")

    with gr.Row():
        with gr.Column(scale=1):
            input_box = gr.Textbox(
                label="Input Moves",
                placeholder="Right, Left, Right, Left",
                lines=2,
            )
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear")
        with gr.Column(scale=2):
            output_img = gr.Image(label="Extracted Model DFA")
            analysis_text = gr.Textbox(label="Result Summary")

    with gr.Row():
        # Dedicated read-only log box fed by the in-memory StringIO handler.
        log_box = gr.Textbox(
            label="System & Probe Logs",
            interactive=False,
            lines=10,
            max_lines=15,
            autoscroll=True,
        )

    # The three outputs map one-to-one onto analyze_dfa's return tuple.
    submit_btn.click(
        fn=analyze_dfa,
        inputs=input_box,
        outputs=[output_img, analysis_text, log_box],
    )
    clear_btn.click(lambda: [None, "", ""], None, [output_img, analysis_text, log_box])

demo.launch()