# Author: AmnaHassan
# Commit: "Create api.py" (a484066, verified)
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from db import init_db, save_experiment, get_experiment
from model_utils import generate_text, run_activation_patching
# FastAPI application object; the route decorators in this module register on it.
app = FastAPI()

# Initialise storage through db.init_db() once, at import time, so it has
# already run before the first request is handled.
# NOTE(review): this is an import-time side effect; consider moving it to a
# FastAPI startup/lifespan hook if importing this module must stay cheap.
init_db()
class ExperimentInput(BaseModel):
    """Request body accepted by the ``POST /generate`` endpoint.

    Its single field, ``prompt``, is forwarded unchanged to the text
    generation and activation-patching helpers.
    """

    # Prompt text supplied by the client.
    prompt: str
@app.post("/generate")
def generate(input: ExperimentInput):
    """Run one experiment on the request's prompt and persist the result.

    Calls ``generate_text`` and ``run_activation_patching`` on the prompt,
    stores the outcome with ``save_experiment``, and echoes it back.

    Args:
        input: Parsed request body; only ``input.prompt`` is used.

    Returns:
        A dict with keys ``id`` (value returned by ``save_experiment``),
        ``generated_text``, ``activations`` (the un-stringified patching
        result) and ``explanation``.
    """
    prompt = input.prompt

    generated = generate_text(prompt)
    traces = run_activation_patching(prompt)

    # NOTE(review): fixed placeholder text; no agent is actually invoked here.
    explanation = "Explanation generated by LangGraph agent"

    # NOTE(review): activations are persisted via str(), i.e. their Python
    # repr, while the response returns the raw object. /results therefore
    # serves this repr string rather than structured data.
    new_id = save_experiment(prompt, generated, str(traces), explanation)

    return {
        "id": new_id,
        "generated_text": generated,
        "activations": traces,
        "explanation": explanation,
    }
@app.get("/results/{id}")
def get_results(id: int):
    """Return the stored experiment whose id is ``id``.

    Args:
        id: Experiment id from the URL path. The name must match the
            ``{id}`` path segment, which is why it shadows the builtin.

    Returns:
        When found, a dict with keys ``id``, ``prompt``, ``generated_text``,
        ``activation_traces``, ``explanation`` and ``timestamp``. Otherwise,
        a ``JSONResponse`` with HTTP status 404 and body
        ``{"error": "Experiment not found"}``.
    """
    row = get_experiment(id)
    if not row:
        # Signal "missing" with 404 instead of 200 so clients can detect it
        # by status code; the body is unchanged for existing consumers.
        return JSONResponse(
            status_code=404,
            content={"error": "Experiment not found"},
        )

    # Positional mapping assumes db.get_experiment returns columns in the
    # order (id, prompt, generated_text, activation_traces, explanation,
    # timestamp) — TODO confirm against the db module's SELECT.
    return {
        "id": row[0],
        "prompt": row[1],
        "generated_text": row[2],
        "activation_traces": row[3],
        "explanation": row[4],
        "timestamp": row[5],
    }