Spaces:
Paused
Paused
| import asyncio | |
| import logging | |
| import os | |
| from contextlib import asynccontextmanager | |
| from typing import AsyncGenerator | |
| import torch | |
| from dotenv import load_dotenv | |
| from fastapi import BackgroundTasks, FastAPI, File, Form, Request, Response, UploadFile | |
| from fastapi.responses import HTMLResponse, StreamingResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from opentelemetry import trace | |
| from slowapi import Limiter | |
| from slowapi.util import get_remote_address | |
| from uvicorn.logging import DefaultFormatter | |
| from classes.base_models import ( | |
| ChatRequest, | |
| CommentRequest, | |
| DeleteFileRequest, | |
| FeedbackRequest, | |
| ) | |
| from classes.pii_filter import PIIFilter | |
| from classes.session_conversation_store import SessionConversationStore | |
| from classes.session_document_store import SessionDocumentStore | |
| from classes.session_tracker import SessionTracker | |
| from constants import ( | |
| MAX_ID_LENGTH, | |
| STATUS_CODE_EXCEED_SIZE_LIMIT, | |
| STATUS_CODE_INTERNAL_SERVER_ERROR, | |
| ) | |
| from exceptions import ( | |
| FILE_EXTRACTION_ERROR_STATUS_CODES, | |
| FILE_VALIDATION_ERROR_STATUS_CODES, | |
| FileExtractionException, | |
| FileValidationException, | |
| ) | |
| from helpers.dynamodb_helper import log_event | |
| from helpers.file_helper import ( | |
| extract_text_from_file, | |
| replace_spaces_in_filename, | |
| validate_file, | |
| ) | |
| from helpers.lifespan_helper import cleanup_loop, load_heavy_models, run_cleanup | |
| from helpers.llm_helper import call_llm | |
| from telemetry import setup_telemetry | |
# Load variables from a .env file (if any) before configuration is read below.
load_dotenv()
# Use the logger named "uvicorn" for application messages in this module.
logger = logging.getLogger("uvicorn")
# -------------------- Config --------------------
# True only when the ENV environment variable is exactly "dev".
DEV = os.getenv("ENV", None) == "dev"
# -------------------- Helpers --------------------
# For now, conversations and uploaded documents are stored in RAM.
# This is tolerable for a demo, but we will have to switch to
# Redis (or another real-time database) at some point. We are
# currently storing sessions in what should be a stateless server.
# These three module-level instances are shared by every request handler
# in this file.
session_tracker = SessionTracker()
session_document_store = SessionDocumentStore()
session_conversation_store = SessionConversationStore()
| # -------------------- FastAPI setup -------------------- | |
async def lifespan(app: FastAPI):
    """Run startup work, yield while the app serves, then shut down cleanly.

    Startup: applies uvicorn's ``DefaultFormatter`` to the first handler of
    the "uvicorn" logger (when it has one), reports whether CUDA is
    available, loads the heavy models, and starts the session cleanup loop
    as a background task.

    Shutdown: cancels the cleanup task and waits for it to finish, even when
    an exception is thrown into the generator at the ``yield``.

    NOTE(review): no ``@asynccontextmanager`` decorator is visible on this
    function in the reviewed excerpt, although it is imported and the
    function is passed as ``FastAPI(lifespan=...)`` -- confirm it is applied
    in the real file.
    """
    logger = logging.getLogger("uvicorn")
    if logger.handlers:
        colored_formatter = DefaultFormatter(
            fmt="%(levelprefix)s %(asctime)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
        )
        logger.handlers[0].setFormatter(colored_formatter)
    # NOTE(review): the original indentation was lost; confirm whether this
    # message belongs inside the branch above.
    logger.info("Logging configured!")

    if torch.cuda.is_available():
        logger.info("CUDA is available")
    else:
        logger.warning("CUDA is NOT available")

    load_heavy_models()

    bg_task = asyncio.create_task(
        cleanup_loop(
            session_tracker, session_document_store, session_conversation_store
        )
    )
    try:
        yield
    finally:
        # cancel() only *requests* cancellation; awaiting the task lets it
        # unwind before the event loop is torn down.
        bg_task.cancel()
        try:
            await bg_task
        except asyncio.CancelledError:
            pass
        except Exception:
            # The loop had already crashed on its own; record it rather than
            # failing shutdown.
            logger.exception("Session cleanup loop failed")
