Phoenix21 committed on
Commit
e62887f
·
verified ·
1 Parent(s): c0985ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -37
app.py CHANGED
@@ -5,87 +5,95 @@ from datasets import load_dataset
5
  from sklearn.cluster import KMeans
6
  import networkx as nx
7
  import matplotlib.pyplot as plt
 
8
 
9
# 1. Configuration: candidate models and the datasets they can be probed on.
MODELS = [
    "gpt2",
    "distilgpt2",
    "qwen/Qwen2.5-0.5B",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
]

# Maps a display key to (dataset name, HuggingFace config name or None).
# The config name is required by HuggingFace for multi-config datasets.
DATASET_CONFIGS = {
    "wikitext (v2-raw)": ("wikitext", "wikitext-2-raw-v1"),
    "wikitext (v103-raw)": ("wikitext", "wikitext-103-raw-v1"),
    "TinyStories": ("roneneldan/TinyStories", None),
    "AG News": ("ag_news", None),
}
19
 
20
def analyze_world_model(model_name, dataset_key, num_samples=20):
    """Probe a language model's hidden states over a dataset, cluster them
    into discrete states, and render the induced transition graph (DFA).

    Parameters
    ----------
    model_name : str
        HuggingFace model id, loaded via AutoModel/AutoTokenizer.
    dataset_key : str
        Key into DATASET_CONFIGS selecting the dataset (and optional config).
    num_samples : int, optional
        Number of streamed examples to analyze (default 20).

    Returns
    -------
    tuple[str | None, str]
        (path to the saved DFA image, status message) on success;
        (None, error message) when loading or sampling fails.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Resolve the dataset name and its optional HuggingFace config.
    dataset_name, config_name = DATASET_CONFIGS[dataset_key]

    # Load model & tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)

    # Load the dataset, passing the config name when one is required.
    try:
        if config_name:
            ds = load_dataset(dataset_name, config_name, split='train', streaming=True).take(num_samples)
        else:
            ds = load_dataset(dataset_name, split='train', streaming=True).take(num_samples)
    except Exception as e:
        return None, f"Error loading dataset: {str(e)}"

    all_hidden_states = []

    # Step A: The Probe (hidden-state extraction).
    for i, example in enumerate(ds):
        # Datasets differ in schema: some use 'text', some use 'content'.
        text = example.get('text', example.get('content', ''))[:100]
        if not text:
            continue

        inputs = tokenizer(text, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        # Snapshot of the last layer's representation at the final token.
        state = outputs.hidden_states[-1][0, -1, :].cpu().numpy()
        all_hidden_states.append(state)

    # FIX: guard the empty case — min(0, 5) would hand KMeans n_clusters=0
    # and raise a ValueError instead of a readable message.
    if not all_hidden_states:
        return None, "No usable text samples were found in the dataset."

    # Step B: Newtonian Recovery (clustering); never ask for more clusters
    # than we have samples.
    n_clusters = min(len(all_hidden_states), 5)
    kmeans = KMeans(n_clusters=n_clusters, n_init=10).fit(all_hidden_states)
    state_assignments = kmeans.labels_

    # Step C: DFA reconstruction from consecutive state transitions.
    G = nx.DiGraph()
    for i in range(len(state_assignments) - 1):
        u, v = f"S{state_assignments[i]}", f"S{state_assignments[i+1]}"
        G.add_edge(u, v)

    plt.figure(figsize=(8, 6))
    pos = nx.spring_layout(G)
    nx.draw(G, pos, with_labels=True, node_color='orange', node_size=3000, font_weight='bold', arrowsize=20)
    plt.savefig("dfa_output.png")
    plt.close()  # Clean up figure memory.

    return "dfa_output.png", f"Model '{model_name}' reduced this dataset into {n_clusters} distinct internal states."
73
 
74
# 3. Gradio UI wiring: two dropdowns feed the analysis callback, which
# returns an image plus a textual summary.
with gr.Blocks() as demo:
    gr.Markdown("# The Universal Newtonian Probe")
    gr.Markdown("Analyze how models build internal maps of different datasets.")

    with gr.Row():
        model_dd = gr.Dropdown(choices=MODELS, label="Select Model", value="gpt2")
        dataset_dd = gr.Dropdown(choices=list(DATASET_CONFIGS.keys()), label="Select Dataset", value="wikitext (v2-raw)")

    run_btn = gr.Button("Analyze Coherence")

    with gr.Row():
        dfa_image = gr.Image(label="Extracted DFA")
        result_box = gr.Textbox(label="Analysis Result")

    run_btn.click(analyze_world_model, inputs=[model_dd, dataset_dd], outputs=[dfa_image, result_box])

demo.launch()
 
5
  from sklearn.cluster import KMeans
6
  import networkx as nx
7
  import matplotlib.pyplot as plt
8
+ import collections
9
 
10
# 1. Configuration: candidate models and the datasets they can be probed on.
MODELS = [
    "gpt2",
    "distilgpt2",
    "qwen/Qwen2.5-0.5B",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
]

# Maps a display key to (dataset name, HuggingFace config name or None).
DATASET_CONFIGS = {
    "wikitext (v2-raw)": ("wikitext", "wikitext-2-raw-v1"),
    "TinyStories": ("roneneldan/TinyStories", None),
    "AG News": ("ag_news", None),
}
17
 
18
def analyze_world_model(model_name, dataset_key, num_samples=25):
    """Probe a model's hidden states over a dataset, cluster them into
    discrete states, and render the induced transition graph (DFA).

    Parameters
    ----------
    model_name : str
        HuggingFace model id, loaded via AutoModel/AutoTokenizer.
    dataset_key : str
        Key into DATASET_CONFIGS selecting the dataset (and optional config).
    num_samples : int, optional
        Number of streamed examples to analyze (default 25).

    Returns
    -------
    tuple[str | None, str, str]
        (image path, brief status message, per-state interpretation markdown);
        the image path is None when the analysis cannot proceed.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dataset_name, config_name = DATASET_CONFIGS[dataset_key]

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)

    # Load the dataset, passing the config name when one is required.
    # FIX: restore the error handling this call had before; a bad dataset
    # id/config would otherwise crash the Gradio handler.
    try:
        if config_name:
            ds = load_dataset(dataset_name, config_name, split='train', streaming=True).take(num_samples)
        else:
            ds = load_dataset(dataset_name, split='train', streaming=True).take(num_samples)
    except Exception as e:
        return None, f"Error loading dataset: {str(e)}", ""

    all_hidden_states = []
    input_snippets = []

    # Step A: Probe (hidden-state extraction).
    for i, example in enumerate(ds):
        # Datasets differ in schema: some use 'text', some use 'content'.
        text = example.get('text', example.get('content', ''))[:150].strip()
        if not text:
            continue

        inputs = tokenizer(text, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        # Final-token representation of the second-to-last layer — presumed
        # to carry more semantic 'world model' signal than the output layer;
        # TODO confirm empirically.
        state = outputs.hidden_states[-2][0, -1, :].cpu().numpy()
        all_hidden_states.append(state)
        input_snippets.append(text)

    # FIX: guard against an empty sample set before clustering.
    if not all_hidden_states:
        return None, "No usable text samples were found in the dataset.", ""

    # Step B: Newtonian Recovery (clustering).
    # FIX: cap n_clusters at the sample count; a hard-coded 5 makes KMeans
    # raise ValueError whenever fewer than 5 samples survive filtering.
    n_clusters = min(len(all_hidden_states), 5)
    kmeans = KMeans(n_clusters=n_clusters, n_init=10).fit(all_hidden_states)
    state_assignments = kmeans.labels_

    # Step C: State elaboration — map each cluster back to sample snippets
    # so the user can read what each abstract state represents.
    state_info = "### 🧠 State Interpretation & Dataset Mapping\n"
    cluster_texts = collections.defaultdict(list)
    for idx, cluster_id in enumerate(state_assignments):
        cluster_texts[cluster_id].append(input_snippets[idx])

    for cluster_id in range(n_clusters):
        snippets = cluster_texts[cluster_id]
        # Show up to two truncated snippets as a proxy for the state's meaning.
        summary = " | ".join([s[:40] + "..." for s in snippets[:2]])
        state_info += f"**State S{cluster_id}**: Representing context such as: *{summary}*\n\n"

    # Step D: DFA reconstruction from consecutive state transitions.
    G = nx.DiGraph()
    for i in range(len(state_assignments) - 1):
        u, v = f"S{state_assignments[i]}", f"S{state_assignments[i+1]}"
        G.add_edge(u, v)

    plt.figure(figsize=(8, 6))
    pos = nx.kamada_kawai_layout(G)
    nx.draw(G, pos, with_labels=True, node_color='#FF8C00', node_size=3500, font_weight='bold', font_size=12, arrowsize=20)
    plt.savefig("dfa_output.png", transparent=True)
    plt.close()  # Clean up figure memory.

    analysis_brief = f"Model '{model_name}' identified {n_clusters} distinct equivalence classes in the '{dataset_key}' dataset."

    return "dfa_output.png", analysis_brief, state_info
79
 
80
# 2. Gradio UI wiring: two dropdowns feed the analysis callback, which
# returns the DFA image, a status line, and a markdown elaboration panel.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🌐 The Universal Newtonian Probe")
    gr.Markdown("Extracting the hidden Deterministic Finite Automaton (DFA) from any model and dataset.")

    with gr.Row():
        model_dd = gr.Dropdown(choices=MODELS, label="Select Model", value="gpt2")
        dataset_dd = gr.Dropdown(choices=list(DATASET_CONFIGS.keys()), label="Select Dataset", value="wikitext (v2-raw)")

    run_btn = gr.Button("Analyze Coherence", variant="primary")

    with gr.Row():
        dfa_image = gr.Image(label="Extracted DFA (World Map)")
        with gr.Column():
            status_box = gr.Textbox(label="Analysis Status")
            elaboration_md = gr.Markdown()  # Markdown renders the interpretation readably

    run_btn.click(analyze_world_model, inputs=[model_dd, dataset_dd], outputs=[dfa_image, status_box, elaboration_md])

demo.launch()