|
|
"""LLM provider abstraction and implementations.""" |
|
|
|
|
|
import asyncio |
|
|
import logging |
|
|
from abc import ABC, abstractmethod |
|
|
from typing import Optional |
|
|
|
|
|
import httpx |
|
|
|
|
|
from src.utils import config |
|
|
|
|
|
|
|
|
class BaseLLMProvider(ABC):
    """Interface that every LLM backend implementation must satisfy."""

    @abstractmethod
    async def analyze(self, log_text: str) -> str:
        """Analyze a security log with the underlying LLM.

        Args:
            log_text: Raw log text to analyze.

        Returns:
            The analysis response produced by the LLM.
        """
|
|
|
|
|
|
|
|
class OpenAIProvider(BaseLLMProvider):
    """LLM provider backed by the OpenAI chat-completions API."""

    def __init__(self, api_key: str, model: str = "gpt-4-turbo"):
        self.api_key = api_key
        self.model = model
        self.logger = logging.getLogger(__name__)

    async def analyze(self, log_text: str) -> str:
        """Call OpenAI API for analysis."""
        # Local import keeps module-load order flexible (matches file pattern).
        from src.llm.prompts import get_analysis_prompt

        request_body = {
            "model": self.model,
            "messages": [{"role": "user", "content": get_analysis_prompt(log_text)}],
            "temperature": 0.7,
            "max_tokens": 1000,
        }
        try:
            async with httpx.AsyncClient() as http:
                resp = await http.post(
                    "https://api.openai.com/v1/chat/completions",
                    headers={"Authorization": f"Bearer {self.api_key}"},
                    json=request_body,
                    timeout=30,
                )
                # Surface HTTP-level failures (4xx/5xx) as exceptions.
                resp.raise_for_status()
                payload = resp.json()
                return payload["choices"][0]["message"]["content"]
        except Exception as e:
            # Log for observability, then let the caller decide how to recover.
            self.logger.error(f"OpenAI API error: {e}")
            raise
|
|
|
|
|
|
|
|
class LocalLLMProvider(BaseLLMProvider):
    """Provider for locally hosted LLM endpoints (Ollama, LM Studio, etc.)."""

    def __init__(self, base_url: str = "http://localhost:11434", model: str = "mistral"):
        self.base_url = base_url
        self.model = model
        self.logger = logging.getLogger(__name__)

    async def analyze(self, log_text: str) -> str:
        """Call local LLM endpoint for analysis."""
        # Local import keeps module-load order flexible (matches file pattern).
        from src.llm.prompts import get_analysis_prompt

        request_body = {
            "model": self.model,
            "prompt": get_analysis_prompt(log_text),
            # Ask for a single complete response rather than a token stream.
            "stream": False,
        }
        try:
            async with httpx.AsyncClient() as session:
                reply = await session.post(
                    f"{self.base_url}/api/generate",
                    json=request_body,
                    # Local models can be slow; allow a longer window than OpenAI.
                    timeout=60,
                )
                reply.raise_for_status()
                return reply.json()["response"]
        except Exception as e:
            # Log for observability, then let the caller decide how to recover.
            self.logger.error(f"Local LLM error: {e}")
            raise
|
|
|
|
|
|
|
|
class MockLLMProvider(BaseLLMProvider):
    """Mock LLM for testing and demonstrations.

    Produces a deterministic canned analysis whose risk level is derived from
    simple keyword checks on the input, so demos and tests need no network.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)

    async def analyze(self, log_text: str) -> str:
        """Return a deterministic mock analysis.

        Args:
            log_text: Raw log text; keywords in it drive the mock risk level.

        Returns:
            A canned Markdown analysis containing the derived risk level.
        """
        # Simulate model latency so callers exercise their async code paths.
        await asyncio.sleep(0.5)

        # Lowercase once instead of once per keyword scan.
        lowered = log_text.lower()
        risk_level = "MEDIUM"
        if any(word in lowered for word in ("critical", "ransomware", "breach")):
            risk_level = "CRITICAL"
        elif any(word in lowered for word in ("failed", "denied", "suspicious")):
            risk_level = "HIGH"
        elif len(log_text) < 100:
            # Very short logs are treated as low-signal.
            risk_level = "LOW"

        return f"""**What Happened:**
An authentication anomaly was detected in the system logs. Multiple failed login attempts were recorded from unusual IP addresses during off-peak hours.

**Risk Level:**
{risk_level}

**Suggested Actions:**
- Review authentication logs for the affected accounts
- Check IP reputation of source addresses
- Consider implementing rate-limiting on login endpoints
- Notify affected users of suspicious activity

**Key Indicators:**
- Multiple failed authentication attempts (5+ in short timeframe)
- Unusual geographic location
- Off-peak access time
"""
|
|
|
|
|
|
|
|
def create_provider() -> BaseLLMProvider:
    """Factory function to create the LLM provider selected in config.

    Returns:
        An ``OpenAIProvider``, ``LocalLLMProvider``, or ``MockLLMProvider``
        depending on ``config.llm_provider``; any other value falls back to
        the mock provider.

    Raises:
        ValueError: If the OpenAI provider is selected but no API key is
            configured.
    """
    # Local import keeps module-load order flexible (matches file pattern).
    from src.utils.config import LLMProvider

    if config.llm_provider == LLMProvider.OPENAI:
        # Fail fast with an actionable message instead of letting the first
        # API call fail later with a confusing authentication error.
        if not config.openai_api_key:
            raise ValueError("OpenAI provider selected but no API key is configured")
        return OpenAIProvider(config.openai_api_key, config.model_name)
    if config.llm_provider == LLMProvider.LOCAL:
        return LocalLLMProvider(model=config.model_name)
    # Default: offline mock provider (also covers unknown settings).
    return MockLLMProvider()
|
|
|