File size: 1,305 Bytes
8aa6d51
a9ae09a
 
 
eed0e46
a9ae09a
 
8aa6d51
a9ae09a
 
eed0e46
 
 
 
 
 
 
a9ae09a
 
 
 
 
8aa6d51
a9ae09a
 
 
 
8aa6d51
7116b5e
 
8aa6d51
a9ae09a
7116b5e
 
 
a9ae09a
 
7116b5e
 
 
a9ae09a
 
 
 
7116b5e
 
eed0e46
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import Union
from uuid import UUID, uuid4

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from langchain_core.documents import Document

from src.model.model import generate, search

# FastAPI application instance; the route decorators below register on it.
app = FastAPI()
# Install CORS middleware with wildcard origins, methods and headers.
# NOTE(review): wildcards combined with allow_credentials=True are very
# permissive — confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# In-process state shared by the handlers, keyed by the thread id created in
# the /query handler. Nothing in the visible code evicts entries.
document_store = {}  # thread id -> list of serialized search-result documents
query_mapping = {}  # thread id -> the query string that produced those documents


@app.get("/")
def index() -> dict[str, str]:
    """Return a static hint pointing clients at the ``/query`` endpoint."""
    hint = "Make a post request to /query."
    return dict(message=hint)


@app.post("/query")
async def query(q: str) -> dict[str, Union[UUID, str, list[dict]]]:
    """Search for documents matching ``q`` and remember them under a new thread id.

    A fresh UUID is generated per call and passed to ``search`` together with
    the query text. Each item under the result's ``"documents"`` key is
    converted with its ``.dict()`` method; the converted list and the query
    text are then stored in the module-level ``document_store`` and
    ``query_mapping`` under that UUID.

    Args:
        q: The query text.

    Returns:
        A dict with the new ``thread_id``, the submitted ``query`` and the
        list of serialized ``documents``.
    """
    new_thread_id = uuid4()
    search_result = search(query=q, thread_id=new_thread_id)
    serialized_docs = [doc.dict() for doc in search_result["documents"]]

    # Persist both pieces under the same key so they can be looked up
    # together later by thread id.
    document_store[new_thread_id] = serialized_docs
    query_mapping[new_thread_id] = q

    return dict(thread_id=new_thread_id, query=q, documents=serialized_docs)


@app.get("/explain/{thread_id}")
async def explain(thread_id: UUID, docid: int) -> dict:
    """Generate an explanation for one document of a previous ``/query`` call.

    Looks up the documents and query text stored under ``thread_id``, rebuilds
    a ``Document`` from the stored ``page_content`` and ``metadata``, and
    passes it with the original query to ``generate``. The stored document
    dict is attached to the result under the ``"document"`` key.

    Args:
        thread_id: Identifier returned by the ``/query`` endpoint.
        docid: Position of the document within that thread's stored list
            (ordinary list indexing, so negative positions count from the end).

    Returns:
        The mapping produced by ``generate`` with an added ``"document"`` entry.

    Raises:
        HTTPException: 404 when ``thread_id`` is not known or ``docid`` is
            outside the stored document list.
    """
    # Unknown ids and bad positions are client errors; report them as 404
    # instead of letting KeyError/IndexError surface as a 500.
    try:
        stored_docs = document_store[thread_id]
        original_query = query_mapping[thread_id]
    except KeyError:
        raise HTTPException(
            status_code=404,
            detail=f"Unknown thread_id: {thread_id}",
        ) from None

    try:
        doc_dict = stored_docs[docid]
    except IndexError:
        raise HTTPException(
            status_code=404,
            detail=f"No document at position {docid} for thread_id {thread_id}",
        ) from None

    document = Document(
        page_content=doc_dict["page_content"],
        metadata=doc_dict["metadata"],
    )
    out = generate(query=original_query, document=document, thread_id=thread_id)
    out["document"] = doc_dict

    return out