"""TxAgent API entry point.

Builds the FastAPI application, mounts the TxAgent router under ``/txagent``
and re-exposes a few of that router's endpoints at the root path for frontend
compatibility.
"""

from typing import Optional

import uvicorn
from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    UploadFile,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware

# Project imports kept in their original relative order (config, endpoints,
# auth) in case any of them depends on another having been imported first.
from config import (
    setup_app,
    agent,
    logger,
    patients_collection,
    analysis_collection,
    users_collection,
    chats_collection,
    notifications_collection,
)
from endpoints import create_router, ChatRequest, VoiceOutputRequest
from auth import get_current_user

# Create the FastAPI app
app = FastAPI(title="TxAgent API", version="2.6.0")

# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; consider restricting allow_origins to the known frontend origin(s).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.websocket("/ws/notifications")
async def websocket_endpoint(websocket: WebSocket):
    """Accept a notifications WebSocket and hold it open until the client leaves.

    Incoming text frames are read and discarded; the loop only ends when
    ``receive_text`` raises ``WebSocketDisconnect``. This handler itself never
    sends anything (presumably notifications are pushed elsewhere — confirm).
    """
    await websocket.accept()
    try:
        while True:
            # Reading keeps the connection serviced until the client disconnects.
            await websocket.receive_text()
    except WebSocketDisconnect:
        logger.info("Client disconnected")


# Setup the app (e.g., initialize globals, startup event)
setup_app(app)

# Create the router once and mount it under /txagent. The root-level aliases
# below reuse this same router instead of rebuilding one per request.
router = create_router(
    agent,
    logger,
    patients_collection,
    analysis_collection,
    users_collection,
    chats_collection,
    notifications_collection,
)
app.include_router(router, prefix="/txagent", tags=["txagent"])


def _find_router_endpoint(path: str, not_found_detail: str):
    """Return the endpoint function registered on ``router`` for ``path``.

    Args:
        path: Route path as declared on the router (without the ``/txagent``
            mount prefix), e.g. ``"/chat-stream"``.
        not_found_detail: Detail message for the 404 raised when no route
            with that path exists.

    Raises:
        HTTPException: 404 with ``not_found_detail`` if no route matches.
    """
    for route in router.routes:
        # Some route types may lack a ``path`` attribute; treat them as non-matches.
        if getattr(route, "path", None) == path:
            return route.endpoint
    raise HTTPException(status_code=404, detail=not_found_detail)


@app.post("/chat-stream")
async def chat_stream_root(
    request: ChatRequest,
    current_user: dict = Depends(get_current_user),
):
    """Chat stream endpoint at root level for frontend compatibility.

    Delegates to the router's ``/chat-stream`` endpoint, passing the already
    authenticated user explicitly (dependency injection is bypassed on a
    direct call).
    """
    endpoint = _find_router_endpoint("/chat-stream", "Chat stream endpoint not found")
    return await endpoint(request, current_user)


@app.post("/voice/synthesize")
async def voice_synthesize_root(
    request: dict,
    current_user: dict = Depends(get_current_user),
):
    """Voice synthesis endpoint at root level for frontend compatibility.

    Accepts a loose JSON object, fills defaults for missing fields, and
    delegates to the router's ``/voice/synthesize`` endpoint.
    """
    # Convert the raw dict into the typed request model, applying defaults.
    voice_request = VoiceOutputRequest(
        text=request.get('text', ''),
        language=request.get('language', 'en-US'),
        slow=request.get('slow', False),
        return_format=request.get('return_format', 'mp3'),
    )
    endpoint = _find_router_endpoint(
        "/voice/synthesize", "Voice synthesis endpoint not found"
    )
    return await endpoint(voice_request, current_user)


@app.post("/analyze-report")
async def analyze_report_root(
    file: UploadFile = File(...),
    patient_id: Optional[str] = Form(None),
    temperature: float = Form(0.5),
    max_new_tokens: int = Form(1024),
    current_user: dict = Depends(get_current_user),
):
    """Report analysis endpoint at root level for frontend compatibility.

    Delegates the uploaded file and form fields to the router's
    ``/analyze-report`` endpoint.
    """
    endpoint = _find_router_endpoint(
        "/analyze-report", "Analyze report endpoint not found"
    )
    return await endpoint(file, patient_id, temperature, max_new_tokens, current_user)


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)