from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any
from datetime import datetime
import asyncio
import json
import logging
import os
import sqlite3
import uuid
from contextlib import asynccontextmanager

from llm_handler import CybersecurityLLM
from knowledge_base import RAGCybersecurityLLM
from optimisations import PerformanceOptimizer, MemoryManager

# Module-level logger used by the /chat error handler
logger = logging.getLogger(__name__)

# Configuration (overridable via environment variables)
MODEL_REPO = os.getenv("MODEL_REPO", "daskalos-apps/phi4-cybersec-Q4_K_M")
MODEL_FILENAME = os.getenv("MODEL_FILENAME", "phi4-mini-instruct-Q4_K_M.gguf")
USE_RAG = os.getenv("USE_RAG", "true").lower() == "true"
CACHE_ENABLED = os.getenv("CACHE_ENABLED", "true").lower() == "true"

# Global instances populated during startup (see lifespan below)
llm_instance = None
optimizer = None
memory_manager = None

# Pick a writable location for the SQLite database, preferring mounted volumes
if os.path.exists("/data"):
    DB_PATH = "/data/interactions.db"
elif os.path.exists("/app/data"):
    DB_PATH = "/app/data/interactions.db"
else:
    DB_PATH = "interactions.db"


def init_db():
    """Initialize SQLite database for interaction tracking"""
    conn = sqlite3.connect(DB_PATH)
    cursor = conn.cursor()
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS interactions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp TEXT NOT NULL,
            session_id TEXT,
            message TEXT,
            response_length INTEGER
        )
    """)
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS interaction_count (
            id INTEGER PRIMARY KEY CHECK (id = 1),
            count INTEGER DEFAULT 0
        )
    """)
    cursor.execute("INSERT OR IGNORE INTO interaction_count (id, count) VALUES (1, 0)")
    conn.commit()
    conn.close()
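
# interaction_count is a single-row table (enforced by CHECK (id = 1)), so
# the running total survives restarts, unlike the in-memory session store
# defined further below.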


def increment_interaction():
    """Increment interaction count and return new count"""
    conn = sqlite3.connect(DB_PATH)
    cursor = conn.cursor()
    cursor.execute("UPDATE interaction_count SET count = count + 1 WHERE id = 1")
    cursor.execute("SELECT count FROM interaction_count WHERE id = 1")
    count = cursor.fetchone()[0]
    conn.commit()
    conn.close()
    return count


def get_interaction_count():
    """Get current interaction count"""
    conn = sqlite3.connect(DB_PATH)
    cursor = conn.cursor()
    cursor.execute("SELECT count FROM interaction_count WHERE id = 1")
    count = cursor.fetchone()[0]
    conn.close()
    return count


def log_interaction(session_id: str, message: str, response_length: int):
    """Log interaction details"""
    conn = sqlite3.connect(DB_PATH)
    cursor = conn.cursor()
    cursor.execute(
        "INSERT INTO interactions (timestamp, session_id, message, response_length) VALUES (?, ?, ?, ?)",
        (datetime.now().isoformat(), session_id, message, response_length)
    )
    conn.commit()
    conn.close()
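
# These helpers are synchronous and open a short-lived connection per call;
# invoked from the async endpoints below they briefly block the event loop,
# which is usually acceptable for small local SQLite writes. For ad-hoc
# inspection (path assumes the container layout above):
#
#   sqlite3 /data/interactions.db "SELECT COUNT(*) FROM interactions;"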


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup and shutdown events"""
    global llm_instance, optimizer, memory_manager

    print(f"🚀 Loading model from Hugging Face: {MODEL_REPO}")

    init_db()
    print("✅ Database initialized")

    try:
        if USE_RAG:
            llm_instance = RAGCybersecurityLLM(
                repo_id=MODEL_REPO,
                filename=MODEL_FILENAME
            )
        else:
            llm_instance = CybersecurityLLM(
                repo_id=MODEL_REPO,
                filename=MODEL_FILENAME
            )

        if CACHE_ENABLED:
            optimizer = PerformanceOptimizer()

        memory_manager = MemoryManager()

        print("✅ Cybersecurity Chatbot ready!")
        print(f"📦 Model: {MODEL_REPO}")
        print(f"💾 Size: {llm_instance.get_model_info()['size_mb']:.2f} MB")
        print(f"🧠 RAG: {'Enabled' if USE_RAG else 'Disabled'}")
        print(f"⚡ Cache: {'Enabled' if CACHE_ENABLED else 'Disabled'}")

    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        raise

    yield

    print("👋 Shutting down...")


app = FastAPI(
    title="Cybersecurity Training Chatbot API",
    description="AI-powered cybersecurity guidance using Phi-4 from Hugging Face",
    version="2.0.0",
    lifespan=lifespan
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
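
# Note: browsers reject credentialed cross-origin responses that carry a
# wildcard Access-Control-Allow-Origin, so for cookie- or auth-based clients
# replace allow_origins=["*"] with an explicit origin list.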


# Request/response models
class ChatRequest(BaseModel):
    message: str = Field(..., description="User's security question")
    session_id: Optional[str] = Field(None, description="Session ID for conversation continuity")
    max_tokens: Optional[int] = Field(256, description="Maximum response length")
    temperature: Optional[float] = Field(0.7, description="Response creativity (0-1)")
    use_rag: Optional[bool] = Field(True, description="Use RAG for enhanced accuracy")
    use_cache: Optional[bool] = Field(True, description="Use cached responses if available")


class ChatResponse(BaseModel):
    response: str
    session_id: str
    timestamp: str
    model: str
    tokens_used: Optional[int] = None
    cached: bool = False
    sources: Optional[List[str]] = None


class ModelInfo(BaseModel):
    repo_id: str
    filename: str
    size_mb: float
    rag_enabled: bool
    cache_enabled: bool


# In-memory session store: session_id -> ordered list of message dicts
sessions: Dict[str, List[Dict[str, Any]]] = {}
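
# Sessions live in process memory only: they vanish on restart and are not
# shared between workers. Running more than one uvicorn worker would need a
# shared store (e.g. Redis) instead.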


@app.get("/", response_model=Dict[str, str])
async def root():
    """API root endpoint"""
    return {
        "message": "Cybersecurity Training Chatbot API",
        "model": MODEL_REPO,
        "documentation": "/docs",
        "health": "/health"
    }


@app.get("/health")
async def health_check():
    """Check API and model health"""
    if llm_instance is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    memory_status = memory_manager.check_memory() if memory_manager else {}

    return {
        "status": "healthy",
        "model": MODEL_REPO,
        "version": "2.0.0",
        "memory": memory_status,
        "cache_enabled": CACHE_ENABLED,
        "rag_enabled": USE_RAG
    }


@app.get("/model/info", response_model=ModelInfo)
async def model_info():
    """Get information about the loaded model"""
    if llm_instance is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    info = llm_instance.get_model_info()

    return ModelInfo(
        repo_id=info['repo_id'],
        filename=info['filename'],
        size_mb=info['size_mb'],
        rag_enabled=USE_RAG,
        cache_enabled=CACHE_ENABLED
    )


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Main chat endpoint"""
    if llm_instance is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # Create or reuse a session
        session_id = request.session_id or str(uuid.uuid4())

        if session_id not in sessions:
            sessions[session_id] = []

        # Record the user message
        sessions[session_id].append({
            "role": "user",
            "content": request.message,
            "timestamp": datetime.now().isoformat()
        })

        # Check the response cache first
        cached = False
        response_text = None
        sources = None

        if CACHE_ENABLED and request.use_cache and optimizer:
            cached_response = optimizer.get_cached_response(request.message)
            if cached_response:
                response_text = cached_response
                cached = True

        # Cache miss: generate a fresh response
        if response_text is None:
            if USE_RAG and hasattr(llm_instance, 'generate_with_rag'):
                result = llm_instance.generate_with_rag(
                    request.message,
                    max_tokens=request.max_tokens,
                    use_rag=request.use_rag
                )
                sources = result.get('sources', [])
            else:
                result = llm_instance.generate(
                    request.message,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature
                )

            response_text = result["response"]

            # Store the fresh response in the cache
            if CACHE_ENABLED and optimizer and request.use_cache:
                optimizer.cache_response(request.message, response_text)

        # Record the assistant message
        sessions[session_id].append({
            "role": "assistant",
            "content": response_text,
            "timestamp": datetime.now().isoformat()
        })

        # Keep only the 20 most recent messages per session
        if len(sessions[session_id]) > 20:
            sessions[session_id] = sessions[session_id][-20:]

        # Reclaim memory if usage is high
        if memory_manager:
            memory_manager.optimize_if_needed()

        return ChatResponse(
            response=response_text,
            session_id=session_id,
            timestamp=datetime.now().isoformat(),
            model=MODEL_REPO,
            cached=cached,
            sources=sources
        )

    except Exception as e:
        logger.error(f"Chat error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
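
# Example request (assumes the server is running locally on port 8000):
#
#   curl -s http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "What is phishing?", "max_tokens": 128}'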


@app.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    """Streaming chat endpoint"""
    if llm_instance is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    count = increment_interaction()
    session_id = request.session_id or str(uuid.uuid4())

    async def generate():
        try:
            full_response = ""

            # Open the stream with session metadata
            yield f"data: {json.dumps({'type': 'start', 'session_id': session_id, 'model': MODEL_REPO, 'interaction_count': count})}\n\n"

            # Forward tokens as server-sent events
            for token in llm_instance.generate_stream(
                request.message,
                max_tokens=request.max_tokens
            ):
                full_response += token
                yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n"
                await asyncio.sleep(0)  # yield control so each event is flushed

            log_interaction(session_id, request.message, len(full_response))

            yield f"data: {json.dumps({'type': 'end'})}\n\n"

        except Exception as e:
            yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
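
# A minimal client sketch for the SSE stream above (illustration only;
# assumes the `requests` package and a server on localhost:8000):
#
#   import json, requests
#
#   with requests.post("http://localhost:8000/chat/stream",
#                      json={"message": "Explain SQL injection"},
#                      stream=True) as r:
#       for line in r.iter_lines(decode_unicode=True):
#           if line and line.startswith("data: "):
#               event = json.loads(line[len("data: "):])
#               if event["type"] == "token":
#                   print(event["content"], end="", flush=True)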


@app.websocket("/ws/chat")
async def websocket_chat(websocket: WebSocket):
    """WebSocket endpoint for real-time chat"""
    await websocket.accept()

    if llm_instance is None:
        await websocket.send_json({"type": "error", "message": "Model not loaded"})
        await websocket.close()
        return

    session_id = str(uuid.uuid4())

    try:
        await websocket.send_json({
            "type": "connected",
            "session_id": session_id,
            "model": MODEL_REPO
        })

        while True:
            # Wait for the next client message
            data = await websocket.receive_text()
            request = json.loads(data)

            # Acknowledge receipt before streaming begins
            await websocket.send_json({
                "type": "acknowledged",
                "session_id": session_id
            })

            # Stream tokens back as they are generated
            full_response = ""

            for token in llm_instance.generate_stream(request.get('message', '')):
                full_response += token
                await websocket.send_json({
                    "type": "token",
                    "content": token
                })
                await asyncio.sleep(0)

            # Signal completion with the assembled response
            await websocket.send_json({
                "type": "complete",
                "full_response": full_response
            })

    except WebSocketDisconnect:
        if session_id in sessions:
            del sessions[session_id]
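
# A minimal WebSocket client sketch (illustration only; assumes the
# `websockets` package and a server on localhost:8000):
#
#   import asyncio, json, websockets
#
#   async def demo():
#       async with websockets.connect("ws://localhost:8000/ws/chat") as ws:
#           print(json.loads(await ws.recv()))  # "connected" event
#           await ws.send(json.dumps({"message": "What is a botnet?"}))
#           while True:
#               event = json.loads(await ws.recv())
#               if event["type"] == "token":
#                   print(event["content"], end="", flush=True)
#               elif event["type"] == "complete":
#                   break
#
#   asyncio.run(demo())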


@app.get("/sessions/{session_id}")
async def get_session(session_id: str):
    """Retrieve session history"""
    if session_id not in sessions:
        raise HTTPException(status_code=404, detail="Session not found")

    return {
        "session_id": session_id,
        "messages": sessions[session_id],
        "model": MODEL_REPO
    }


@app.delete("/sessions/{session_id}")
async def clear_session(session_id: str):
    """Clear session history"""
    if session_id in sessions:
        del sessions[session_id]

    return {"message": "Session cleared"}


@app.get("/interactions/count")
async def get_interactions_count():
    """Get total interaction count"""
    count = get_interaction_count()
    return {"count": count}


@app.get("/metrics")
async def get_metrics():
    """Get performance metrics"""
    metrics = {
        "model": MODEL_REPO,
        "sessions_active": len(sessions),
        "total_messages": sum(len(s) for s in sessions.values()),
        "total_interactions": get_interaction_count()
    }

    if optimizer:
        metrics["cache"] = optimizer.get_metrics()

    if memory_manager:
        metrics["memory"] = memory_manager.check_memory()

    return metrics


@app.post("/cache/clear")
async def clear_cache():
    """Clear response cache"""
    if not CACHE_ENABLED or not optimizer:
        raise HTTPException(status_code=400, detail="Cache not enabled")

    optimizer.clear_cache()
    return {"message": "Cache cleared"}


@app.get("/test")
async def serve_test_interface():
    """Serve the test interface HTML"""
    return FileResponse("test_interface.html")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="info",
        access_log=True
    )