# backend1 / app/main.py
# Last change: commit 5845f07 (verified) by wang16888 — "Update app/main.py"
from langchain_core.runnables import Runnable
from langchain_core.callbacks import BaseCallbackHandler
from fastapi import FastAPI, Request, Depends
from sse_starlette.sse import EventSourceResponse
from sqlalchemy.orm import Session
from langserve.serialization import WellKnownLCSerializer
from typing import Any, List
import crud
import models
import schemas
from database import SessionLocal, engine
from chains import simple_chain, formatted_chain, history_chain, rag_chain
from prompts import format_chat_history
from callbacks import LogResponseCallback
import orjson
# Create the database tables for every ORM model registered on models.Base.
# This runs once, at import time, before any request is served.
models.Base.metadata.create_all(engine)

# The ASGI application object (imported by uvicorn as "main:app").
app = FastAPI()
def get_db():
    """Yield a fresh database session and close it when the caller is done.

    Intended as a FastAPI dependency (``Depends(get_db)``). The ``finally``
    clause closes the session even if the consumer raises at the ``yield``.

    NOTE(review): the streaming endpoints hand this session to
    ``LogResponseCallback``, which runs while the SSE response is still being
    sent; depending on the FastAPI version, dependency teardown (and thus
    ``close()``) may already have happened by then — verify.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
async def generate_stream(
    input_data: schemas.BaseModel,
    callback: Any,
    callbacks: "List[BaseCallbackHandler] | None" = None,
):
    """Run a LangChain runnable on ``input_data`` and yield its output as SSE events.

    Each chunk produced by the runnable is serialised with ``orjson`` and
    yielded as ``{"data": <json>, "event": "data"}``, the dict shape accepted
    by ``EventSourceResponse``. A final ``{"event": "end"}`` marks the end of
    the stream.

    Args:
        input_data: Pydantic model whose ``.dict()`` is passed as the
            runnable's input.
        callback: Either the runnable to stream (e.g. ``simple_chain``) or, for
            backward compatibility with callers using the older
            ``generate_stream(input, handler)`` form, a single callback
            handler. In that case ``rag_chain`` is streamed with the handler
            attached.
        callbacks: Optional callback handlers attached to the run through the
            runnable ``config``.
    """
    handlers = list(callbacks or [])
    if isinstance(callback, Runnable):
        runnable = callback
    else:
        # Legacy calling convention: the second argument is a callback handler
        # and the RAG chain is the runnable to stream.
        runnable = rag_chain
        if callback is not None:
            handlers.insert(0, callback)

    # Use ``astream`` rather than ``stream``: iterating the synchronous
    # generator inside this async generator would block the event loop for
    # the entire LLM generation, stalling every other request.
    async for output in runnable.astream(
        input_data.dict(),
        config={"callbacks": handlers},
    ):
        data = orjson.dumps(output).decode("utf-8")
        yield {'data': data, "event": "data"}
    yield {"event": "end"}
@app.get("/")
def greet_json():
    """Root endpoint: respond with a fixed JSON greeting."""
    greeting = {"Hello": "World!"}
    return greeting
@app.post("/simple/stream")
async def simple_stream(request: Request):
    """Stream a reply to a user question as server-sent events.

    Expects a JSON body ``{"input": {...}}``. The ``input`` object is
    validated as ``schemas.UserQuestion`` and handed to ``generate_stream``
    together with ``simple_chain``.
    """
    payload = await request.json()
    question = schemas.UserQuestion(**payload['input'])
    events = generate_stream(question, simple_chain)
    return EventSourceResponse(events)
@app.post("/formatted/stream")
async def formatted_stream(request: Request):
    """Stream a reply to a user question as server-sent events.

    Expects a JSON body ``{"input": {...}}``. The ``input`` object is
    validated as ``schemas.UserQuestion`` and handed to ``generate_stream``
    together with ``formatted_chain``.
    """
    payload = await request.json()
    question = schemas.UserQuestion(**payload['input'])
    events = generate_stream(question, formatted_chain)
    return EventSourceResponse(events)
@app.post("/history/stream")
async def history_stream(request: Request, db: Session = Depends(get_db)):
    """Answer a question in the context of the user's chat history, as SSE.

    Reads ``{"input": {...}}`` (validated as ``schemas.UserRequest``). It then
    loads the user's stored chat history and records the new question as a
    ``'User'`` message. Finally it streams ``history_chain`` over the question
    plus formatted history, with a ``LogResponseCallback`` attached.
    """
    payload = await request.json()
    req = schemas.UserRequest(**payload['input'])
    username = req.username

    # The history is read before the new question is stored; it is formatted
    # only afterwards, matching the original call order.
    past_messages = crud.get_user_chat_history(db=db, username=username)
    crud.add_message(db, schemas.MessageBase(message=req.question, type='User'), username)

    chain_input = schemas.HistoryInput(
        question=req.question,
        chat_history=format_chat_history(past_messages),
    )
    handlers = [LogResponseCallback(req, db)]
    return EventSourceResponse(generate_stream(chain_input, history_chain, handlers))
@app.post("/rag/stream")
async def rag_stream(request: Request, db: Session = Depends(get_db)):
    """Answer a question with the RAG chain, streamed as server-sent events.

    Reads ``{"input": {...}}`` (validated as ``schemas.UserRequest``). It then
    loads the user's stored chat history and records the new question as a
    ``'User'`` message. Finally it streams via ``generate_stream`` with a
    ``LogResponseCallback`` passed as the second argument.
    """
    body = await request.json()
    req = schemas.UserRequest(**body['input'])
    username = req.username

    history = crud.get_user_chat_history(db=db, username=username)
    crud.add_message(db, schemas.MessageBase(message=req.question, type='User'), username)

    chain_input = schemas.RagInput(
        question=req.question,
        chat_history=format_chat_history(history),
    )
    response_logger = LogResponseCallback(req, db)
    return EventSourceResponse(generate_stream(chain_input, response_logger))
@app.post("/filtered_rag/stream")
async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
    """Answer a question with ``rag_chain`` using hybrid search, streamed as SSE.

    Reads ``{"input": {...}}`` (validated as ``schemas.UserRequest``). It then
    loads the user's stored chat history and records the new question as a
    ``'User'`` message. Finally it streams ``rag_chain`` over a ``RagInput``
    with ``hybrid_search=True``, with a ``LogResponseCallback`` attached.
    """
    import logging

    data = await request.json()
    # The raw body carries the user's name and question; log it at DEBUG level
    # instead of printing it to stdout on every request.
    logging.getLogger(__name__).debug("filtered_rag_stream request: %s", data)
    user_request = schemas.UserRequest(**data['input'])
    chat_history = crud.get_user_chat_history(db=db, username=user_request.username)
    message = schemas.MessageBase(message=user_request.question, type='User')
    crud.add_message(db, message, user_request.username)
    rag_input = schemas.RagInput(
        question=user_request.question,
        chat_history=format_chat_history(chat_history),
        hybrid_search=True,
    )
    return EventSourceResponse(generate_stream(
        rag_input,
        rag_chain,
        [LogResponseCallback(user_request, db)],
    ))
if __name__ == "__main__":
    # Local development entry point. Auto-reload needs the app as an import
    # string ("module:attribute") rather than the app object itself.
    import uvicorn

    uvicorn.run(app="main:app", port=8002, host="localhost", reload=True)