# app.py
"""Gradio front end for the mechanistic-analysis prototype.

Wires a text-generation model wrapper, a layer-importance analysis, an
experiment database and an explanation agent into a single web form.
"""
import base64  # retained from the original file (currently unused here)
import io  # retained from the original file (currently unused here)

import gradio as gr
import matplotlib.pyplot as plt  # retained from the original file
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure

from agents import ExperimentAgent, ExplanationAgent
from database import DB
from model import ModelWrapper

# Initialize components
MODEL_NAME = "gpt2"
model = ModelWrapper(MODEL_NAME)
db = DB("experiments.db")
exp_agent = ExperimentAgent(model, db)
expl_agent = ExplanationAgent()


def _render_heatmap(layer_scores):
    """Render *layer_scores* as a one-row heatmap and return it as an image.

    The figure is built with ``matplotlib.figure.Figure`` and an Agg canvas
    instead of ``pyplot``: pyplot keeps every figure alive in a global
    registry until it is explicitly closed, which leaks memory when this is
    called once per web request.

    Args:
        layer_scores: Sequence of per-layer scores; it is wrapped in a list
            so ``imshow`` draws it as a single row.

    Returns:
        An ``(height, width, 4)`` ``uint8`` RGBA numpy array.
    """
    fig = Figure(figsize=(6, 1.5))
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(111)
    ax.imshow([layer_scores], aspect='auto')
    ax.set_yticks([])
    ax.set_xlabel('Layer')
    ax.set_title('Layer importance (proxy)')
    fig.tight_layout()
    canvas.draw()
    # Copy so the result does not alias the canvas' internal pixel buffer.
    return np.asarray(canvas.buffer_rgba()).copy()


def run_experiment(prompt, experiment_type, top_k, max_length):
    """Run one generation + layer-importance experiment for the UI.

    Args:
        prompt: Text entered by the user.
        experiment_type: One of the radio choices offered by the interface;
            forwarded unchanged to ``model.layer_importance``.
        top_k: Top-k value for generation (slider value).
        max_length: Maximum generation length (slider value).

    Returns:
        A ``(generated_text, explanation, heatmap)`` tuple matching the three
        output components of the interface. ``heatmap`` is an RGBA numpy
        array, which the image output component can display directly.
    """
    # Slider values may be delivered as floats; generation settings are
    # integer counts, so normalise them before use.
    top_k = int(top_k)
    max_length = int(max_length)

    # 1) generate
    gen_text = model.generate_text(prompt, max_length=max_length, top_k=top_k)

    # 2) run layer importance analysis (proxy for activation patching)
    layer_scores = model.layer_importance(
        prompt, experiment_type=experiment_type
    )

    # 3) save to DB (the returned id is not needed by the UI)
    db.save_experiment(prompt, gen_text, layer_scores)

    # 4) explanation
    explanation = expl_agent.explain_layer_importance(layer_scores)

    # 5) heatmap image
    heatmap = _render_heatmap(layer_scores)

    return gen_text, explanation, heatmap


demo = gr.Interface(
    fn=run_experiment,
    inputs=[
        gr.Textbox(
            lines=3,
            label="Prompt",
            placeholder="Enter a sentence or prompt...",
        ),
        gr.Radio(
            choices=[
                "story_continuation",
                "sentence_completion",
                "token_prediction",
            ],
            value="story_continuation",
            label="Experiment type",
        ),
        gr.Slider(
            minimum=1, maximum=50, step=1, value=10,
            label="Top-k (generation)",
        ),
        gr.Slider(
            minimum=10, maximum=200, step=1, value=50,
            label="Max generation length",
        ),
    ],
    outputs=[
        gr.Textbox(label="Generated text"),
        gr.Textbox(label="Explanation"),
        gr.Image(type="pil", label="Layer importance heatmap"),
    ],
    title="Mechanistic Analysis Prototype (GPT-2 + Layer Importance)",
    description="Quick prototype: GPT-2 generation + layer importance (proxy for activation patching) + SQLite logging",
)

if __name__ == "__main__":
    # NOTE: 0.0.0.0 listens on all network interfaces, not just localhost.
    demo.launch(server_name="0.0.0.0", server_port=7860)