|
|
"""
|
|
|
API Server for Mamba Swarm
|
|
|
FastAPI-based server for serving the distributed Mamba language model
|
|
|
"""
|
|
|
|
|
|
import asyncio
import json
import logging
import time
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, List, Optional

import torch
import uvicorn
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field

from routing.router import Router
from system.inference import InferenceEngine
from system.mambaSwarm import SwarmEngine
from training.trainer import setup_logging
|
|
|
|
|
|
|
|
|
class GenerationRequest(BaseModel):
    """Request body for single-prompt text generation.

    Sampling bounds are enforced by the pydantic Field validators; values
    outside the ge/le ranges are rejected (422) before any handler runs.
    """

    prompt: str = Field(..., description="Input text prompt")
    max_length: int = Field(default=100, ge=1, le=2048, description="Maximum generation length")
    temperature: float = Field(default=0.7, ge=0.1, le=2.0, description="Sampling temperature")
    top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p sampling")
    top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling")
    repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0, description="Repetition penalty")
    # Only checked by /generate/stream (must be True there); /generate ignores it.
    stream: bool = Field(default=False, description="Enable streaming response")
    # Optional routing hint forwarded to the inference engine's params dict.
    domain: Optional[str] = Field(default=None, description="Specific domain for routing")
|
|
|
class GenerationResponse(BaseModel):
    """Response body for a completed (non-streaming) generation."""

    generated_text: str
    prompt: str  # echoed input prompt
    generation_time: float  # wall-clock seconds measured around the generate call
    tokens_generated: int  # 0 when the engine does not report a count
    model_info: Dict[str, Any]  # engine-reported metadata; may be empty
|
|
class StreamingToken(BaseModel):
    """A single event emitted on the /generate/stream SSE stream."""

    token: str  # token text; empty string on the error-terminating event
    is_final: bool = False  # True ends the stream on both server and client side
    metadata: Optional[Dict[str, Any]] = None  # carries {"error": ...} on failure
|
|
class HealthResponse(BaseModel):
    """Response body for GET /health."""

    status: str  # "healthy" whenever the handler completes without error
    swarm_status: Dict[str, Any]  # raw output of SwarmEngine.get_status()
    system_info: Dict[str, Any]  # CUDA availability / device count snapshot
    timestamp: float  # server time (time.time()) when the check ran
|
|
class ModelInfo(BaseModel):
    """Response body for GET /model/info."""

    total_parameters: int  # defaults to 7_000_000_000 when the swarm omits it
    active_encoders: int
    total_encoders: int
    memory_usage: Dict[str, float]  # system/gpu/cache usage, reported in GB
    device_info: List[str]  # e.g. ["cuda:0"] or ["cpu"]
