# Author: Phoenix21 — "Update app.py" (commit 2991e6e, verified)
import torch
import gradio as gr
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset
from sklearn.cluster import KMeans
import networkx as nx
import matplotlib.pyplot as plt
import collections
import os
import google.generativeai as genai
# 1. Models & dataset configurations

# Hugging Face model ids offered in the "Select Model" dropdown.
# NOTE(review): the canonical Hub id is "Qwen/Qwen2.5-0.5B"; the lowercase
# org presumably resolves via the Hub's redirect — confirm.
MODELS = [
    "gpt2",
    "distilgpt2",
    "qwen/Qwen2.5-0.5B",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
]

# Dropdown label -> (dataset repo id, config name or None). The tuple is
# unpacked before calling `load_dataset`; None means "no named config".
DATASET_CONFIGS = {
    "wikitext (v2-raw)": ("wikitext", "wikitext-2-raw-v1"),
    "TinyStories": ("roneneldan/TinyStories", None),
    "AG News": ("ag_news", None),
}
# Prompt sent to Gemini for each cluster; filled in with str.format, so only
# the three named placeholders below may contain braces.
_INTERPRETATION_PROMPT = """\
Act as a Mechanistic Interpretability Researcher. You are reverse-engineering Cluster S{cluster_id}
from the '{dataset_key}' dataset. Analyze this cluster with high-fidelity Newtonian depth.
### RAW SNIPPETS:
{context_payload}
### REQUIRED OUTPUT FORMAT:
**State S{cluster_id} [Structural State Label]**
- **Internal World Model**: CORE 'Law' or 'Invariant'.
- **Dataset Sensor**: Triggers (Nouns, Syntax).
- **Predictive Function**: Biased future tokens.
"""


