Spaces:
Sleeping
Sleeping
| import asyncio | |
| import logging | |
| from fastapi.concurrency import asynccontextmanager | |
| import uvicorn | |
| import os | |
| from dotenv import load_dotenv | |
| from fastapi import FastAPI, Response, WebSocket, WebSocketDisconnect, status, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from models.connection_manager import ConnectionManager | |
| from models.request_payload import RequestPayload | |
| from utils.package_manager import PackageManager | |
# Load environment variables from a local .env file (if present) before reading them.
load_dotenv()

# Anything other than ENV=PROD (including an unset ENV) counts as development.
IS_DEV = os.environ.get('ENV', 'DEV') != 'PROD'
# Shared secret compared against the `x-token` header of WebSocket clients.
WEBSOCKET_SECURE_TOKEN = os.getenv("SECURE_TOKEN")
# Credentials compared by is_valid(); None when the variables are unset.
X_REQUEST_USER = os.environ.get('X_REQUEST_USER')
X_API_KEY = os.environ.get('X_API_KEY')
# Comma-separated channel ids allowed to open the WebSocket. Entries are
# stripped and empty entries dropped, so "a, b" matches "b" and an empty
# variable yields [] rather than [''] (which would whitelist an empty id).
WHITELIST_CHANNEL_IDS = [
    channel.strip()
    for channel in os.getenv('WHITELIST_CHANNEL_IDS', '').split(',')
    if channel.strip()
]
# ASGI application object; uvicorn loads it as 'app:app' in the __main__ block.
app = FastAPI()

# Module-wide singletons shared by the HTTP and WebSocket handlers.
manager = ConnectionManager()  # registry of connected MLaaS WebSocket channels
package = PackageManager()  # gzips request payloads before they are sent
# Root logging: WARNING and above only, each line stamped as "(HH:MM:SS)".
logging.basicConfig(
    format='%(asctime)s %(name)s %(levelname)-8s %(message)s',
    datefmt='(%H:%M:%S)',
    level=logging.WARNING,
)
# CORS middleware. Allowed origins come from the comma-separated
# CORS_ALLOW_ORIGINS environment variable; when it is unset every origin is
# allowed ("*"). Set it (e.g. "https://your-frontend-domain.com") to restrict
# which browser origins may call this API.
_CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.environ.get('CORS_ALLOW_ORIGINS', '*').split(',')
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ALLOW_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
def root():
    """Return a plain-text "ok" response with HTTP 200."""
    # Starlette's Response takes the body via `content`; it has no `data`
    # parameter (passing one raises TypeError on every request).
    return Response(
        content='ok',
        status_code=status.HTTP_200_OK,
        media_type='text/plain',
    )
def healthcheck():
    """Return a plain-text "ok" response with HTTP 200 (for health checks)."""
    # Starlette's Response takes the body via `content`; it has no `data`
    # parameter (passing one raises TypeError on every request).
    return Response(
        content='ok',
        status_code=status.HTTP_200_OK,
        media_type='text/plain',
    )
async def hi_mlhub(payload: RequestPayload):
    """Forward *payload* to the connected MLaaS client and relay its reply.

    The payload is gzipped by ``package`` and sent as binary to the target
    held in ``manager.available``. The handler then waits up to 10 seconds
    for the reply tagged with the same request id.

    Returns:
        JSONResponse: 200 with the reply data; 504 if no reply arrived within
        the timeout; 502 if no MLaaS client is available or the exchange
        failed for any other reason.
    """
    # Read the target once: other coroutines can change manager.available
    # while this handler is suspended at an await.
    # NOTE(review): manager.available presumably identifies a channel id
    # (it is passed where channel ids are used elsewhere) — confirm.
    target = manager.available
    if target is None:
        return JSONResponse(
            status_code=status.HTTP_502_BAD_GATEWAY,
            content={"error": "MLaaS is not available."},
        )

    request_id, compressed_data = package.gzip(payload)
    try:
        await manager.send_bytes(target, compressed_data)
        data = await asyncio.wait_for(
            manager.listen(target, request_id), timeout=10.0
        )
    except asyncio.TimeoutError:
        return JSONResponse(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            content={"error": "Timeout"},
        )
    except Exception:
        # Not a timeout: log the cause instead of mislabelling it as one.
        logging.getLogger(__name__).exception(
            "Request %s to MLaaS target %s failed", request_id, target
        )
        return JSONResponse(
            status_code=status.HTTP_502_BAD_GATEWAY,
            content={"error": "MLaaS request failed."},
        )
    return JSONResponse(status_code=status.HTTP_200_OK, content=data)
