# champ-chatbot / main.py
# Last commit: 2d42370 "deployment" (verified), by qyle.
import asyncio
import logging
import os
from contextlib import asynccontextmanager
from typing import AsyncGenerator
import torch
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, File, Form, Request, Response, UploadFile
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from opentelemetry import trace
from slowapi import Limiter
from slowapi.util import get_remote_address
from uvicorn.logging import DefaultFormatter
from classes.base_models import (
ChatRequest,
CommentRequest,
DeleteFileRequest,
FeedbackRequest,
)
from classes.pii_filter import PIIFilter
from classes.session_conversation_store import SessionConversationStore
from classes.session_document_store import SessionDocumentStore
from classes.session_tracker import SessionTracker
from constants import (
MAX_ID_LENGTH,
STATUS_CODE_EXCEED_SIZE_LIMIT,
STATUS_CODE_INTERNAL_SERVER_ERROR,
)
from exceptions import (
FILE_EXTRACTION_ERROR_STATUS_CODES,
FILE_VALIDATION_ERROR_STATUS_CODES,
FileExtractionException,
FileValidationException,
)
from helpers.dynamodb_helper import log_event
from helpers.file_helper import (
extract_text_from_file,
replace_spaces_in_filename,
validate_file,
)
from helpers.lifespan_helper import cleanup_loop, load_heavy_models, run_cleanup
from helpers.llm_helper import call_llm
from telemetry import setup_telemetry
load_dotenv()

logger = logging.getLogger("uvicorn")

# -------------------- Config --------------------
# Development mode is enabled only when the ENV variable equals "dev".
DEV = os.getenv("ENV") == "dev"

# -------------------- Session state --------------------
# Session activity, conversations and uploaded documents are all kept in
# this process's memory. That is acceptable for a demo, but it turns what
# should be a stateless server into a stateful one. These objects should
# eventually be replaced by Redis or another real-time datastore.
session_tracker = SessionTracker()
session_document_store = SessionDocumentStore()
session_conversation_store = SessionConversationStore()
# -------------------- FastAPI setup --------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    Startup:
      * applies a colored ``DefaultFormatter`` to the first handler of the
        "uvicorn" logger (if it has any handlers),
      * logs whether CUDA is available,
      * calls ``load_heavy_models()``,
      * starts ``cleanup_loop`` over the session stores as a background task.

    Shutdown: cancels that background task and waits for it to finish. A
    failure inside the task is logged instead of being silently dropped.
    """
    logger = logging.getLogger("uvicorn")
    if logger.handlers:
        colored_formatter = DefaultFormatter(
            fmt="%(levelprefix)s %(asctime)s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        # Only the first handler is reformatted.
        logger.handlers[0].setFormatter(colored_formatter)
        logger.info("Logging configured!")

    if torch.cuda.is_available():
        logger.info("CUDA is available")
    else:
        logger.warning("CUDA is NOT available")

    # Called without ``await``: it runs to completion before the background
    # cleanup task is started.
    load_heavy_models()

    bg_task = asyncio.create_task(
        cleanup_loop(
            session_tracker, session_document_store, session_conversation_store
        )
    )
    try:
        yield
    finally:
        # Cancel the cleanup task and wait for it, so it is not left pending
        # when the event loop shuts down.
        bg_task.cancel()
        try:
            await bg_task
        except asyncio.CancelledError:
            pass  # Expected: we just cancelled it.
        except Exception:
            logger.exception("Session cleanup task failed")
# The application object; ``lifespan`` handles model loading and the
# background session cleanup.
app = FastAPI(lifespan=lifespan)
setup_telemetry(app)

# Files in the local "static" directory are served under /static.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")

# Jinja2 templates rendered by HTML endpoints (e.g. index.html).
templates = Jinja2Templates(directory="templates")
@app.middleware("http")
async def cleanup_middleware(request: Request, call_next):
    """Run the session cleanup, then pass the request on to the next handler.

    ``run_cleanup`` is called without ``await`` before every HTTP request
    handled by this app.
    """
    run_cleanup(session_tracker, session_document_store, session_conversation_store)
    return await call_next(request)
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the ``index.html`` template as the landing page."""
    template_context = {"request": request}
    return templates.TemplateResponse("index.html", template_context)
# OpenTelemetry tracer; its spans time the expensive steps of a request
# (PII sanitization, LLM calls).
tracer = trace.get_tracer(__name__)