# Create the application; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan)
# NOTE(review): setup_telemetry comes from the project's `telemetry` module;
# presumably it instruments `app` for tracing -- confirm there.
setup_telemetry(app)
# Serve files from the local "static" directory under the /static prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Jinja2 templates are looked up in the local "templates" directory.
templates = Jinja2Templates(directory="templates")
async def cleanup_middleware(request: Request, call_next):
    """Run a session cleanup pass, then forward the request down the chain.

    ``run_cleanup`` is invoked with the three module-level session stores
    before ``call_next`` is awaited; the downstream response is returned
    unchanged.

    NOTE(review): no registration decorator is visible in the reviewed
    excerpt; presumably this is registered as HTTP middleware -- confirm.
    """
    run_cleanup(session_tracker, session_document_store, session_conversation_store)
    return await call_next(request)
async def home(request: Request):
    """Render the ``index.html`` template for the incoming request."""
    template_context = {"request": request}
    return templates.TemplateResponse("index.html", template_context)
# Time profiler: OpenTelemetry tracer used below to wrap sanitization and
# LLM calls in named spans.
tracer = trace.get_tracer(__name__)
# Rate limiter; requests are keyed by slowapi's get_remote_address.
limiter = Limiter(key_func=get_remote_address)
async def chat_endpoint(
    payload: ChatRequest, background_tasks: BackgroundTasks, request: Request
):
    """Handle one chat turn: sanitize, call the LLM, log the event, reply.

    The human message is passed through ``PIIFilter.sanitize`` and appended
    to the session's conversation; ``call_llm`` is then run in the default
    executor with that conversation and the session's document contents.

    Args:
        payload: Chat request body (session/conversation ids, message,
            model type, language, and participant metadata).
        background_tasks: Receives the ``log_event`` calls (and, when
            streaming, the conversation-store update) to run after the
            response has been sent.
        request: Not read in this function. NOTE(review): presumably needed
            by a rate-limit decorator that is not visible in this excerpt.

    Returns:
        A ``StreamingResponse`` (``text/event-stream``) when ``call_llm``
        returns an async generator; otherwise ``{"reply": reply}``. If the
        LLM call raises, the error is logged and ``reply`` is the empty
        string.
    """
    session_id = payload.session_id
    conversation_id = payload.conversation_id

    session_tracker.update_session(session_id)

    pii_filter = PIIFilter()
    with tracer.start_as_current_span("sanitize_document"):
        pii_filtered_msg = pii_filter.sanitize(payload.human_message)

    conversation = session_conversation_store.add_human_message(
        session_id, conversation_id, pii_filtered_msg
    )
    document_contents = session_document_store.get_document_contents(session_id)

    # Defaults used by the final log/return when the LLM call fails.
    reply = ""
    triage_meta = {}
    context = []

    try:
        loop = asyncio.get_running_loop()
        with tracer.start_as_current_span("call_llm"):
            result = await loop.run_in_executor(
                None,
                call_llm,
                payload.model_type,
                payload.lang,
                conversation,
                document_contents,
            )

        if isinstance(result, AsyncGenerator):

            async def logging_wrapper():
                # Forward tokens as they arrive while keeping a copy so the
                # complete reply can be persisted once the stream ends.
                tokens = []
                async for token in result:
                    tokens.append(token)
                    yield token
                streamed_reply = "".join(tokens)

                # Save the messages in DB
                background_tasks.add_task(
                    log_event,
                    user_id=payload.user_id,
                    session_id=session_id,
                    data=_chat_event_data(
                        payload, {"reply": streamed_reply, "triage_meta": {}}
                    ),
                )
                # Save the messages in session_conversation_store
                background_tasks.add_task(
                    session_conversation_store.add_assistant_reply,
                    session_id=session_id,
                    conversation_id=conversation_id,
                    reply=streamed_reply,
                )

            return StreamingResponse(logging_wrapper(), media_type="text/event-stream")

        reply, triage_meta, context = result
    except Exception as e:
        # Keep the failure visible in the server log as well as the event
        # store; execution then deliberately falls through so the client
        # still receives a (blank) reply.
        logger.exception("chat_endpoint: LLM call failed")
        background_tasks.add_task(
            log_event,
            user_id=payload.user_id,
            session_id=session_id,
            data=_chat_event_data(payload, {"error": str(e)}),
        )

    background_tasks.add_task(
        log_event,
        user_id=payload.user_id,
        session_id=session_id,
        data=_chat_event_data(
            payload,
            {"reply": reply, "context": context, **(triage_meta or {})},
        ),
    )
    session_conversation_store.add_assistant_reply(session_id, conversation_id, reply)
    return {"reply": reply}


