# Hugging Face Space: The Universal Newtonian Probe
import torch
import gradio as gr
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset
from sklearn.cluster import KMeans
import networkx as nx
import matplotlib.pyplot as plt
import collections
import os
import google.generativeai as genai
# 1. Models & Datasets Configs
# Hub ids of the small causal LMs offered in the model dropdown.
MODELS = [
    "gpt2",
    "distilgpt2",
    "qwen/Qwen2.5-0.5B",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
]
# Display label -> (dataset path, optional config name) as passed to load_dataset.
DATASET_CONFIGS = {
    "wikitext (v2-raw)": ("wikitext", "wikitext-2-raw-v1"),
    "TinyStories": ("roneneldan/TinyStories", None),
    "AG News": ("ag_news", None),
}
def analyze_world_model(api_key, model_name, dataset_key, num_samples=25):
    """Probe a language model's hidden states over a dataset sample, cluster
    them into discrete "states", interpret each cluster with Gemini, and
    render the observed state transitions as a DFA image.

    Args:
        api_key: User-supplied Gemini API key (minimally length-validated).
        model_name: Hugging Face hub id of the model to probe (see MODELS).
        dataset_key: Key into DATASET_CONFIGS selecting the corpus.
        num_samples: Number of streamed examples to probe (default 25).

    Returns:
        (image_path, status_text, markdown_report) on success, or
        (None, error_text, "") when validation or input fails.
    """
    # Validate API Key before any heavy model/dataset work.
    if not api_key or len(api_key) < 10:
        return None, "Error: Please provide a valid Gemini API Key.", ""
    # Configure Gemini with the user-provided key.
    genai.configure(api_key=api_key)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dataset_name, config_name = DATASET_CONFIGS[dataset_key]
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)

    # Load dataset (streaming, so only num_samples examples are pulled).
    if config_name:
        ds = load_dataset(dataset_name, config_name, split='train', streaming=True).take(num_samples)
    else:
        ds = load_dataset(dataset_name, split='train', streaming=True).take(num_samples)

    all_hidden_states = []
    input_snippets = []
    # Step A: Probe (Hidden State Extraction)
    for example in ds:
        text = example.get('text', example.get('content', ''))[:150].strip()
        if not text:
            continue
        inputs = tokenizer(text, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        # Last token of the penultimate layer is used as the sequence's state vector.
        state = outputs.hidden_states[-2][0, -1, :].cpu().numpy()
        all_hidden_states.append(state)
        input_snippets.append(text)

    # Guard: every sampled row was empty -> nothing to cluster (KMeans would
    # otherwise raise an opaque sklearn error).
    if not all_hidden_states:
        return None, "Error: No usable text found in the sampled dataset rows.", ""

    # Step B: Newtonian Recovery (Clustering)
    # Cap the cluster count by the sample count so KMeans cannot fail when
    # fewer than 5 usable snippets were collected.
    n_clusters = min(5, len(all_hidden_states))
    kmeans = KMeans(n_clusters=n_clusters, n_init=10).fit(all_hidden_states)
    state_assignments = kmeans.labels_

    # Step C: Iterative Newtonian Interpretation — group snippets per cluster.
    cluster_texts = collections.defaultdict(list)
    for idx, cluster_id in enumerate(state_assignments):
        cluster_texts[cluster_id].append(input_snippets[idx])

    # Initialize Gemini model (widely available flash version).
    gemini_model = genai.GenerativeModel('gemini-2.5-flash')
    state_info = "## 🧠 Newtonian State Interpretation\n"
    state_info += "Each state represents a discovered *Equivalence Class*.\n\n"
    for cluster_id in range(n_clusters):
        snippets = cluster_texts[cluster_id]
        # Limit the prompt payload to 8 snippets per cluster.
        context_payload = "\n".join([f"- {s}" for s in snippets[:8]])
        prompt = f"""
Act as a Mechanistic Interpretability Researcher. You are reverse-engineering Cluster S{cluster_id}
from the '{dataset_key}' dataset. Analyze this cluster with high-fidelity Newtonian depth.
### RAW SNIPPETS:
{context_payload}
### REQUIRED OUTPUT FORMAT:
**State S{cluster_id} [Structural State Label]**
- **Internal World Model**: CORE 'Law' or 'Invariant'.
- **Dataset Sensor**: Triggers (Nouns, Syntax).
- **Predictive Function**: Biased future tokens.
"""
        try:
            response = gemini_model.generate_content(prompt, generation_config={"temperature": 0.2})
            state_info += response.text.strip() + "\n\n---\n\n"
        except Exception as e:
            # Best-effort: report a per-cluster API failure inline and keep going.
            state_info += f"**State S{cluster_id} [API Error]**: {str(e)}\n\n---\n\n"

    # Step D: DFA Reconstruction — edges follow consecutive sample order.
    G = nx.DiGraph()
    for i in range(len(state_assignments) - 1):
        u, v = f"S{state_assignments[i]}", f"S{state_assignments[i+1]}"
        G.add_edge(u, v)
    plt.figure(figsize=(8, 6))
    pos = nx.kamada_kawai_layout(G)
    nx.draw(G, pos, with_labels=True, node_color='#FF8C00', node_size=3500, font_weight='bold', font_size=12, arrowsize=20)
    plt.savefig("dfa_output.png", transparent=True)
    plt.close()

    analysis_brief = f"Model '{model_name}' identified {n_clusters} distinct equivalence classes."
    return "dfa_output.png", analysis_brief, state_info
# 2. Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🌐 The Universal Newtonian Probe")
    gr.Markdown("Extracting the hidden Deterministic Finite Automaton (DFA) from any model and dataset.")

    with gr.Row():
        # The Gemini key is typed per-session (password field), never hard-coded.
        gemini_key_box = gr.Textbox(
            label="Gemini API Key",
            placeholder="paste your API key here...",
            type="password"
        )
    with gr.Row():
        model_selector = gr.Dropdown(choices=MODELS, label="Select Model", value="gpt2")
        dataset_selector = gr.Dropdown(choices=list(DATASET_CONFIGS.keys()), label="Select Dataset", value="wikitext (v2-raw)")
    run_button = gr.Button("Analyze Coherence", variant="primary")
    with gr.Row():
        dfa_image = gr.Image(label="Extracted DFA (World Map)")
        with gr.Column():
            status_box = gr.Textbox(label="Analysis Status")
            report_md = gr.Markdown()

    # Wire the button: (key, model, dataset) in -> (image, status, report) out.
    run_button.click(
        analyze_world_model,
        inputs=[gemini_key_box, model_selector, dataset_selector],
        outputs=[dfa_image, status_box, report_md]
    )

demo.launch()