File size: 704 Bytes
768c65d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
from langgraph import Agent, Workflow
from model_utils import generate_text, run_activation_patching
class ExperimentAgent(Agent):
    """Agent that runs text generation and activation patching on a prompt."""

    def run(self, prompt):
        """Generate text for ``prompt`` and collect its activation-patching output.

        Both helpers receive the same ``prompt``; generation is invoked first.

        Returns:
            A dict with two keys: ``"generated_text"`` holds the value returned
            by ``generate_text(prompt)`` and ``"activations"`` holds the value
            returned by ``run_activation_patching(prompt)``.
        """
        generated = generate_text(prompt)
        patching_output = run_activation_patching(prompt)

        outcome = {}
        outcome["generated_text"] = generated
        outcome["activations"] = patching_output
        return outcome
class ExplanationAgent(Agent):
    """Agent that returns a textual explanation for a set of activations."""

    def run(self, activations):
        """Return the explanation text.

        ``activations`` is accepted but never read by this method; the
        returned string is a fixed literal.
        """
        return "Layer 5 and 7 had the most influence on the model output."
def _select_activations(result):
    """Return the ``"activations"`` entry of an upstream result mapping."""
    return result["activations"]


# Register both agents under their names, then link "experiment" to
# "explanation".
workflow = Workflow()
workflow.add_agent("experiment", ExperimentAgent())
workflow.add_agent("explanation", ExplanationAgent())
# NOTE(review): presumably the third argument maps the upstream agent's result
# to the downstream agent's input -- confirm against Workflow.connect.
workflow.connect("experiment", "explanation", _select_activations)
|