File size: 3,235 Bytes
0c3c7ed
 
 
869eb7d
 
593b823
 
 
869eb7d
 
 
 
0c3c7ed
 
 
869eb7d
 
 
 
 
 
 
 
 
 
 
3c35194
 
 
 
 
 
593b823
 
 
 
 
 
 
 
 
 
 
 
 
 
3c35194
 
 
 
 
593b823
 
 
 
 
 
 
3c35194
869eb7d
593b823
869eb7d
 
 
593b823
 
869eb7d
 
 
 
 
593b823
 
869eb7d
 
 
 
 
 
0c3c7ed
869eb7d
 
593b823
869eb7d
 
 
 
 
 
 
 
593b823
b11dd45
869eb7d
 
 
593b823
 
 
 
 
 
869eb7d
 
 
3c35194
869eb7d
 
 
0c3c7ed
 
3c35194
869eb7d
 
0c3c7ed
 
 
869eb7d
 
 
 
0c3c7ed
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os

from dotenv import load_dotenv
from fastapi import FastAPI, Response, Body, Security
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, model_validator
from typing import List
import json

from conversation.conversation_store import ConversationStore
from rag_langchain import LangChainRAG

# Load variables from a .env file (python-dotenv) before the environment is
# read on the next line.
load_dotenv()
# Keys accepted by valid_api_key(). os.environ[...] raises KeyError at import
# time when API_API_KEY is not set, so the service refuses to start without it.
api_keys = [os.environ["API_API_KEY"]]

# The FastAPI application object; the route decorators below register on it.
api = FastAPI()

# Shared store used by /qa to save answers and by /emo to attach user feedback.
conversation_store = ConversationStore()

# Security dependency that reads the API key from the "Authorization" header.
# With auto_error=True FastAPI itself rejects requests lacking the header, so
# handlers only ever see a present (but not yet verified) value.
api_key_header = APIKeyHeader(name="Authorization", auto_error=True)
# Prompt identifiers handed to LangChainRAG by /qa and recorded with each
# saved conversation.
prompt_id = "summarize_rag_1"
check_prompt_id = "check_control_challenge_step_back"
rewrite_prompt_id = "first"
# Default value of QModel.llm when the request does not name a model.
default_llm = "gpt-4o 128k"


class QModel(BaseModel):
    """Request body of the ``/qa`` endpoint.

    Attributes:
        q: The question; passed to the RAG chain and stored with the answer.
        retrieval_count: Forwarded to the RAG configuration as
            ``retrieve_documents``.
        temperature: Forwarded to the RAG configuration as ``temperature``
            (kept as a string, as callers already send it that way).
        llm: Model name handed to the RAG chain; defaults to ``default_llm``.
    """

    q: str
    retrieval_count: int = 10
    temperature: str = "0.2"
    llm: str = default_llm

    # NOTE: ``@model_validator`` must be the outermost decorator. With
    # ``@classmethod`` on the outside the attribute stored on the class is a
    # plain classmethod rather than Pydantic's decorator proxy, so the
    # validator is never registered and silently never runs.
    @model_validator(mode='before')
    @classmethod
    def validate_to_json(cls, value):
        """Decode a JSON-encoded payload before field validation.

        Args:
            value: The raw input handed to the model: normally a dict, but a
                ``str``/``bytes``/``bytearray`` holding JSON text is accepted
                as well.

        Returns:
            The decoded JSON data for textual input, otherwise ``value``
            unchanged. Field validation then runs on the returned data.

        Raises:
            ValueError: ``json.JSONDecodeError`` (a ``ValueError`` subclass)
                when the text is not valid JSON; Pydantic reports it as a
                validation error.
        """
        if isinstance(value, (str, bytes, bytearray)):
            # Return the decoded data instead of building ``cls(...)`` here:
            # Pydantic validates it next, and non-object JSON (a list, a
            # number, ...) becomes a validation error instead of a TypeError.
            return json.loads(value)
        return value


# Response body of the /qa endpoint (used as its response_model).
class AModel(BaseModel):
    # The question as received in the request.
    q: str
    # The answer produced by the RAG chain.
    a: str
    # page_content of each source document returned alongside the answer.
    sources: List[str]
    # Identifier returned by conversation_store.save_content(); clients send it
    # back as `qid` to /emo.
    oid: str


