AmnaHassan commited on
Commit
a484066
·
verified ·
1 Parent(s): 768c65d

Create api.py

Browse files
Files changed (1) hide show
  1. api.py +32 -0
api.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from db import init_db, save_experiment, get_experiment
4
+ from model_utils import generate_text, run_activation_patching
5
+
6
+ app = FastAPI()
7
+ init_db()
8
+
9
+ class ExperimentInput(BaseModel):
10
+ prompt: str
11
+
12
+ @app.post("/generate")
13
+ def generate(input: ExperimentInput):
14
+ text = generate_text(input.prompt)
15
+ activations = run_activation_patching(input.prompt)
16
+ explanation = "Explanation generated by LangGraph agent"
17
+ exp_id = save_experiment(input.prompt, text, str(activations), explanation)
18
+ return {"id": exp_id, "generated_text": text, "activations": activations, "explanation": explanation}
19
+
20
+ @app.get("/results/{id}")
21
+ def get_results(id: int):
22
+ row = get_experiment(id)
23
+ if row:
24
+ return {
25
+ "id": row[0],
26
+ "prompt": row[1],
27
+ "generated_text": row[2],
28
+ "activation_traces": row[3],
29
+ "explanation": row[4],
30
+ "timestamp": row[5]
31
+ }
32
+ return {"error": "Experiment not found"}