Andrew McCracken and Claude committed
Commit 8cfe5b7 · Parent: 3f2ee19
Optimize for faster inference
Performance optimizations:
1. Reduced max_tokens from 512 to 256 (faster responses)
2. Reduced n_ctx from 4096 to 2048 (faster prompt processing)
3. Added token buffering in streaming (better perceived speed)
   - Buffers 3 tokens or until whitespace
   - Reduces network overhead
Expected speedup: 15s → 8-10s per response
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
- llm_handler.py +1 -1
- main.py +15 -4
llm_handler.py CHANGED

@@ -51,7 +51,7 @@ class CybersecurityLLM:
         logger.info("Initializing model...")
         self.llm = Llama(
             model_path=model_path,
-            n_ctx=4096,
+            n_ctx=2048,       # Reduced context window for faster prompt processing
             n_batch=512,      # Batch size for prompt processing
             n_threads=8,      # Use all 8 vCPUs for maximum inference speed
             n_gpu_layers=0,   # CPU only
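Note: with the context window halved to 2048 tokens, the assembled prompt plus the 256-token response budget must now fit in n_ctx. A minimal pre-flight check, sketched under the assumption that `llm` is the Llama instance built above and `prompt` is the fully assembled prompt text (this helper is illustrative, not part of the repo):

# Sketch only: guard against overflowing the reduced 2048-token window.
# `llm` is assumed to be the Llama instance created in CybersecurityLLM.
MAX_CTX = 2048        # matches n_ctx after this commit
MAX_RESPONSE = 256    # matches the new max_tokens default in main.py

def fits_in_context(llm, prompt: str) -> bool:
    prompt_tokens = llm.tokenize(prompt.encode("utf-8"))
    return len(prompt_tokens) + MAX_RESPONSE <= MAX_CTX

If the check fails, older conversation turns or RAG context would need to be trimmed before generation.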
main.py CHANGED

@@ -158,7 +158,7 @@ app.add_middleware(
 class ChatRequest(BaseModel):
     message: str = Field(..., description="User's security question")
     session_id: Optional[str] = Field(None, description="Session ID for conversation continuity")
-    max_tokens: Optional[int] = Field(512, description="Maximum response length")
+    max_tokens: Optional[int] = Field(256, description="Maximum response length")
     temperature: Optional[float] = Field(0.7, description="Response creativity (0-1)")
     use_rag: Optional[bool] = Field(True, description="Use RAG for enhanced accuracy")
     use_cache: Optional[bool] = Field(True, description="Use cached responses if available")

@@ -328,18 +328,29 @@ async def chat_stream(request: ChatRequest):
     async def generate():
         try:
             full_response = ""
+            buffer = ""
+            buffer_size = 3  # Send every 3 tokens for better perceived speed

             # Send initial metadata
             yield f"data: {json.dumps({'type': 'start', 'session_id': session_id, 'model': MODEL_REPO, 'interaction_count': count})}\n\n"

-            # Stream tokens
+            # Stream tokens with buffering
             for token in llm_instance.generate_stream(
                 request.message,
                 max_tokens=request.max_tokens
             ):
                 full_response += token
-                yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n"
-                await asyncio.sleep(0)
+                buffer += token
+
+                # Send buffer when it reaches buffer_size or contains whitespace
+                if len(buffer) >= buffer_size or ' ' in token or '\n' in token:
+                    yield f"data: {json.dumps({'type': 'token', 'content': buffer})}\n\n"
+                    buffer = ""
+                    await asyncio.sleep(0)
+
+            # Send any remaining buffered tokens
+            if buffer:
+                yield f"data: {json.dumps({'type': 'token', 'content': buffer})}\n\n"

             # Log interaction
             log_interaction(session_id, request.message, len(full_response))
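For reference, a minimal client-side sketch of consuming the buffered stream. The route path, port, and POST method are assumptions (the route decorator is not shown in this diff); the 'start' and 'token' event shapes match the payloads yielded above.

# Sketch only: read the buffered SSE stream from chat_stream().
# The URL below is an assumption; adjust to the actual route.
import json
import requests

payload = {"message": "What is SQL injection?", "max_tokens": 256}
with requests.post("http://localhost:8000/chat/stream", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        event = json.loads(line[len("data: "):])
        if event["type"] == "token":
            # Chunks arrive roughly 3 tokens at a time, or sooner on whitespace.
            print(event["content"], end="", flush=True)

Because whitespace in a token flushes the buffer immediately, word boundaries still appear promptly; the 3-token batching mainly reduces the number of SSE messages sent over the network.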