def _chat_event_data(payload, extra):
    """Return the ``data`` dict for a chat ``log_event``: shared fields + ``extra``.

    Keys in ``extra`` override the shared fields on collision.
    """
    # NOTE(review): the raw ``human_message`` is logged here, not the
    # PII-filtered text stored in the conversation -- confirm this is intended.
    data = {
        "model_type": payload.model_type,
        "consent": payload.consent,
        "human_message": payload.human_message,
        "age_group": payload.age_group,
        "gender": payload.gender,
        "roles": payload.roles,
        "participant_id": payload.participant_id,
        "conversation_id": payload.conversation_id,
        "lang": payload.lang,
    }
    data.update(extra)
    return data
# Endpoint for feedback on a specific reply
def feedback_endpoint(
    payload: FeedbackRequest, background_tasks: BackgroundTasks, request: Request
):
    """Queue a background ``log_event`` call recording feedback on one reply.

    The logged data carries the participant metadata together with the
    rated message's index, rating, and content. Nothing is returned.
    ``request`` is not read here.
    """
    feedback_data = {
        "consent": payload.consent,
        "comment": payload.comment,
        "age_group": payload.age_group,
        "gender": payload.gender,
        "roles": payload.roles,
        "participant_id": payload.participant_id,
        "message_index": payload.message_index,
        "rating": payload.rating,
        "reply_content": payload.reply_content,
    }
    background_tasks.add_task(
        log_event,
        user_id=payload.user_id,
        session_id=payload.session_id,
        data=feedback_data,
    )
# Endpoint for generic comments (not tied to a specific reply)
def comment_endpoint(
    payload: CommentRequest, background_tasks: BackgroundTasks, request: Request
):
    """Log receipt of a comment and queue it for ``log_event`` in the background.

    The logged data carries the comment text and participant metadata.
    Nothing is returned. ``request`` is not read here.
    """
    logger.info("Received comment")
    comment_data = {
        "consent": payload.consent,
        "comment": payload.comment,
        "age_group": payload.age_group,
        "gender": payload.gender,
        "roles": payload.roles,
        "participant_id": payload.participant_id,
    }
    background_tasks.add_task(
        log_event,
        user_id=payload.user_id,
        session_id=payload.session_id,
        data=comment_data,
    )
async def upload_file(
    request: Request,
    file: UploadFile = File(...),
    session_id: str = Form(
        pattern="^[a-zA-Z0-9_-]+$", min_length=1, max_length=MAX_ID_LENGTH
    ),
):
    """Validate an upload, extract and sanitize its text, and store it.

    Args:
        request: Not read in this function.
        file: The uploaded file.
        session_id: Session to attach the document to; constrained by the
            form pattern and length limits above.

    Returns:
        ``None`` on success. Otherwise a bare ``Response`` whose status code
        comes from the validation/extraction error tables, is
        ``STATUS_CODE_INTERNAL_SERVER_ERROR`` for an unexpected extraction
        failure, or ``STATUS_CODE_EXCEED_SIZE_LIMIT`` when the document
        store rejects the document.
    """
    try:
        validated_file = await validate_file(file)
    except FileValidationException as e:
        return Response(status_code=FILE_VALIDATION_ERROR_STATUS_CODES[e.error])

    file_content = validated_file.content
    file_name = validated_file.filename
    file_mime = validated_file.mime_type

    try:
        file_text = await extract_text_from_file(file_content, file_mime)
    except FileExtractionException as e:
        return Response(status_code=FILE_EXTRACTION_ERROR_STATUS_CODES[e.error])
    except Exception:
        # Record the traceback so unexpected extractor failures can be
        # diagnosed; the file name is left out of the message on purpose.
        logger.exception("Unexpected failure while extracting text from upload")
        return Response(status_code=STATUS_CODE_INTERNAL_SERVER_ERROR)

    pii_filter = PIIFilter()
    with tracer.start_as_current_span("sanitize_document"):
        pii_filtered_file_text = pii_filter.sanitize(file_text)

    stored = session_document_store.create_document(
        session_id, pii_filtered_file_text, file_name
    )
    if not stored:
        return Response(status_code=STATUS_CODE_EXCEED_SIZE_LIMIT)

    # Only refresh the session once the document has actually been stored.
    session_tracker.update_session(session_id)
def delete_file(
    payload: DeleteFileRequest,
    request: Request,
):
    """Delete one stored document from the session's document store.

    The requested file name is first passed through
    ``replace_spaces_in_filename``; the result is the name handed to
    ``delete_document``. Nothing is returned. ``request`` is not read here.
    """
    normalized_name = replace_spaces_in_filename(payload.file_name)
    session_document_store.delete_document(payload.session_id, normalized_name)