File size: 5,503 Bytes
a40cb3b
 
3a924ef
 
 
a40cb3b
 
e62887f
3c6abe9
 
 
e62887f
3a924ef
c0985ce
 
 
 
 
 
69fe038
 
 
 
 
 
 
 
 
3a924ef
c0985ce
 
3a924ef
 
a40cb3b
e62887f
 
 
 
 
a40cb3b
3a924ef
e62887f
a40cb3b
e62887f
3a924ef
e62887f
c0985ce
 
3a924ef
 
 
e62887f
3a924ef
e62887f
3a924ef
c0985ce
e62887f
3a924ef
 
cddf7c8
 
e62887f
 
 
5c3b16a
cddf7c8
2991e6e
cddf7c8
5c3b16a
69fe038
5c3b16a
e62887f
 
cddf7c8
8788b70
5c3b16a
cddf7c8
69fe038
5c3b16a
69fe038
5c3b16a
 
69fe038
d13742d
69fe038
 
 
5c3b16a
 
 
cddf7c8
 
 
69fe038
 
e62887f
3a924ef
 
 
c0985ce
3a924ef
 
e62887f
 
 
 
 
69fe038
a40cb3b
e62887f
a40cb3b
69fe038
e62887f
 
 
c0985ce
69fe038
 
 
 
 
 
 
 
3a924ef
 
c0985ce
 
e62887f
c0985ce
 
e62887f
 
 
69fe038
3a924ef
69fe038
 
 
 
 
 
a40cb3b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch
import gradio as gr
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset
from sklearn.cluster import KMeans
import networkx as nx
import matplotlib.pyplot as plt
import collections
import os
import google.generativeai as genai

# 1. Models & Datasets Configs
# Hugging Face model IDs offered in the "Select Model" dropdown; each is
# loaded with AutoTokenizer / AutoModel by analyze_world_model.
MODELS = [
    "gpt2",
    "distilgpt2",
    "qwen/Qwen2.5-0.5B",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
]

# Dropdown label -> (dataset path, config name or None). The pair is passed
# positionally to `load_dataset`; a None config is simply omitted.
DATASET_CONFIGS = dict(
    [
        ("wikitext (v2-raw)", ("wikitext", "wikitext-2-raw-v1")),
        ("TinyStories", ("roneneldan/TinyStories", None)),
        ("AG News", ("ag_news", None)),
    ]
)

