Spaces:
Build error
Build error
| import os | |
| import time | |
| import json | |
| import requests | |
| import asyncio | |
| from datetime import datetime | |
| from typing import Dict, List, Optional | |
| from fastapi import FastAPI, HTTPException, BackgroundTasks | |
| from fastapi.responses import StreamingResponse | |
| import uvicorn | |
| from pydantic import BaseModel | |
| from shared.models import ChatRequest, ChatResponse, ChatMessage, WorkerStatus | |
| from shared.chat_history import save_detailed_chat_log, initialize_chat_file | |
# Gateway application object.
app = FastAPI(
    title="Multi-Node Hugging Face API Gateway",
    description="API Gateway that routes requests to specialized worker nodes",
    version="1.0.0",
)

# Initialize the chat history file (shared.chat_history) once, at import time.
initialize_chat_file()

# Model name -> worker base URL. Each row is (model, environment variable that
# overrides the URL, fallback URL used when the variable is unset).
# In production these would all come from the environment.
WORKER_NODES = {
    model_name: os.getenv(env_var, fallback_url)
    for model_name, env_var, fallback_url in (
        ("sam-x-nano", "NANO_WORKER_URL", "http://nano-worker:8000"),
        ("sam-x-mini", "MINI_WORKER_URL", "http://mini-worker:8000"),
        ("sam-x-fast", "FAST_WORKER_URL", "http://fast-worker:8000"),
        ("sam-x-large", "LARGE_WORKER_URL", "http://large-worker:8000"),
    )
}

# Per-model status dicts ("active", "last_check", "load"), held in this
# process's memory only; a production deployment would use Redis or a database.
worker_status: Dict[str, Dict] = {}
# NOTE(review): no startup registration (e.g. ``@app.on_event("startup")``) is
# visible in this listing; confirm this hook is actually wired up, otherwise
# ``worker_status`` starts out empty.
def startup_event():
    """
    Announce gateway startup and seed ``worker_status``.

    Every model in WORKER_NODES gets an optimistic entry: marked active, with
    ``last_check`` set to now and ``load`` 0.0. No worker is contacted here.
    """
    print("Starting Multi-Node Hugging Face API Gateway...")
    worker_status.update(
        {
            model_name: {"active": True, "last_check": time.time(), "load": 0.0}
            for model_name in WORKER_NODES
        }
    )
def route_to_worker(chat_request: ChatRequest) -> Dict:
    """
    Forward a chat request to the worker node that serves its model.

    The model name is lower-cased and looked up in WORKER_NODES. The request
    body is POSTed as JSON to ``{worker_url}/chat/completions``. After each
    attempt the worker's entry in ``worker_status`` is refreshed:
    - It is marked active when the worker answered (success or a 4xx rejection).
    - It is marked inactive when the worker was unreachable or failed with a 5xx.

    This function performs blocking network I/O.

    Args:
        chat_request: Incoming request; its ``model`` field selects the worker.

    Returns:
        The worker's decoded JSON response.

    Raises:
        HTTPException:
            - 400 if the model is not configured.
            - The worker's own 4xx status if it rejected the request.
            - 503 if the worker could not be reached, timed out, or answered
              with a 5xx (other ``requests`` errors are treated the same way).
            - 500 on any other unexpected error.
    """
    model = chat_request.model.lower()
    if model not in WORKER_NODES:
        raise HTTPException(status_code=400, detail=f"Model {model} not available")

    worker_url = WORKER_NODES[model]

    try:
        response = requests.post(
            f"{worker_url}/chat/completions",
            json=chat_request.dict(),
            # (connect, read) timeouts: give up quickly on an unreachable
            # worker, but still allow 5 minutes for long inference.
            timeout=(10, 300),
        )
        response.raise_for_status()
        payload = response.json()
    except requests.exceptions.RequestException as e:
        rejected = e.response if isinstance(e, requests.exceptions.HTTPError) else None
        if rejected is not None and rejected.status_code < 500:
            # The worker is up and answered; only this particular request was
            # rejected (e.g. validation). Marking the worker inactive here
            # would hide a healthy model from list_models.
            worker_status[model] = {"active": True, "last_check": time.time(), "load": 0.0}
            print(
                f"Worker {worker_url} rejected request with status "
                f"{rejected.status_code}: {rejected.text}"
            )
            raise HTTPException(
                status_code=rejected.status_code,
                detail=f"Worker for model {model} rejected the request",
            ) from e
        print(f"Error contacting worker {worker_url}: {str(e)}")
        worker_status[model] = {"active": False, "last_check": time.time(), "load": 0.0}
        raise HTTPException(
            status_code=503, detail=f"Worker for model {model} is not available"
        ) from e
    except Exception as e:
        print(f"Unexpected error contacting worker {worker_url}: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error") from e

    # A successful round trip proves the worker is reachable, so clear any
    # earlier "inactive" flag; otherwise one transient failure would keep the
    # model hidden from list_models indefinitely.
    worker_status[model] = {"active": True, "last_check": time.time(), "load": 0.0}
    return payload
