Phoenix21 commited on
Commit
c0985ce
·
verified ·
1 Parent(s): 3a924ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -19
app.py CHANGED
@@ -6,35 +6,53 @@ from sklearn.cluster import KMeans
6
  import networkx as nx
7
  import matplotlib.pyplot as plt
8
 
9
- # 1. Configuration for Models & Datasets
10
  MODELS = ["gpt2", "distilgpt2", "qwen/Qwen2.5-0.5B", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
11
- DATASETS = ["wikitext", "tinystories", "ag_news"]
12
 
13
- def analyze_world_model(model_name, dataset_name, num_samples=20):
14
- # Load Model & Tokenizer
 
 
 
 
 
 
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
  model = AutoModel.from_pretrained(model_name).to(device)
18
 
19
- # Load Dataset
20
- ds = load_dataset(dataset_name, split='train', streaming=True).take(num_samples)
 
 
 
 
 
 
 
21
 
22
  all_hidden_states = []
23
- labels = []
24
 
25
- # Step A: The Probe (Keplerian Observation)
26
  for i, example in enumerate(ds):
27
- text = example['text'][:100] # Use a snippet
 
 
 
28
  inputs = tokenizer(text, return_tensors="pt").to(device)
29
  with torch.no_grad():
30
  outputs = model(**inputs, output_hidden_states=True)
31
- # Take the last hidden state of the sequence
32
  state = outputs.hidden_states[-1][0, -1, :].cpu().numpy()
33
  all_hidden_states.append(state)
34
- labels.append(f"Seq_{i}")
35
 
36
- # Step B: Myhill-Nerode Clustering (Newtonian Recovery)
37
- # We cluster to find 'Equivalence Classes' (Internal States)
38
  n_clusters = min(len(all_hidden_states), 5)
39
  kmeans = KMeans(n_clusters=n_clusters, n_init=10).fit(all_hidden_states)
40
  state_assignments = kmeans.labels_
@@ -43,25 +61,30 @@ def analyze_world_model(model_name, dataset_name, num_samples=20):
43
  G = nx.DiGraph()
44
  for i in range(len(state_assignments) - 1):
45
  u, v = f"S{state_assignments[i]}", f"S{state_assignments[i+1]}"
46
- G.add_edge(u, v, label=f"Next_{i}")
47
 
48
- # Draw the DFA
49
  plt.figure(figsize=(8, 6))
50
  pos = nx.spring_layout(G)
51
- nx.draw(G, pos, with_labels=True, node_color='orange', node_size=3000, font_weight='bold')
52
  plt.savefig("dfa_output.png")
 
53
 
54
  return "dfa_output.png", f"Model '{model_name}' reduced this dataset into {n_clusters} distinct internal states."
55
 
56
  # 3. Gradio UI
57
  with gr.Blocks() as demo:
58
  gr.Markdown("# The Universal Newtonian Probe")
 
 
59
  with gr.Row():
60
  m_drop = gr.Dropdown(choices=MODELS, label="Select Model", value="gpt2")
61
- d_drop = gr.Dropdown(choices=DATASETS, label="Select Dataset", value="wikitext")
 
62
  btn = gr.Button("Analyze Coherence")
63
- out_img = gr.Image(label="Extracted DFA")
64
- out_txt = gr.Textbox(label="Analysis Result")
 
 
65
 
66
  btn.click(analyze_world_model, inputs=[m_drop, d_drop], outputs=[out_img, out_txt])
67
 
 
6
  import networkx as nx
7
  import matplotlib.pyplot as plt
8
 
9
# 1. Configuration for Models & Specific Dataset Configs
MODELS = [
    "gpt2",
    "distilgpt2",
    "qwen/Qwen2.5-0.5B",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
]

# Maps each human-readable dropdown label to the (dataset repo, config) pair
# that HuggingFace `load_dataset` expects. The config entry is None for
# datasets that ship a single default configuration.
DATASET_CONFIGS = {
    "wikitext (v2-raw)": ("wikitext", "wikitext-2-raw-v1"),
    "wikitext (v103-raw)": ("wikitext", "wikitext-103-raw-v1"),
    "TinyStories": ("roneneldan/TinyStories", None),
    "AG News": ("ag_news", None),
}
19
+
20
def analyze_world_model(model_name, dataset_key, num_samples=20):
    """Probe a language model's hidden states over streamed dataset samples
    and render the induced state-transition graph as a DFA-style image.

    Parameters
    ----------
    model_name : str
        HuggingFace model id loadable via ``AutoTokenizer``/``AutoModel``.
    dataset_key : str
        Key into ``DATASET_CONFIGS`` mapping to ``(dataset_name, config_name)``.
    num_samples : int, optional
        Number of streamed examples to probe (default 20).

    Returns
    -------
    tuple
        ``(image_path, message)`` on success, or ``(None, error_message)``
        when the dataset fails to load or yields no usable text.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Resolve the dataset repo name and its (optional) config.
    dataset_name, config_name = DATASET_CONFIGS[dataset_key]

    # Load Model & Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)

    # 2. FIXED: Load Dataset with config_name — pass the config only when
    # the dataset actually requires one.
    try:
        if config_name:
            ds = load_dataset(dataset_name, config_name, split='train', streaming=True).take(num_samples)
        else:
            ds = load_dataset(dataset_name, split='train', streaming=True).take(num_samples)
    except Exception as e:
        return None, f"Error loading dataset: {str(e)}"

    all_hidden_states = []

    # Step A: The Probe (Hidden State Extraction)
    for i, example in enumerate(ds):
        # Handle different dataset structures (some use 'text', some use 'content')
        text = example.get('text', example.get('content', ''))[:100]
        if not text:
            continue

        # truncation=True guards against inputs longer than the model's
        # maximum context window.
        inputs = tokenizer(text, return_tensors="pt", truncation=True).to(device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        # Snapshot of the last layer's representation at the final token.
        state = outputs.hidden_states[-1][0, -1, :].cpu().numpy()
        all_hidden_states.append(state)

    # FIXED: if every example was empty/skipped, KMeans would be called with
    # n_clusters=0 on an empty list and raise. Fail gracefully in the same
    # (None, message) shape used for dataset-load errors.
    if not all_hidden_states:
        return None, "No usable text found in the sampled examples."

    # Step B: Newtonian Recovery (Clustering)
    n_clusters = min(len(all_hidden_states), 5)
    kmeans = KMeans(n_clusters=n_clusters, n_init=10).fit(all_hidden_states)
    state_assignments = kmeans.labels_

    # Build the transition graph between consecutive cluster assignments.
    G = nx.DiGraph()
    for i in range(len(state_assignments) - 1):
        u, v = f"S{state_assignments[i]}", f"S{state_assignments[i+1]}"
        G.add_edge(u, v)

    plt.figure(figsize=(8, 6))
    pos = nx.spring_layout(G)
    nx.draw(G, pos, with_labels=True, node_color='orange', node_size=3000, font_weight='bold', arrowsize=20)
    plt.savefig("dfa_output.png")
    plt.close()  # Clean up memory

    return "dfa_output.png", f"Model '{model_name}' reduced this dataset into {n_clusters} distinct internal states."
73
 
74
# 3. Gradio UI
with gr.Blocks() as demo:
    # Page header and one-line description.
    gr.Markdown("# The Universal Newtonian Probe")
    gr.Markdown("Analyze how models build internal maps of different datasets.")

    # Selector row: model and dataset side by side.
    with gr.Row():
        m_drop = gr.Dropdown(choices=MODELS, label="Select Model", value="gpt2")
        d_drop = gr.Dropdown(choices=list(DATASET_CONFIGS), label="Select Dataset", value="wikitext (v2-raw)")

    btn = gr.Button("Analyze Coherence")

    # Output row: extracted graph image next to the textual summary.
    with gr.Row():
        out_img = gr.Image(label="Extracted DFA")
        out_txt = gr.Textbox(label="Analysis Result")

    # Wire the button to the analysis routine.
    btn.click(fn=analyze_world_model, inputs=[m_drop, d_drop], outputs=[out_img, out_txt])
90