# Rate limiter whose buckets are keyed on the caller's remote address.
# NOTE(review): if the app runs behind a reverse proxy, the remote address may
# be the proxy's IP, so all users would share one bucket. Confirm that
# forwarded headers are trusted.
limiter = Limiter(key_func=get_remote_address)
@app.post("/chat")
@limiter.limit("20/minute")
async def chat_endpoint(
    payload: ChatRequest, background_tasks: BackgroundTasks, request: Request
):
    """Handle one chat turn.

    1. Mark the session as active and PII-sanitize the human message. Only the
       sanitized text is appended to the session conversation that is passed
       to ``call_llm``.
    2. Run ``call_llm`` in the default thread-pool executor.
    3. If it returns an async generator, stream the tokens back as
       ``text/event-stream``; the full reply is logged and stored after the
       stream has been consumed. Otherwise unpack
       ``(reply, triage_meta, context)`` and return ``{"reply": reply}``.

    If the LLM call raises, the traceback is written to the server log and an
    error event is sent to ``log_event``. The endpoint then continues with an
    empty reply.

    ``request`` is not used directly; presumably it is required by the
    ``@limiter.limit`` decorator.

    NOTE(review): the ``log_event`` records contain the raw
    ``payload.human_message``, not the PII-filtered text. Confirm that this is
    intended.
    """
    session_id = payload.session_id
    model_type = payload.model_type
    lang = payload.lang
    conversation_id = payload.conversation_id
    human_message = payload.human_message

    session_tracker.update_session(session_id)

    pii_filter = PIIFilter()
    # NOTE(review): the span is named "sanitize_document" but it wraps
    # sanitization of the chat message. Kept unchanged so existing traces stay
    # comparable.
    with tracer.start_as_current_span("sanitize_document"):
        pii_filtered_msg = pii_filter.sanitize(human_message)

    conversation = session_conversation_store.add_human_message(
        session_id, conversation_id, pii_filtered_msg
    )
    document_contents = session_document_store.get_document_contents(session_id)

    # Fallback values that are logged and stored if the LLM call fails.
    reply = ""
    triage_meta = {}
    context = []
    try:
        loop = asyncio.get_running_loop()
        with tracer.start_as_current_span("call_llm"):
            # Run in the executor so the call does not block the event loop.
            result = await loop.run_in_executor(
                None, call_llm, model_type, lang, conversation, document_contents
            )

        if isinstance(result, AsyncGenerator):

            async def logging_wrapper():
                """Forward tokens to the client, then log and store the reply."""
                # Collect the tokens and join them once at the end; repeated
                # ``str +=`` can be quadratic in the reply length.
                tokens = []
                async for token in result:
                    tokens.append(token)
                    yield token
                streamed_reply = "".join(tokens)

                # NOTE(review): these tasks are registered only after the
                # stream has been fully consumed. If the client disconnects
                # early, the reply is presumably neither logged nor stored.

                # Save the messages in DB
                background_tasks.add_task(
                    log_event,
                    user_id=payload.user_id,
                    session_id=payload.session_id,
                    data={
                        "model_type": payload.model_type,
                        "consent": payload.consent,
                        "human_message": payload.human_message,
                        "reply": streamed_reply,
                        "age_group": payload.age_group,
                        "gender": payload.gender,
                        "roles": payload.roles,
                        "participant_id": payload.participant_id,
                        "conversation_id": payload.conversation_id,
                        "lang": payload.lang,
                        "triage_meta": {},
                    },
                )

                # Save the messages in session_conversation_store
                background_tasks.add_task(
                    session_conversation_store.add_assistant_reply,
                    session_id=session_id,
                    conversation_id=conversation_id,
                    reply=streamed_reply,
                )

            return StreamingResponse(
                logging_wrapper(), media_type="text/event-stream"
            )

        reply, triage_meta, context = result
    except Exception as e:
        # Also log the traceback on the server. Otherwise the client only sees
        # an empty reply and operators never see the failure.
        logger.exception("chat_endpoint: call_llm failed; replying with an empty message")
        background_tasks.add_task(
            log_event,
            user_id=payload.user_id,
            session_id=payload.session_id,
            data={
                "error": str(e),
                "model_type": payload.model_type,
                "consent": payload.consent,
                "human_message": payload.human_message,
                "age_group": payload.age_group,
                "gender": payload.gender,
                "roles": payload.roles,
                "participant_id": payload.participant_id,
                "conversation_id": payload.conversation_id,
                "lang": payload.lang,
            },
        )

    # NOTE(review): after a failure this still records a normal-looking event
    # with an empty reply, in addition to the error event above.
    background_tasks.add_task(
        log_event,
        user_id=payload.user_id,
        session_id=payload.session_id,
        data={
            "model_type": payload.model_type,
            "consent": payload.consent,
            "human_message": payload.human_message,
            "reply": reply,
            "context": context,
            "age_group": payload.age_group,
            "gender": payload.gender,
            "roles": payload.roles,
            "participant_id": payload.participant_id,
            "conversation_id": payload.conversation_id,
            "lang": payload.lang,
            **(triage_meta or {}),
        },
    )

    session_conversation_store.add_assistant_reply(session_id, conversation_id, reply)

    return {"reply": reply}