def _extract_last_token_states(model_name, dataset_key, num_samples):
    """Run *model_name* over streamed training examples and collect vectors.

    Reads up to ``num_samples`` examples from the train split of the dataset
    registered under *dataset_key* in ``DATASET_CONFIGS``. Each example's
    ``text`` (or ``content``) field is cut to 150 characters and stripped;
    empty results are skipped.

    Returns:
        ``(states, snippets)`` of equal length. ``states[i]`` is the
        last-token vector taken from the second-to-last entry of the model's
        ``hidden_states`` for ``snippets[i]``.

    Raises:
        KeyError: if *dataset_key* is not in ``DATASET_CONFIGS``.
        Any exception from model or dataset loading propagates unchanged.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dataset_name, config_name = DATASET_CONFIGS[dataset_key]
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)

    # A config name of None means the dataset has no named configuration.
    load_args = (dataset_name, config_name) if config_name else (dataset_name,)
    ds = load_dataset(*load_args, split="train", streaming=True).take(num_samples)

    states, snippets = [], []
    for example in ds:
        # `or ""` guards against a field that is present but None.
        raw = example.get("text", example.get("content", "")) or ""
        text = raw[:150].strip()
        if not text:
            continue
        inputs = tokenizer(text, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        # Second-to-last hidden-state entry, batch item 0, last token.
        # .float() first because .numpy() cannot convert bfloat16 tensors.
        states.append(outputs.hidden_states[-2][0, -1, :].float().cpu().numpy())
        snippets.append(text)
    return states, snippets


def _interpret_clusters(cluster_texts, dataset_key):
    """Return Markdown with one Gemini-written section per cluster.

    *cluster_texts* maps a cluster id to its snippets. At most 8 snippets per
    cluster are sent. A failure for one cluster is reported inline instead of
    aborting the whole report.
    """
    gemini_model = genai.GenerativeModel("gemini-2.5-flash")
    sections = [
        "## 🧠 Newtonian State Interpretation\n"
        "Each state represents a discovered *Equivalence Class*.\n\n"
    ]
    for cluster_id in sorted(cluster_texts):
        prompt = _INTERPRETATION_PROMPT.format(
            cluster_id=cluster_id,
            dataset_key=dataset_key,
            context_payload="\n".join(f"- {s}" for s in cluster_texts[cluster_id][:8]),
        )
        try:
            response = gemini_model.generate_content(
                prompt, generation_config={"temperature": 0.2}
            )
            sections.append(response.text.strip())
        except Exception as exc:  # UI boundary: show any API/SDK failure to the user.
            sections.append(f"**State S{cluster_id} [API Error]**: {exc}")
        sections.append("\n\n---\n\n")
    return "".join(sections)


def _render_transition_graph(labels):
    """Draw the cluster-transition graph for *labels* and return the PNG path.

    Each distinct label becomes a node ``S<label>``. An edge joins the labels
    of every pair of consecutive samples. The image is written to a fresh
    temporary file, so concurrent requests cannot overwrite each other's output.
    """
    import tempfile

    graph = nx.DiGraph()
    # Register nodes explicitly so a single-sample run still draws its state.
    graph.add_nodes_from(f"S{label}" for label in labels)
    graph.add_edges_from(
        (f"S{src}", f"S{dst}") for src, dst in zip(labels, labels[1:])
    )

    fig = plt.figure(figsize=(8, 6))
    try:
        pos = nx.kamada_kawai_layout(graph)
        nx.draw(graph, pos, with_labels=True, node_color="#FF8C00", node_size=3500,
                font_weight="bold", font_size=12, arrowsize=20)
        # delete=False: the path is handed to the caller. NOTE(review): nothing
        # here removes the file afterwards — confirm Gradio's cache/cleanup policy.
        with tempfile.NamedTemporaryFile(prefix="dfa_", suffix=".png", delete=False) as tmp:
            fig.savefig(tmp, format="png", transparent=True)
    finally:
        # Release the figure even if layout or drawing fails.
        plt.close(fig)
    return tmp.name


def analyze_world_model(api_key, model_name, dataset_key, num_samples=25, n_clusters=5):
    """Probe a model's hidden states and summarise them as a state graph.

    Pipeline:
      A. Collect last-token vectors for up to ``num_samples`` streamed
         examples (see ``_extract_last_token_states``).
      B. Cluster them with k-means. ``n_clusters`` is capped at the number of
         vectors, and the seed is fixed so repeated runs give the same
         clustering.
      C. Ask Gemini to describe every non-empty cluster.
      D. Draw a directed graph with an edge S_a -> S_b whenever sample i is
         in cluster a and sample i+1 is in cluster b.

    Args:
        api_key: Gemini API key. Surrounding whitespace is ignored; ``None``
            and keys shorter than 10 characters are rejected.
        model_name: Hugging Face model id passed to ``from_pretrained``.
        dataset_key: Key of ``DATASET_CONFIGS``.
        num_samples: Number of streamed examples to read. Examples with
            empty text are skipped.
        n_clusters: Requested number of k-means clusters.

    Returns:
        ``(image_path, status, markdown)``:
          - the PNG path, or ``None`` on a validation error;
          - a one-line status or error message;
          - the Markdown interpretation, or ``""`` on a validation error.

    Raises:
        KeyError: if *dataset_key* is not in ``DATASET_CONFIGS``.
        Model and dataset loading errors propagate unchanged.
    """
    api_key = (api_key or "").strip()
    if len(api_key) < 10:
        return None, "Error: Please provide a valid Gemini API Key.", ""
    # NOTE(review): genai.configure is module-level configuration, presumably
    # shared by the whole process — concurrent sessions could see each other's key.
    genai.configure(api_key=api_key)

    # Step A: probe (hidden-state extraction)
    states, snippets = _extract_last_token_states(model_name, dataset_key, num_samples)
    if not states:
        return None, (
            f"Error: no non-empty text found in the first {num_samples} "
            f"samples of '{dataset_key}'."
        ), ""

    # Step B: clustering; k-means needs at least as many samples as clusters.
    k = max(1, min(int(n_clusters), len(states)))
    kmeans = KMeans(n_clusters=k, n_init=10, random_state=0).fit(states)
    labels = [int(label) for label in kmeans.labels_]

    # Step C: group snippets by cluster and interpret each group.
    cluster_texts = collections.defaultdict(list)
    for label, snippet in zip(labels, snippets):
        cluster_texts[label].append(snippet)
    state_info = _interpret_clusters(cluster_texts, dataset_key)

    # Step D: transition-graph ("DFA") reconstruction.
    image_path = _render_transition_graph(labels)

    analysis_brief = (
        f"Model '{model_name}' identified {len(cluster_texts)} distinct equivalence classes."
    )
    return image_path, analysis_brief, state_info
# 2. Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🌐 The Universal Newtonian Probe")
    gr.Markdown("Extracting the hidden Deterministic Finite Automaton (DFA) from any model and dataset.")

    with gr.Row():
        # Masked input; its value is passed as the first argument of analyze_world_model.
        api_key_input = gr.Textbox(
            label="Gemini API Key",
            placeholder="paste your API key here...",
            type="password",
        )
    with gr.Row():
        m_drop = gr.Dropdown(choices=MODELS, label="Select Model", value="gpt2")
        d_drop = gr.Dropdown(
            choices=list(DATASET_CONFIGS.keys()),
            label="Select Dataset",
            value="wikitext (v2-raw)",
        )
    btn = gr.Button("Analyze Coherence", variant="primary")

    with gr.Row():
        out_img = gr.Image(label="Extracted DFA (World Map)")
        with gr.Column():
            out_txt = gr.Textbox(label="Analysis Status")
            out_elaboration = gr.Markdown()

    # Input order matches the (api_key, model_name, dataset_key) parameters,
    # and output order matches the (image, status, markdown) return tuple.
    btn.click(
        analyze_world_model,
        inputs=[api_key_input, m_drop, d_drop],
        outputs=[out_img, out_txt, out_elaboration],
    )

# Launch only when run as a script, so importing the module has no side effects.
if __name__ == "__main__":
    demo.launch()