class EmoModel(BaseModel):
    """Request body of the ``/emo`` feedback endpoint.

    Attributes:
        qid: Identifier of the stored conversation to annotate.
        helpfulness: The user's rating; stored under ``user_grading`` in the
            conversation's params.
    """

    qid: str
    helpfulness: str

    # NOTE: ``@model_validator`` must be the outermost decorator. With
    # ``@classmethod`` on the outside the attribute stored on the class is a
    # plain classmethod rather than Pydantic's decorator proxy, so the
    # validator is never registered and silently never runs.
    @model_validator(mode='before')
    @classmethod
    def validate_to_json(cls, value):
        """Decode a JSON-encoded payload before field validation.

        Args:
            value: The raw input handed to the model: normally a dict, but a
                ``str``/``bytes``/``bytearray`` holding JSON text is accepted
                as well.

        Returns:
            The decoded JSON data for textual input, otherwise ``value``
            unchanged. Field validation then runs on the returned data.

        Raises:
            ValueError: ``json.JSONDecodeError`` (a ``ValueError`` subclass)
                when the text is not valid JSON; Pydantic reports it as a
                validation error.
        """
        if isinstance(value, (str, bytes, bytearray)):
            # Return the decoded data instead of building ``cls(...)`` here:
            # Pydantic validates it next, and non-object JSON (a list, a
            # number, ...) becomes a validation error instead of a TypeError.
            return json.loads(value)
        return value


@api.get("/")
async def read_root():
    # Root route: takes no API key and always answers with the literal
    # string "Empty".
    return "Empty"


@api.post("/qa", response_model=AModel)
async def qa(api_key: str = Security(api_key_header), data: QModel = Body(...)):
    # Answer a question with the RAG chain, save the exchange in the
    # conversation store, and return the answer with its sources and the
    # stored conversation's id.
    #
    # An unknown API key gets a bare 401 response; nothing is run or stored.
    if not valid_api_key(api_key):
        return Response(status_code=401)

    # Per-request chain configuration: retrieval size and temperature come
    # from the request, the prompt ids from the module-level settings.
    rag_config = {
        "retrieve_documents": data.retrieval_count,
        "temperature": data.temperature,
        "prompt_id": prompt_id,
        "check_prompt_id": check_prompt_id,
        "rewrite_prompt_id": rewrite_prompt_id,
    }
    rag = LangChainRAG(config=rag_config)

    answer, check_result, sources = await rag.rag_chain(data.q, data.llm)

    # Persist the exchange together with the settings that produced it.
    oid = conversation_store.save_content(
        q=data.q,
        a=answer,
        sources=[doc.page_content for doc in sources],
        params={
            "prompt_id": prompt_id,
            "check_prompt_id": check_prompt_id,
            "rewrite_prompt_id": rewrite_prompt_id,
            "check_result": check_result,
            "temperature": data.temperature,
            "retrieve_document_count": str(data.retrieval_count),
        },
    )

    return AModel(
        q=data.q,
        a=answer,
        sources=[doc.page_content for doc in sources],
        oid=oid,
    )


@api.post("/emo")
async def emo(api_key: str = Security(api_key_header), json_body: EmoModel = Body(...)):
    # Record a user's helpfulness rating on a previously stored conversation.
    #
    # The conversation named by json_body.qid is fetched, the rating is added
    # to its params under "user_grading", and the conversation is written back
    # with its first question/answer/sources unchanged.
    #
    # An unknown API key gets a bare 401 response; nothing is read or written.
    if not valid_api_key(api_key):
        return Response(status_code=401)

    conversation = conversation_store.get(json_body.qid)
    # This is the stored params object itself, not a copy; the rating is set
    # on it in place and the same object is handed to update().
    new_params = conversation.params
    new_params["user_grading"] = str(json_body.helpfulness)

    first_turn = conversation.conversation[0]
    conversation_store.update(
        # EmoModel is a Pydantic model, not a mapping: the id has to be read
        # as an attribute (subscripting the model raises TypeError).
        oid=json_body.qid,
        q=first_turn.q,
        a=first_turn.a,
        sources=first_turn.sources,
        params=new_params
    )


def valid_api_key(api_key: str) -> bool:
    """Return whether *api_key* equals one of the configured API keys.

    The comparison uses ``hmac.compare_digest`` so that the time taken does
    not reveal how many leading characters of a guess were correct, and every
    configured key is compared (no early exit) so the position of a match is
    not revealed either.

    Args:
        api_key: The value supplied by the client.

    Returns:
        ``True`` if the value matches an entry of ``api_keys`` exactly,
        ``False`` otherwise (including when it is not a string).
    """
    import hmac  # local import so the block does not depend on file-level imports

    if not isinstance(api_key, str):
        return False

    # compare_digest only accepts ASCII-only str; comparing encoded bytes
    # keeps non-ASCII header values from raising. "surrogatepass" lets values
    # containing lone surrogates (possible for os.environ entries) encode.
    candidate = api_key.encode("utf-8", "surrogatepass")
    matched = False
    for known_key in api_keys:
        matched |= hmac.compare_digest(
            candidate, known_key.encode("utf-8", "surrogatepass")
        )
    return matched