Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| import networkx as nx | |
| import matplotlib.pyplot as plt | |
| from transformers import GPT2Model, GPT2Tokenizer | |
| from sklearn.cluster import KMeans | |
# 1. Load a real small model
# The GPT-2 base checkpoint (124M parameters) and its tokenizer are loaded once
# when the module is imported; get_hidden_state() reuses them on every call.
model_name = "gpt2"

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2Model.from_pretrained(model_name)
model = model.to(device)
def get_hidden_state(sequence_str):
    """Return the model's final hidden state for the last token of a prompt.

    Args:
        sequence_str: Text to tokenize and run through the module-level
            ``model``.

    Returns:
        A NumPy array holding the last entry of ``hidden_states`` for the
        first (only) batch item at the last token position, moved to the CPU.
    """
    encoded = tokenizer(sequence_str, return_tensors="pt").to(device)

    # Inference only: no gradients are needed for probing activations.
    with torch.no_grad():
        result = model(**encoded, output_hidden_states=True)

    final_layer = result.hidden_states[-1]
    # Batch item 0, last token position, every hidden dimension.
    last_token_state = final_layer[0, -1, :]
    return last_token_state.cpu().numpy()
def analyze_dfa(input_text):
    """Probe the model's hidden states along a move sequence and plot them.

    Simulates a 'State Probe'. The comma-separated moves in ``input_text``
    (e.g. ``"Right, Up, Left"``) are appended one at a time to a growing
    prompt. After each move the hidden state of the prompt's last token is
    recorded, the recorded vectors are clustered with k-means (at most 4
    clusters), and every cluster is treated as one "state". Each pair of
    consecutive prompts contributes a directed edge from the earlier state to
    the later one, labelled with the move that caused the transition. The
    first move only establishes the starting state, so it never appears as an
    edge label.

    Args:
        input_text: Comma-separated move names. Blank entries (for example
            from a trailing comma) are ignored.

    Returns:
        A ``(image_path, summary)`` tuple. ``image_path`` is the PNG file the
        graph was written to, or ``None`` when ``input_text`` contains no
        moves; ``summary`` is a one-line human-readable result.
    """
    raw_moves = (input_text or "").split(",")
    moves = [m.strip() for m in raw_moves if m.strip()]
    if not moves:
        # Nothing to probe: avoid feeding an empty " Move ." prompt to the
        # model and asking k-means for zero clusters.
        return None, "No moves provided. Enter moves separated by commas."

    # Track the "path" through the model's internal space: one vector per
    # prefix of the move sequence.
    history = ""
    states_vectors = []
    for move in moves:
        history += f" Move {move}."
        states_vectors.append(get_hidden_state(history))

    # Clustering: Vafa's Compression metric
    # We cluster activations to see which moves the model thinks are
    # 'equivalent'. k-means cannot use more clusters than samples, hence the
    # min(); a fixed random_state makes repeated runs on the same input agree.
    num_clusters = min(len(moves), 4)
    kmeans = KMeans(n_clusters=num_clusters, n_init=10, random_state=0)
    labels = kmeans.fit(states_vectors).labels_
    state_names = [f"S{label}" for label in labels]

    # Build the DFA graph. Nodes are added explicitly so a single-move input
    # still shows its one state instead of an empty picture.
    G = nx.DiGraph()
    G.add_nodes_from(state_names)

    # A DiGraph keeps one edge per (source, target) pair, so collect every
    # move seen on a pair first; otherwise later moves overwrite earlier ones.
    transitions = {}
    for src, dst, move in zip(state_names, state_names[1:], moves[1:]):
        seen = transitions.setdefault((src, dst), [])
        if move not in seen:
            seen.append(move)
    for (src, dst), edge_moves in transitions.items():
        G.add_edge(src, dst, label=", ".join(edge_moves))

    # Draw on an explicit figure/axes rather than pyplot's implicit "current
    # figure", and always close the figure so repeated calls do not
    # accumulate open figures in a long-running server.
    fig = plt.figure(figsize=(6, 4))
    try:
        ax = fig.add_axes((0.0, 0.0, 1.0, 1.0))
        pos = nx.spring_layout(G, seed=0)  # seeded for a stable layout
        nx.draw(
            G,
            pos,
            ax=ax,
            with_labels=True,
            node_color='lightblue',
            node_size=2000,
        )
        edge_labels = nx.get_edge_attributes(G, 'label')
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, ax=ax)
        # NOTE(review): the output path is fixed, so overlapping calls would
        # overwrite each other's image -- confirm requests are serialized.
        fig.savefig("dfa_plot.png")
    finally:
        plt.close(fig)

    # Report the states actually assigned, not merely the number requested.
    num_states = len(set(state_names))
    return "dfa_plot.png", f"Found {num_states} distinct internal states."
# 3. Gradio Interface
# One text box in, an image plus a one-line text summary out; both outputs are
# produced by analyze_dfa.
demo = gr.Interface(
    fn=analyze_dfa,
    inputs=gr.Textbox(placeholder="Enter moves separated by commas, e.g.: Right, Up, Left, Down"),
    outputs=[gr.Image(label="Extracted Model DFA"), gr.Text(label="Analysis")],
    title="World Model DFA Extractor",
    description="This tool probes GPT-2's internal activations to see if it treats different move sequences as the same 'State'.",
)

# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. to reuse analyze_dfa or `demo`) has no such
# side effect.
if __name__ == "__main__":
    demo.launch()