Phoenix21 committed on
Commit
8113b74
·
verified ·
1 Parent(s): 1c8cef7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -59
app.py CHANGED
@@ -4,11 +4,12 @@ import networkx as nx
4
  import matplotlib.pyplot as plt
5
  import logging
6
  import io
 
7
  from transformers import GPT2Model, GPT2Tokenizer
8
  from sklearn.cluster import KMeans
9
- import lightning as L # Using Lightning for structural logging
10
 
11
- # 1. Setup Logging Buffer
12
  log_capture = io.StringIO()
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger("DFA_Probe")
@@ -16,7 +17,7 @@ handler = logging.StreamHandler(log_capture)
16
  handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
17
  logger.addHandler(handler)
18
 
19
- # 2. Model & Tokenizer Initialization
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
  model_name = "gpt2"
22
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
@@ -29,84 +30,61 @@ def get_hidden_state(sequence_str):
29
  return outputs.hidden_states[-1][0, -1, :].cpu().numpy()
30
 
31
  def analyze_dfa(input_text):
32
- # Clear logs for a fresh run
33
  log_capture.truncate(0)
34
  log_capture.seek(0)
35
 
36
- logger.info(f"🚀 Starting analysis for input: '{input_text}'")
37
-
38
  moves = [m.strip() for m in input_text.split(",")]
39
  history = ""
40
  states_vectors = []
41
 
42
- # Probing loop
43
  for i, move in enumerate(moves):
44
  history += f" Move {move}."
45
- logger.info(f"Processing Step {i+1}: Extracting activations for history '{history}'")
46
  vec = get_hidden_state(history)
47
  states_vectors.append(vec)
48
 
49
- # Clustering (The World Model logic)
50
- logger.info(f"🧠 Running KMeans clustering to find equivalent latent states...")
51
  num_clusters = min(len(moves), 4)
52
  kmeans = KMeans(n_clusters=num_clusters, n_init=10).fit(states_vectors)
53
- labels = kmeans.labels_
54
 
55
- logger.info(f"📊 State mapping completed: {labels}")
56
-
57
- # Build and Draw DFA
58
- G = nx.DiGraph()
59
  for i in range(len(moves)-1):
60
- u, v = f"S{labels[i]}", f"S{labels[i+1]}"
61
- G.add_edge(u, v, label=moves[i+1])
62
 
63
- plt.figure(figsize=(6, 4))
64
- pos = nx.spring_layout(G)
65
- nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=2000)
66
- edge_labels = nx.get_edge_attributes(G, 'label')
67
- nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- plot_path = "dfa_plot.png"
70
  plt.savefig(plot_path)
71
  plt.close()
72
 
73
- logger.info("✅ Analysis finished. DFA plot generated.")
74
- return plot_path, f"Found {num_clusters} distinct internal states.", log_capture.getvalue()
75
-
76
- # 3. Custom Gradio UI with Log View
77
- with gr.Blocks(title="World Model DFA Extractor") as demo:
78
- gr.Markdown("# World Model DFA Extractor")
79
- gr.Markdown("Probing GPT-2 activations to visualize internal state logic.")
80
-
81
- with gr.Row():
82
- with gr.Column(scale=1):
83
- input_box = gr.Textbox(
84
- label="Input Moves",
85
- placeholder="Right, Left, Right, Left",
86
- lines=2
87
- )
88
- submit_btn = gr.Button("Submit", variant="primary")
89
- clear_btn = gr.Button("Clear")
90
-
91
- with gr.Column(scale=2):
92
- output_img = gr.Image(label="Extracted Model DFA")
93
- analysis_text = gr.Textbox(label="Result Summary")
94
 
 
 
 
 
 
95
  with gr.Row():
96
- # Dedicated Log Box
97
- log_box = gr.Textbox(
98
- label="System & Probe Logs",
99
- interactive=False,
100
- lines=10,
101
- max_lines=15,
102
- autoscroll=True
103
- )
104
-
105
- submit_btn.click(
106
- fn=analyze_dfa,
107
- inputs=input_box,
108
- outputs=[output_img, analysis_text, log_box]
109
- )
110
- clear_btn.click(lambda: [None, "", ""], None, [output_img, analysis_text, log_box])
111
 
