import torch
import gradio as gr
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset
from sklearn.cluster import KMeans
import networkx as nx
import matplotlib.pyplot as plt
import collections
import os
import google.generativeai as genai

# 1. Models & Datasets Configs
MODELS = ["gpt2", "distilgpt2", "qwen/Qwen2.5-0.5B", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
DATASET_CONFIGS = {
    "wikitext (v2-raw)": ("wikitext", "wikitext-2-raw-v1"),
    "TinyStories": ("roneneldan/TinyStories", None),
    "AG News": ("ag_news", None),
}


def analyze_world_model(api_key, model_name, dataset_key, num_samples=25):
    """Probe a language model's hidden states over a dataset and render a DFA.

    Pipeline: (A) extract one hidden-state vector per text sample,
    (B) cluster the vectors with KMeans into equivalence classes,
    (C) ask Gemini to interpret each cluster, and
    (D) draw the sample-order state-transition graph to ``dfa_output.png``.

    Args:
        api_key: User-supplied Gemini API key (validated only for length).
        model_name: Hugging Face model id from ``MODELS``.
        dataset_key: Key into ``DATASET_CONFIGS``.
        num_samples: Number of streamed dataset examples to probe.

    Returns:
        Tuple of (image path or None, status string, markdown interpretation).
        On validation/data errors, returns ``(None, error_message, "")``.
    """
    # Validate the API key up front so we fail before any heavy model download.
    if not api_key or len(api_key) < 10:
        return None, "Error: Please provide a valid Gemini API Key.", ""

    # Configure Gemini with the user-provided key.
    genai.configure(api_key=api_key)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dataset_name, config_name = DATASET_CONFIGS[dataset_key]

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)

    # Load dataset in streaming mode so the full corpus is never downloaded.
    if config_name:
        ds = load_dataset(dataset_name, config_name, split='train', streaming=True).take(num_samples)
    else:
        ds = load_dataset(dataset_name, split='train', streaming=True).take(num_samples)

    all_hidden_states = []
    input_snippets = []

    # Step A: Probe (hidden-state extraction), one vector per non-empty sample.
    for example in ds:
        # Datasets expose the text under either 'text' or 'content'.
        text = example.get('text', example.get('content', ''))[:150].strip()
        if not text:
            continue
        inputs = tokenizer(text, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        # Penultimate layer, final token position — the "state" summarizing the prefix.
        state = outputs.hidden_states[-2][0, -1, :].cpu().numpy()
        all_hidden_states.append(state)
        input_snippets.append(text)

    # Guard: every sample may have been empty/filtered; KMeans.fit would
    # otherwise raise an opaque error on an empty input.
    if not all_hidden_states:
        return None, "Error: No usable text samples were found in the dataset.", ""

    # Step B: Newtonian Recovery (clustering). Never request more clusters
    # than samples — KMeans raises ValueError when n_samples < n_clusters.
    n_clusters = min(5, len(all_hidden_states))
    kmeans = KMeans(n_clusters=n_clusters, n_init=10).fit(all_hidden_states)
    state_assignments = kmeans.labels_

    # Step C: Iterative Newtonian Interpretation — group snippets per cluster.
    cluster_texts = collections.defaultdict(list)
    for idx, cluster_id in enumerate(state_assignments):
        cluster_texts[cluster_id].append(input_snippets[idx])

    # Initialize Gemini model (widely available flash tier).
    gemini_model = genai.GenerativeModel('gemini-2.5-flash')

    state_info = "## 🧠 Newtonian State Interpretation\n"
    state_info += "Each state represents a discovered *Equivalence Class*.\n\n"

    for cluster_id in range(n_clusters):
        snippets = cluster_texts[cluster_id]
        # Cap the payload at 8 snippets to keep the prompt small.
        context_payload = "\n".join([f"- {s}" for s in snippets[:8]])
        prompt = f"""
Act as a Mechanistic Interpretability Researcher. You are reverse-engineering Cluster S{cluster_id} from the '{dataset_key}' dataset.
Analyze this cluster with high-fidelity Newtonian depth.

### RAW SNIPPETS:
{context_payload}

### REQUIRED OUTPUT FORMAT:
**State S{cluster_id} [Structural State Label]**
- **Internal World Model**: CORE 'Law' or 'Invariant'.
- **Dataset Sensor**: Triggers (Nouns, Syntax).
- **Predictive Function**: Biased future tokens.
"""
        try:
            response = gemini_model.generate_content(prompt, generation_config={"temperature": 0.2})
            state_info += response.text.strip() + "\n\n---\n\n"
        except Exception as e:
            # Best-effort: surface the API failure inline and keep going so
            # one bad call doesn't abort the remaining clusters.
            state_info += f"**State S{cluster_id} [API Error]**: {str(e)}\n\n---\n\n"

    # Step D: DFA Reconstruction — edges follow consecutive sample order.
    # NOTE(review): consecutive streamed samples are usually unrelated texts,
    # so these edges reflect dataset ordering, not within-text dynamics.
    G = nx.DiGraph()
    for i in range(len(state_assignments) - 1):
        G.add_edge(f"S{state_assignments[i]}", f"S{state_assignments[i+1]}")

    plt.figure(figsize=(8, 6))
    pos = nx.kamada_kawai_layout(G)
    nx.draw(G, pos, with_labels=True, node_color='#FF8C00', node_size=3500,
            font_weight='bold', font_size=12, arrowsize=20)
    plt.savefig("dfa_output.png", transparent=True)
    plt.close()

    analysis_brief = f"Model '{model_name}' identified {n_clusters} distinct equivalence classes."
    return "dfa_output.png", analysis_brief, state_info
# 2. Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🌐 The Universal Newtonian Probe")
    gr.Markdown("Extracting the hidden Deterministic Finite Automaton (DFA) from any model and dataset.")

    # The key is collected per-session and never stored server-side.
    with gr.Row():
        api_key_box = gr.Textbox(
            label="Gemini API Key",
            placeholder="paste your API key here...",
            type="password",
        )

    with gr.Row():
        model_choice = gr.Dropdown(choices=MODELS, label="Select Model", value="gpt2")
        dataset_choice = gr.Dropdown(
            choices=list(DATASET_CONFIGS.keys()),
            label="Select Dataset",
            value="wikitext (v2-raw)",
        )

    run_button = gr.Button("Analyze Coherence", variant="primary")

    with gr.Row():
        dfa_image = gr.Image(label="Extracted DFA (World Map)")
        with gr.Column():
            status_box = gr.Textbox(label="Analysis Status")
            elaboration_md = gr.Markdown()

    # The API key is passed through as the first positional argument of
    # analyze_world_model, followed by the model and dataset selections.
    run_button.click(
        analyze_world_model,
        inputs=[api_key_box, model_choice, dataset_choice],
        outputs=[dfa_image, status_box, elaboration_md],
    )

demo.launch()