Andrew McCracken
Claude
commited on
Commit
Β·
efd4459
1
Parent(s):
bfa102d
Add concurrent request handling with model pool
Browse filesImplemented ModelPool for true concurrent processing:
- Created ModelPool class with thread-safe queue
- Initializes 10 model instances (configurable via MODEL_POOL_SIZE)
- Each instance can handle one request simultaneously
- Automatic model checkout/return from pool
- Added pool statistics to /health endpoint
Configuration:
- MODEL_POOL_SIZE=10 (supports 10 concurrent users)
- 60s timeout if all instances busy
- Each model instance ~2.4GB VRAM
- Total VRAM: ~24GB for 10 instances (fits in 48GB GPU)
Sessions are handled via session_id parameter (already present)
Pool automatically balances load across instances
π€ Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
- Dockerfile.gpu +5 -2
- main.py +111 -18
Dockerfile.gpu
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
# Use pre-built GPU image from Docker Hub
|
| 2 |
-
# Build this image locally with: docker buildx build --platform linux/amd64 -f Dockerfile.base.gpu -t techdaskalos/cybersecchatbot:gpu . --push
|
| 3 |
-
FROM techdaskalos/cybersecchatbot:gpu
|
| 4 |
|
| 5 |
# Environment variables (already set in base image, but can override)
|
| 6 |
ENV PYTHONUNBUFFERED=1
|
|
@@ -12,6 +12,9 @@ ENV CACHE_ENABLED=true
|
|
| 12 |
# GPU configuration - offload all layers to GPU
|
| 13 |
ENV N_GPU_LAYERS=35
|
| 14 |
|
|
|
|
|
|
|
|
|
|
| 15 |
# Set Hugging Face cache to /data for persistence and write permissions
|
| 16 |
ENV HF_HOME=/data/huggingface
|
| 17 |
|
|
|
|
| 1 |
# Use pre-built GPU image from Docker Hub
|
| 2 |
+
# Build this image locally with: docker buildx build --platform linux/amd64 -f Dockerfile.base.gpu -t techdaskalos/cybersecchatbot:latest-gpu . --push
|
| 3 |
+
FROM techdaskalos/cybersecchatbot:latest-gpu
|
| 4 |
|
| 5 |
# Environment variables (already set in base image, but can override)
|
| 6 |
ENV PYTHONUNBUFFERED=1
|
|
|
|
| 12 |
# GPU configuration - offload all layers to GPU
|
| 13 |
ENV N_GPU_LAYERS=35
|
| 14 |
|
| 15 |
+
# Concurrent request handling - 10 model instances for 10 concurrent users
|
| 16 |
+
ENV MODEL_POOL_SIZE=10
|
| 17 |
+
|
| 18 |
# Set Hugging Face cache to /data for persistence and write permissions
|
| 19 |
ENV HF_HOME=/data/huggingface
|
| 20 |
|
main.py
CHANGED
|
@@ -10,22 +10,98 @@ import uuid
|
|
| 10 |
import os
|
| 11 |
import sqlite3
|
| 12 |
from contextlib import asynccontextmanager
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# Import our handlers
|
| 15 |
from llm_handler import CybersecurityLLM
|
| 16 |
from knowledge_base import RAGCybersecurityLLM
|
| 17 |
from optimisations import PerformanceOptimizer, MemoryManager
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
# Configuration from environment variables
|
| 20 |
MODEL_REPO = os.getenv("MODEL_REPO", "daskalos-apps/phi4-cybersec-Q4_K_M")
|
| 21 |
MODEL_FILENAME = os.getenv("MODEL_FILENAME", "phi4-mini-instruct-Q4_K_M.gguf")
|
| 22 |
USE_RAG = os.getenv("USE_RAG", "true").lower() == "true"
|
| 23 |
CACHE_ENABLED = os.getenv("CACHE_ENABLED", "true").lower() == "true"
|
|
|
|
| 24 |
|
| 25 |
# Global instances
|
| 26 |
llm_instance = None
|
| 27 |
optimizer = None
|
| 28 |
memory_manager = None
|
|
|
|
| 29 |
|
| 30 |
# Database setup
|
| 31 |
# Support multiple deployment platforms: /data (HF Spaces), /app/data (Render/Railway), or local
|
|
@@ -94,26 +170,31 @@ def log_interaction(session_id: str, message: str, response_length: int):
|
|
| 94 |
@asynccontextmanager
|
| 95 |
async def lifespan(app: FastAPI):
|
| 96 |
"""Startup and shutdown events"""
|
| 97 |
-
global llm_instance, optimizer, memory_manager
|
| 98 |
|
| 99 |
# Startup
|
| 100 |
print(f"π Loading model from Hugging Face: {MODEL_REPO}")
|
|
|
|
| 101 |
|
| 102 |
# Initialize database
|
| 103 |
init_db()
|
| 104 |
print("β
Database initialized")
|
| 105 |
|
| 106 |
try:
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
if CACHE_ENABLED:
|
| 119 |
optimizer = PerformanceOptimizer()
|
|
@@ -125,6 +206,7 @@ async def lifespan(app: FastAPI):
|
|
| 125 |
print(f"πΎ Size: {llm_instance.get_model_info()['size_mb']:.2f} MB")
|
| 126 |
print(f"π§ RAG: {'Enabled' if USE_RAG else 'Disabled'}")
|
| 127 |
print(f"β‘ Cache: {'Enabled' if CACHE_ENABLED else 'Disabled'}")
|
|
|
|
| 128 |
|
| 129 |
except Exception as e:
|
| 130 |
print(f"β Failed to load model: {e}")
|
|
@@ -204,6 +286,7 @@ async def health_check():
|
|
| 204 |
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 205 |
|
| 206 |
memory_status = memory_manager.check_memory() if memory_manager else {}
|
|
|
|
| 207 |
|
| 208 |
return {
|
| 209 |
"status": "healthy",
|
|
@@ -211,7 +294,8 @@ async def health_check():
|
|
| 211 |
"version": "2.0.0",
|
| 212 |
"memory": memory_status,
|
| 213 |
"cache_enabled": CACHE_ENABLED,
|
| 214 |
-
"rag_enabled": USE_RAG
|
|
|
|
| 215 |
}
|
| 216 |
|
| 217 |
|
|
@@ -317,23 +401,28 @@ async def chat(request: ChatRequest):
|
|
| 317 |
|
| 318 |
@app.post("/chat/stream")
|
| 319 |
async def chat_stream(request: ChatRequest):
|
| 320 |
-
"""Streaming chat endpoint"""
|
| 321 |
-
if
|
| 322 |
-
raise HTTPException(status_code=503, detail="Model not
|
| 323 |
|
| 324 |
# Track interaction
|
| 325 |
count = increment_interaction()
|
| 326 |
session_id = request.session_id or str(uuid.uuid4())
|
| 327 |
|
| 328 |
async def generate():
|
|
|
|
| 329 |
try:
|
| 330 |
full_response = ""
|
| 331 |
|
| 332 |
-
#
|
| 333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
|
| 335 |
# Stream tokens
|
| 336 |
-
for token in
|
| 337 |
request.message,
|
| 338 |
max_tokens=request.max_tokens
|
| 339 |
):
|
|
@@ -348,6 +437,10 @@ async def chat_stream(request: ChatRequest):
|
|
| 348 |
|
| 349 |
except Exception as e:
|
| 350 |
yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
|
| 352 |
return StreamingResponse(generate(), media_type="text/event-stream")
|
| 353 |
|
|
|
|
| 10 |
import os
|
| 11 |
import sqlite3
|
| 12 |
from contextlib import asynccontextmanager
|
| 13 |
+
import queue
|
| 14 |
+
import threading
|
| 15 |
|
| 16 |
# Import our handlers
|
| 17 |
from llm_handler import CybersecurityLLM
|
| 18 |
from knowledge_base import RAGCybersecurityLLM
|
| 19 |
from optimisations import PerformanceOptimizer, MemoryManager
|
| 20 |
|
| 21 |
+
|
| 22 |
+
class ModelPool:
|
| 23 |
+
"""Thread-safe pool of model instances for concurrent request handling"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, pool_size: int, model_class, **model_kwargs):
|
| 26 |
+
"""
|
| 27 |
+
Initialize a pool of model instances
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
pool_size: Number of model instances to create
|
| 31 |
+
model_class: The model class to instantiate (CybersecurityLLM or RAGCybersecurityLLM)
|
| 32 |
+
**model_kwargs: Arguments to pass to each model instance
|
| 33 |
+
"""
|
| 34 |
+
self.pool_size = pool_size
|
| 35 |
+
self.model_class = model_class
|
| 36 |
+
self.model_kwargs = model_kwargs
|
| 37 |
+
self.pool = queue.Queue(maxsize=pool_size)
|
| 38 |
+
self.lock = threading.Lock()
|
| 39 |
+
self._initialize_pool()
|
| 40 |
+
|
| 41 |
+
def _initialize_pool(self):
|
| 42 |
+
"""Create and add model instances to the pool"""
|
| 43 |
+
print(f"π Initializing model pool with {self.pool_size} instances...")
|
| 44 |
+
for i in range(self.pool_size):
|
| 45 |
+
print(f" Loading model instance {i + 1}/{self.pool_size}...")
|
| 46 |
+
model = self.model_class(**self.model_kwargs)
|
| 47 |
+
self.pool.put(model)
|
| 48 |
+
print(f"β
Model pool ready with {self.pool_size} instances")
|
| 49 |
+
|
| 50 |
+
async def get_model(self, timeout: float = 30.0):
|
| 51 |
+
"""
|
| 52 |
+
Get an available model from the pool (async)
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
timeout: Maximum time to wait for an available model
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
Model instance
|
| 59 |
+
|
| 60 |
+
Raises:
|
| 61 |
+
HTTPException: If no model available within timeout
|
| 62 |
+
"""
|
| 63 |
+
start_time = asyncio.get_event_loop().time()
|
| 64 |
+
|
| 65 |
+
while True:
|
| 66 |
+
try:
|
| 67 |
+
# Try to get a model without blocking
|
| 68 |
+
model = self.pool.get_nowait()
|
| 69 |
+
return model
|
| 70 |
+
except queue.Empty:
|
| 71 |
+
# Check timeout
|
| 72 |
+
if asyncio.get_event_loop().time() - start_time > timeout:
|
| 73 |
+
raise HTTPException(
|
| 74 |
+
status_code=503,
|
| 75 |
+
detail=f"All {self.pool_size} model instances are busy. Please try again later."
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# Wait a bit before trying again
|
| 79 |
+
await asyncio.sleep(0.1)
|
| 80 |
+
|
| 81 |
+
def return_model(self, model):
|
| 82 |
+
"""Return a model to the pool"""
|
| 83 |
+
self.pool.put(model)
|
| 84 |
+
|
| 85 |
+
def get_stats(self) -> Dict[str, Any]:
|
| 86 |
+
"""Get pool statistics"""
|
| 87 |
+
return {
|
| 88 |
+
"pool_size": self.pool_size,
|
| 89 |
+
"available": self.pool.qsize(),
|
| 90 |
+
"in_use": self.pool_size - self.pool.qsize()
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
# Configuration from environment variables
|
| 94 |
MODEL_REPO = os.getenv("MODEL_REPO", "daskalos-apps/phi4-cybersec-Q4_K_M")
|
| 95 |
MODEL_FILENAME = os.getenv("MODEL_FILENAME", "phi4-mini-instruct-Q4_K_M.gguf")
|
| 96 |
USE_RAG = os.getenv("USE_RAG", "true").lower() == "true"
|
| 97 |
CACHE_ENABLED = os.getenv("CACHE_ENABLED", "true").lower() == "true"
|
| 98 |
+
MODEL_POOL_SIZE = int(os.getenv("MODEL_POOL_SIZE", "10")) # Number of concurrent model instances
|
| 99 |
|
| 100 |
# Global instances
|
| 101 |
llm_instance = None
|
| 102 |
optimizer = None
|
| 103 |
memory_manager = None
|
| 104 |
+
model_pool = None # Pool of model instances for concurrent processing
|
| 105 |
|
| 106 |
# Database setup
|
| 107 |
# Support multiple deployment platforms: /data (HF Spaces), /app/data (Render/Railway), or local
|
|
|
|
| 170 |
@asynccontextmanager
|
| 171 |
async def lifespan(app: FastAPI):
|
| 172 |
"""Startup and shutdown events"""
|
| 173 |
+
global llm_instance, optimizer, memory_manager, model_pool
|
| 174 |
|
| 175 |
# Startup
|
| 176 |
print(f"π Loading model from Hugging Face: {MODEL_REPO}")
|
| 177 |
+
print(f"π Concurrent instances: {MODEL_POOL_SIZE}")
|
| 178 |
|
| 179 |
# Initialize database
|
| 180 |
init_db()
|
| 181 |
print("β
Database initialized")
|
| 182 |
|
| 183 |
try:
|
| 184 |
+
# Initialize model pool for concurrent requests
|
| 185 |
+
model_class = RAGCybersecurityLLM if USE_RAG else CybersecurityLLM
|
| 186 |
+
model_pool = ModelPool(
|
| 187 |
+
pool_size=MODEL_POOL_SIZE,
|
| 188 |
+
model_class=model_class,
|
| 189 |
+
repo_id=MODEL_REPO,
|
| 190 |
+
filename=MODEL_FILENAME
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
# Keep one instance for backward compatibility (health checks, etc.)
|
| 194 |
+
llm_instance = model_class(
|
| 195 |
+
repo_id=MODEL_REPO,
|
| 196 |
+
filename=MODEL_FILENAME
|
| 197 |
+
)
|
| 198 |
|
| 199 |
if CACHE_ENABLED:
|
| 200 |
optimizer = PerformanceOptimizer()
|
|
|
|
| 206 |
print(f"πΎ Size: {llm_instance.get_model_info()['size_mb']:.2f} MB")
|
| 207 |
print(f"π§ RAG: {'Enabled' if USE_RAG else 'Disabled'}")
|
| 208 |
print(f"β‘ Cache: {'Enabled' if CACHE_ENABLED else 'Disabled'}")
|
| 209 |
+
print(f"π₯ Concurrent capacity: {MODEL_POOL_SIZE} users")
|
| 210 |
|
| 211 |
except Exception as e:
|
| 212 |
print(f"β Failed to load model: {e}")
|
|
|
|
| 286 |
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 287 |
|
| 288 |
memory_status = memory_manager.check_memory() if memory_manager else {}
|
| 289 |
+
pool_status = model_pool.get_stats() if model_pool else {"pool_size": 0, "available": 0, "in_use": 0}
|
| 290 |
|
| 291 |
return {
|
| 292 |
"status": "healthy",
|
|
|
|
| 294 |
"version": "2.0.0",
|
| 295 |
"memory": memory_status,
|
| 296 |
"cache_enabled": CACHE_ENABLED,
|
| 297 |
+
"rag_enabled": USE_RAG,
|
| 298 |
+
"concurrent_capacity": pool_status
|
| 299 |
}
|
| 300 |
|
| 301 |
|
|
|
|
| 401 |
|
| 402 |
@app.post("/chat/stream")
|
| 403 |
async def chat_stream(request: ChatRequest):
|
| 404 |
+
"""Streaming chat endpoint with concurrent request support"""
|
| 405 |
+
if model_pool is None:
|
| 406 |
+
raise HTTPException(status_code=503, detail="Model pool not initialized")
|
| 407 |
|
| 408 |
# Track interaction
|
| 409 |
count = increment_interaction()
|
| 410 |
session_id = request.session_id or str(uuid.uuid4())
|
| 411 |
|
| 412 |
async def generate():
|
| 413 |
+
model = None
|
| 414 |
try:
|
| 415 |
full_response = ""
|
| 416 |
|
| 417 |
+
# Get a model from the pool (will wait if all busy)
|
| 418 |
+
model = await model_pool.get_model(timeout=60.0)
|
| 419 |
+
|
| 420 |
+
# Send initial metadata with pool stats
|
| 421 |
+
pool_stats = model_pool.get_stats()
|
| 422 |
+
yield f"data: {json.dumps({{'type': 'start', 'session_id': session_id, 'model': MODEL_REPO, 'interaction_count': count, 'pool_available': pool_stats['available']})}\n\n"
|
| 423 |
|
| 424 |
# Stream tokens
|
| 425 |
+
for token in model.generate_stream(
|
| 426 |
request.message,
|
| 427 |
max_tokens=request.max_tokens
|
| 428 |
):
|
|
|
|
| 437 |
|
| 438 |
except Exception as e:
|
| 439 |
yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
|
| 440 |
+
finally:
|
| 441 |
+
# Always return the model to the pool
|
| 442 |
+
if model is not None:
|
| 443 |
+
model_pool.return_model(model)
|
| 444 |
|
| 445 |
return StreamingResponse(generate(), media_type="text/event-stream")
|
| 446 |
|