112
  demo.launch()
 
4
  import matplotlib.pyplot as plt
5
  import logging
6
  import io
7
+ import numpy as np
8
  from transformers import GPT2Model, GPT2Tokenizer
9
  from sklearn.cluster import KMeans
10
+ from sklearn.decomposition import PCA
11
 
12
+ # Setup Logging
13
  log_capture = io.StringIO()
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger("DFA_Probe")
 
17
  handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
18
  logger.addHandler(handler)
19
 
20
+ # Load GPT-2
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
  model_name = "gpt2"
23
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
 
30
  return outputs.hidden_states[-1][0, -1, :].cpu().numpy()
31
 
32
  def analyze_dfa(input_text):
 
33
  log_capture.truncate(0)
34
  log_capture.seek(0)
35
 
 
 
36
  moves = [m.strip() for m in input_text.split(",")]
37
  history = ""
38
  states_vectors = []
39
 
 
40
  for i, move in enumerate(moves):
41
  history += f" Move {move}."
 
42
  vec = get_hidden_state(history)
43
  states_vectors.append(vec)
44
 
45
+ # --- 1. KMeans Graph (Unsupervised State Map) ---
 
46
  num_clusters = min(len(moves), 4)
47
  kmeans = KMeans(n_clusters=num_clusters, n_init=10).fit(states_vectors)
48
+ km_labels = kmeans.labels_
49
 
50
+ G_km = nx.DiGraph()
 
 
 
51
  for i in range(len(moves)-1):
52
+ G_km.add_edge(f"S{km_labels[i]}", f"S{km_labels[i+1]}", label=moves[i+1])
 
53
 
54
+ plt.figure(figsize=(12, 5))
55
+ plt.subplot(1, 2, 1)
56
+ pos_km = nx.spring_layout(G_km)
57
+ nx.draw(G_km, pos_km, with_labels=True, node_color='lightblue', node_size=1500)
58
+ nx.draw_networkx_edge_labels(G_km, pos_km, edge_labels=nx.get_edge_attributes(G_km, 'label'))
59
+ plt.title("KMeans DFA (State-Based)")
60
+
61
+ # --- 2. Linear Probe / PCA (Geometric Map) ---
62
+ logger.info("📐 Running Linear Probe (PCA) to find the 'Spatial Axis'...")
63
+ pca = PCA(n_components=2)
64
+ coords = pca.fit_transform(states_vectors)
65
+
66
+ plt.subplot(1, 2, 2)
67
+ plt.scatter(coords[:, 0], coords[:, 1], c=range(len(moves)), cmap='viridis', s=100)
68
+ for i, move in enumerate(moves):
69
+ plt.annotate(f"{i}:{move}", (coords[i, 0], coords[i, 1]))
70
+ plt.plot(coords[:, 0], coords[:, 1], 'r--', alpha=0.3) # Path line
71
+ plt.title("Linear Probe (Spatial Projection)")
72
 
73
+ plot_path = "comparison_plot.png"
74
  plt.savefig(plot_path)
75
  plt.close()
76
 
77
+ return plot_path, f"KMeans Labels: {km_labels}", log_capture.getvalue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
+ # Launching with dual display
80
+ with gr.Blocks() as demo:
81
+ gr.Markdown("# KMeans vs. Linear Probe Analysis")
82
+ input_box = gr.Textbox(label="Moves (Right, Left...)")
83
+ submit_btn = gr.Button("Compare")
84
  with gr.Row():
85
+ output_img = gr.Image(label="KMeans (Left) vs Linear PCA (Right)")
86
+ analysis_text = gr.Textbox(label="Mapping Results")
87
+ log_box = gr.Textbox(label="Probe Logs", lines=5)
88
+ submit_btn.click(analyze_dfa, input_box, [output_img, analysis_text, log_box])
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  demo.launch()