wang16888 commited on
Commit
a79699d
·
verified ·
1 Parent(s): 4776abb

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +10 -10
app/main.py CHANGED
@@ -24,10 +24,14 @@ def get_db():
24
  db.close()
25
 
26
 
27
- async def generate_stream(input_data: schemas.BaseModel, runnable: Runnable, callbacks: List[BaseCallbackHandler]=[]):
28
- for output in runnable.stream(input_data.dict(), config={"callbacks": callbacks}):
29
- data = WellKnownLCSerializer().dumps(output).decode("utf-8")
30
- yield {'data': data, "event": "data"}
 
 
 
 
31
  yield {"event": "end"}
32
 
33
 
@@ -82,12 +86,8 @@ async def rag_stream(request: Request, db: Session = Depends(get_db)):
82
  question=user_request.question,
83
  chat_history=format_chat_history(chat_history),
84
  )
85
-
86
- return EventSourceResponse(generate_stream(
87
- rag_input,
88
- rag_chain,
89
- [LogResponseCallback(user_request, db)]
90
- ))
91
 
92
  @app.post("/filtered_rag/stream")
93
  async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
 
24
  db.close()
25
 
26
 
27
+ async def generate_stream(input_data: schemas.BaseModel, callbacks):
28
+ stream = rag_chain.stream(
29
+ input_data.dict(),
30
+ config={"callbacks": [callback]}
31
+ )
32
+ for output in stream:
33
+ data = orjson.dumps(output).decode("utf-8")
34
+ yield {'data': data, "event": "data"}
35
  yield {"event": "end"}
36
 
37
 
 
86
  question=user_request.question,
87
  chat_history=format_chat_history(chat_history),
88
  )
89
+ callback = LogResponseCallback(user_request, db)
90
+ return EventSourceResponse(generate_stream(rag_input, callback))
 
 
 
 
91
 
92
  @app.post("/filtered_rag/stream")
93
  async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):