Andrew McCracken and Claude committed
Commit 8cfe5b7 · 1 Parent(s): 3f2ee19

Optimize for faster inference


Performance optimizations:
1. Reduced max_tokens from 512 to 256 (shorter, faster responses)
2. Reduced n_ctx from 4096 to 2048 (faster prompt processing)
3. Added token buffering in streaming (better perceived speed)
   - Flushes once the buffer holds 3+ characters or a token contains whitespace
   - Fewer SSE events, less network overhead

Expected speedup: 15s → 8-10s per response

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (2):
  llm_handler.py  +1 -1
  main.py         +15 -4
llm_handler.py CHANGED
@@ -51,7 +51,7 @@ class CybersecurityLLM:
         logger.info("Initializing model...")
         self.llm = Llama(
             model_path=model_path,
-            n_ctx=4096,       # Context window
+            n_ctx=2048,       # Reduced context window for faster prompt processing
             n_batch=512,      # Batch size for prompt processing
             n_threads=8,      # Use all 8 vCPUs for maximum inference speed
             n_gpu_layers=0,   # CPU only
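
Note that with n_ctx=2048, a prompt that together with the completion budget exceeds 2048 tokens no longer fits the context window. A minimal guard, sketched under the assumption that the handler can call llama-cpp-python's Llama.tokenize/detokenize; the fit_prompt helper itself is hypothetical and not part of this commit:

def fit_prompt(llm, prompt: str, n_ctx: int = 2048, max_tokens: int = 256) -> str:
    """Hypothetical helper: trim the prompt so prompt + completion fit in n_ctx."""
    budget = n_ctx - max_tokens                    # reserve room for the response
    tokens = llm.tokenize(prompt.encode("utf-8"))  # llama-cpp-python returns token ids
    if len(tokens) <= budget:
        return prompt
    # Keep the most recent tokens; the oldest context is dropped first.
    return llm.detokenize(tokens[-budget:]).decode("utf-8", errors="ignore")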
main.py CHANGED
@@ -158,7 +158,7 @@ app.add_middleware(
 class ChatRequest(BaseModel):
     message: str = Field(..., description="User's security question")
     session_id: Optional[str] = Field(None, description="Session ID for conversation continuity")
-    max_tokens: Optional[int] = Field(512, description="Maximum response length")
+    max_tokens: Optional[int] = Field(256, description="Maximum response length")
     temperature: Optional[float] = Field(0.7, description="Response creativity (0-1)")
     use_rag: Optional[bool] = Field(True, description="Use RAG for enhanced accuracy")
     use_cache: Optional[bool] = Field(True, description="Use cached responses if available")
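
Callers that need longer answers can still override the new 256-token default per request. A minimal streaming-client sketch, assuming the app serves chat_stream at /chat/stream on localhost:8000 (URL, port, route, and the sample question are all assumptions):

import json
import requests

# Hypothetical client; the URL, port, and route are assumptions.
resp = requests.post(
    "http://localhost:8000/chat/stream",
    json={"message": "What is SQL injection?", "max_tokens": 512},  # override the 256 default
    stream=True,
)
for line in resp.iter_lines():
    if line.startswith(b"data: "):
        event = json.loads(line[len(b"data: "):])
        if event.get("type") == "token":
            print(event["content"], end="", flush=True)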
 
@@ -328,18 +328,29 @@ async def chat_stream(request: ChatRequest):
     async def generate():
         try:
             full_response = ""
+            buffer = ""
+            buffer_size = 3  # Flush once the buffer holds 3+ characters, for better perceived speed
 
             # Send initial metadata
             yield f"data: {json.dumps({'type': 'start', 'session_id': session_id, 'model': MODEL_REPO, 'interaction_count': count})}\n\n"
 
-            # Stream tokens
+            # Stream tokens with buffering
             for token in llm_instance.generate_stream(
                 request.message,
                 max_tokens=request.max_tokens
             ):
                 full_response += token
-                yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n"
-                await asyncio.sleep(0)
+                buffer += token
+
+                # Flush when the buffer reaches buffer_size or the token contains whitespace
+                if len(buffer) >= buffer_size or ' ' in token or '\n' in token:
+                    yield f"data: {json.dumps({'type': 'token', 'content': buffer})}\n\n"
+                    buffer = ""
+                    await asyncio.sleep(0)
+
+            # Send any remaining buffered tokens
+            if buffer:
+                yield f"data: {json.dumps({'type': 'token', 'content': buffer})}\n\n"
 
             # Log interaction
             log_interaction(session_id, request.message, len(full_response))
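
The flush rule is easy to check in isolation. A minimal sketch that replays a token stream through the same condition; the coalesce helper and the sample tokens are illustrative, not part of the commit:

def coalesce(tokens, buffer_size=3):
    """Replay the flush rule: emit once the buffer holds buffer_size+
    characters or the incoming token contains whitespace."""
    chunks, buffer = [], ""
    for token in tokens:
        buffer += token
        if len(buffer) >= buffer_size or ' ' in token or '\n' in token:
            chunks.append(buffer)
            buffer = ""
    if buffer:
        chunks.append(buffer)  # trailing partial buffer
    return chunks

# Short sub-word tokens merge into fewer SSE events:
print(coalesce(["S", "Q", "L", " in", "ject", "ion", " is", " a", " code", " injection", " attack"]))

Since most tokens are already three or more characters, the condition fires on nearly every iteration; the saving comes from merging short sub-word pieces and punctuation into single SSE writes.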