Update app/main.py
Browse files- app/main.py +10 -10
app/main.py
CHANGED
|
@@ -24,10 +24,14 @@ def get_db():
|
|
| 24 |
db.close()
|
| 25 |
|
| 26 |
|
| 27 |
-
async def generate_stream(input_data: schemas.BaseModel,
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
yield {"event": "end"}
|
| 32 |
|
| 33 |
|
|
@@ -82,12 +86,8 @@ async def rag_stream(request: Request, db: Session = Depends(get_db)):
|
|
| 82 |
question=user_request.question,
|
| 83 |
chat_history=format_chat_history(chat_history),
|
| 84 |
)
|
| 85 |
-
|
| 86 |
-
return EventSourceResponse(generate_stream(
|
| 87 |
-
rag_input,
|
| 88 |
-
rag_chain,
|
| 89 |
-
[LogResponseCallback(user_request, db)]
|
| 90 |
-
))
|
| 91 |
|
| 92 |
@app.post("/filtered_rag/stream")
|
| 93 |
async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
|
|
|
|
| 24 |
db.close()
|
| 25 |
|
| 26 |
|
| 27 |
+
async def generate_stream(input_data: schemas.BaseModel, callbacks):
|
| 28 |
+
stream = rag_chain.stream(
|
| 29 |
+
input_data.dict(),
|
| 30 |
+
config={"callbacks": [callback]}
|
| 31 |
+
)
|
| 32 |
+
for output in stream:
|
| 33 |
+
data = orjson.dumps(output).decode("utf-8")
|
| 34 |
+
yield {'data': data, "event": "data"}
|
| 35 |
yield {"event": "end"}
|
| 36 |
|
| 37 |
|
|
|
|
| 86 |
question=user_request.question,
|
| 87 |
chat_history=format_chat_history(chat_history),
|
| 88 |
)
|
| 89 |
+
callback = LogResponseCallback(user_request, db)
|
| 90 |
+
return EventSourceResponse(generate_stream(rag_input, callback))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
@app.post("/filtered_rag/stream")
|
| 93 |
async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
|