File size: 3,806 Bytes
a40cb3b
 
 
 
68d7012
 
a40cb3b
 
68d7012
a40cb3b
68d7012
 
 
 
 
 
 
 
 
a40cb3b
68d7012
a40cb3b
 
 
 
 
 
 
 
 
 
68d7012
 
 
 
 
 
a40cb3b
 
 
 
68d7012
 
a40cb3b
68d7012
a40cb3b
 
 
68d7012
 
a40cb3b
 
 
 
68d7012
 
 
a40cb3b
 
 
 
 
 
 
 
 
 
 
68d7012
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a40cb3b
68d7012
 
 
 
 
 
a40cb3b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
import gradio as gr
import networkx as nx
import matplotlib.pyplot as plt
import logging
import io
from transformers import GPT2Model, GPT2Tokenizer
from sklearn.cluster import KMeans
import lightning as L  # Using Lightning for structural logging

# 1. Setup Logging Buffer
# In-memory buffer that collects log lines so the UI can display them after each run.
log_capture = io.StringIO()
# basicConfig is a no-op when the root logger already has handlers; in that case the
# root level is left unchanged (WARNING by default).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("DFA_Probe")
# Set the level on this logger directly so INFO records reach the capture handler even
# when basicConfig above did nothing; otherwise the log box would silently stay empty.
logger.setLevel(logging.INFO)
_CAPTURE_HANDLER_NAME = "dfa_probe_capture"
# getLogger returns the same object if this module is executed again in the same
# process (e.g. hot reload); remove a capture handler left from a previous execution so
# stale buffers are not kept alive and written to forever.
for stale_handler in list(logger.handlers):
    if stale_handler.get_name() == _CAPTURE_HANDLER_NAME:
        logger.removeHandler(stale_handler)
handler = logging.StreamHandler(log_capture)
handler.set_name(_CAPTURE_HANDLER_NAME)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)

# 2. Model & Tokenizer Initialization
# Tokenizer and weights are both loaded from this single pretrained checkpoint name.
model_name = "gpt2"
# Run on the GPU when torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2Model.from_pretrained(model_name)
model = model.to(device)

def get_hidden_state(sequence_str):
    """Return the final-layer hidden state of the last token of *sequence_str*.

    The text is tokenized with the module-level ``tokenizer`` and run through the
    module-level ``model`` on ``device`` without gradient tracking.

    Args:
        sequence_str: Text to encode; must produce at least one token.

    Returns:
        A 1-D NumPy array: the last layer's activation vector at the final token.

    Raises:
        ValueError: If *sequence_str* tokenizes to zero tokens.
    """
    encoded = tokenizer(sequence_str, return_tensors="pt")
    if encoded["input_ids"].shape[1] == 0:
        raise ValueError("sequence_str must produce at least one token")

    # GPT-2 cannot attend beyond its positional window. Because only the final token's
    # state is needed, drop the oldest tokens (from the left) rather than the newest.
    max_len = getattr(model.config, "n_positions", None)
    if max_len is not None and encoded["input_ids"].shape[1] > max_len:
        encoded = {name: tensor[:, -max_len:] for name, tensor in encoded.items()}
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        outputs = model(**inputs)
    # For GPT2Model, last_hidden_state is the same tensor as hidden_states[-1], so the
    # per-layer hidden states do not need to be collected.
    return outputs.last_hidden_state[0, -1, :].cpu().numpy()

def _parse_moves(input_text):
    """Split comma-separated moves, dropping blanks (e.g. from a trailing comma)."""
    if not input_text:
        return []
    return [m.strip() for m in input_text.split(",") if m.strip()]


def _build_transition_graph(moves, labels):
    """Build a directed graph of transitions between clustered latent states.

    Node ``S<k>`` stands for cluster ``k``. For consecutive steps, an edge goes from
    the cluster of step ``i`` to the cluster of step ``i + 1``, labeled with the move
    taken at step ``i + 1``. Distinct moves that produce the same transition are joined
    into one label instead of overwriting each other.
    """
    graph = nx.DiGraph()
    # Add every visited state so that even a single-move input draws its node.
    graph.add_nodes_from(f"S{label}" for label in labels)
    for (src, dst), move in zip(zip(labels, labels[1:]), moves[1:]):
        u, v = f"S{src}", f"S{dst}"
        if graph.has_edge(u, v):
            edge_moves = graph[u][v]["moves"]
            if move not in edge_moves:
                edge_moves.append(move)
        else:
            graph.add_edge(u, v, moves=[move])
    for _, _, data in graph.edges(data=True):
        data["label"] = ", ".join(data["moves"])
    return graph