# NOTE(review): no route decorator is visible for this handler in this
# listing; confirm it is registered on ``app``.
async def chat_completions(request: ChatRequest, background_tasks: BackgroundTasks):
    """
    Main chat completions endpoint: forward the request to the worker for
    ``request.model`` and return the worker's JSON response unchanged.

    The first choice's message content is recorded as a background task via
    ``save_detailed_chat_log``. It is logged together with the request, the
    model name and the worker round-trip time in seconds. If the worker
    returned no content, an empty string is logged.

    Raises:
        HTTPException: whatever ``route_to_worker`` raises, or 500 if
            post-processing the worker response fails unexpectedly.
    """
    start_time = time.perf_counter()
    try:
        # route_to_worker performs blocking ``requests`` I/O that can take
        # minutes. Running it in the default executor keeps the event loop
        # free to serve other requests meanwhile.
        loop = asyncio.get_running_loop()
        worker_response = await loop.run_in_executor(None, route_to_worker, request)

        # Monotonic clock: immune to wall-clock adjustments during the call.
        processing_time = time.perf_counter() - start_time

        # Extract the first choice's content, tolerating null fields.
        response_content = ""
        if "choices" in worker_response and len(worker_response["choices"]) > 0:
            message = worker_response["choices"][0].get("message") or {}
            response_content = message.get("content") or ""

        # Persist the exchange after the response has been sent.
        background_tasks.add_task(
            save_detailed_chat_log,
            request.dict(),
            response_content,
            request.model,
            processing_time,
        )

        return worker_response
    except HTTPException:
        # Already carries the right status code (from route_to_worker).
        raise
    except Exception as e:
        print(f"Error in chat_completions: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
# NOTE(review): no route decorator is visible for this handler in this
# listing; confirm it is registered on ``app``.
async def list_models():
    """
    List the configured models as an OpenAI-style ``{"object": "list"}`` payload.

    A model is left out only when ``worker_status`` explicitly flags it
    inactive; a model with no status entry yet is listed. ``created`` is the
    current Unix time at the moment of the call.
    """
    # NOTE(review): health_check treats a missing status entry as inactive,
    # whereas this listing treats it as active; confirm which is intended.
    entries = []
    for model_id in WORKER_NODES:
        status = worker_status.get(model_id, {})
        if not status.get("active", True):
            continue
        entries.append(
            {
                "id": model_id,
                "object": "model",
                "created": int(time.time()),
                "owned_by": "multinode-hf-api",
            }
        )
    return {"object": "list", "data": entries}
# NOTE(review): no route decorator is visible for this handler in this
# listing; confirm it is registered on ``app``.
async def health_check():
    """
    Report gateway health from the in-memory worker status table.

    Returns a dict with three keys:
    - ``status``: "healthy" if at least one tracked worker is flagged active,
      otherwise "no_active_workers".
    - ``active_workers``: the names of the active workers, in status-table order.
    - ``total_workers``: the number of configured workers (WORKER_NODES).

    A worker without an ``active`` flag counts as inactive.
    """
    active_names = [
        name for name, info in worker_status.items() if info.get("active", False)
    ]
    overall = "healthy" if active_names else "no_active_workers"
    return {
        "status": overall,
        "active_workers": active_names,
        "total_workers": len(WORKER_NODES),
    }
# NOTE(review): no route decorator is visible for this handler in this
# listing; confirm it is registered on ``app``.
async def get_worker_status():
    """
    Return the gateway's in-memory worker status table.

    Maps model name -> the status dict this module writes (keys ``active``,
    ``last_check``, ``load``). The live dict is returned, not a snapshot.
    """
    status_table = worker_status
    return status_table
# NOTE(review): no route decorator is visible for this handler in this
# listing; confirm it is registered on ``app``.
async def simple_chat(message: str, model: str = "sam-x-nano", max_tokens: int = 512):
    """
    Simplified chat endpoint: send one user message and return only the reply text.

    Args:
        message: The user's message.
        model: Worker model name (matched case-insensitively by route_to_worker).
        max_tokens: Generation limit forwarded to the worker.

    Returns:
        ``{"response": <first choice's message content>}``.

    Raises:
        HTTPException: whatever route_to_worker raises, or 500 "No response
            from worker" if the reply lacks a first-choice message content.
    """
    chat_request = ChatRequest(
        messages=[ChatMessage(role="user", content=message)],
        model=model,
        max_tokens=max_tokens,
    )

    # route_to_worker does blocking HTTP I/O; run it off the event loop so
    # other requests are not stalled while the worker generates.
    loop = asyncio.get_running_loop()
    worker_response = await loop.run_in_executor(None, route_to_worker, chat_request)

    # Walk the payload defensively: a malformed reply should produce the
    # intended 500 below, not an unhandled KeyError/TypeError.
    choices = worker_response.get("choices") if isinstance(worker_response, dict) else None
    first_choice = choices[0] if isinstance(choices, list) and choices else None
    reply = first_choice.get("message") if isinstance(first_choice, dict) else None
    content = reply.get("content") if isinstance(reply, dict) else None

    if content is None:
        raise HTTPException(status_code=500, detail="No response from worker")
    return {"response": content}
if __name__ == "__main__":
    # The listen port defaults to the previously hard-coded 7860 (the usual
    # Hugging Face Spaces port). It can be overridden through PORT, matching
    # how the worker URLs above are configured via environment variables.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))