import asyncio
import logging
import os
from contextlib import asynccontextmanager
from typing import AsyncGenerator

import torch
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, File, Form, Request, Response, UploadFile
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from opentelemetry import trace
from slowapi import Limiter
from slowapi.util import get_remote_address
from uvicorn.logging import DefaultFormatter

from classes.base_models import (
    ChatRequest,
    CommentRequest,
    DeleteFileRequest,
    FeedbackRequest,
)
from classes.pii_filter import PIIFilter
from classes.session_conversation_store import SessionConversationStore
from classes.session_document_store import SessionDocumentStore
from classes.session_tracker import SessionTracker
from constants import (
    MAX_ID_LENGTH,
    STATUS_CODE_EXCEED_SIZE_LIMIT,
    STATUS_CODE_INTERNAL_SERVER_ERROR,
)
from exceptions import (
    FILE_EXTRACTION_ERROR_STATUS_CODES,
    FILE_VALIDATION_ERROR_STATUS_CODES,
    FileExtractionException,
    FileValidationException,
)
from helpers.dynamodb_helper import log_event
from helpers.file_helper import (
    extract_text_from_file,
    replace_spaces_in_filename,
    validate_file,
)
from helpers.lifespan_helper import cleanup_loop, load_heavy_models, run_cleanup
from helpers.llm_helper import call_llm
from telemetry import setup_telemetry

# Load environment variables from a .env file before any config is read.
load_dotenv()

# Reuse uvicorn's logger so app log lines share the server's handlers/format.
logger = logging.getLogger("uvicorn")

# -------------------- Config --------------------

# True when running with ENV=dev in the environment.
DEV = os.getenv("ENV", None) == "dev"

# -------------------- Helpers --------------------

# For now, conversations and uploaded documents are stored in RAM.
# This is tolerable for a demo, but we will have to switch to
# Redis (or another real-time database) at some point. We are
# currently storing sessions in what should be a stateless server.
# Process-local, in-memory session state (see the RAM note above the stores).
session_tracker = SessionTracker()
session_document_store = SessionDocumentStore()
session_conversation_store = SessionConversationStore()

# -------------------- FastAPI setup --------------------


@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: configure logging, report CUDA availability, load models,
    and run the periodic session-cleanup task for the server's lifetime."""
    logger = logging.getLogger("uvicorn")
    if logger.handlers:
        # Re-format uvicorn's first handler to include a timestamp.
        colored_formatter = DefaultFormatter(
            fmt="%(levelprefix)s %(asctime)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
        )
        logger.handlers[0].setFormatter(colored_formatter)
        logger.info("Logging configured!")
    if torch.cuda.is_available():
        logger.info("CUDA is available")
    else:
        logger.warning("CUDA is NOT available")
    # Blocking model load happens once, before the app starts serving.
    load_heavy_models()
    # Background task that periodically evicts stale sessions from the stores.
    bg_task = asyncio.create_task(
        cleanup_loop(
            session_tracker, session_document_store, session_conversation_store
        )
    )
    yield
    # NOTE(review): cancellation is requested but never awaited, so shutdown
    # does not wait for the cleanup loop to actually stop — confirm acceptable.
    bg_task.cancel()


app = FastAPI(lifespan=lifespan)
setup_telemetry(app)
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")


@app.middleware("http")
async def cleanup_middleware(request: Request, call_next):
    """Run session-store cleanup before every request.

    NOTE(review): run_cleanup is called synchronously inside an async
    middleware; if it does non-trivial work it blocks the event loop — confirm.
    """
    run_cleanup(session_tracker, session_document_store, session_conversation_store)
    response = await call_next(request)
    return response


