import os
import json
from typing import List

from dotenv import load_dotenv
from fastapi import FastAPI, Response, Body, Security
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, model_validator

from conversation.conversation_store import ConversationStore
from rag_langchain import LangChainRAG

load_dotenv()

api_keys = [os.environ["API_API_KEY"]]

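# A minimal .env sketch (assumption: only the variable read above is
# required; the value is a placeholder, not a real key):
#
#   API_API_KEY=change-me
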
api = FastAPI()
conversation_store = ConversationStore()
api_key_header = APIKeyHeader(name="Authorization", auto_error=True)

# Prompt identifiers and default model handed to the RAG chain.
prompt_id = "summarize_rag_1"
check_prompt_id = "check_control_challenge_step_back"
rewrite_prompt_id = "first"
default_llm = "gpt-4o 128k"

class QModel(BaseModel):
    q: str
    retrieval_count: int = 10
    temperature: str = "0.2"  # kept as a string; forwarded and stored as-is
    llm: str = default_llm

    @model_validator(mode="before")
    @classmethod
    def validate_to_json(cls, value):
        # Accept a JSON-encoded string body as well as a regular dict.
        if isinstance(value, str):
            return cls(**json.loads(value))
        return value

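# QModel's "before" validator means the model accepts its payload either as a
# JSON object or as a JSON-encoded string. A hypothetical illustration (not
# executed anywhere in this module):
#
#   QModel(q="What is RAG?")                        # plain construction
#   QModel.model_validate('{"q": "What is RAG?"}')  # string parsed via json.loads
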
class AModel(BaseModel):
    q: str
    a: str
    sources: List[str]
    oid: str

class EmoModel(BaseModel):
    qid: str
    helpfulness: str

    @model_validator(mode="before")
    @classmethod
    def validate_to_json(cls, value):
        # Same string-or-dict handling as QModel.
        if isinstance(value, str):
            return cls(**json.loads(value))
        return value

@api.get("/")
async def read_root():
return "Empty"
@api.post("/qa", response_model=AModel)
async def qa(api_key: str = Security(api_key_header), data: QModel = Body(...)):
if not valid_api_key(api_key):
return Response(status_code=401)
rag = LangChainRAG(
config={
"retrieve_documents": data.retrieval_count,
"temperature": data.temperature,
"prompt_id": prompt_id,
"check_prompt_id": check_prompt_id,
"rewrite_prompt_id": rewrite_prompt_id
}
)
answer, check_result, sources = await rag.rag_chain(data.q, data.llm)
oid = conversation_store.save_content(
q=data.q,
a=answer,
sources=list(map(lambda doc: doc.page_content, sources)),
params=
{
"prompt_id": prompt_id,
"check_prompt_id": check_prompt_id,
"rewrite_prompt_id": rewrite_prompt_id,
"check_result": check_result,
"temperature": data.temperature,
"retrieve_document_count": str(data.retrieval_count),
}
)
return AModel(
a=answer,
q=data.q,
sources=list(map(lambda doc: doc.page_content, sources)),
oid=oid
)
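# Usage sketch for /qa (assumptions: the app is served locally on port 8000,
# e.g. via `uvicorn main:api` where the module name "main" is a guess, and
# <key> is the API_API_KEY value):
#
#   curl -X POST http://localhost:8000/qa \
#        -H "Authorization: <key>" \
#        -H "Content-Type: application/json" \
#        -d '{"q": "What is retrieval-augmented generation?", "retrieval_count": 5}'
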
@api.post("/emo")
async def emo(api_key: str = Security(api_key_header), json_body: EmoModel = Body(...)):
if not valid_api_key(api_key):
return Response(status_code=401)
conversation = conversation_store.get(json_body.qid)
new_params = conversation.params
new_params["user_grading"] = str(json_body.helpfulness)
conversation_store.update(
oid=json_body["qid"],
q=conversation.conversation[0].q,
a=conversation.conversation[0].a,
sources=conversation.conversation[0].sources,
params=new_params
)
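# Usage sketch for /emo (same assumptions as the /qa example; "qid" is the
# "oid" returned by /qa):
#
#   curl -X POST http://localhost:8000/emo \
#        -H "Authorization: <key>" \
#        -H "Content-Type: application/json" \
#        -d '{"qid": "<oid-from-/qa>", "helpfulness": "4"}'
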
def valid_api_key(api_key: str) -> bool:
    return api_key in api_keys