|
|
|
|
|
import os |
|
|
import json |
|
|
import streamlit as st |
|
|
import requests |
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
from model_utils import generate_text, run_activation_patching |
|
|
|
|
|
|
|
|
|
|
|
# Backend base URL: prefer the deployed secret; fall back to a sidebar text
# field so the app stays usable in local development without secrets.toml.
_configured_url = st.secrets.get("render_url")
if not _configured_url:
    _configured_url = st.sidebar.text_input(
        "Render API base URL", value="https://activation-patching-api.onrender.com"
    )
RENDER_API_BASE = _configured_url

# Endpoint that persists experiment metadata (trailing slash normalized away).
SAVE_ENDPOINT = RENDER_API_BASE.rstrip("/") + "/save"
|
|
|
|
|
# --- Page header -----------------------------------------------------------
st.title("Mechanistic Analysis Interface (HF Space)")
st.write(
    "This Streamlit app runs GPT-2 + activation patching locally, then saves metadata to your Render backend."
)

# --- Inputs ----------------------------------------------------------------
# Main-panel prompt box plus a sidebar control for generation length
# (min 20, max 200, default 60 tokens).
prompt = st.text_area("Enter your sentence / prompt", height=150)
max_length = st.sidebar.slider("Max generation length", 20, 200, 60)
|
|
|
|
|
# Main experiment flow: generate text, run activation patching, show results,
# then persist a JSON-safe summary of the run to the Render backend.
if st.button("Run Experiment"):
    if not prompt.strip():
        st.warning("Please enter a prompt.")
    else:
        # --- 1. Generation -------------------------------------------------
        with st.spinner("Generating text with GPT-2..."):
            try:
                generated = generate_text(prompt, max_length=max_length)
            except Exception as e:
                # Show a friendly error and halt this script run. A bare
                # `raise` here would dump a duplicate raw traceback on top
                # of the st.error message; st.stop() is the idiomatic halt.
                st.error(f"Error running generation: {e}")
                st.stop()

        st.subheader("Generated Text")
        st.write(generated)

        # --- 2. Activation patching ----------------------------------------
        with st.spinner("Running activation patching (TransformerLens)..."):
            try:
                activations = run_activation_patching(prompt)
            except Exception as e:
                st.error(f"Error running activation patching: {e}")
                st.stop()

        st.subheader("Activation traces (sample)")
        # Preview the first 10 entries: tensor-like values show their shape,
        # anything else shows its type name.
        sample = {
            k: (v.shape if hasattr(v, "shape") else type(v).__name__)
            for k, v in list(activations.items())[:10]
        }
        st.json(sample)

        # Placeholder until a real explanation agent is wired in.
        explanation = "Explanation placeholder: top influencing layers ... (expand with LangGraph later)"
        st.subheader("Explanation")
        st.write(explanation)

        # --- 3. Persist metadata to the Render backend ---------------------
        # Tensors are not JSON-serializable, so each trace is reduced to a
        # plain-Python summary before dumping: shape as an explicit list
        # (list(...) works for torch.Size, numpy shapes, and plain tuples
        # alike) plus scalar min/max when available.
        payload = {
            "prompt": prompt,
            "generated_text": generated,
            "activation_traces": json.dumps({
                k: {
                    "shape": list(v.shape) if hasattr(v, "shape") else None,
                    "min": float(v.min()) if hasattr(v, "min") else None,
                    "max": float(v.max()) if hasattr(v, "max") else None,
                } for k, v in activations.items()
            }),
            "explanation": explanation,
        }

        try:
            res = requests.post(SAVE_ENDPOINT, json=payload, timeout=30)
            st.write("Save status:", res.status_code)
            st.write("Save response:", res.text)
            if res.ok:
                # The backend may return a non-JSON body; guard the decode so
                # a successful save is not mis-reported as a network failure.
                try:
                    data = res.json()
                except ValueError:
                    data = {}
                st.success(f"Experiment saved with ID {data.get('id')}")
            else:
                st.error(f"Failed to save experiment: {res.text}")
        except Exception as e:
            # Network / timeout problems are surfaced but do not crash the app.
            st.error(f"Error saving to Render: {e}")
|
|
|
|
|
# Footer: usage notes rendered after the main interface.
st.markdown("---")
st.write("Notes:")
for _note in (
    "- This app runs heavy ML locally in the HF Space container; Render is used only to persist metadata.",
    "- If you want LangGraph explanation, we can call a low-cost open model here or run agents locally.",
):
    st.write(_note)
|
|
|