# Endpoint for feedback on one specific assistant reply.
@app.post("/feedback")
@limiter.limit("20/minute")
def feedback_endpoint(
    payload: FeedbackRequest, background_tasks: BackgroundTasks, request: Request
):
    """Schedule a ``log_event`` background task recording feedback on a reply.

    ``request`` is not used directly; presumably it is required by the
    ``@limiter.limit`` decorator.
    """
    logged_fields = (
        "consent",
        "comment",
        "age_group",
        "gender",
        "roles",
        "participant_id",
        "message_index",
        "rating",
        "reply_content",
    )
    event_data = {field: getattr(payload, field) for field in logged_fields}
    background_tasks.add_task(
        log_event,
        user_id=payload.user_id,
        session_id=payload.session_id,
        data=event_data,
    )
# Endpoint for generic comments (not tied to a particular reply).
@app.post("/comment")
@limiter.limit("20/minute")
def comment_endpoint(
    payload: CommentRequest, background_tasks: BackgroundTasks, request: Request
):
    """Log receipt of a comment and schedule a ``log_event`` background task for it.

    ``request`` is not used directly; presumably it is required by the
    ``@limiter.limit`` decorator.
    """
    logger.info("Received comment")
    logged_fields = (
        "consent",
        "comment",
        "age_group",
        "gender",
        "roles",
        "participant_id",
    )
    background_tasks.add_task(
        log_event,
        user_id=payload.user_id,
        session_id=payload.session_id,
        data={field: getattr(payload, field) for field in logged_fields},
    )
@app.put("/file")
@limiter.limit("12/minute")
async def upload_file(
    request: Request,
    file: UploadFile = File(...),
    session_id: str = Form(
        pattern="^[a-zA-Z0-9_-]+$", min_length=1, max_length=MAX_ID_LENGTH
    ),
):
    """Validate an uploaded file, extract its text, sanitize it and store it.

    Error responses (empty body):
      * ``FILE_VALIDATION_ERROR_STATUS_CODES[e.error]`` when ``validate_file``
        raises ``FileValidationException``;
      * ``FILE_EXTRACTION_ERROR_STATUS_CODES[e.error]`` when text extraction
        raises ``FileExtractionException``;
      * ``STATUS_CODE_INTERNAL_SERVER_ERROR`` for any other extraction error
        (the traceback is logged);
      * ``STATUS_CODE_EXCEED_SIZE_LIMIT`` when the session document store
        refuses the document.

    On success the session's activity is refreshed and ``None`` is returned.

    NOTE(review): ``delete_file`` passes names through
    ``replace_spaces_in_filename``; presumably ``validate_file`` already
    normalizes ``filename`` the same way. Confirm, otherwise files with spaces
    cannot be deleted.
    """
    try:
        validated_file = await validate_file(file)
    except FileValidationException as e:
        return Response(status_code=FILE_VALIDATION_ERROR_STATUS_CODES[e.error])

    file_content = validated_file.content
    file_name = validated_file.filename
    file_mime = validated_file.mime_type

    try:
        file_text = await extract_text_from_file(file_content, file_mime)
    except FileExtractionException as e:
        return Response(status_code=FILE_EXTRACTION_ERROR_STATUS_CODES[e.error])
    except Exception:
        # Unexpected extractor failure: log the traceback for diagnosis and
        # still answer the client with a generic internal-error status.
        # Only the MIME type is logged; the file name may contain personal data.
        logger.exception(
            "Unexpected error while extracting text from upload (mime=%s)",
            file_mime,
        )
        return Response(status_code=STATUS_CODE_INTERNAL_SERVER_ERROR)

    # Only the PII-sanitized text is stored in the session.
    pii_filter = PIIFilter()
    with tracer.start_as_current_span("sanitize_document"):
        pii_filtered_file_text = pii_filter.sanitize(file_text)

    # A falsy result from create_document is reported as a size-limit error.
    if not session_document_store.create_document(
        session_id, pii_filtered_file_text, file_name
    ):
        return Response(status_code=STATUS_CODE_EXCEED_SIZE_LIMIT)

    session_tracker.update_session(session_id)
@app.delete("/file")
@limiter.limit("20/minute")
def delete_file(
    payload: DeleteFileRequest,
    request: Request,
):
    """Remove one uploaded document from the caller's session.

    The name is passed through ``replace_spaces_in_filename`` before lookup,
    presumably so it matches the name the document was stored under.
    ``request`` is presumably required by the ``@limiter.limit`` decorator.
    """
    normalized_name = replace_spaces_in_filename(payload.file_name)
    session_document_store.delete_document(payload.session_id, normalized_name)