def _render_graph(graph, plot_path):
    """Draw *graph* with its edge labels and save the figure to *plot_path*."""
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        # Fixed seed so the same input produces the same layout on every run.
        pos = nx.spring_layout(graph, seed=0)
        nx.draw(graph, pos, ax=ax, with_labels=True, node_color='lightblue', node_size=2000)
        edge_labels = nx.get_edge_attributes(graph, 'label')
        nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, ax=ax)
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even if drawing fails, so figures do not pile up.
        plt.close(fig)


def analyze_dfa(input_text):
    """Probe GPT-2 on a move sequence and plot the implied state machine.

    For each prefix of the comma-separated moves, the last-token hidden state is
    extracted. The states are clustered with KMeans (at most 4 clusters), and the
    cluster-to-cluster transitions are drawn as a directed graph.

    Args:
        input_text: Comma-separated moves, e.g. ``"Right, Left, Right"``.

    Returns:
        A tuple ``(plot_path, summary, logs)``. If no moves are given, ``plot_path``
        is ``None`` and ``summary`` asks for input. ``logs`` holds the log lines
        captured during this call.
    """
    # Clear logs for a fresh run
    log_capture.truncate(0)
    log_capture.seek(0)

    logger.info(f"🚀 Starting analysis for input: '{input_text}'")

    moves = _parse_moves(input_text)
    if not moves:
        # KMeans cannot run with zero clusters; report instead of crashing.
        logger.warning("No moves provided; nothing to analyze.")
        return None, "Please enter at least one move (e.g. 'Right, Left').", log_capture.getvalue()

    # Probing loop: the history grows by one move per step.
    history = ""
    states_vectors = []
    for step, move in enumerate(moves, start=1):
        history += f" Move {move}."
        logger.info(f"Processing Step {step}: Extracting activations for history '{history}'")
        states_vectors.append(get_hidden_state(history))

    # Clustering (The World Model logic)
    logger.info("🧠 Running KMeans clustering to find equivalent latent states...")
    num_clusters = min(len(moves), 4)
    # Fixed random_state so identical input yields an identical DFA across runs.
    kmeans = KMeans(n_clusters=num_clusters, n_init=10, random_state=0).fit(states_vectors)
    labels = kmeans.labels_

    logger.info(f"📊 State mapping completed: {labels}")

    # Build and Draw DFA
    graph = _build_transition_graph(moves, labels)
    plot_path = "dfa_plot.png"
    _render_graph(graph, plot_path)

    logger.info("✅ Analysis finished. DFA plot generated.")
    # KMeans may assign fewer distinct clusters than requested (e.g. duplicate
    # vectors), so report the number of states actually found.
    num_states = len(set(labels.tolist()))
    return plot_path, f"Found {num_states} distinct internal states.", log_capture.getvalue()

# 3. Custom Gradio UI with Log View
with gr.Blocks(title="World Model DFA Extractor") as demo:
    gr.Markdown("# World Model DFA Extractor")
    gr.Markdown("Probing GPT-2 activations to visualize internal state logic.")

    with gr.Row():
        with gr.Column(scale=1):
            input_box = gr.Textbox(
                label="Input Moves",
                placeholder="Right, Left, Right, Left",
                lines=2,
            )
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear")

        with gr.Column(scale=2):
            output_img = gr.Image(label="Extracted Model DFA")
            analysis_text = gr.Textbox(label="Result Summary")

    with gr.Row():
        # Dedicated Log Box
        log_box = gr.Textbox(
            label="System & Probe Logs",
            interactive=False,
            lines=10,
            max_lines=15,
            autoscroll=True,
        )

    submit_btn.click(
        fn=analyze_dfa,
        inputs=input_box,
        outputs=[output_img, analysis_text, log_box],
    )
    # Reset the input as well as the outputs, so "Clear" returns the UI to its start state.
    clear_btn.click(
        lambda: ["", None, "", ""],
        None,
        [input_box, output_img, analysis_text, log_box],
    )

# Only start the server when run as a script, so importing this module has no side effect.
if __name__ == "__main__":
    demo.launch()