|
|
import torch |
|
|
import gradio as gr |
|
|
import networkx as nx |
|
|
import matplotlib.pyplot as plt |
|
|
import logging |
|
|
import io |
|
|
from transformers import GPT2Model, GPT2Tokenizer |
|
|
from sklearn.cluster import KMeans |
|
|
import lightning as L |
|
|
|
|
|
|
|
|
# In-memory buffer that mirrors all log output so analyze_dfa() can return
# the captured text to the Gradio log panel after each run.
log_capture = io.StringIO()

logging.basicConfig(level=logging.INFO)

logger = logging.getLogger("DFA_Probe")

# Secondary handler that writes formatted records into the StringIO buffer
# (in addition to the basicConfig stream handler on the root logger).
handler = logging.StreamHandler(log_capture)

handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))

logger.addHandler(handler)
|
|
|
|
|
|
|
|
# Use the GPU when available; the model and all inputs live on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "gpt2"

# Tokenizer and headless (no LM head) GPT-2 model used purely as an
# activation probe — only hidden states are read, never logits.
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

model = GPT2Model.from_pretrained(model_name).to(device)
|
|
|
|
|
def get_hidden_state(sequence_str):
    """Return the last-layer hidden state of the final token of *sequence_str*.

    The string is tokenized, run through the module-level GPT-2 model with
    hidden-state output enabled, and the activation vector for the last
    position of the last layer is returned as a 1-D numpy array on CPU.
    """
    encoded = tokenizer(sequence_str, return_tensors="pt").to(device)
    with torch.no_grad():
        result = model(**encoded, output_hidden_states=True)
    final_layer = result.hidden_states[-1]
    last_token_vec = final_layer[0, -1, :]
    return last_token_vec.cpu().numpy()
|
|
|
|
|
def analyze_dfa(input_text):
    """Probe GPT-2 activations over a move sequence and render the induced DFA.

    For each prefix of the comma-separated move list, the last-token,
    last-layer hidden state is extracted via get_hidden_state(). KMeans then
    groups these vectors into "latent states", and consecutive prefixes are
    connected into a directed transition graph, rendered to a PNG.

    Args:
        input_text: Comma-separated moves, e.g. "Right, Left, Right, Left".
            Empty segments are ignored.

    Returns:
        Tuple of (path to the DFA plot PNG, or None when there is no input;
                  human-readable summary string;
                  captured log text for this run).
    """
    # Reset the shared capture buffer so the UI shows only this run's logs.
    log_capture.seek(0)
    log_capture.truncate(0)

    logger.info("Starting analysis for input: '%s'", input_text)

    # Drop empty segments so inputs like "" or "Right,,Left" don't produce
    # bogus moves (and so KMeans never receives zero samples).
    moves = [m.strip() for m in input_text.split(",") if m.strip()]
    if not moves:
        logger.warning("No moves found in input; nothing to analyze.")
        return None, "No moves provided.", log_capture.getvalue()

    history = ""
    states_vectors = []

    # One activation vector per prefix of the move sequence.
    for i, move in enumerate(moves):
        history += f" Move {move}."
        logger.info("Processing Step %d: Extracting activations for history '%s'",
                    i + 1, history)
        states_vectors.append(get_hidden_state(history))

    logger.info("Running KMeans clustering to find equivalent latent states...")
    # Cap at 4 clusters, and never exceed the sample count (KMeans requires
    # n_clusters <= n_samples).
    num_clusters = min(len(moves), 4)
    kmeans = KMeans(n_clusters=num_clusters, n_init=10).fit(states_vectors)
    labels = kmeans.labels_

    logger.info("State mapping completed: %s", labels)

    # Edge from prefix i to prefix i+1 is labeled with the move that caused
    # that transition (moves[i+1] is the move appended between the two).
    G = nx.DiGraph()
    for i in range(len(moves) - 1):
        u, v = f"S{labels[i]}", f"S{labels[i + 1]}"
        G.add_edge(u, v, label=moves[i + 1])

    plt.figure(figsize=(6, 4))
    pos = nx.spring_layout(G)
    nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=2000)
    edge_labels = nx.get_edge_attributes(G, 'label')
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)

    plot_path = "dfa_plot.png"
    plt.savefig(plot_path)
    plt.close()  # release the figure so repeated runs don't accumulate memory

    logger.info("Analysis finished. DFA plot generated.")
    return plot_path, f"Found {num_clusters} distinct internal states.", log_capture.getvalue()
|
|
|
|
|
|
|
|
# --- Gradio UI layout ------------------------------------------------------
# Two-column interface: move input and action buttons on the left, the
# rendered DFA image and a summary on the right, with a read-only log panel
# underneath that displays the text captured in log_capture.
with gr.Blocks(title="World Model DFA Extractor") as demo:

    gr.Markdown("# World Model DFA Extractor")
    gr.Markdown("Probing GPT-2 activations to visualize internal state logic.")

    with gr.Row():
        with gr.Column(scale=1):
            # Comma-separated moves, parsed by analyze_dfa().
            input_box = gr.Textbox(
                label="Input Moves",
                placeholder="Right, Left, Right, Left",
                lines=2
            )
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear")

        with gr.Column(scale=2):
            output_img = gr.Image(label="Extracted Model DFA")
            analysis_text = gr.Textbox(label="Result Summary")

    with gr.Row():
        # Read-only view of the captured probe logs for the latest run.
        log_box = gr.Textbox(
            label="System & Probe Logs",
            interactive=False,
            lines=10,
            max_lines=15,
            autoscroll=True
        )

    # Run the probe and populate the image, summary, and log widgets.
    submit_btn.click(
        fn=analyze_dfa,
        inputs=input_box,
        outputs=[output_img, analysis_text, log_box]
    )
    # Reset all three output widgets; the handler takes no inputs.
    clear_btn.click(lambda: [None, "", ""], None, [output_img, analysis_text, log_box])


demo.launch()