def _extract_hidden_states(model_name, dataset_key, num_samples):
    """Run ``model_name`` over the first ``num_samples`` rows of a dataset.

    Returns ``(states, snippets)``. ``snippets[i]`` is the first 150
    characters of a row's text, stripped. ``states[i]`` is the
    second-to-last layer's hidden state at that snippet's final token, as a
    float32 NumPy vector. Rows with no text are skipped, so both lists may be
    shorter than ``num_samples`` (or empty).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dataset_name, config_name = DATASET_CONFIGS[dataset_key]

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)

    # Streaming mode reads rows lazily, so only the rows we take are fetched.
    load_args = (dataset_name, config_name) if config_name else (dataset_name,)
    rows = load_dataset(*load_args, split='train', streaming=True).take(num_samples)

    states, snippets = [], []
    for example in rows:
        # `or`-chaining (rather than a .get default) also covers a key that is
        # present but holds None or an empty string.
        raw = example.get('text') or example.get('content') or ''
        text = raw[:150].strip()
        if not text:
            continue

        inputs = tokenizer(text, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        # Second-to-last layer, first (only) batch item, last token.
        # `.float()` first: NumPy has no bfloat16, so `.numpy()` would fail
        # for models whose weights are loaded in that dtype.
        states.append(outputs.hidden_states[-2][0, -1, :].float().cpu().numpy())
        snippets.append(text)
    return states, snippets


def _interpret_clusters(gemini_model, dataset_key, cluster_texts, n_clusters):
    """Ask Gemini to describe each non-empty cluster; return a Markdown report.

    ``cluster_texts`` maps cluster id -> list of snippets. API failures are
    reported inline per cluster instead of aborting the whole report.
    """
    state_info = "## 🧠 Newtonian State Interpretation\n"
    state_info += "Each state represents a discovered *Equivalence Class*.\n\n"

    for cluster_id in range(n_clusters):
        snippets = cluster_texts.get(cluster_id, [])
        if not snippets:
            # KMeans can leave a cluster unused (e.g. duplicate inputs);
            # there is nothing to send to the model for it.
            continue
        # Cap the payload at 8 snippets to keep the prompt small.
        context_payload = "\n".join(f"- {s}" for s in snippets[:8])

        prompt = f"""
        Act as a Mechanistic Interpretability Researcher. You are reverse-engineering Cluster S{cluster_id} 
        from the '{dataset_key}' dataset. Analyze this cluster with high-fidelity Newtonian depth.
        
        ### RAW SNIPPETS:
        {context_payload}
        
        ### REQUIRED OUTPUT FORMAT:
        **State S{cluster_id} [Structural State Label]**
        - **Internal World Model**: CORE 'Law' or 'Invariant'.
        - **Dataset Sensor**: Triggers (Nouns, Syntax).
        - **Predictive Function**: Biased future tokens.
        """

        try:
            response = gemini_model.generate_content(prompt, generation_config={"temperature": 0.2})
            state_info += response.text.strip() + "\n\n---\n\n"
        except Exception as e:
            # Best-effort: one failed cluster should not hide the others.
            state_info += f"**State S{cluster_id} [API Error]**: {str(e)}\n\n---\n\n"
    return state_info


def _render_dfa(state_assignments, path="dfa_output.png"):
    """Draw cluster-to-cluster transitions as a directed graph; return ``path``.

    An edge S_a -> S_b is added whenever a snippet in cluster a is
    immediately followed by a snippet in cluster b.
    """
    graph = nx.DiGraph()
    # Add every visited state first so a state with no transitions (e.g. a
    # single sample) still appears in the picture.
    graph.add_nodes_from(f"S{c}" for c in sorted(set(state_assignments)))
    for a, b in zip(state_assignments, state_assignments[1:]):
        graph.add_edge(f"S{a}", f"S{b}")

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        pos = nx.kamada_kawai_layout(graph)
        nx.draw(graph, pos, ax=ax, with_labels=True, node_color='#FF8C00', node_size=3500,
                font_weight='bold', font_size=12, arrowsize=20)
        # NOTE(review): fixed output path is shared across requests; this is
        # only safe if calls are serialized — confirm the Gradio queue settings.
        fig.savefig(path, transparent=True)
    finally:
        # Always release the figure, even if layout/drawing/saving fails.
        plt.close(fig)
    return path


def analyze_world_model(api_key, model_name, dataset_key, num_samples=25, n_clusters=5):
    """Probe a language model's hidden states and summarize them as a state graph.

    Pipeline:
    1. Extract second-to-last-layer, final-token hidden states for up to
       ``num_samples`` dataset rows.
    2. Cluster the states with KMeans.
    3. Ask Gemini to describe each cluster.
    4. Draw transitions between the clusters of consecutive snippets as a
       directed graph.

    Args:
        api_key: Gemini API key; missing values or values shorter than 10
            characters are rejected before any work is done.
        model_name: Hugging Face model id (see ``MODELS``).
        dataset_key: Key into ``DATASET_CONFIGS``.
        num_samples: Number of dataset rows to read; empty rows are skipped.
        n_clusters: Maximum number of clusters. It is clamped to the number
            of usable snippets because KMeans needs at least as many samples
            as clusters.

    Returns:
        ``(image_path, status_text, markdown_report)``. On a validation
        failure, ``image_path`` is None and ``markdown_report`` is "".
    """
    if not api_key or len(api_key) < 10:
        return None, "Error: Please provide a valid Gemini API Key.", ""

    # NOTE(review): genai.configure sets process-wide configuration, so the
    # most recent caller's key is the one in effect.
    genai.configure(api_key=api_key)

    states, snippets = _extract_hidden_states(model_name, dataset_key, num_samples)
    if not states:
        return None, f"Error: No non-empty text found in the first {num_samples} rows of '{dataset_key}'.", ""

    # KMeans raises if asked for more clusters than samples.
    k = min(n_clusters, len(states))
    # Fixed random_state makes repeated runs produce the same clustering.
    kmeans = KMeans(n_clusters=k, n_init=10, random_state=0).fit(states)
    state_assignments = kmeans.labels_

    cluster_texts = collections.defaultdict(list)
    for snippet, cluster_id in zip(snippets, state_assignments):
        cluster_texts[cluster_id].append(snippet)

    gemini_model = genai.GenerativeModel('gemini-2.5-flash')
    state_info = _interpret_clusters(gemini_model, dataset_key, cluster_texts, k)

    image_path = _render_dfa(state_assignments)

    # Report how many clusters actually received snippets, not how many were requested.
    analysis_brief = f"Model '{model_name}' identified {len(cluster_texts)} distinct equivalence classes."
    return image_path, analysis_brief, state_info

# 2. Gradio UI
# Builds the page and wires the inputs to analyze_world_model. Component
# creation order inside the `with` blocks defines the on-screen layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🌐 The Universal Newtonian Probe")
    gr.Markdown("Extracting the hidden Deterministic Finite Automaton (DFA) from any model and dataset.")
    
    with gr.Row():
        # User-supplied Gemini key; type="password" masks it in the text field.
        api_key_input = gr.Textbox(
            label="Gemini API Key", 
            placeholder="paste your API key here...", 
            type="password"
        )
    
    with gr.Row():
        # Choices come from the module-level configs; the defaults ("gpt2",
        # "wikitext (v2-raw)") are entries of MODELS / DATASET_CONFIGS.
        m_drop = gr.Dropdown(choices=MODELS, label="Select Model", value="gpt2")
        d_drop = gr.Dropdown(choices=list(DATASET_CONFIGS.keys()), label="Select Dataset", value="wikitext (v2-raw)")
    
    btn = gr.Button("Analyze Coherence", variant="primary")
    
    with gr.Row():
        out_img = gr.Image(label="Extracted DFA (World Map)")
        with gr.Column():
            out_txt = gr.Textbox(label="Analysis Status")
            out_elaboration = gr.Markdown()
    
    # Inputs map positionally to (api_key, model_name, dataset_key); num_samples
    # is not exposed, so the function's default is used. Outputs follow the
    # order of analyze_world_model's return tuple: (image path, status, report).
    btn.click(
        analyze_world_model, 
        inputs=[api_key_input, m_drop, d_drop], 
        outputs=[out_img, out_txt, out_elaboration]
    )

# NOTE(review): launches at import time; consider an
# `if __name__ == "__main__":` guard if this module is ever imported.
demo.launch()