trykopy / api.py
Pavol Liška
async
0c3c7ed
raw
history blame
3.24 kB
import os
from dotenv import load_dotenv
from fastapi import FastAPI, Response, Body, Security
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, model_validator
from typing import List
import json
from conversation.conversation_store import ConversationStore
from rag_langchain import LangChainRAG
# Module-level setup. Order matters: the .env file is loaded first so that the
# environment lookups below can see values defined there.
load_dotenv()
# API keys accepted by valid_api_key(); currently the single value of the
# API_API_KEY environment variable. The lookup raises KeyError at import time
# if the variable is not set.
api_keys = [os.environ["API_API_KEY"]]
# The FastAPI application the endpoints below are registered on.
api = FastAPI()
# Store used by the endpoints to save, fetch and update conversations.
conversation_store = ConversationStore()
# Security scheme that reads the API key from the "Authorization" header;
# with auto_error=True FastAPI answers with an error when the header is missing.
api_key_header = APIKeyHeader(name="Authorization", auto_error=True)
# Prompt identifiers handed to LangChainRAG and recorded with every saved answer.
prompt_id = "summarize_rag_1"
check_prompt_id = "check_control_challenge_step_back"
rewrite_prompt_id = "first"
# Default value of QModel.llm, used when a request does not name a model.
default_llm = "gpt-4o 128k"
class QModel(BaseModel):
    """Request body of the ``/qa`` endpoint.

    The payload may be a JSON object or a JSON-encoded string that contains
    such an object; the string form is decoded before field validation.
    """

    # Question text passed to the RAG chain.
    q: str
    # Forwarded to LangChainRAG as the "retrieve_documents" setting.
    retrieval_count: int = 10
    # Forwarded to LangChainRAG as "temperature"; kept as a string because
    # that is the type existing clients send.
    temperature: str = "0.2"
    # Model name passed as the second argument of the RAG chain call.
    llm: str = default_llm

    # ``model_validator`` has to be the outermost decorator. With
    # ``@classmethod`` on top, the class attribute is a plain classmethod
    # wrapping pydantic's proxy object and the validator is never registered.
    @model_validator(mode='before')
    @classmethod
    def validate_to_json(cls, value):
        """Decode a JSON-encoded string payload before field validation.

        Args:
            value: Raw input handed to the model; a ``str`` is treated as
                JSON text, anything else is passed through unchanged.

        Returns:
            The decoded JSON value for string input, otherwise ``value``.

        Raises:
            ValueError: ``json.JSONDecodeError`` when the string is not valid
                JSON; pydantic reports it as a validation error.
        """
        if isinstance(value, str):
            # Return the decoded data and let pydantic validate it once,
            # instead of constructing a second instance inside the validator.
            return json.loads(value)
        return value
# Response body of the /qa endpoint: the question, the generated answer,
# the text of the documents the answer was built from, and the identifier
# returned by the conversation store for the saved exchange.
class AModel(BaseModel):
    q: str  # question as received in the request
    a: str  # answer produced by the RAG chain
    sources: List[str]  # page_content of each source document
    oid: str  # identifier returned by conversation_store.save_content
class EmoModel(BaseModel):
    """Request body of the ``/emo`` endpoint (user feedback on an answer).

    The payload may be a JSON object or a JSON-encoded string that contains
    such an object; the string form is decoded before field validation.
    """

    # Identifier of the stored conversation the feedback refers to.
    qid: str
    # Rating supplied by the user; stored under the "user_grading" parameter.
    helpfulness: str

    # ``model_validator`` has to be the outermost decorator. With
    # ``@classmethod`` on top, the class attribute is a plain classmethod
    # wrapping pydantic's proxy object and the validator is never registered.
    @model_validator(mode='before')
    @classmethod
    def validate_to_json(cls, value):
        """Decode a JSON-encoded string payload before field validation.

        Args:
            value: Raw input handed to the model; a ``str`` is treated as
                JSON text, anything else is passed through unchanged.

        Returns:
            The decoded JSON value for string input, otherwise ``value``.

        Raises:
            ValueError: ``json.JSONDecodeError`` when the string is not valid
                JSON; pydantic reports it as a validation error.
        """
        if isinstance(value, str):
            # Return the decoded data and let pydantic validate it once,
            # instead of constructing a second instance inside the validator.
            return json.loads(value)
        return value
