|
|
import json

from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from db import init_db, save_experiment, get_experiment
from model_utils import generate_text, run_activation_patching
|
|
|
|
|
def _create_app() -> FastAPI:
    """Build the FastAPI application and initialise the database.

    Called exactly once, at module import, so ``init_db()`` has run before
    any route registered below can be served.
    """
    application = FastAPI()
    init_db()
    return application


# The application instance every route in this module is registered on.
app = _create_app()
|
|
|
|
|
class ExperimentInput(BaseModel):
    """Request body for ``POST /generate``."""

    # Text passed to both generate_text() and run_activation_patching().
    prompt: str
|
|
|
|
|
@app.post("/generate")
def generate(input: ExperimentInput):
    """Generate text for a prompt, run activation patching, and persist both.

    Declared as a plain ``def`` so FastAPI executes it in its threadpool; the
    model helpers are presumably blocking (confirm in ``model_utils``).

    Args:
        input: Request body carrying the ``prompt``. The name shadows the
            builtin ``input`` but is kept so the signature is unchanged.

    Returns:
        dict with the new experiment ``id``, the ``generated_text``, the raw
        ``activations`` object, and the ``explanation`` string.
    """
    text = generate_text(input.prompt)
    activations = run_activation_patching(input.prompt)
    # TODO: placeholder - no agent is invoked; this constant is stored verbatim.
    explanation = "Explanation generated by LangGraph agent"

    # Persist traces as JSON rather than str()/repr so stored rows can be parsed
    # back. jsonable_encoder is the same conversion FastAPI applies to the
    # response below, so stored and returned activations agree. Encoding before
    # saving means an unencodable result fails without writing a row.
    activations_json = json.dumps(jsonable_encoder(activations))
    exp_id = save_experiment(input.prompt, text, activations_json, explanation)

    return {
        "id": exp_id,
        "generated_text": text,
        "activations": activations,
        "explanation": explanation,
    }
|
|
|
|
|
@app.get("/results/{id}")
def get_results(id: int):
    """Return the stored experiment with the given id.

    ``id`` shadows the builtin but must match the ``{id}`` path segment, so it
    is kept as-is.

    Returns:
        dict with ``id``, ``prompt``, ``generated_text``, ``activation_traces``,
        ``explanation`` and ``timestamp``, read positionally from the row
        returned by ``get_experiment``. This assumes the db column order
        matches; verify against the db schema. If no row is found, a 404
        response with body ``{"error": "Experiment not found"}`` is returned.
    """
    row = get_experiment(id)
    if not row:
        # Real 404 so clients/proxies see a miss; body unchanged from the old
        # 200 response for clients that check the "error" key.
        return JSONResponse(
            status_code=404,
            content={"error": "Experiment not found"},
        )

    return {
        "id": row[0],
        "prompt": row[1],
        "generated_text": row[2],
        "activation_traces": row[3],
        "explanation": row[4],
        "timestamp": row[5],
    }
|
|
|