Phoenix21 commited on
Commit
68d7012
·
verified ·
1 Parent(s): 1c6672e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -25
app.py CHANGED
@@ -2,12 +2,23 @@ import torch
2
  import gradio as gr
3
  import networkx as nx
4
  import matplotlib.pyplot as plt
 
 
5
  from transformers import GPT2Model, GPT2Tokenizer
6
  from sklearn.cluster import KMeans
 
7
 
8
- # 1. Load a real small model
 
 
 
 
 
 
 
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_name = "gpt2" # 124M parameters
11
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
12
  model = GPT2Model.from_pretrained(model_name).to(device)
13
 
@@ -15,55 +26,87 @@ def get_hidden_state(sequence_str):
15
  inputs = tokenizer(sequence_str, return_tensors="pt").to(device)
16
  with torch.no_grad():
17
  outputs = model(**inputs, output_hidden_states=True)
18
- # Use the last hidden state of the last token
19
  return outputs.hidden_states[-1][0, -1, :].cpu().numpy()
20
 
21
  def analyze_dfa(input_text):
22
- """
23
- Simulates a 'State Probe'.
24
- Input: 'Right, Up, Left'
25
- Logic: Generates a graph showing how the model's internal representation
26
- changes with each move.
27
- """
28
  moves = [m.strip() for m in input_text.split(",")]
29
  history = ""
30
  states_vectors = []
31
 
32
- # Track the "path" through the model's internal space
33
- for move in moves:
34
  history += f" Move {move}."
 
35
  vec = get_hidden_state(history)
36
  states_vectors.append(vec)
37
 
38
- # Clustering: Vafa's Compression metric
39
- # We cluster activations to see which moves the model thinks are 'equivalent'
40
  num_clusters = min(len(moves), 4)
41
  kmeans = KMeans(n_clusters=num_clusters, n_init=10).fit(states_vectors)
42
  labels = kmeans.labels_
43
 
44
- # Build the DFA Graph
 
 
45
  G = nx.DiGraph()
46
  for i in range(len(moves)-1):
47
  u, v = f"S{labels[i]}", f"S{labels[i+1]}"
48
  G.add_edge(u, v, label=moves[i+1])
49
 
50
- # Draw the DFA
51
  plt.figure(figsize=(6, 4))
52
  pos = nx.spring_layout(G)
53
  nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=2000)
54
  edge_labels = nx.get_edge_attributes(G, 'label')
55
  nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
56
 
57
- plt.savefig("dfa_plot.png")
58
- return "dfa_plot.png", f"Found {num_clusters} distinct internal states."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # 3. Gradio Interface
61
- demo = gr.Interface(
62
- fn=analyze_dfa,
63
- inputs=gr.Textbox(placeholder="Enter moves separated by commas, e.g.: Right, Up, Left, Down"),
64
- outputs=[gr.Image(label="Extracted Model DFA"), gr.Text(label="Analysis")],
65
- title="World Model DFA Extractor",
66
- description="This tool probes GPT-2's internal activations to see if it treats different move sequences as the same 'State'."
67
- )
68
 
69
  demo.launch()
 
2
  import gradio as gr
3
  import networkx as nx
4
  import matplotlib.pyplot as plt
5
+ import logging
6
+ import io
7
  from transformers import GPT2Model, GPT2Tokenizer
8
  from sklearn.cluster import KMeans
9
+ import lightning as L # Using Lightning for structural logging
10
 
11
+ # 1. Setup Logging Buffer
12
+ log_capture = io.StringIO()
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger("DFA_Probe")
15
+ handler = logging.StreamHandler(log_capture)
16
+ handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
17
+ logger.addHandler(handler)
18
+
19
+ # 2. Model & Tokenizer Initialization
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
+ model_name = "gpt2"
22
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
23
  model = GPT2Model.from_pretrained(model_name).to(device)
24
 
 
26
  inputs = tokenizer(sequence_str, return_tensors="pt").to(device)
27
  with torch.no_grad():
28
  outputs = model(**inputs, output_hidden_states=True)
 
29
  return outputs.hidden_states[-1][0, -1, :].cpu().numpy()
30
 
31
  def analyze_dfa(input_text):
32
+ # Clear logs for a fresh run
33
+ log_capture.truncate(0)
34
+ log_capture.seek(0)
35
+
36
+ logger.info(f"🚀 Starting analysis for input: '{input_text}'")
37
+
38
  moves = [m.strip() for m in input_text.split(",")]
39
  history = ""
40
  states_vectors = []
41
 
42
+ # Probing loop
43
+ for i, move in enumerate(moves):
44
  history += f" Move {move}."
45
+ logger.info(f"Processing Step {i+1}: Extracting activations for history '{history}'")
46
  vec = get_hidden_state(history)
47
  states_vectors.append(vec)
48
 
49
+ # Clustering (The World Model logic)
50
+ logger.info(f"🧠 Running KMeans clustering to find equivalent latent states...")
51
  num_clusters = min(len(moves), 4)
52
  kmeans = KMeans(n_clusters=num_clusters, n_init=10).fit(states_vectors)
53
  labels = kmeans.labels_
54
 
55
+ logger.info(f"📊 State mapping completed: {labels}")
56
+
57
+ # Build and Draw DFA
58
  G = nx.DiGraph()
59
  for i in range(len(moves)-1):
60
  u, v = f"S{labels[i]}", f"S{labels[i+1]}"
61
  G.add_edge(u, v, label=moves[i+1])
62
 
 
63
  plt.figure(figsize=(6, 4))
64
  pos = nx.spring_layout(G)
65
  nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=2000)
66
  edge_labels = nx.get_edge_attributes(G, 'label')
67
  nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
68
 
69
+ plot_path = "dfa_plot.png"
70
+ plt.savefig(plot_path)
71
+ plt.close()
72
+
73
+ logger.info("✅ Analysis finished. DFA plot generated.")
74
+ return plot_path, f"Found {num_clusters} distinct internal states.", log_capture.getvalue()
75
+
76
+ # 3. Custom Gradio UI with Log View
77
+ with gr.Blocks(title="World Model DFA Extractor") as demo:
78
+ gr.Markdown("# World Model DFA Extractor")
79
+ gr.Markdown("Probing GPT-2 activations to visualize internal state logic.")
80
+
81
+ with gr.Row():
82
+ with gr.Column(scale=1):
83
+ input_box = gr.Textbox(
84
+ label="Input Moves",
85
+ placeholder="Right, Left, Right, Left",
86
+ lines=2
87
+ )
88
+ submit_btn = gr.Button("Submit", variant="primary")
89
+ clear_btn = gr.Button("Clear")
90
+
91
+ with gr.Column(scale=2):
92
+ output_img = gr.Image(label="Extracted Model DFA")
93
+ analysis_text = gr.Textbox(label="Result Summary")
94
+
95
+ with gr.Row():
96
+ # Dedicated Log Box
97
+ log_box = gr.Textbox(
98
+ label="System & Probe Logs",
99
+ interactive=False,
100
+ lines=10,
101
+ max_lines=15,
102
+ autoscroll=True
103
+ )
104
 
105
+ submit_btn.click(
106
+ fn=analyze_dfa,
107
+ inputs=input_box,
108
+ outputs=[output_img, analysis_text, log_box]
109
+ )
110
+ clear_btn.click(lambda: [None, "", ""], None, [output_img, analysis_text, log_box])
 
 
111
 
112
  demo.launch()