"""Gradio demo: probe GPT-2 hidden states and sketch a DFA from them.

For each prefix of a comma-separated move list, the hidden state of the last
token is extracted from GPT-2, the vectors are clustered with KMeans, and the
cluster sequence is drawn as a directed graph whose edges are labelled with
the moves that caused each transition.
"""
import io
import logging
from collections import defaultdict

import matplotlib

# Select a non-interactive backend before pyplot is imported: the analysis
# callback is invoked by the Gradio server, where no GUI event loop exists.
matplotlib.use("Agg")

import gradio as gr
import lightning as L  # Using Lightning for structural logging
import matplotlib.pyplot as plt
import networkx as nx
import torch
from sklearn.cluster import KMeans
from transformers import GPT2Model, GPT2Tokenizer

# 1. Setup Logging Buffer
# Log records are mirrored into an in-memory buffer so that they can be
# returned to the UI alongside the plot.
log_capture = io.StringIO()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("DFA_Probe")
handler = logging.StreamHandler(log_capture)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)

# 2. Model & Tokenizer Initialization
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2Model.from_pretrained(model_name).to(device)

# Upper bound on the number of KMeans clusters (i.e. DFA states) extracted.
MAX_STATES = 4
# File the rendered graph is written to; its path is handed to gr.Image.
PLOT_PATH = "dfa_plot.png"
# Fixed seed so the same input yields the same state numbering and layout.
RANDOM_SEED = 0


def get_hidden_state(sequence_str):
    """Return the model's hidden state for the last token of *sequence_str*.

    The text is tokenized, run through the model without gradient tracking,
    and the last entry of ``hidden_states`` is indexed at batch 0, final
    token position. The result is moved to the CPU and returned as a NumPy
    array.
    """
    inputs = tokenizer(sequence_str, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    return outputs.hidden_states[-1][0, -1, :].cpu().numpy()


def _parse_moves(input_text):
    """Split comma-separated text into stripped, non-empty move names."""
    pieces = (input_text or "").split(",")
    return [piece.strip() for piece in pieces if piece.strip()]


def _build_dfa_graph(moves, labels):
    """Build a directed graph of state transitions.

    ``labels[i]`` is the state reached after ``moves[i]``; the edge from
    ``labels[i]`` to ``labels[i + 1]`` is therefore caused by ``moves[i + 1]``.
    A DiGraph holds one edge per node pair, so every distinct move seen on the
    same pair is collected into a single comma-joined label rather than letting
    later moves overwrite earlier ones.
    """
    transitions = defaultdict(list)
    for src, dst, move in zip(labels, labels[1:], moves[1:]):
        seen = transitions[(f"S{src}", f"S{dst}")]
        if move not in seen:
            seen.append(move)

    graph = nx.DiGraph()
    # Add the nodes explicitly so a single-move input still shows its state.
    graph.add_nodes_from(f"S{label}" for label in labels)
    for (src_name, dst_name), edge_moves in transitions.items():
        graph.add_edge(src_name, dst_name, label=", ".join(edge_moves))
    return graph


def _render_dfa(graph, plot_path):
    """Draw *graph* to *plot_path* and return the path."""
    fig = plt.figure(figsize=(6, 4))
    try:
        # Draw on an explicit full-figure axes instead of pyplot's implicit
        # "current" figure, which is shared global state.
        ax = fig.add_axes((0, 0, 1, 1))
        pos = nx.spring_layout(graph, seed=RANDOM_SEED)
        nx.draw(graph, pos, ax=ax, with_labels=True,
                node_color='lightblue', node_size=2000)
        edge_labels = nx.get_edge_attributes(graph, 'label')
        nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, ax=ax)
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    return plot_path


def analyze_dfa(input_text):
    """Probe the model on a move sequence and render the inferred DFA.

    Args:
        input_text: Comma-separated move names, e.g. ``"Right, Left, Right"``.
            Empty entries (blank input, doubled or trailing commas) are
            ignored.

    Returns:
        A 3-tuple ``(plot_path, summary, logs)``: the path of the saved plot
        (``None`` when the input contains no moves), a one-line summary, and
        the log text captured during this run.
    """
    # Clear logs for a fresh run
    log_capture.truncate(0)
    log_capture.seek(0)

    logger.info("🚀 Starting analysis for input: '%s'", input_text)

    moves = _parse_moves(input_text)
    if not moves:
        logger.warning("No moves found in input; nothing to analyze.")
        return (
            None,
            "No moves provided. Enter a comma-separated list of moves.",
            log_capture.getvalue(),
        )

    # Probing loop: one activation vector per growing history prefix.
    history = ""
    states_vectors = []
    for step, move in enumerate(moves, start=1):
        history += f" Move {move}."
        logger.info(
            "Processing Step %d: Extracting activations for history '%s'",
            step, history,
        )
        states_vectors.append(get_hidden_state(history))

    # Clustering (The World Model logic)
    logger.info("🧠 Running KMeans clustering to find equivalent latent states...")
    # KMeans cannot form more clusters than there are samples.
    num_clusters = min(len(moves), MAX_STATES)
    kmeans = KMeans(
        n_clusters=num_clusters, n_init=10, random_state=RANDOM_SEED
    ).fit(states_vectors)
    labels = kmeans.labels_
    logger.info("📊 State mapping completed: %s", labels)

    # Build and Draw DFA
    plot_path = _render_dfa(_build_dfa_graph(moves, labels), PLOT_PATH)

    logger.info("✅ Analysis finished. DFA plot generated.")
    return (
        plot_path,
        f"Found {num_clusters} distinct internal states.",
        log_capture.getvalue(),
    )


# 3. Custom Gradio UI with Log View
with gr.Blocks(title="World Model DFA Extractor") as demo:
    gr.Markdown("# World Model DFA Extractor")
    gr.Markdown("Probing GPT-2 activations to visualize internal state logic.")

    with gr.Row():
        with gr.Column(scale=1):
            input_box = gr.Textbox(
                label="Input Moves",
                placeholder="Right, Left, Right, Left",
                lines=2
            )
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear")

        with gr.Column(scale=2):
            output_img = gr.Image(label="Extracted Model DFA")
            analysis_text = gr.Textbox(label="Result Summary")

    with gr.Row():
        # Dedicated Log Box
        log_box = gr.Textbox(
            label="System & Probe Logs",
            interactive=False,
            lines=10,
            max_lines=15,
            autoscroll=True
        )

    submit_btn.click(
        fn=analyze_dfa,
        inputs=input_box,
        outputs=[output_img, analysis_text, log_box]
    )
    # Reset the three output widgets; the input box is left untouched.
    clear_btn.click(lambda: [None, "", ""], None, [output_img, analysis_text, log_box])

demo.launch()