Andrew McCracken and Claude committed
Commit efd4459 · 1 Parent(s): bfa102d

Add concurrent request handling with model pool


Implemented ModelPool for true concurrent processing:
- Created ModelPool class with thread-safe queue
- Initializes 10 model instances (configurable via MODEL_POOL_SIZE)
- Each instance handles one request at a time, so up to 10 requests run in parallel
- Automatic model checkout/return from the pool
- Added pool statistics to the /health endpoint (example below)
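A minimal sketch of reading the new pool statistics, assuming the API is served at http://localhost:8000 (a hypothetical URL; the concurrent_capacity keys mirror ModelPool.get_stats() in the diff below):

```python
# Health-check sketch; the base URL is an assumption, not part of this commit.
import requests

health = requests.get("http://localhost:8000/health").json()

# "concurrent_capacity" mirrors ModelPool.get_stats(), e.g.:
#   {"pool_size": 10, "available": 9, "in_use": 1}
print(health["concurrent_capacity"])
```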

Configuration:
- MODEL_POOL_SIZE=10 (supports 10 concurrent users)
- 60s timeout if all instances are busy; the request then fails with a 503 (see the retry sketch after this list)
- Each model instance uses ~2.4GB VRAM
- Total VRAM: ~24GB for 10 instances (fits in a 48GB GPU)
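When every instance stays checked out past the timeout, the server answers 503. A minimal client-side retry sketch, assuming the /chat endpoint at a hypothetical http://localhost:8000 and an illustrative backoff policy:

```python
# Client-side handling of the pool's busy timeout. URL and retry policy are
# illustrative assumptions; the 503 is raised by ModelPool.get_model().
import time
import requests

def chat_with_retry(message: str, retries: int = 3, backoff: float = 5.0) -> dict:
    """POST to /chat, retrying while all pool instances are busy (HTTP 503)."""
    for attempt in range(retries):
        resp = requests.post(
            "http://localhost:8000/chat",
            json={"message": message},
            timeout=120,  # generation itself can take a while
        )
        if resp.status_code != 503:
            resp.raise_for_status()
            return resp.json()
        time.sleep(backoff * (attempt + 1))  # linear backoff before retrying
    raise RuntimeError("All model instances stayed busy across retries")

print(chat_with_retry("What is a buffer overflow?"))
```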

Sessions are handled via the session_id parameter (already present).
The pool automatically balances load across instances; the smoke test below opens several streams at once.
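A smoke test for the balancing, sketched with httpx (the client library, URL, and request count are assumptions):

```python
# Fires five streaming requests concurrently; each should check out its own
# model instance from the pool. Hypothetical local deployment URL.
import asyncio
import httpx

async def one_stream(client: httpx.AsyncClient, i: int) -> None:
    async with client.stream(
        "POST",
        "http://localhost:8000/chat/stream",
        json={"message": f"Request {i}: what is SQL injection?"},
        timeout=120,
    ) as resp:
        async for line in resp.aiter_lines():
            if line.startswith("data: "):
                # The 'start' event carries pool_available, so the pool can be
                # watched draining as requests arrive.
                print(f"[req {i}] {line[:100]}")

async def main() -> None:
    async with httpx.AsyncClient() as client:
        await asyncio.gather(*(one_stream(client, i) for i in range(5)))

asyncio.run(main())
```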

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (2)
  1. Dockerfile.gpu +5 -2
  2. main.py +111 -18
Dockerfile.gpu CHANGED
@@ -1,6 +1,6 @@
 # Use pre-built GPU image from Docker Hub
-# Build this image locally with: docker buildx build --platform linux/amd64 -f Dockerfile.base.gpu -t techdaskalos/cybersecchatbot:gpu . --push
-FROM techdaskalos/cybersecchatbot:gpu
+# Build this image locally with: docker buildx build --platform linux/amd64 -f Dockerfile.base.gpu -t techdaskalos/cybersecchatbot:latest-gpu . --push
+FROM techdaskalos/cybersecchatbot:latest-gpu
 
 # Environment variables (already set in base image, but can override)
 ENV PYTHONUNBUFFERED=1
@@ -12,6 +12,9 @@ ENV CACHE_ENABLED=true
 # GPU configuration - offload all layers to GPU
 ENV N_GPU_LAYERS=35
 
+# Concurrent request handling - 10 model instances for 10 concurrent users
+ENV MODEL_POOL_SIZE=10
+
 # Set Hugging Face cache to /data for persistence and write permissions
 ENV HF_HOME=/data/huggingface
main.py CHANGED
@@ -10,22 +10,98 @@ import uuid
 import os
 import sqlite3
 from contextlib import asynccontextmanager
+import queue
+import threading
 
 # Import our handlers
 from llm_handler import CybersecurityLLM
 from knowledge_base import RAGCybersecurityLLM
 from optimisations import PerformanceOptimizer, MemoryManager
 
+
+class ModelPool:
+    """Thread-safe pool of model instances for concurrent request handling"""
+
+    def __init__(self, pool_size: int, model_class, **model_kwargs):
+        """
+        Initialize a pool of model instances
+
+        Args:
+            pool_size: Number of model instances to create
+            model_class: The model class to instantiate (CybersecurityLLM or RAGCybersecurityLLM)
+            **model_kwargs: Arguments to pass to each model instance
+        """
+        self.pool_size = pool_size
+        self.model_class = model_class
+        self.model_kwargs = model_kwargs
+        self.pool = queue.Queue(maxsize=pool_size)
+        self.lock = threading.Lock()
+        self._initialize_pool()
+
+    def _initialize_pool(self):
+        """Create and add model instances to the pool"""
+        print(f"🔄 Initializing model pool with {self.pool_size} instances...")
+        for i in range(self.pool_size):
+            print(f"   Loading model instance {i + 1}/{self.pool_size}...")
+            model = self.model_class(**self.model_kwargs)
+            self.pool.put(model)
+        print(f"✅ Model pool ready with {self.pool_size} instances")
+
+    async def get_model(self, timeout: float = 30.0):
+        """
+        Get an available model from the pool (async)
+
+        Args:
+            timeout: Maximum time to wait for an available model
+
+        Returns:
+            Model instance
+
+        Raises:
+            HTTPException: If no model is available within the timeout
+        """
+        start_time = asyncio.get_event_loop().time()
+
+        while True:
+            try:
+                # Try to get a model without blocking
+                model = self.pool.get_nowait()
+                return model
+            except queue.Empty:
+                # Check timeout
+                if asyncio.get_event_loop().time() - start_time > timeout:
+                    raise HTTPException(
+                        status_code=503,
+                        detail=f"All {self.pool_size} model instances are busy. Please try again later."
+                    )
+
+                # Wait a bit before trying again
+                await asyncio.sleep(0.1)
+
+    def return_model(self, model):
+        """Return a model to the pool"""
+        self.pool.put(model)
+
+    def get_stats(self) -> Dict[str, Any]:
+        """Get pool statistics"""
+        return {
+            "pool_size": self.pool_size,
+            "available": self.pool.qsize(),
+            "in_use": self.pool_size - self.pool.qsize()
+        }
+
 # Configuration from environment variables
 MODEL_REPO = os.getenv("MODEL_REPO", "daskalos-apps/phi4-cybersec-Q4_K_M")
 MODEL_FILENAME = os.getenv("MODEL_FILENAME", "phi4-mini-instruct-Q4_K_M.gguf")
 USE_RAG = os.getenv("USE_RAG", "true").lower() == "true"
 CACHE_ENABLED = os.getenv("CACHE_ENABLED", "true").lower() == "true"
+MODEL_POOL_SIZE = int(os.getenv("MODEL_POOL_SIZE", "10"))  # Number of concurrent model instances
 
 # Global instances
 llm_instance = None
 optimizer = None
 memory_manager = None
+model_pool = None  # Pool of model instances for concurrent processing
 
 # Database setup
 # Support multiple deployment platforms: /data (HF Spaces), /app/data (Render/Railway), or local
@@ -94,26 +170,31 @@ def log_interaction(session_id: str, message: str, response_length: int):
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """Startup and shutdown events"""
-    global llm_instance, optimizer, memory_manager
+    global llm_instance, optimizer, memory_manager, model_pool
 
     # Startup
     print(f"🚀 Loading model from Hugging Face: {MODEL_REPO}")
+    print(f"📊 Concurrent instances: {MODEL_POOL_SIZE}")
 
     # Initialize database
     init_db()
     print("✅ Database initialized")
 
     try:
-        if USE_RAG:
-            llm_instance = RAGCybersecurityLLM(
-                repo_id=MODEL_REPO,
-                filename=MODEL_FILENAME
-            )
-        else:
-            llm_instance = CybersecurityLLM(
-                repo_id=MODEL_REPO,
-                filename=MODEL_FILENAME
-            )
+        # Initialize model pool for concurrent requests
+        model_class = RAGCybersecurityLLM if USE_RAG else CybersecurityLLM
+        model_pool = ModelPool(
+            pool_size=MODEL_POOL_SIZE,
+            model_class=model_class,
+            repo_id=MODEL_REPO,
+            filename=MODEL_FILENAME
+        )
+
+        # Keep one instance for backward compatibility (health checks, etc.)
+        llm_instance = model_class(
+            repo_id=MODEL_REPO,
+            filename=MODEL_FILENAME
+        )
 
         if CACHE_ENABLED:
             optimizer = PerformanceOptimizer()
@@ -125,6 +206,7 @@ async def lifespan(app: FastAPI):
         print(f"💾 Size: {llm_instance.get_model_info()['size_mb']:.2f} MB")
         print(f"🔧 RAG: {'Enabled' if USE_RAG else 'Disabled'}")
         print(f"⚡ Cache: {'Enabled' if CACHE_ENABLED else 'Disabled'}")
+        print(f"👥 Concurrent capacity: {MODEL_POOL_SIZE} users")
 
     except Exception as e:
         print(f"❌ Failed to load model: {e}")
@@ -204,6 +286,7 @@ async def health_check():
         raise HTTPException(status_code=503, detail="Model not loaded")
 
     memory_status = memory_manager.check_memory() if memory_manager else {}
+    pool_status = model_pool.get_stats() if model_pool else {"pool_size": 0, "available": 0, "in_use": 0}
 
     return {
         "status": "healthy",
@@ -211,7 +294,8 @@ async def health_check():
         "version": "2.0.0",
         "memory": memory_status,
         "cache_enabled": CACHE_ENABLED,
-        "rag_enabled": USE_RAG
+        "rag_enabled": USE_RAG,
+        "concurrent_capacity": pool_status
     }
 
 
@@ -317,23 +401,28 @@ async def chat(request: ChatRequest):
 
 @app.post("/chat/stream")
 async def chat_stream(request: ChatRequest):
-    """Streaming chat endpoint"""
-    if llm_instance is None:
-        raise HTTPException(status_code=503, detail="Model not loaded")
+    """Streaming chat endpoint with concurrent request support"""
+    if model_pool is None:
+        raise HTTPException(status_code=503, detail="Model pool not initialized")
 
     # Track interaction
     count = increment_interaction()
     session_id = request.session_id or str(uuid.uuid4())
 
     async def generate():
+        model = None
         try:
             full_response = ""
 
-            # Send initial metadata
-            yield f"data: {json.dumps({'type': 'start', 'session_id': session_id, 'model': MODEL_REPO, 'interaction_count': count})}\n\n"
+            # Get a model from the pool (will wait if all busy)
+            model = await model_pool.get_model(timeout=60.0)
+
+            # Send initial metadata with pool stats
+            pool_stats = model_pool.get_stats()
+            yield f"data: {json.dumps({'type': 'start', 'session_id': session_id, 'model': MODEL_REPO, 'interaction_count': count, 'pool_available': pool_stats['available']})}\n\n"
 
             # Stream tokens
-            for token in llm_instance.generate_stream(
+            for token in model.generate_stream(
                 request.message,
                 max_tokens=request.max_tokens
             ):
@@ -348,6 +437,10 @@ async def chat_stream(request: ChatRequest):
 
         except Exception as e:
             yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
+        finally:
+            # Always return the model to the pool
+            if model is not None:
+                model_pool.return_model(model)
 
     return StreamingResponse(generate(), media_type="text/event-stream")
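For a standalone check of the checkout/return mechanics, the pool can be exercised with a stub model class; a sketch assuming ModelPool is importable from main (DummyModel is hypothetical):

```python
# Exercises ModelPool without loading real weights. DummyModel is a stub
# standing in for CybersecurityLLM; ModelPool is the class added above.
import asyncio
from main import ModelPool

class DummyModel:
    """Hypothetical stand-in that loads instantly."""
    def __init__(self, **kwargs):
        self.kwargs = kwargs

async def demo() -> None:
    pool = ModelPool(pool_size=2, model_class=DummyModel)
    print(pool.get_stats())  # {'pool_size': 2, 'available': 2, 'in_use': 0}

    model = await pool.get_model(timeout=1.0)
    print(pool.get_stats())  # {'pool_size': 2, 'available': 1, 'in_use': 1}

    pool.return_model(model)
    print(pool.get_stats())  # all instances available again

asyncio.run(demo())
```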