def is_valid_token(token: str):
    """Return True if *token* equals the configured ``SECURE_TOKEN``.

    Both the supplied token and the configured secret must be non-empty
    strings. This way, an unset ``SECURE_TOKEN`` can never be matched by a
    missing header (``None == None``), and an empty secret never authenticates.
    """
    import hmac  # local import keeps this block self-contained

    expected = WEBSOCKET_SECURE_TOKEN
    if not (isinstance(token, str) and token):
        return False
    if not (isinstance(expected, str) and expected):
        return False
    # Constant-time comparison so response timing does not leak the secret.
    # Encode first: compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(token.encode('utf-8'), expected.encode('utf-8'))
def is_valid_apikey(channel_id: str):
    """Return True if *channel_id* is a non-empty, whitelisted channel id.

    Despite its name, this checks a channel id (the ``x-channel-id`` header
    at the call site) against ``WHITELIST_CHANNEL_IDS``; no API key is
    involved. Empty ids are rejected explicitly, because splitting an empty
    whitelist variable on ',' produces ``['']``.
    """
    if not channel_id:
        return False
    return channel_id in WHITELIST_CHANNEL_IDS
async def websocket_endpoint(websocket: WebSocket):
    """Authenticate an MLaaS WebSocket client and relay its messages.

    The client must send an ``x-token`` header accepted by is_valid_token
    and an ``x-channel-id`` header accepted by is_valid_apikey. Otherwise the
    handshake is closed with code 1008 (policy violation). After
    ``manager.connect``, each text message received on the channel is passed
    to ``manager.notify``. The channel is unregistered when the loop ends.
    """
    logger = logging.getLogger(__name__)
    headers = websocket.headers
    token = headers.get("x-token")
    channel_id = headers.get("x-channel-id")

    # Reject by closing the not-yet-accepted socket. Returning an
    # HTTPException object has no effect; it must be raised, and on a
    # WebSocket an explicit close is the clear way to refuse the handshake.
    if not token or not is_valid_token(token):
        logger.warning("Rejected WebSocket connection: invalid token")
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return None
    if not is_valid_apikey(channel_id):
        logger.warning("Rejected WebSocket connection: channel %r not allowed", channel_id)
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return None

    await manager.connect(channel_id, websocket)
    try:
        while True:
            data = await manager.receive_text(channel_id)
            logger.info("Message from MLaaS: %s", data)
            # Let the manager resolve whichever request this message answers.
            await manager.notify(channel_id, data)
    except WebSocketDisconnect:
        manager.disconnect(channel_id)
        await manager.broadcast(f"A client has disconnected with ID: {channel_id}")
    except BaseException:
        # Any other failure (including task cancellation) must still
        # unregister the channel so a dead connection is not left in the manager.
        manager.disconnect(channel_id)
        raise
    return None
def is_valid(u, p):
    """Return True if *u* / *p* equal ``X_REQUEST_USER`` / ``X_API_KEY``.

    All four values must be non-empty strings, so missing credentials
    (``None``) can never match unset environment variables (``None == None``).
    Both comparisons run in constant time and are always evaluated, so
    timing does not reveal which of the two failed.
    """
    import hmac  # local import keeps this block self-contained

    values = (u, p, X_REQUEST_USER, X_API_KEY)
    if not all(isinstance(value, str) and value for value in values):
        return False
    user_ok = hmac.compare_digest(u.encode('utf-8'), X_REQUEST_USER.encode('utf-8'))
    key_ok = hmac.compare_digest(p.encode('utf-8'), X_API_KEY.encode('utf-8'))
    return user_ok and key_ok
if __name__ == "__main__":
    # Local entry point. Auto-reload is a development aid, so it follows
    # IS_DEV (True unless ENV=PROD) instead of being forced on everywhere.
    uvicorn.run('app:app', host='0.0.0.0', port=7860, reload=IS_DEV)