Spaces:
Runtime error
Runtime error
File size: 4,066 Bytes
62d835c f126604 5795f83 7cd1148 a4ccb56 f275c80 62d835c f275c80 f126604 62d835c f126604 046255d ac9926b f126604 a4ccb56 62d835c 1e0df14 046255d 5795f83 0ccca39 7cd1148 ac9926b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from config import setup_app, agent, logger, patients_collection, analysis_collection, users_collection, chats_collection, notifications_collection
from endpoints import create_router
from fastapi import WebSocket, WebSocketDisconnect
# Application object; the title and version are what the generated API docs report.
app = FastAPI(version="2.6.0", title="TxAgent API")

# Cross-origin policy: every origin, method and header is accepted, and
# credentialed requests (cookies / auth headers) are permitted.
# NOTE(review): a wildcard origin list combined with allow_credentials=True is
# very permissive; consider an explicit list of frontend origins.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
@app.websocket("/ws/notifications")
async def websocket_endpoint(websocket: WebSocket):
    """Hold a notifications WebSocket open until the client goes away.

    The handshake is accepted, then incoming text frames are read and
    discarded in a loop. A ``WebSocketDisconnect`` is logged and ends the
    handler; any other exception propagates unchanged.
    """
    await websocket.accept()
    listening = True
    while listening:
        try:
            # Payload is ignored; reading simply keeps the connection serviced.
            await websocket.receive_text()
        except WebSocketDisconnect:
            logger.info("Client disconnected")
            listening = False
# Hand the app to config for its setup work (per the original note: globals
# and startup hooks — see config.setup_app for specifics).
setup_app(app)

# Build the TxAgent router from the shared agent, logger and collections
# exported by config, then mount all of its routes under /txagent.
router = create_router(
    agent,
    logger,
    patients_collection,
    analysis_collection,
    users_collection,
    chats_collection,
    notifications_collection,
)
app.include_router(router, tags=["txagent"], prefix="/txagent")
# Also include some endpoints at root level for frontend compatibility
from endpoints import ChatRequest, VoiceOutputRequest
from fastapi import Depends, HTTPException, UploadFile, File, Form
from typing import Optional
from auth import get_current_user
@app.post("/chat-stream")
async def chat_stream_root(
    request: ChatRequest,
    current_user: dict = Depends(get_current_user),
):
    """Chat stream endpoint at root level for frontend compatibility.

    Delegates to the ``/chat-stream`` handler of the module-level TxAgent
    ``router``, so ``/chat-stream`` and ``/txagent/chat-stream`` share one
    implementation.

    Raises:
        HTTPException: 404 if the router exposes no ``/chat-stream`` route.
    """
    # Reuse the router built once at import time. The previous version called
    # create_router(...) with identical arguments on every request, rebuilding
    # every route just to look one up. The router's own routes keep their
    # unprefixed paths; include_router copies them onto the app with the prefix.
    for route in router.routes:
        if getattr(route, "path", None) == "/chat-stream":
            # Positional call assumes the handler takes (request, current_user)
            # in that order — confirm against endpoints.create_router.
            return await route.endpoint(request, current_user)
    raise HTTPException(status_code=404, detail="Chat stream endpoint not found")
@app.post("/voice/synthesize")
async def voice_synthesize_root(
    request: dict,
    current_user: dict = Depends(get_current_user),
):
    """Voice synthesis endpoint at root level for frontend compatibility.

    Builds a ``VoiceOutputRequest`` from the raw JSON body, filling in
    defaults for missing keys (text ``''``, language ``'en-US'``,
    slow ``False``, return_format ``'mp3'``). It then delegates to the
    ``/voice/synthesize`` handler of the module-level TxAgent ``router``.

    Raises:
        HTTPException: 422 if the body cannot be converted into a
            ``VoiceOutputRequest``; 404 if the router exposes no
            ``/voice/synthesize`` route.
    """
    try:
        voice_request = VoiceOutputRequest(
            text=request.get('text', ''),
            language=request.get('language', 'en-US'),
            slow=request.get('slow', False),
            return_format=request.get('return_format', 'mp3'),
        )
    except ValueError as exc:
        # VoiceOutputRequest is presumably a pydantic model; pydantic's
        # ValidationError subclasses ValueError. Previously this escaped as an
        # unhandled 500 instead of a client error.
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    # Reuse the router built at import time rather than rebuilding it with
    # create_router(...) on every request.
    for route in router.routes:
        if getattr(route, "path", None) == "/voice/synthesize":
            return await route.endpoint(voice_request, current_user)
    raise HTTPException(status_code=404, detail="Voice synthesis endpoint not found")
@app.post("/analyze-report")
async def analyze_report_root(
    file: UploadFile = File(...),
    patient_id: Optional[str] = Form(None),
    temperature: float = Form(0.5),
    max_new_tokens: int = Form(1024),
    current_user: dict = Depends(get_current_user),
):
    """Report analysis endpoint at root level for frontend compatibility.

    Accepts a multipart upload plus optional form fields and delegates to
    the ``/analyze-report`` handler of the module-level TxAgent ``router``.

    Raises:
        HTTPException: 404 if the router exposes no ``/analyze-report`` route.
    """
    # Reuse the router built at import time rather than rebuilding it with
    # create_router(...) on every request.
    for route in router.routes:
        if getattr(route, "path", None) == "/analyze-report":
            # Positional order must match the target handler's signature —
            # confirm against endpoints.create_router.
            return await route.endpoint(
                file, patient_id, temperature, max_new_tokens, current_user
            )
    raise HTTPException(status_code=404, detail="Analyze report endpoint not found")
if __name__ == "__main__":
    # Direct-run entry point: serve the app on all interfaces, port 8000.
    # (Removed a stray trailing " |" that made this line a SyntaxError.)
    uvicorn.run(app, host="0.0.0.0", port=8000)