"""FastAPI application that streams LangChain chain output as server-sent events."""
import logging
from typing import Any, List, Optional

from langchain_core.runnables import Runnable
from langchain_core.callbacks import BaseCallbackHandler
from fastapi import FastAPI, Request, Depends
from sse_starlette.sse import EventSourceResponse
from sqlalchemy.orm import Session
from langserve.serialization import WellKnownLCSerializer
import orjson

import crud
import models
import schemas
from database import SessionLocal, engine
from chains import simple_chain, formatted_chain, history_chain, rag_chain
from prompts import format_chat_history
from callbacks import LogResponseCallback

logger = logging.getLogger(__name__)

# Create any missing tables for the declared models at import time.
models.Base.metadata.create_all(bind=engine)

app = FastAPI()


def get_db():
    """Yield a database session and close it once the dependency is torn down."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


async def generate_stream(
    input_data: schemas.BaseModel,
    runnable: Runnable,
    callbacks: Optional[List[BaseCallbackHandler]] = None,
):
    """Stream the output of ``runnable`` as server-sent-event payloads.

    Args:
        input_data: Model whose ``.dict()`` form is passed to the runnable.
        runnable: The chain to execute. Every endpoint passes its own chain,
            so the chain must not be hard-coded here.
        callbacks: Optional callback handlers forwarded through the runnable's
            ``config``. ``None`` means no callbacks.

    Yields:
        One ``{"data": <json>, "event": "data"}`` dict per streamed chunk,
        followed by a final ``{"event": "end"}`` dict.
    """
    # Build a fresh list per call instead of sharing a mutable default.
    config = {"callbacks": list(callbacks) if callbacks else []}
    # NOTE(review): ``stream`` is the synchronous API, so each chunk is
    # produced on the event loop thread -- confirm this is acceptable under load.
    for output in runnable.stream(input_data.dict(), config=config):
        data = orjson.dumps(output).decode("utf-8")
        yield {'data': data, "event": "data"}
    yield {"event": "end"}


def _record_question_and_load_history(db: Session, user_request: Any):
    """Return the user's formatted prior chat history, then store the new question.

    The history is read *before* the new message is added, so the returned
    history does not contain the question currently being asked.
    """
    chat_history = crud.get_user_chat_history(db=db, username=user_request.username)
    message = schemas.MessageBase(message=user_request.question, type='User')
    crud.add_message(db, message, user_request.username)
    return format_chat_history(chat_history)


@app.get("/")
def greet_json():
    """Return a static greeting payload."""
    return {"Hello": "World!"}


@app.post("/simple/stream")
async def simple_stream(request: Request):
    """Stream ``simple_chain`` output for the question in the request body."""
    data = await request.json()
    user_question = schemas.UserQuestion(**data['input'])
    return EventSourceResponse(generate_stream(user_question, simple_chain))


@app.post("/formatted/stream")
async def formatted_stream(request: Request):
    """Stream ``formatted_chain`` output for the question in the request body."""
    data = await request.json()
    user_question = schemas.UserQuestion(**data['input'])
    return EventSourceResponse(generate_stream(user_question, formatted_chain))


@app.post("/history/stream")
async def history_stream(request: Request, db: Session = Depends(get_db)):
    """Stream ``history_chain`` output, using and updating the user's chat history."""
    data = await request.json()
    user_request = schemas.UserRequest(**data['input'])
    history_input = schemas.HistoryInput(
        question=user_request.question,
        chat_history=_record_question_and_load_history(db, user_request),
    )
    # NOTE(review): the callback holds ``db`` while the response streams;
    # confirm the ``get_db`` teardown does not close the session first.
    return EventSourceResponse(generate_stream(
        history_input, history_chain, [LogResponseCallback(user_request, db)]
    ))


@app.post("/rag/stream")
async def rag_stream(request: Request, db: Session = Depends(get_db)):
    """Stream ``rag_chain`` output, using and updating the user's chat history."""
    data = await request.json()
    user_request = schemas.UserRequest(**data['input'])
    rag_input = schemas.RagInput(
        question=user_request.question,
        chat_history=_record_question_and_load_history(db, user_request),
    )
    return EventSourceResponse(generate_stream(
        rag_input, rag_chain, [LogResponseCallback(user_request, db)]
    ))


@app.post("/filtered_rag/stream")
async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
    """Stream ``rag_chain`` output with ``hybrid_search`` enabled on the input."""
    data = await request.json()
    # Debug-level logging instead of printing raw request bodies to stdout.
    logger.debug("filtered_rag request payload: %s", data)
    user_request = schemas.UserRequest(**data['input'])
    rag_input = schemas.RagInput(
        question=user_request.question,
        chat_history=_record_question_and_load_history(db, user_request),
        hybrid_search=True,
    )
    return EventSourceResponse(generate_stream(
        rag_input, rag_chain, [LogResponseCallback(user_request, db)]
    ))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="localhost", reload=True, port=8002)