from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any
from datetime import datetime
from contextlib import asynccontextmanager
import asyncio
import json
import logging
import os
import queue
import sqlite3
import threading
import uuid

from llm_handler import CybersecurityLLM
from knowledge_base import RAGCybersecurityLLM
from optimisations import PerformanceOptimizer, MemoryManager

# Module-level logger, used by the /chat error handler below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelPool:
    """Thread-safe pool of model instances for concurrent request handling"""

    def __init__(self, pool_size: int, model_class, **model_kwargs):
        """
        Initialize a pool of model instances

        Args:
            pool_size: Number of model instances to create
            model_class: The model class to instantiate (CybersecurityLLM or RAGCybersecurityLLM)
            **model_kwargs: Arguments to pass to each model instance
        """
        self.pool_size = pool_size
        self.model_class = model_class
        self.model_kwargs = model_kwargs
        self.pool = queue.Queue(maxsize=pool_size)
        self.lock = threading.Lock()
        self._initialize_pool()

    def _initialize_pool(self):
        """Create and add model instances to the pool"""
        print(f"🔄 Initializing model pool with {self.pool_size} instances...")
        for i in range(self.pool_size):
            print(f"   Loading model instance {i + 1}/{self.pool_size}...")
            model = self.model_class(**self.model_kwargs)
            self.pool.put(model)
        print(f"✅ Model pool ready with {self.pool_size} instances")

    async def get_model(self, timeout: float = 30.0):
        """
        Get an available model from the pool (async)

        Args:
            timeout: Maximum time to wait for an available model

        Returns:
            Model instance

        Raises:
            HTTPException: If no model is available within the timeout
        """
        start_time = asyncio.get_event_loop().time()

        while True:
            try:
                # Non-blocking take, so the event loop is never stalled.
                model = self.pool.get_nowait()
                return model
            except queue.Empty:
                if asyncio.get_event_loop().time() - start_time > timeout:
                    raise HTTPException(
                        status_code=503,
                        detail=f"All {self.pool_size} model instances are busy. Please try again later."
                    )
                # Yield to other coroutines before polling again.
                await asyncio.sleep(0.1)

    def return_model(self, model):
        """Return a model to the pool"""
        self.pool.put(model)

    def get_stats(self) -> Dict[str, Any]:
        """Get pool statistics"""
        return {
            "pool_size": self.pool_size,
            "available": self.pool.qsize(),
            "in_use": self.pool_size - self.pool.qsize()
        }


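# Usage sketch for ModelPool. The constructor kwargs shown here are
# illustrative; the real ones are whatever CybersecurityLLM accepts:
#
#     pool = ModelPool(pool_size=2, model_class=CybersecurityLLM,
#                      repo_id="some/repo", filename="model.gguf")
#     model = await pool.get_model(timeout=30.0)
#     try:
#         ...  # run inference with the borrowed instance
#     finally:
#         pool.return_model(model)  # always return it, even on error

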
# Configuration (all overridable via environment variables)
MODEL_REPO = os.getenv("MODEL_REPO", "daskalos-apps/phi4-cybersec-Q4_K_M")
MODEL_FILENAME = os.getenv("MODEL_FILENAME", "phi4-mini-instruct-Q4_K_M.gguf")
USE_RAG = os.getenv("USE_RAG", "true").lower() == "true"
CACHE_ENABLED = os.getenv("CACHE_ENABLED", "true").lower() == "true"
MODEL_POOL_SIZE = int(os.getenv("MODEL_POOL_SIZE", "10"))


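# NOTE: the pool constructs MODEL_POOL_SIZE separate model instances at
# startup (see ModelPool._initialize_pool), so this setting trades memory
# for concurrency. Example launch with overrides; the "app:app" module
# path is an assumption about how this file is named:
#
#     MODEL_POOL_SIZE=4 USE_RAG=false uvicorn app:app --port 8000

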
# Global state, populated during startup (see lifespan below)
llm_instance = None
optimizer = None
memory_manager = None
model_pool = None


# Pick a writable location for the SQLite database: prefer mounted
# volumes (/data, /app/data), fall back to the working directory.
if os.path.exists("/data"):
    DB_PATH = "/data/interactions.db"
elif os.path.exists("/app/data"):
    DB_PATH = "/app/data/interactions.db"
else:
    DB_PATH = "interactions.db"


def init_db():
    """Initialize SQLite database for interaction tracking"""
    conn = sqlite3.connect(DB_PATH)
    cursor = conn.cursor()
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS interactions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp TEXT NOT NULL,
            session_id TEXT,
            message TEXT,
            response_length INTEGER
        )
    """)
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS interaction_count (
            id INTEGER PRIMARY KEY CHECK (id = 1),
            count INTEGER DEFAULT 0
        )
    """)
    # Seed the single-row counter table exactly once.
    cursor.execute("INSERT OR IGNORE INTO interaction_count (id, count) VALUES (1, 0)")
    conn.commit()
    conn.close()


# Serialise SQLite access across request-handler threads.
db_lock = threading.Lock()


def increment_interaction():
    """Increment interaction count and return new count (thread-safe)"""
    with db_lock:
        conn = sqlite3.connect(DB_PATH, check_same_thread=False, timeout=10.0)
        cursor = conn.cursor()
        cursor.execute("UPDATE interaction_count SET count = count + 1 WHERE id = 1")
        cursor.execute("SELECT count FROM interaction_count WHERE id = 1")
        count = cursor.fetchone()[0]
        conn.commit()
        conn.close()
        return count


def get_interaction_count():
    """Get current interaction count (thread-safe)"""
    with db_lock:
        conn = sqlite3.connect(DB_PATH, check_same_thread=False, timeout=10.0)
        cursor = conn.cursor()
        cursor.execute("SELECT count FROM interaction_count WHERE id = 1")
        count = cursor.fetchone()[0]
        conn.close()
        return count


def log_interaction(session_id: str, message: str, response_length: int):
    """Log interaction details (thread-safe)"""
    with db_lock:
        conn = sqlite3.connect(DB_PATH, check_same_thread=False, timeout=10.0)
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO interactions (timestamp, session_id, message, response_length) VALUES (?, ?, ?, ?)",
            (datetime.now().isoformat(), session_id, message, response_length)
        )
        conn.commit()
        conn.close()


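# Each helper above opens a short-lived connection, and db_lock serialises
# the calls so concurrent requests never interleave SQLite transactions.
# check_same_thread=False additionally permits a connection to be touched
# from a thread other than its creator, although each helper opens and
# closes its connection within a single call.

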
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup and shutdown events"""
    global llm_instance, optimizer, memory_manager, model_pool

    print(f"🚀 Loading model from Hugging Face: {MODEL_REPO}")
    print(f"📊 Concurrent instances: {MODEL_POOL_SIZE}")

    init_db()
    print("✅ Database initialized")

    try:
        # Build the pool of model instances used by the streaming endpoint.
        model_class = RAGCybersecurityLLM if USE_RAG else CybersecurityLLM
        model_pool = ModelPool(
            pool_size=MODEL_POOL_SIZE,
            model_class=model_class,
            repo_id=MODEL_REPO,
            filename=MODEL_FILENAME
        )

        # Dedicated instance for the non-streaming endpoints.
        llm_instance = model_class(
            repo_id=MODEL_REPO,
            filename=MODEL_FILENAME
        )

        if CACHE_ENABLED:
            optimizer = PerformanceOptimizer()

        memory_manager = MemoryManager()

        print("✅ Cybersecurity Chatbot ready!")
        print(f"📦 Model: {MODEL_REPO}")
        print(f"💾 Size: {llm_instance.get_model_info()['size_mb']:.2f} MB")
        print(f"🔧 RAG: {'Enabled' if USE_RAG else 'Disabled'}")
        print(f"⚡ Cache: {'Enabled' if CACHE_ENABLED else 'Disabled'}")
        print(f"👥 Concurrent capacity: {MODEL_POOL_SIZE} users")

    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        raise

    yield

    print("👋 Shutting down...")


app = FastAPI(
    title="Cybersecurity Training Chatbot API",
    description="AI-powered cybersecurity guidance using Phi-4 from Hugging Face",
    version="2.0.0",
    lifespan=lifespan
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ChatRequest(BaseModel):
    message: str = Field(..., description="User's security question")
    session_id: Optional[str] = Field(None, description="Session ID for conversation continuity")
    max_tokens: Optional[int] = Field(256, description="Maximum response length")
    temperature: Optional[float] = Field(0.7, description="Response creativity (0-1)")
    use_rag: Optional[bool] = Field(True, description="Use RAG for enhanced accuracy")
    use_cache: Optional[bool] = Field(True, description="Use cached responses if available")


class ChatResponse(BaseModel):
    response: str
    session_id: str
    timestamp: str
    model: str
    tokens_used: Optional[int] = None
    cached: bool = False
    sources: Optional[List[str]] = None


class ModelInfo(BaseModel):
    repo_id: str
    filename: str
    size_mb: float
    rag_enabled: bool
    cache_enabled: bool


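# Illustrative payloads for the models above (values are made up):
#
#     ChatRequest : {"message": "What is phishing?", "max_tokens": 128}
#     ChatResponse: {"response": "...", "session_id": "...",
#                    "timestamp": "...", "model": "...",
#                    "cached": false, "sources": null}

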
# In-memory session store: session_id -> list of chat messages.
sessions: Dict[str, List[Dict[str, Any]]] = {}
sessions_lock = threading.Lock()


@app.get("/", response_model=Dict[str, str])
async def root():
    """API root endpoint"""
    return {
        "message": "Cybersecurity Training Chatbot API",
        "model": MODEL_REPO,
        "documentation": "/docs",
        "health": "/health"
    }


@app.get("/health")
async def health_check():
    """Check API and model health"""
    if llm_instance is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    memory_status = memory_manager.check_memory() if memory_manager else {}
    pool_status = model_pool.get_stats() if model_pool else {"pool_size": 0, "available": 0, "in_use": 0}

    return {
        "status": "healthy",
        "model": MODEL_REPO,
        "version": "2.0.0",
        "memory": memory_status,
        "cache_enabled": CACHE_ENABLED,
        "rag_enabled": USE_RAG,
        "concurrent_capacity": pool_status
    }


@app.get("/model/info", response_model=ModelInfo)
async def model_info():
    """Get information about the loaded model"""
    if llm_instance is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    info = llm_instance.get_model_info()

    return ModelInfo(
        repo_id=info['repo_id'],
        filename=info['filename'],
        size_mb=info['size_mb'],
        rag_enabled=USE_RAG,
        cache_enabled=CACHE_ENABLED
    )


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Main chat endpoint"""
    if llm_instance is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        session_id = request.session_id or str(uuid.uuid4())

        # Record the user message in the session history.
        with sessions_lock:
            if session_id not in sessions:
                sessions[session_id] = []
            sessions[session_id].append({
                "role": "user",
                "content": request.message,
                "timestamp": datetime.now().isoformat()
            })

        cached = False
        response_text = None
        sources = None

        # Serve from the cache when possible.
        if CACHE_ENABLED and request.use_cache and optimizer:
            cached_response = optimizer.get_cached_response(request.message)
            if cached_response:
                response_text = cached_response
                cached = True

        # Otherwise generate a fresh response.
        if response_text is None:
            if USE_RAG and hasattr(llm_instance, 'generate_with_rag'):
                result = llm_instance.generate_with_rag(
                    request.message,
                    max_tokens=request.max_tokens,
                    use_rag=request.use_rag
                )
                sources = result.get('sources', [])
            else:
                result = llm_instance.generate(
                    request.message,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature
                )

            response_text = result["response"]

            # Cache the fresh response for future requests.
            if CACHE_ENABLED and optimizer and request.use_cache:
                optimizer.cache_response(request.message, response_text)

        # Record the assistant reply and trim history to the last 20 entries.
        with sessions_lock:
            sessions[session_id].append({
                "role": "assistant",
                "content": response_text,
                "timestamp": datetime.now().isoformat()
            })
            if len(sessions[session_id]) > 20:
                sessions[session_id] = sessions[session_id][-20:]

        if memory_manager:
            memory_manager.optimize_if_needed()

        return ChatResponse(
            response=response_text,
            session_id=session_id,
            timestamp=datetime.now().isoformat(),
            model=MODEL_REPO,
            cached=cached,
            sources=sources
        )

    except Exception as e:
        logger.error(f"Chat error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


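# Example call against a local server (assumes the default port 8000
# configured at the bottom of this file):
#
#     curl -X POST http://localhost:8000/chat \
#          -H "Content-Type: application/json" \
#          -d '{"message": "How do I spot a phishing email?"}'

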
@app.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    """Streaming chat endpoint with concurrent request support"""
    if model_pool is None:
        raise HTTPException(status_code=503, detail="Model pool not initialized")

    count = increment_interaction()
    session_id = request.session_id or str(uuid.uuid4())

    async def generate():
        model = None
        try:
            full_response = ""

            # Borrow a model instance from the pool (waits up to 60s).
            model = await model_pool.get_model(timeout=60.0)

            # Announce the stream with session metadata.
            pool_stats = model_pool.get_stats()
            start_data = {
                'type': 'start',
                'session_id': session_id,
                'model': MODEL_REPO,
                'interaction_count': count,
                'pool_available': pool_stats['available']
            }
            yield f"data: {json.dumps(start_data)}\n\n"

            # Stream tokens as server-sent events.
            for token in model.generate_stream(
                request.message,
                max_tokens=request.max_tokens
            ):
                full_response += token
                token_data = {'type': 'token', 'content': token}
                yield f"data: {json.dumps(token_data)}\n\n"
                await asyncio.sleep(0)  # let other coroutines run

            log_interaction(session_id, request.message, len(full_response))

            end_data = {'type': 'end'}
            yield f"data: {json.dumps(end_data)}\n\n"

        except Exception as e:
            error_data = {'type': 'error', 'message': str(e)}
            yield f"data: {json.dumps(error_data)}\n\n"
        finally:
            # Always return the model, even on error or client disconnect.
            if model is not None:
                model_pool.return_model(model)

    return StreamingResponse(generate(), media_type="text/event-stream")


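# Frame format emitted by generate() above, as standard server-sent
# events (one JSON object per "data:" line):
#
#     data: {"type": "start", "session_id": "...", "model": "...", ...}
#     data: {"type": "token", "content": "..."}
#     data: {"type": "end"}
#     data: {"type": "error", "message": "..."}   (only on failure)

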
@app.websocket("/ws/chat")
async def websocket_chat(websocket: WebSocket):
    """WebSocket endpoint for real-time chat"""
    await websocket.accept()

    if llm_instance is None:
        await websocket.send_json({"type": "error", "message": "Model not loaded"})
        await websocket.close()
        return

    session_id = str(uuid.uuid4())

    try:
        await websocket.send_json({
            "type": "connected",
            "session_id": session_id,
            "model": MODEL_REPO
        })

        while True:
            # Wait for the next client message (JSON with a "message" field).
            data = await websocket.receive_text()
            request = json.loads(data)

            await websocket.send_json({
                "type": "acknowledged",
                "session_id": session_id
            })

            # Stream tokens back one at a time.
            full_response = ""
            for token in llm_instance.generate_stream(request.get('message', '')):
                full_response += token
                await websocket.send_json({
                    "type": "token",
                    "content": token
                })
                await asyncio.sleep(0)

            await websocket.send_json({
                "type": "complete",
                "full_response": full_response
            })

    except WebSocketDisconnect:
        # Drop any session state for the disconnected client.
        with sessions_lock:
            if session_id in sessions:
                del sessions[session_id]


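# Wire protocol, as implemented above: the client sends a JSON text frame
#     {"message": "<question>"}
# and receives "connected" once, then per question an "acknowledged"
# frame, a stream of "token" frames, and a closing "complete" frame
# carrying the full response.

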
@app.get("/sessions/{session_id}")
async def get_session(session_id: str):
    """Retrieve session history"""
    with sessions_lock:
        if session_id not in sessions:
            raise HTTPException(status_code=404, detail="Session not found")

        return {
            "session_id": session_id,
            "messages": sessions[session_id].copy(),
            "model": MODEL_REPO
        }


@app.delete("/sessions/{session_id}")
async def clear_session(session_id: str):
    """Clear session history"""
    with sessions_lock:
        if session_id in sessions:
            del sessions[session_id]

    return {"message": "Session cleared"}


@app.get("/interactions/count")
async def get_interactions_count():
    """Get total interaction count"""
    count = get_interaction_count()
    return {"count": count}


@app.get("/metrics")
async def get_metrics():
    """Get performance metrics"""
    with sessions_lock:
        metrics = {
            "model": MODEL_REPO,
            "sessions_active": len(sessions),
            "total_messages": sum(len(s) for s in sessions.values()),
            "total_interactions": get_interaction_count()
        }

    if optimizer:
        metrics["cache"] = optimizer.get_metrics()

    if memory_manager:
        metrics["memory"] = memory_manager.check_memory()

    return metrics


@app.post("/cache/clear")
async def clear_cache():
    """Clear response cache"""
    if not CACHE_ENABLED or not optimizer:
        raise HTTPException(status_code=400, detail="Cache not enabled")

    optimizer.clear_cache()
    return {"message": "Cache cleared"}


@app.get("/test")
async def serve_test_interface():
    """Serve the test interface HTML"""
    return FileResponse("test_interface.html")


if __name__ == "__main__":
    import uvicorn

    # A single worker process keeps the model pool, cache, and in-memory
    # session store consistent; concurrency comes from the ModelPool.
    config = uvicorn.Config(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="info",
        access_log=True,
        workers=1,
        limit_concurrency=100,
        timeout_keep_alive=120,
        backlog=2048,
        loop="asyncio"
    )

    server = uvicorn.Server(config)
    server.run()
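
# Launch sketch (the file name "app.py" is an assumption):
#
#     python app.py
#
# Running under plain `uvicorn app:app` also works, but skips the tuned
# limit_concurrency / keep-alive settings above.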