File size: 3,757 Bytes
f8e78b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
Utility functions for the Security Scanner MCP server.
"""

import json
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Union

# Project root directory: two levels above this module (the parent of the
# directory that contains this file). Used as the base path for
# mcp_config.json and for relative log-file paths. Symlinks are not resolved.
PROJECT_ROOT = Path(__file__).parent.parent


def load_config(config_path: Optional[Union[str, Path]] = None) -> Dict[str, Any]:
    """
    Load configuration from a JSON file (mcp_config.json by default).

    Args:
        config_path: Path of the configuration file to read. Defaults to
            ``mcp_config.json`` inside PROJECT_ROOT.

    Returns:
        Dictionary containing configuration settings

    Raises:
        FileNotFoundError: If the configuration file does not exist
        ValueError: If the file is not valid JSON (json.JSONDecodeError is a
            ValueError subclass) or its top-level value is not a JSON object
    """
    if config_path is None:
        path = PROJECT_ROOT / "mcp_config.json"
    else:
        path = Path(config_path)

    # Open directly instead of checking exists() first: this avoids a race
    # between the check and the open while still naming the missing file.
    try:
        with open(path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError(f"Configuration file not found: {path}") from None

    # Callers use dict methods (.get) on the result; fail early and clearly
    # if the file holds a list, string, number, etc.
    if not isinstance(config, dict):
        raise ValueError(
            "Configuration file must contain a JSON object, "
            f"got {type(config).__name__}: {path}"
        )

    return config


class _JsonLogFormatter(logging.Formatter):
    """
    Render each log record as a single-line JSON object.

    A ``%``-style template that places ``%(message)s`` inside quotes breaks
    as soon as a message contains a quote, backslash or newline. This
    formatter serializes the fields with json.dumps, so every line is valid
    JSON. The keys are "time", "level" and "message", plus "exception" or
    "stack" when present.
    """

    def format(self, record: logging.LogRecord) -> str:
        payload: Dict[str, Any] = {
            "time": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "message": record.getMessage(),
        }
        # Keep tracebacks inside the JSON object instead of letting them be
        # appended as raw text after it.
        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)
        if record.stack_info:
            payload["stack"] = self.formatStack(record.stack_info)
        return json.dumps(payload, ensure_ascii=False, separators=(",", ":"))


def _resolve_log_level(level: Any) -> int:
    """
    Convert a configured logging level into its numeric value.

    Args:
        level: Level name such as "INFO" or "debug" (case-insensitive,
            surrounding whitespace ignored) or an integer level

    Returns:
        Numeric logging level

    Raises:
        ValueError: If the name is not a registered logging level
    """
    if isinstance(level, int):
        return level
    name = str(level).strip().upper()
    # getLevelName maps a registered name to its int and returns a
    # "Level <name>" string for anything unknown.
    value = logging.getLevelName(name)
    if isinstance(value, int):
        return value
    raise ValueError(f"Unknown logging level: {level!r}")


def setup_logging(debug: bool = False) -> logging.Logger:
    """
    Set up the "security-scanner-mcp" logger from the configuration.

    Reads the optional "logging" section of the configuration:
      - "level": level name (case-insensitive), default "INFO"
      - "file": log file path; relative paths are resolved against
        PROJECT_ROOT. Default "logs/mcp_server.log"
      - "console": also log to a StreamHandler (stderr), default True
      - "json_format": emit one JSON object per line, default False

    Calling this again removes and closes the handlers that were attached to
    the logger before, so repeated setup does not leak open log files.

    Args:
        debug: If True, use DEBUG level regardless of the configured level

    Returns:
        Configured logger instance

    Raises:
        ValueError: If the configured level is not a known logging level
            (only checked when debug is False)
    """
    config = load_config()
    log_config = config.get("logging", {})

    # Determine log level
    if debug:
        log_level = logging.DEBUG
    else:
        log_level = _resolve_log_level(log_config.get("level", "INFO"))

    # Create logs directory if it doesn't exist
    log_file = log_config.get("file", "logs/mcp_server.log")
    log_path = PROJECT_ROOT / log_file
    log_path.parent.mkdir(parents=True, exist_ok=True)

    handlers = []

    file_handler = logging.FileHandler(log_path, encoding="utf-8")
    file_handler.setLevel(log_level)
    handlers.append(file_handler)

    if log_config.get("console", True):
        console_handler = logging.StreamHandler()
        console_handler.setLevel(log_level)
        handlers.append(console_handler)

    if log_config.get("json_format", False):
        formatter: logging.Formatter = _JsonLogFormatter()
    else:
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )

    logger = logging.getLogger("security-scanner-mcp")
    logger.setLevel(log_level)

    # Close previously attached handlers before dropping them. Replacing
    # logger.handlers with an empty list would leave their files open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    for handler in handlers:
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


def validate_severity_threshold(threshold: str) -> bool:
    """
    Check whether a severity threshold names one of the known levels.

    The comparison is case-insensitive, but surrounding whitespace is not
    stripped, so " HIGH" is rejected.

    Args:
        threshold: Candidate threshold such as "high" or "CRITICAL"

    Returns:
        True when the upper-cased value is CRITICAL, HIGH, MEDIUM or LOW;
        False otherwise
    """
    normalized = threshold.upper()
    return normalized in ("CRITICAL", "HIGH", "MEDIUM", "LOW")


def get_severity_order() -> Dict[str, int]:
    """
    Return the severity ranking used for threshold comparisons.

    The mapping comes from the "severity" -> "thresholds" entry of the
    configuration. When that entry or the whole "severity" section is
    absent, a built-in ranking is returned instead: CRITICAL=0, HIGH=1,
    MEDIUM=2, LOW=3, so lower numbers mean more severe. A fresh fallback
    dictionary is built on every call.

    Returns:
        Dictionary mapping severity levels to numeric order
    """
    fallback = {"CRITICAL": 0, "HIGH": 1, "MEDIUM": 2, "LOW": 3}
    severity_section = load_config().get("severity", {})
    return severity_section.get("thresholds", fallback)


def filter_by_severity(
    vulnerabilities: list,
    threshold: str,
    severity_order: Optional[Dict[str, int]] = None,
) -> list:
    """
    Keep only vulnerabilities that are at least as severe as a threshold.

    Each severity name is upper-cased and looked up in a severity-to-rank
    mapping. Because of the upper-casing, the mapping's keys are expected to
    be upper-case. A vulnerability is kept when its rank is <= the
    threshold's rank, i.e. lower ranks are treated as more severe.

    Args:
        vulnerabilities: List of vulnerability dictionaries. Each may carry
            a "severity" string; a missing or None severity counts as "LOW".
            Unknown severity names get rank 3.
        threshold: Minimum severity threshold. Unknown names fall back to
            rank 2 (MEDIUM in the default ranking).
        severity_order: Optional severity-to-rank mapping to use instead of
            get_severity_order(). Pass it when filtering repeatedly to avoid
            re-reading the configuration file on every call.

    Returns:
        New list of the matching vulnerabilities, in their original order
    """
    if severity_order is None:
        severity_order = get_severity_order()
    threshold_value = severity_order.get(threshold.upper(), 2)

    def rank(vuln: Dict[str, Any]) -> int:
        # .get(key, default) only covers a missing key; an explicit None
        # would otherwise crash on .upper().
        severity = vuln.get("severity")
        if severity is None:
            severity = "LOW"
        return severity_order.get(severity.upper(), 3)

    return [vuln for vuln in vulnerabilities if rank(vuln) <= threshold_value]