Spaces:
Paused
Paused
File size: 4,303 Bytes
d8328bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
"""Simple HTTP server for the NexaSci model to enable sharing across processes."""
from __future__ import annotations
import json
import sys
from pathlib import Path
from typing import Any, Dict, List
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from .client_llm import Message, NexaSciModelClient
# Ensure the project root (the directory one level above this file's folder)
# is importable when this file is the entry point or has been imported under
# the name "agent.model_server".
# NOTE(review): this runs after the relative import of .client_llm, so it
# cannot influence how that import resolves - confirm the intended launch mode.
if "agent.model_server" in sys.modules or __name__ == "__main__":
    project_root = Path(__file__).resolve().parents[1]
    # Prepend only once so repeated imports do not grow sys.path.
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))
# FastAPI application object; the route and startup decorators in this module
# register their handlers on it.
app = FastAPI(
    title="NexaSci Model Server",
    version="0.1.0",
)

# Process-wide model client. It stays None until the startup handler assigns
# it, and the endpoints check for None before using it.
_model_client: NexaSciModelClient | None = None
# Request body accepted by POST /generate.
class GenerateRequest(BaseModel):
    # Conversation turns; the /generate handler reads the "role" and "content"
    # keys of every entry.
    messages: list[dict[str, str]]
    # Optional generation overrides. They default to None and are forwarded
    # unchanged to the model client's generate() call.
    max_new_tokens: int | None = None
    temperature: float | None = None
    top_p: float | None = None
# Response body returned by POST /generate.
class GenerateResponse(BaseModel):
    # Text returned by the model client's generate() call.
    text: str
    # Whether a model client was available; the handler sets this to True on
    # every successful response.
    model_loaded: bool
@app.on_event("startup")
async def load_model() -> None:
    """Load the model client once when the server starts.

    Assigns the module-level ``_model_client`` and prints progress, the load
    duration and, when CUDA is available, the GPU name and memory usage.

    Raises:
        Exception: Any error raised while constructing the client or querying
            the GPU is printed together with its traceback and then re-raised.
    """
    global _model_client
    import os
    import time
    import traceback

    banner = "=" * 80
    print(banner)
    print("Loading NexaSci model (this may take 30-60 seconds)...")
    print(banner)
    print("Step 1: Loading tokenizer...")

    # perf_counter() is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments the way a time.time() difference can.
    start_time = time.perf_counter()
    try:
        # Set tokenizers parallelism to avoid warnings
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        _model_client = NexaSciModelClient()
        elapsed = time.perf_counter() - start_time
        print(f"✓ Model loaded successfully in {elapsed:.1f}s")
        if torch.cuda.is_available():
            bytes_per_gb = 1024**3
            print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
            total_mem = torch.cuda.get_device_properties(0).total_memory / bytes_per_gb
            allocated = torch.cuda.memory_allocated(0) / bytes_per_gb
            print(f"✓ GPU Memory: {allocated:.1f} GB / {total_mem:.1f} GB allocated")
        print(banner)
        print("Model server ready! Listening on http://0.0.0.0:8001")
        print(banner)
    except Exception as e:
        elapsed = time.perf_counter() - start_time
        print(f"✗ Failed to load model after {elapsed:.1f}s: {e}")
        traceback.print_exc()
        # Re-raise so the failure is not silently swallowed.
        raise
@app.get("/health")
async def health_check() -> Dict[str, Any]:
    """Report server status together with model and GPU availability.

    The response always contains ``status``, ``model_loaded`` and
    ``gpu_available``. When CUDA is available and the model client is loaded,
    the model's device, the GPU name and allocated/total GPU memory (in GB,
    rounded to two decimals) are added. If collecting those details fails, the
    error text is reported under ``model_device_check_error`` instead of
    being raised.
    """
    cuda_ok = torch.cuda.is_available()
    model_ready = _model_client is not None
    payload = {
        "status": "healthy",
        "model_loaded": model_ready,
        "gpu_available": cuda_ok,
    }

    # GPU details only make sense with both a CUDA device and a loaded model.
    if not (cuda_ok and model_ready):
        return payload

    bytes_per_gb = 1024**3
    try:
        # The device of the first parameter is reported as the model's device.
        first_param = next(_model_client.model.parameters())
        payload["model_device"] = str(first_param.device)
        payload["gpu_name"] = torch.cuda.get_device_name(0)
        allocated_gb = torch.cuda.memory_allocated(0) / bytes_per_gb
        payload["gpu_memory_allocated_gb"] = round(allocated_gb, 2)
        total_gb = torch.cuda.get_device_properties(0).total_memory / bytes_per_gb
        payload["gpu_memory_total_gb"] = round(total_gb, 2)
    except Exception as e:
        payload["model_device_check_error"] = str(e)
    return payload
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest) -> GenerateResponse:
    """Generate text from the loaded model for a list of chat messages.

    Args:
        request: Request body carrying ``messages`` (each dict must provide
            ``"role"`` and ``"content"``) and the optional ``max_new_tokens``,
            ``temperature`` and ``top_p`` values, which are forwarded
            unchanged to the model client.

    Returns:
        A ``GenerateResponse`` holding the generated text, with
        ``model_loaded`` set to ``True``.

    Raises:
        HTTPException: 503 when no model client is loaded, 422 when a message
            lacks ``"role"`` or ``"content"``, and 500 when building the
            messages or running generation fails.
    """
    if _model_client is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Report malformed messages as a client error. Without this check the
    # KeyError below would be caught by the broad handler and surface as a
    # misleading 500 "Generation failed: 'role'".
    for index, msg in enumerate(request.messages):
        missing = [key for key in ("role", "content") if key not in msg]
        if missing:
            raise HTTPException(
                status_code=422,
                detail=f"messages[{index}] is missing required key(s): {', '.join(missing)}",
            )

    try:
        messages = [Message(role=msg["role"], content=msg["content"]) for msg in request.messages]
        # NOTE(review): generate() is called synchronously inside an async
        # handler, so other requests presumably wait until it returns -
        # confirm this serialization is intended.
        text = _model_client.generate(
            messages,
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
        )
        return GenerateResponse(text=text, model_loaded=True)
    except Exception as e:
        # Chain the original exception so its traceback is kept for debugging.
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e
@app.get("/tools")
async def list_tools() -> Dict[str, List[str]]:
    """Return the tools exposed by the model client under the ``tools`` key.

    The list is empty when no model client has been loaded.
    """
    tool_names = [] if _model_client is None else list(_model_client.available_tools)
    return {"tools": tool_names}
# Script entry point: serve the app with uvicorn on all interfaces, port 8001.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8001,
    )
|