# Root endpoint: needs no API key and always answers with a fixed
# placeholder string.
@api.get("/")
async def read_root():
    placeholder = "Empty"
    return placeholder
@api.post("/qa", response_model=AModel)
async def qa(api_key: str = Security(api_key_header), data: QModel = Body(...)):
    """Answer a question with the RAG chain and store the exchange.

    Args:
        api_key: Value of the ``Authorization`` header.
        data: Question together with retrieval and model settings.

    Returns:
        The question, the generated answer, the text of the source documents
        and the identifier of the stored conversation; an empty 401 response
        when the API key is not accepted.
    """
    if not valid_api_key(api_key):
        return Response(status_code=401)

    rag = LangChainRAG(
        config={
            "retrieve_documents": data.retrieval_count,
            "temperature": data.temperature,
            "prompt_id": prompt_id,
            "check_prompt_id": check_prompt_id,
            "rewrite_prompt_id": rewrite_prompt_id,
        }
    )
    answer, check_result, sources = await rag.rag_chain(data.q, data.llm)

    # Extract the document texts once; both the stored record and the
    # response need the same list.
    source_texts = [doc.page_content for doc in sources]

    oid = conversation_store.save_content(
        q=data.q,
        a=answer,
        sources=source_texts,
        params={
            "prompt_id": prompt_id,
            "check_prompt_id": check_prompt_id,
            "rewrite_prompt_id": rewrite_prompt_id,
            "check_result": check_result,
            "temperature": data.temperature,
            # Stored as text, unlike the int handed to the RAG config above.
            "retrieve_document_count": str(data.retrieval_count),
        },
    )
    return AModel(
        a=answer,
        q=data.q,
        sources=source_texts,
        oid=oid,
    )
@api.post("/emo")
async def emo(api_key: str = Security(api_key_header), json_body: EmoModel = Body(...)):
    """Record a user's helpfulness rating on a stored conversation.

    Args:
        api_key: Value of the ``Authorization`` header.
        json_body: Conversation identifier and the rating to attach to it.

    Returns:
        Nothing on success; an empty 401 response when the API key is not
        accepted, or an empty 404 response when no conversation is found.
    """
    if not valid_api_key(api_key):
        return Response(status_code=401)

    conversation = conversation_store.get(json_body.qid)
    # NOTE(review): assumes the store returns None for an unknown id; confirm
    # against ConversationStore.get. Without this guard a missing conversation
    # would surface as an AttributeError on the next line.
    if conversation is None:
        return Response(status_code=404)

    new_params = conversation.params
    new_params["user_grading"] = str(json_body.helpfulness)

    # Only the first entry of the conversation is rewritten.
    first_turn = conversation.conversation[0]
    conversation_store.update(
        # EmoModel is a pydantic model, not a mapping: subscripting it with
        # ["qid"] raises TypeError, so the field is read as an attribute.
        oid=json_body.qid,
        q=first_turn.q,
        a=first_turn.a,
        sources=first_turn.sources,
        params=new_params,
    )
def valid_api_key(api_key: str) -> bool:
    """Return True when ``api_key`` equals one of the configured API keys.

    The key comes from a request header, so each comparison uses a
    constant-time check instead of ``==``/``in``.

    Args:
        api_key: Key presented by the client.

    Returns:
        True if the key matches an entry of ``api_keys``, False otherwise
        (including when ``api_key`` is not a string).
    """
    # Imported locally so the function does not depend on a new file-level import.
    from hmac import compare_digest

    if not isinstance(api_key, str):
        return False

    # compare_digest only accepts non-ASCII text as bytes; "surrogatepass"
    # keeps encoding from failing on lone surrogates, which os.environ can
    # produce for undecodable bytes.
    presented = api_key.encode("utf-8", "surrogatepass")
    matched = False
    for known_key in api_keys:
        # Deliberately no early exit: every configured key is compared.
        if compare_digest(presented, known_key.encode("utf-8", "surrogatepass")):
            matched = True
    return matched