# CyberSecChatbot / llm_handler.py
from llama_cpp import Llama
from typing import Generator, Optional, Dict, Any
import logging
import os
from huggingface_hub import hf_hub_download
import hashlib
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class CybersecurityLLM:
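    """Cybersecurity assistant wrapper around a quantized Phi-4-mini GGUF model.

    Downloads the model from Hugging Face (cached in local_dir), loads it with
    llama-cpp-python, and exposes blocking (generate) and streaming
    (generate_stream) completion behind a fixed security-focused system prompt.
    """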
def __init__(self,
repo_id: str = "daskalos-apps/phi4-cybersec-Q4_K_M",
filename: str = "phi4-mini-instruct-Q4_K_M.gguf",
local_dir: str = "./models",
force_download: bool = False):
"""
Initialize Phi-4 from Hugging Face
Args:
repo_id: Your Hugging Face repository ID
filename: The GGUF filename in the repository
local_dir: Local directory to cache the model
force_download: Force re-download even if cached
"""
# Create local directory if it doesn't exist
os.makedirs(local_dir, exist_ok=True)
# Download model from Hugging Face
logger.info(f"Loading model from Hugging Face: {repo_id}")
try:
model_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
local_dir=local_dir,
                local_dir_use_symlinks=False,  # deprecated and ignored in recent huggingface_hub releases
force_download=force_download
)
logger.info(f"Model downloaded/cached at: {model_path}")
except Exception as e:
logger.error(f"Failed to download model: {e}")
# Fallback to local file if exists
model_path = os.path.join(local_dir, filename)
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model not found locally or on Hugging Face: {repo_id}")
# Initialize llama.cpp with the model
logger.info("Initializing model...")
# Check for GPU support via environment variable
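        # e.g. N_GPU_LAYERS=35 offloads 35 transformer layers to the GPU (requires a
        # CUDA/Metal-enabled build of llama-cpp-python); leave unset or 0 for CPU-only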
n_gpu_layers = int(os.getenv("N_GPU_LAYERS", "0"))
if n_gpu_layers > 0:
logger.info(f"GPU acceleration enabled: {n_gpu_layers} layers")
else:
logger.info("Running in CPU-only mode")
self.llm = Llama(
model_path=model_path,
n_ctx=4096, # Context window
n_batch=512, # Batch size for prompt processing
n_threads=6 if n_gpu_layers == 0 else 4, # Fewer threads needed with GPU
n_gpu_layers=n_gpu_layers, # GPU layers (0 for CPU-only)
seed=-1, # Random seed
f16_kv=True, # Use f16 for key/value cache (saves memory)
logits_all=False, # Only compute logits for last token
vocab_only=False, # Load full model
use_mmap=True, # Memory-map model for efficiency
use_mlock=False, # Don't lock model in RAM
verbose=True # Enable verbose for debugging
)
# Store model info
self.model_info = {
"repo_id": repo_id,
"filename": filename,
"path": model_path,
"size_mb": os.path.getsize(model_path) / (1024 * 1024)
}
# Cybersecurity-focused system prompt
self.system_prompt = """You are a cybersecurity expert assistant helping employees understand and implement security best practices. Your role is to provide clear, actionable guidance that non-technical users can understand and apply.
Core expertise areas:
• Email Security & Phishing Detection
• Password Management & Authentication
• Malware Prevention & Detection
• Safe Browsing & Download Practices
• Data Protection & Encryption
• Social Engineering Defense
• Remote Work Security
• Incident Response & Reporting
• Physical Security
• Mobile Device Security
• Cloud Security Basics
• Compliance Basics (GDPR, HIPAA, etc.)
Guidelines:
- Always prioritize user safety and security
- Provide step-by-step instructions when applicable
- Use simple language, avoid excessive jargon
- Include real-world examples
- Emphasize prevention over remediation
- Never ask users to disable security features
- If unsure, recommend consulting IT security team"""
        # Phi-4-mini instruct chat format: <|system|> / <|user|> / <|assistant|> turns, each closed by <|end|>
self.prompt_template = """<|system|>
{system}<|end|>
<|user|>
{user}<|end|>
<|assistant|>"""
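        # Example rendering produced by format_prompt() (user question is arbitrary):
        #   <|system|>
        #   You are a cybersecurity expert assistant...<|end|>
        #   <|user|>
        #   How do I spot a phishing email?<|end|>
        #   <|assistant|>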
self.stop_tokens = ["<|end|>", "<|user|>", "<|endoftext|>", "<|assistant|>"]
logger.info(f"Model ready! Size: {self.model_info['size_mb']:.2f} MB")
def format_prompt(self, user_input: str, context: Optional[str] = None) -> str:
"""Format prompt with optional context for RAG"""
if context:
user_input = f"Context: {context}\n\nQuestion: {user_input}"
return self.prompt_template.format(
system=self.system_prompt,
user=user_input
)
def generate(self,
prompt: str,
max_tokens: int = 512,
temperature: float = 0.7,
context: Optional[str] = None) -> Dict[str, Any]:
"""Generate response with metadata"""
full_prompt = self.format_prompt(prompt, context)
try:
response = self.llm(
full_prompt,
max_tokens=max_tokens,
temperature=temperature,
top_p=0.95,
top_k=40,
repeat_penalty=1.1,
stop=self.stop_tokens,
echo=False
)
text = response['choices'][0]['text'].strip()
return {
"response": text,
"tokens_used": response['usage']['total_tokens'],
"model": self.model_info['repo_id']
}
except Exception as e:
logger.error(f"Generation error: {e}")
return {
"response": "I apologize, but I encountered an error. Please try rephrasing your question.",
"error": str(e)
}
def generate_stream(self,
prompt: str,
max_tokens: int = 512,
context: Optional[str] = None) -> Generator:
"""Stream response tokens"""
full_prompt = self.format_prompt(prompt, context)
stream = self.llm(
full_prompt,
max_tokens=max_tokens,
temperature=0.7,
top_p=0.95,
top_k=40,
repeat_penalty=1.1,
stop=self.stop_tokens,
echo=False,
stream=True
)
for output in stream:
token = output['choices'][0].get('text', '')
if token:
yield token
def get_model_info(self) -> Dict[str, Any]:
"""Get information about the loaded model"""
return self.model_info
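

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): assumes the default repo_id above is
# reachable from this machine and that N_GPU_LAYERS is optionally set in the
# environment before launch. The prompts below are arbitrary examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    llm = CybersecurityLLM()

    # Blocking generation with metadata
    result = llm.generate("How can I recognize a phishing email?", max_tokens=256)
    print(result["response"])
    print(f"Tokens used: {result.get('tokens_used', 'n/a')}")

    # Streaming generation, printing tokens as they arrive
    for token in llm.generate_stream("What makes a strong password?", max_tokens=128):
        print(token, end="", flush=True)
    print()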