|
|
|
|
# Module-level singletons populated by lifespan() at startup and read by the
# Depends() providers below; both remain None until initialization completes.
swarm_engine: Optional[SwarmEngine] = None
inference_engine: Optional[InferenceEngine] = None
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan.

    Initializes the module-level ``swarm_engine`` / ``inference_engine``
    singletons before the server accepts requests, and guarantees the swarm
    is shut down on exit. Re-raises any startup failure so the server
    refuses to come up half-initialized.
    """
    global swarm_engine, inference_engine

    logging.info("Initializing Mamba Swarm API Server...")

    try:
        swarm_engine = SwarmEngine()
        # initialize() is blocking; run it in the default executor so the
        # event loop stays responsive during startup. get_running_loop()
        # replaces the deprecated get_event_loop() inside coroutines.
        await asyncio.get_running_loop().run_in_executor(None, swarm_engine.initialize)

        inference_engine = InferenceEngine(swarm_engine)

        logging.info("Mamba Swarm API Server initialized successfully")
    except Exception:
        # logging.exception records the full traceback for diagnosis.
        logging.exception("Failed to initialize swarm")
        raise

    try:
        yield
    finally:
        # try/finally ensures shutdown runs even if the application body
        # exits with an error.
        logging.info("Shutting down Mamba Swarm API Server...")
        if swarm_engine:
            swarm_engine.shutdown()
|
|
|
# Application instance; lifespan() wires up the swarm before requests are served.
app = FastAPI(
    title="Mamba Swarm API",
    description="Distributed Mamba Language Model API with 100 encoder units",
    version="1.0.0",
    lifespan=lifespan
)
|
|
|
|
|
|
|
|
|
# NOTE(review): wildcard origins combined with allow_credentials=True means any
# site may send credentialed requests — tighten allow_origins before exposing
# this server publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
|
async def get_swarm_engine() -> SwarmEngine:
    """Dependency provider: the initialized swarm engine, or 503 if absent."""
    engine = swarm_engine
    if engine is None:
        raise HTTPException(status_code=503, detail="Swarm engine not initialized")
    return engine
|
|
|
|
|
async def get_inference_engine() -> InferenceEngine:
    """Dependency provider: the initialized inference engine, or 503 if absent."""
    engine = inference_engine
    if engine is None:
        raise HTTPException(status_code=503, detail="Inference engine not initialized")
    return engine
|
|
|
|
|
@app.get("/health", response_model=HealthResponse)
async def health_check(swarm: SwarmEngine = Depends(get_swarm_engine)):
    """Health check endpoint: swarm status plus a CUDA/system snapshot."""
    try:
        cuda_ok = torch.cuda.is_available()
        system_info = {
            "cuda_available": cuda_ok,
            "cuda_device_count": torch.cuda.device_count() if cuda_ok else 0,
            "python_version": "3.8+",
        }

        return HealthResponse(
            status="healthy",
            swarm_status=swarm.get_status(),
            system_info=system_info,
            timestamp=time.time(),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Health check failed: {str(e)}")
|
|
|
|
|
@app.get("/model/info", response_model=ModelInfo)
async def get_model_info(swarm: SwarmEngine = Depends(get_swarm_engine)):
    """Return parameter counts, encoder counts, and memory/device info."""
    try:
        model_details = swarm.get_model_info()
        mem = swarm.memory_manager.get_memory_stats()
        default_devices = ["cuda:0" if torch.cuda.is_available() else "cpu"]

        memory_usage = {
            "system_memory_gb": mem.used_memory,
            "gpu_memory_gb": mem.gpu_memory,
            "cache_size_gb": mem.cache_size,
        }

        return ModelInfo(
            total_parameters=model_details.get("total_parameters", 7000000000),
            active_encoders=model_details.get("active_encoders", 100),
            total_encoders=model_details.get("total_encoders", 100),
            memory_usage=memory_usage,
            device_info=model_details.get("devices", default_devices),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get model info: {str(e)}")
|
|
|
|
|
@app.post("/generate", response_model=GenerationResponse)
async def generate_text(
    request: GenerationRequest,
    inference: InferenceEngine = Depends(get_inference_engine)
):
    """Generate text from a single prompt.

    Runs the blocking ``inference.generate`` call in the default executor so
    the event loop stays responsive, and returns the text with timing and
    token-count metadata. Any failure surfaces as HTTP 500.
    """
    try:
        start_time = time.time()

        # get_running_loop() replaces the deprecated get_event_loop() pattern
        # inside coroutines; the blocking call goes to a worker thread.
        result = await asyncio.get_running_loop().run_in_executor(
            None,
            inference.generate,
            request.prompt,
            {
                "max_length": request.max_length,
                "temperature": request.temperature,
                "top_p": request.top_p,
                "top_k": request.top_k,
                "repetition_penalty": request.repetition_penalty,
                "domain": request.domain,
            },
        )

        generation_time = time.time() - start_time

        return GenerationResponse(
            generated_text=result["generated_text"],
            prompt=request.prompt,
            generation_time=generation_time,
            tokens_generated=result.get("tokens_generated", 0),
            model_info=result.get("model_info", {}),
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
|
|
|
|
|
@app.post("/generate/stream")
async def generate_text_stream(
    request: GenerationRequest,
    inference: InferenceEngine = Depends(get_inference_engine)
):
    """Generate text with a Server-Sent-Events streaming response.

    Each event is a ``data: <StreamingToken JSON>`` frame; the final event
    has is_final=True (or carries {"error": ...} in metadata on failure).
    Requires request.stream to be True, otherwise 400.
    """
    if not request.stream:
        raise HTTPException(status_code=400, detail="Streaming not requested")

    generation_params = {
        "max_length": request.max_length,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "top_k": request.top_k,
        "repetition_penalty": request.repetition_penalty,
        "domain": request.domain,
    }

    async def generate_stream() -> AsyncGenerator[str, None]:
        loop = asyncio.get_running_loop()
        exhausted = object()  # sentinel distinguishing "generator done" from a token
        try:
            generator = inference.generate_stream(request.prompt, generation_params)

            while True:
                # Pull each token in a worker thread: iterating the blocking
                # generator inline would stall the event loop between tokens.
                token_data = await loop.run_in_executor(None, next, generator, exhausted)
                if token_data is exhausted:
                    break

                # NOTE(review): .json() is pydantic v1 API (deprecated in v2,
                # where it is model_dump_json()) — confirm the pinned version.
                streaming_token = StreamingToken(
                    token=token_data.get("token", ""),
                    is_final=token_data.get("is_final", False),
                    metadata=token_data.get("metadata", {})
                )

                yield f"data: {streaming_token.json()}\n\n"

                if streaming_token.is_final:
                    break

        except Exception as e:
            # Report the failure as a final SSE event instead of dropping the
            # connection mid-stream.
            error_token = StreamingToken(
                token="",
                is_final=True,
                metadata={"error": str(e)}
            )
            yield f"data: {error_token.json()}\n\n"

    # text/event-stream is the media type EventSource clients require for
    # "data:" framing; text/plain would not be parsed as SSE.
    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
    )
|
|
|
|
|
|
@app.post("/generate/batch")
async def generate_batch(
    requests: List[GenerationRequest],
    inference: InferenceEngine = Depends(get_inference_engine)
):
    """Generate text for multiple prompts (max 10 per batch).

    Fans the blocking generate calls out to the default thread pool and
    gathers them concurrently. A failure for one prompt is reported in place
    (as an {"error", "prompt", "index"} entry) without failing the batch.
    """
    if len(requests) > 10:
        raise HTTPException(status_code=400, detail="Batch size too large (max 10)")

    try:
        # get_running_loop() replaces the deprecated get_event_loop() pattern.
        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(
                None,
                inference.generate,
                req.prompt,
                {
                    "max_length": req.max_length,
                    "temperature": req.temperature,
                    "top_p": req.top_p,
                    "top_k": req.top_k,
                    "repetition_penalty": req.repetition_penalty,
                    "domain": req.domain,
                },
            )
            for req in requests
        ]

        # return_exceptions=True keeps per-item failures from cancelling the rest.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        responses = []
        for i, (req, result) in enumerate(zip(requests, results)):
            if isinstance(result, Exception):
                responses.append({
                    "error": str(result),
                    "prompt": req.prompt,
                    "index": i
                })
            else:
                responses.append(GenerationResponse(
                    generated_text=result["generated_text"],
                    prompt=req.prompt,
                    generation_time=result.get("generation_time", 0),
                    tokens_generated=result.get("tokens_generated", 0),
                    model_info=result.get("model_info", {})
                ))

        return {"responses": responses}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch generation failed: {str(e)}")
|
|
|
|
|
|
@app.get("/metrics")
async def get_metrics(swarm: SwarmEngine = Depends(get_swarm_engine)):
    """Return a snapshot of memory, swarm, and (optional) inference metrics."""
    try:
        inference_stats = (
            swarm.get_inference_stats() if hasattr(swarm, 'get_inference_stats') else {}
        )
        return {
            "memory_report": swarm.memory_manager.get_memory_report(),
            "swarm_metrics": swarm.get_metrics(),
            "inference_stats": inference_stats,
            "timestamp": time.time(),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get metrics: {str(e)}")
|
|
|
|
|
|
@app.post("/admin/reload")
async def reload_model(
    background_tasks: BackgroundTasks,
    swarm: SwarmEngine = Depends(get_swarm_engine)
):
    """Admin endpoint: schedule a model reload after the response is sent."""
    try:
        # Deferred via BackgroundTasks so the request returns immediately.
        background_tasks.add_task(swarm.reload_model)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to reload model: {str(e)}")
    return {"message": "Model reload initiated"}
|
|
|
|
|
|
@app.post("/admin/cleanup")
async def cleanup_memory(swarm: SwarmEngine = Depends(get_swarm_engine)):
    """Admin endpoint: force an aggressive memory cleanup pass."""
    try:
        swarm.memory_manager.cleanup_memory(aggressive=True)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to cleanup memory: {str(e)}")
    return {"message": "Memory cleanup completed"}
|
|
|
|
|
|
|
|
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Render HTTPExceptions as JSON carrying their own status code.

    Returning a bare dict from an exception handler would be serialized
    with HTTP 200 regardless of exc.status_code; JSONResponse propagates
    the real status to the client.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.detail,
            "status_code": exc.status_code,
            "timestamp": time.time(),
        },
    )
|
|
|
|
|
|
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Catch-all handler: log the traceback and return a generic JSON 500.

    Uses JSONResponse so the client actually receives status 500 (a bare
    dict would be sent with HTTP 200); exc_info=True records the traceback.
    """
    logging.error(f"Unhandled exception: {exc}", exc_info=True)
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal server error",
            "status_code": 500,
            "timestamp": time.time(),
        },
    )
|
|
|
|
|
|
def run_server(host: str = "0.0.0.0", port: int = 8000, workers: int = 1):
    """Configure logging and serve the FastAPI app with uvicorn (blocking)."""
    setup_logging()

    # NOTE(review): uvicorn only honors workers > 1 when the app is passed as
    # an import string rather than an object — confirm before relying on it.
    server_config = uvicorn.Config(
        app=app,
        host=host,
        port=port,
        workers=workers,
        log_level="info",
        access_log=True,
        reload=False,
    )
    uvicorn.Server(server_config).run()
|
|
|
|
|
|
# Script entry point: serve on the default host/port with one worker.
if __name__ == "__main__":
    run_server()