@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Serve the single-page UI."""
    return templates.TemplateResponse("index.html", {"request": request})


# Time profiler
tracer = trace.get_tracer(__name__)

# Rate limiter
limiter = Limiter(key_func=get_remote_address)


@app.post("/chat")
@limiter.limit("20/minute")
async def chat_endpoint(
    payload: ChatRequest, background_tasks: BackgroundTasks, request: Request
):
    """Handle one chat turn.

    Sanitizes the user message for PII, appends it to the in-memory
    conversation, calls the LLM off the event loop, then either streams the
    reply (if the model returned an async generator) or returns it as JSON.
    The exchange is logged to DynamoDB via background tasks in both cases.
    """
    session_id = payload.session_id
    model_type = payload.model_type
    lang = payload.lang
    conversation_id = payload.conversation_id
    human_message = payload.human_message
    session_tracker.update_session(session_id)
    # NOTE(review): a fresh PIIFilter is built per request; if construction is
    # expensive this belongs at module scope — confirm with PIIFilter's cost.
    pii_filter = PIIFilter()
    with tracer.start_as_current_span("sanitize_document"):
        pii_filtered_msg = pii_filter.sanitize(human_message)
    # Only the PII-sanitized text is stored in the conversation history.
    conversation = session_conversation_store.add_human_message(
        session_id,
        payload.conversation_id,
        pii_filtered_msg,
    )
    document_contents = session_document_store.get_document_contents(session_id)
    reply = ""
    triage_meta = {}
    context = []
    try:
        loop = asyncio.get_running_loop()
        with tracer.start_as_current_span("call_llm"):
            # call_llm is synchronous, so it runs in the default executor to
            # keep the event loop responsive.
            result = await loop.run_in_executor(
                None, call_llm, model_type, lang, conversation, document_contents
            )
        if isinstance(result, AsyncGenerator):
            # Streaming path: forward tokens to the client while accumulating
            # the full reply, then schedule logging/storage once the stream ends.

            async def logging_wrapper():
                # Shadows the outer `reply` on purpose: this local accumulates
                # the streamed text; the outer variable is never updated here.
                reply = ""
                async for token in result:
                    reply += token
                    yield token
                # Save the messages in DB
                background_tasks.add_task(
                    log_event,
                    user_id=payload.user_id,
                    session_id=payload.session_id,
                    data={
                        "model_type": payload.model_type,
                        "consent": payload.consent,
                        "human_message": payload.human_message,
                        "reply": reply,
                        "age_group": payload.age_group,
                        "gender": payload.gender,
                        "roles": payload.roles,
                        "participant_id": payload.participant_id,
                        "conversation_id": payload.conversation_id,
                        "lang": payload.lang,
                        "triage_meta": {},
                    },
                )
                # Save the messages in session_conversation_store
                background_tasks.add_task(
                    session_conversation_store.add_assistant_reply,
                    session_id=session_id,
                    conversation_id=conversation_id,
                    reply=reply,
                )

            return StreamingResponse(logging_wrapper(), media_type="text/event-stream")
        # Non-streaming path: call_llm returned the finished triple.
        reply, triage_meta, context = result
    except Exception as e:
        # Best-effort error logging; the request itself is not failed here.
        background_tasks.add_task(
            log_event,
            user_id=payload.user_id,
            session_id=payload.session_id,
            data={
                "error": str(e),
                "model_type": payload.model_type,
                "consent": payload.consent,
                "human_message": payload.human_message,
                "age_group": payload.age_group,
                "gender": payload.gender,
                "roles": payload.roles,
                "participant_id": payload.participant_id,
                "conversation_id": payload.conversation_id,
                "lang": payload.lang,
            },
        )
    # NOTE(review): after an exception the code falls through to here, so a
    # success-shaped event is still logged (with reply="") and the endpoint
    # returns {"reply": ""} with HTTP 200 — confirm this is intentional.
    background_tasks.add_task(
        log_event,
        user_id=payload.user_id,
        session_id=payload.session_id,
        data={
            "model_type": payload.model_type,
            "consent": payload.consent,
            "human_message": payload.human_message,
            "reply": reply,
            "context": context,
            "age_group": payload.age_group,
            "gender": payload.gender,
            "roles": payload.roles,
            "participant_id": payload.participant_id,
            "conversation_id": payload.conversation_id,
            "lang": payload.lang,
            # Flatten triage metadata into the event record when present.
            **(triage_meta or {}),
        },
    )
    session_conversation_store.add_assistant_reply(session_id, conversation_id, reply)
    return {"reply": reply}


# Endpoint for specific replies/responses
@app.post("/feedback")
@limiter.limit("20/minute")
def feedback_endpoint(
    payload: FeedbackRequest, background_tasks: BackgroundTasks, request: Request
):
    """Log user feedback (rating + comment) about one assistant reply."""
    background_tasks.add_task(
        log_event,
        user_id=payload.user_id,
        session_id=payload.session_id,
        data={
            "consent": payload.consent,
            "comment": payload.comment,
            "age_group": payload.age_group,
            "gender": payload.gender,
            "roles": payload.roles,
            "participant_id": payload.participant_id,
            "message_index": payload.message_index,
            "rating": payload.rating,
            "reply_content": payload.reply_content,
        },
    )


# Endpoint for specific generic comments
@app.post("/comment")
@limiter.limit("20/minute")
def comment_endpoint(
    payload: CommentRequest, background_tasks: BackgroundTasks, request: Request
):
    """Log a free-form comment not tied to a specific reply."""
    logger.info("Received comment")
    background_tasks.add_task(
        log_event,
        user_id=payload.user_id,
        session_id=payload.session_id,
        data={
            "consent": payload.consent,
            "comment": payload.comment,
            "age_group": payload.age_group,
            "gender": payload.gender,
            "roles": payload.roles,
            "participant_id": payload.participant_id,
        },
    )


@app.put("/file")
@limiter.limit("12/minute")
async def upload_file(
    request: Request,
    file: UploadFile = File(...),
    # Session id is validated at the framework level: alphanumeric/_/- only.
    session_id: str = Form(
        pattern="^[a-zA-Z0-9_-]+$", min_length=1, max_length=MAX_ID_LENGTH
    ),
):
    """Validate, extract, PII-sanitize, and store an uploaded document.

    Returns an empty 200 on success; maps validation/extraction failures to
    their configured status codes, and returns the size-limit status when the
    document store rejects the document.
    """
    try:
        validated_file = await validate_file(file)
    except FileValidationException as e:
        # Error → status code mapping is centralized in exceptions.py.
        status_code = FILE_VALIDATION_ERROR_STATUS_CODES[e.error]
        return Response(status_code=status_code)
    file_content = validated_file.content
    file_name = validated_file.filename
    file_mime = validated_file.mime_type
    try:
        file_text = await extract_text_from_file(file_content, file_mime)
    except FileExtractionException as e:
        status_code = FILE_EXTRACTION_ERROR_STATUS_CODES[e.error]
        return Response(status_code=status_code)
    except Exception:
        # TODO: Log the unexpected failure
        return Response(status_code=STATUS_CODE_INTERNAL_SERVER_ERROR)
    pii_filter = PIIFilter()
    with tracer.start_as_current_span("sanitize_document"):
        pii_filtered_file_text = pii_filter.sanitize(file_text)
    # create_document returns falsy when the per-session size budget is
    # exceeded; only successful stores refresh the session's last-seen time.
    if session_document_store.create_document(
        session_id, pii_filtered_file_text, file_name
    ):
        session_tracker.update_session(session_id)
    else:
        return Response(status_code=STATUS_CODE_EXCEED_SIZE_LIMIT)


@app.delete("/file")
@limiter.limit("20/minute")
def delete_file(
    payload: DeleteFileRequest,
    request: Request,
):
    """Remove a previously uploaded document from the session's store.

    The filename is normalized the same way uploads are (spaces replaced) so
    the delete key matches the stored key.
    """
    session_id = payload.session_id
    file_name = payload.file_name
    file_name = replace_spaces_in_filename(file_name)
    session_document_store.delete_document(session_id, file_name)