Fix websocket pattern + improve beacon override logic to prevent false positives/negatives
ab9baba
verified
| #!/usr/bin/env python3 | |
| """ | |
| C2Sentinel - Network Traffic C2 Beacon Detection Model | |
| A machine learning model for detecting Command and Control (C2) beacon | |
| communications in network traffic. Built on a fine-tuned LogBERT transformer | |
| architecture. | |
| Author: Daniel Ostrow | |
| Website: https://neuralintellect.com | |
| Features: | |
| - Detection of 34+ C2 framework behavioral patterns across all ports | |
| - Smart context inference for additional metadata (process, DNS, reputation) | |
| - Legitimate service pattern recognition (SSH keepalive, health checks) | |
| - Reconnaissance support (IP enrichment, IOC generation) | |
| - Comprehensive scripting API for automation | |
| Uses safetensors format for secure model serialization. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import json | |
| import math | |
| import socket | |
| import struct | |
| import hashlib | |
| from pathlib import Path | |
| from typing import Dict, List, Tuple, Optional, Union, Any, Callable | |
| from dataclasses import dataclass, asdict, field | |
| from collections import defaultdict | |
| from enum import Enum | |
| import re | |
| from datetime import datetime | |
| import ipaddress | |
| # Safetensors for safe model serialization | |
| from safetensors.torch import save_file, load_file | |
| # ============================================================================ | |
| # ENUMS AND CONSTANTS | |
| # ============================================================================ | |
class DetectionMethod(Enum):
    """Detection method used for classification.

    Every member's ``value`` is simply its lowercase name.
    """

    SIGNATURE = "signature"
    BEHAVIORAL = "behavioral"
    ML = "ml"
    CONTEXT = "context"
    HEURISTIC = "heuristic"
    WHITELIST = "whitelist"
class TrafficType(Enum):
    """Classification of traffic type.

    The three C2 members carry descriptive values (``c2_beacon``,
    ``c2_exfiltration``, ``c2_lateral_movement``); the rest use their
    lowercase name.
    """

    # Command-and-control categories
    C2_BEACON = "c2_beacon"
    C2_EXFIL = "c2_exfiltration"
    C2_LATERAL = "c2_lateral_movement"

    # Non-C2 verdicts
    LEGITIMATE = "legitimate"
    SUSPICIOUS = "suspicious"
    UNKNOWN = "unknown"
class ServiceType(Enum):
    """Known service types for context.

    Each member's ``value`` is its lowercase name; ``UNKNOWN`` is the
    fallback when no service can be attributed.
    """

    SSH = "ssh"
    HTTP = "http"
    HTTPS = "https"
    DNS = "dns"
    DATABASE = "database"
    API = "api"
    STREAMING = "streaming"
    GAMING = "gaming"
    VPN = "vpn"
    MONITORING = "monitoring"
    UNKNOWN = "unknown"
@dataclass
class C2SentinelConfig:
    """Configuration for LogBERT-C2Sentinel model.

    A plain hyper-parameter container. ``to_dict`` / ``from_dict`` round-trip
    it through an ordinary ``dict`` (e.g. for storing next to the weights).
    The class must be a dataclass: ``to_dict`` relies on ``asdict`` and
    ``from_dict`` on ``__dataclass_fields__``.
    """

    num_features: int = 40  # input feature width (FeatureExtractor emits 40)
    d_model: int = 256
    nhead: int = 8
    num_encoder_layers: int = 6
    dim_feedforward: int = 1024
    dropout: float = 0.1
    max_seq_length: int = 512
    num_c2_types: int = 35  # matches the 35 entries of FeatureExtractor.C2_TYPES
    version: str = "2.0.0"

    def to_dict(self) -> dict:
        """Return the configuration as a plain ``dict`` of field values."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> 'C2SentinelConfig':
        """Build a config from ``d``.

        Keys that are not config fields are ignored (so dicts written by
        other versions still load); missing keys keep their defaults.
        """
        return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__})
# High-confidence C2 ports - these are VERY rarely used legitimately.
# (5555 is the one known overlap: Android debug bridge uses it as well.)
C2_INDICATOR_PORTS = {
    4444,   # Metasploit default
    4445,   # Metasploit alternative
    5555,   # Metasploit (Note: Android debug uses this too)
    31337,  # Elite/Sliver
    40056,  # Havoc default
}

# Ports commonly used by C2 but also by plenty of legitimate traffic, so a hit
# here is weak evidence on its own.
C2_COMMON_PORTS = {
    80,    # HTTP
    443,   # HTTPS
    53,    # DNS
    8080,  # HTTP alt
    8443,  # HTTPS alt
    8888,  # Sliver default
}
# Known legitimate service ports with expected behaviors.
# Note 80, 443 and 8080 also appear in C2_COMMON_PORTS: a port mapping here
# only names the expected service, it does not clear the traffic.
LEGITIMATE_SERVICE_PORTS = {
    # SSH / web / DNS
    22: ServiceType.SSH,
    80: ServiceType.HTTP,
    443: ServiceType.HTTPS,
    53: ServiceType.DNS,
    # Databases
    3306: ServiceType.DATABASE,    # MySQL
    5432: ServiceType.DATABASE,    # PostgreSQL
    6379: ServiceType.DATABASE,    # Redis
    27017: ServiceType.DATABASE,   # MongoDB
    # Application / API servers
    5000: ServiceType.API,         # Flask default
    3000: ServiceType.API,         # Node.js default
    8080: ServiceType.API,         # Common API port
    # Monitoring stack
    9090: ServiceType.MONITORING,  # Prometheus
    3100: ServiceType.MONITORING,  # Grafana Loki
}
# C2 Framework Signatures.
# Per framework: candidate ports, beacon interval range in seconds, typical
# (min, max) packet size bands and the jitter fraction range.
C2_SIGNATURES = dict(
    metasploit=dict(
        ports=[4444, 4445, 5555],
        interval_range=(1, 30),
        packet_sizes=[(50, 200), (500, 2000)],
        jitter_range=(0.0, 0.3),
    ),
    cobalt_strike=dict(
        ports=[50050],
        interval_range=(30, 300),
        packet_sizes=[(68, 200), (200, 1000)],
        jitter_range=(0.0, 0.5),
    ),
    sliver=dict(
        ports=[8888, 31337],
        interval_range=(5, 60),
        packet_sizes=[(100, 500)],
        jitter_range=(0.0, 0.3),
    ),
    havoc=dict(
        ports=[40056],
        interval_range=(2, 30),
        packet_sizes=[(64, 256)],
        jitter_range=(0.0, 0.2),
    ),
)
| # ============================================================================ | |
| # LEGITIMATE SERVICE PATTERNS - Key to reducing false positives | |
| # ============================================================================ | |
@dataclass
class LegitimatePattern:
    """Defines a known legitimate traffic pattern.

    A pattern is a set of coarse constraints that a window of connection
    records must satisfy before it is treated as this benign service:
    destination port(s), per-record ``bytes_sent`` bounds, the overall
    sent/recv byte ratio, and a minimum amount of size variance.

    NOTE(review): ``max_interval_cv`` and ``max_size_cv`` are declared per
    pattern but ``matches`` does not enforce them.
    """

    name: str
    service_type: 'ServiceType'
    port: Optional[int] = None  # required destination port (falsy = no constraint)
    ports: Optional[List[int]] = None  # at least one must be seen (falsy = no constraint)
    min_packet_size: int = 0  # lower bound applied to every bytes_sent value
    max_packet_size: int = 100000  # upper bound applied to every bytes_sent value
    symmetric_ratio: Tuple[float, float] = (0.0, 10.0)  # sent/recv ratio range
    max_interval_cv: float = 1.0  # coefficient of variation for intervals (not enforced)
    max_size_cv: float = 1.0  # coefficient of variation for sizes (not enforced)
    description: str = ""

    def matches(self, connections: List[Dict], stats: Dict) -> Tuple[bool, float]:
        """Check if connections match this legitimate pattern.

        Args:
            connections: Records read via ``dst_port``, ``bytes_sent`` and
                ``bytes_recv``; missing keys count as 0.
            stats: Precomputed statistics. ``sent_cv`` and ``recv_cv`` are
                read; missing values count as 0.

        Returns:
            ``(True, 0.8)`` on a match, otherwise ``(False, 0.0)``.
        """
        if not connections:
            return False, 0.0

        # Port constraints
        ports = {conn.get('dst_port', 0) for conn in connections}
        if self.port and self.port not in ports:
            return False, 0.0
        if self.ports and not any(p in ports for p in self.ports):
            return False, 0.0

        # Per-record request size bounds
        bytes_sent = [conn.get('bytes_sent', 0) for conn in connections]
        bytes_recv = [conn.get('bytes_recv', 0) for conn in connections]
        if max(bytes_sent) > self.max_packet_size or min(bytes_sent) < self.min_packet_size:
            return False, 0.0

        # Overall sent/recv ratio
        total_sent = sum(bytes_sent)
        total_recv = sum(bytes_recv)
        if total_recv > 0:
            ratio = total_sent / total_recv
        elif total_sent > 0:
            # One-way traffic (no reply bytes at all) has an unbounded ratio.
            # Skipping the check here used to let e.g. a silent beacon on
            # port 22 or 443 pass as keepalive/streaming traffic.
            ratio = float('inf')
        else:
            ratio = None  # no payload in either direction: ratio undefined
        low, high = self.symmetric_ratio
        if ratio is not None and not (low <= ratio <= high):
            return False, 0.0

        # CRITICAL: legitimate traffic shows size variance, whereas C2 beacons
        # tend to repeat the same sizes. If BOTH directions are very
        # consistent (CV < 0.3) reject, except for SSH keepalive, which is
        # intentionally uniform.
        recv_cv = stats.get('recv_cv', 0)
        sent_cv = stats.get('sent_cv', 0)
        if recv_cv < 0.3 and sent_cv < 0.3 and self.name != "ssh_keepalive":
            return False, 0.0

        return True, 0.8
# Pre-defined legitimate patterns.
# Patterns are tried in list order and the first match wins, so the more
# specific patterns (e.g. keepalive) must precede broader ones.
LEGITIMATE_PATTERNS = [
    # --- SSH ---------------------------------------------------------------
    LegitimatePattern(
        name="ssh_keepalive",
        service_type=ServiceType.SSH,
        port=22,
        min_packet_size=20,
        max_packet_size=100,          # keepalive probes are tiny
        symmetric_ratio=(0.8, 1.2),   # nearly symmetric request/reply
        max_interval_cv=0.3,
        max_size_cv=0.15,             # very consistent sizes
        description="SSH keepalive probes - small symmetric packets at regular intervals",
    ),
    LegitimatePattern(
        name="ssh_interactive",
        service_type=ServiceType.SSH,
        port=22,
        min_packet_size=20,
        max_packet_size=50000,
        symmetric_ratio=(0.01, 100.0),  # may be heavily asymmetric
        max_interval_cv=2.0,            # human typing: very irregular timing
        max_size_cv=2.0,                # very variable sizes
        description="Interactive SSH session with variable human-driven timing",
    ),
    # --- Monitoring --------------------------------------------------------
    LegitimatePattern(
        name="health_check",
        service_type=ServiceType.MONITORING,
        ports=[80, 443, 8080, 8443, 9090],
        min_packet_size=50,
        max_packet_size=10000,
        symmetric_ratio=(0.01, 0.5),  # small requests, larger responses
        max_interval_cv=0.3,          # regular polling
        max_size_cv=1.0,              # status payloads can vary
        description="Health check endpoint with variable response sizes",
    ),
    # --- Databases ---------------------------------------------------------
    LegitimatePattern(
        name="database_heartbeat",
        service_type=ServiceType.DATABASE,
        ports=[3306, 5432, 6379, 27017],
        min_packet_size=20,
        max_packet_size=100000,
        symmetric_ratio=(0.01, 100.0),
        max_interval_cv=0.3,
        max_size_cv=5.0,  # query results vary dramatically
        description="Database connection with variable query responses",
    ),
    # --- Streaming APIs ----------------------------------------------------
    LegitimatePattern(
        name="websocket_stream",
        service_type=ServiceType.API,
        ports=[80, 443, 8080],
        min_packet_size=100,           # WebSocket frames tend to be larger
        max_packet_size=100000,
        symmetric_ratio=(0.001, 0.3),  # receives far more than it sends
        max_interval_cv=1.5,           # event-driven, irregular timing
        max_size_cv=2.0,               # high variance in pushed payloads
        description="WebSocket streaming connection with variable push data",
    ),
]
| # ============================================================================ | |
| # CONTEXT INFERENCE SYSTEM | |
| # ============================================================================ | |
@dataclass
class ConnectionContext:
    """
    Additional context for connection analysis.

    Provide any available context to improve detection accuracy. All fields
    are optional (default ``None``) - more context = better analysis. Must be
    a dataclass: ``to_dict`` relies on ``asdict`` and callers construct it
    with keyword arguments.
    """

    # Process information
    process_name: Optional[str] = None
    process_path: Optional[str] = None
    process_pid: Optional[int] = None
    parent_process: Optional[str] = None
    command_line: Optional[str] = None

    # Network metadata
    dns_queries: Optional[List[str]] = None  # associated DNS lookups
    resolved_hostname: Optional[str] = None
    tls_sni: Optional[str] = None  # TLS Server Name Indication
    tls_ja3: Optional[str] = None  # JA3 fingerprint
    tls_ja3s: Optional[str] = None  # JA3S fingerprint
    certificate_issuer: Optional[str] = None
    certificate_subject: Optional[str] = None
    certificate_valid: Optional[bool] = None
    http_user_agent: Optional[str] = None
    http_host: Optional[str] = None

    # Reputation and intelligence
    ip_reputation: Optional[float] = None  # 0.0 (bad) to 1.0 (good)
    domain_reputation: Optional[float] = None  # same scale as ip_reputation
    known_good: Optional[bool] = None  # explicitly whitelisted
    known_bad: Optional[bool] = None  # explicitly blacklisted
    threat_intel_match: Optional[str] = None  # matched threat intel indicator

    # Host context
    source_hostname: Optional[str] = None
    source_user: Optional[str] = None
    source_is_server: Optional[bool] = None
    source_is_workstation: Optional[bool] = None

    # Additional metadata
    geo_country: Optional[str] = None
    geo_asn: Optional[str] = None
    tags: Optional[List[str]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return only the populated fields (those that are not ``None``).

        Falsy-but-set values such as ``False`` or ``0.0`` are kept.
        """
        return {k: v for k, v in asdict(self).items() if v is not None}
class ContextInference:
    """
    Smart context inference engine.

    Uses whitelists/blacklists, optional ``ConnectionContext`` metadata and
    user-supplied rules to refine detection decisions and reduce false
    positives. Every piece of evidence multiplies a single
    ``probability_modifier`` (below 1.0 lowers suspicion, above 1.0 raises
    it), so independent findings compound.
    """

    # Known legitimate process names. Looked up by normalized executable name
    # (lowercase, directory stripped, trailing ".exe" removed).
    KNOWN_LEGITIMATE_PROCESSES = {
        'sshd', 'ssh', 'openssh', 'dropbear',            # SSH
        'chrome', 'firefox', 'safari', 'edge', 'brave',  # Browsers
        'curl', 'wget', 'httpd', 'nginx', 'apache2',     # HTTP tools/servers
        'python', 'python3', 'node', 'java', 'ruby',     # Interpreters
        'postgres', 'mysql', 'mongod', 'redis-server',   # Databases
        'docker', 'containerd', 'kubelet',               # Container tools
        'systemd', 'init', 'launchd',                    # System processes
        'prometheus', 'grafana', 'telegraf',             # Monitoring
        'code', 'code-server', 'vim', 'emacs',           # Editors
        'git', 'git-remote-https',                       # Version control
        'apt', 'yum', 'dnf', 'brew', 'pip',              # Package managers
        'zoom', 'slack', 'teams', 'discord',             # Communication
        'spotify', 'vlc', 'mpv',                         # Media
    }

    # Suspicious process names (often used by malware or C2), looked up the
    # same way. Most are Windows binaries whose image names carry ".exe"
    # (e.g. "powershell.exe"), which is why names are normalized first.
    SUSPICIOUS_PROCESSES = {
        'powershell', 'cmd', 'wscript', 'cscript', 'mshta',  # Windows scripting
        'rundll32', 'regsvr32', 'msiexec',                   # Windows LOLBins
        'nc', 'netcat', 'ncat', 'socat',                     # Network utilities (legit but suspicious)
        'mimikatz', 'procdump', 'psexec',                    # Known attack tools
        'beacon', 'payload', 'implant', 'agent',             # Common C2 names
    }

    # Known C2 JA3 fingerprints (example - would be populated from threat intel)
    KNOWN_C2_JA3 = {
        '72a589da586844d7f0818ce684948eea',  # Cobalt Strike (example)
        '51c64c77e60f3980eea90869b68c58a8',  # Metasploit (example)
    }

    # Suspicious TLS certificate subject patterns (searched case-insensitively)
    SUSPICIOUS_CERT_PATTERNS = [
        r'localhost',
        r'test\.',
        r'example\.',
        r'\.local$',
        r'^C2',
        r'beacon',
    ]

    def __init__(self):
        self.whitelist_ips: set = set()
        self.whitelist_domains: set = set()
        self.blacklist_ips: set = set()
        self.blacklist_domains: set = set()
        self.custom_rules: List[Callable] = []

    @staticmethod
    def _normalize_process_name(name: str) -> str:
        """Return the bare, lowercase executable name for ``name``.

        Directory components (split on ``/`` or ``\\``) and a trailing
        ``.exe`` are removed, so ``/usr/sbin/sshd`` becomes ``sshd`` and
        ``PowerShell.EXE`` becomes ``powershell``.
        """
        base = re.split(r'[\\/]', name.strip())[-1].lower()
        if base.endswith('.exe'):
            base = base[:-len('.exe')]
        return base

    @staticmethod
    def _normalize_domain(domain: str) -> str:
        """Lowercase ``domain`` and strip surrounding whitespace and a trailing root dot."""
        return domain.strip().lower().rstrip('.')

    def add_whitelist_ip(self, ip: str):
        """Add IP to whitelist (exact string match against ``dst_ip``)."""
        self.whitelist_ips.add(ip)

    def add_whitelist_domain(self, domain: str):
        """Add domain to whitelist (compared against ``context.dns_queries``)."""
        self.whitelist_domains.add(self._normalize_domain(domain))

    def add_blacklist_ip(self, ip: str):
        """Add IP to blacklist (exact string match against ``dst_ip``)."""
        self.blacklist_ips.add(ip)

    def add_blacklist_domain(self, domain: str):
        """Add domain to blacklist (compared against ``context.dns_queries``)."""
        self.blacklist_domains.add(self._normalize_domain(domain))

    def add_custom_rule(self, rule: Callable[[List[Dict], 'ConnectionContext'], Tuple[Optional[float], str]]):
        """
        Add custom inference rule.

        The rule is called as ``rule(connections, context)`` (``context`` may be
        ``None``) and should return ``(probability_modifier, reason)``, or
        ``(None, "")`` to skip. A rule that raises is logged and ignored.
        """
        self.custom_rules.append(rule)

    def infer(self, connections: List[Dict], context: Optional['ConnectionContext'] = None) -> Dict[str, Any]:
        """
        Perform context-based inference.

        Args:
            connections: Connection records; ``dst_ip`` and ``dst_port`` are read here.
            context: Optional extra metadata, applied by ``_apply_context``.

        Returns:
            Dict with ``probability_modifier`` (multiplicative, starts at 1.0),
            ``confidence_boost``, ``is_whitelisted``, ``is_blacklisted``,
            ``matched_patterns``, ``risk_factors``, ``mitigating_factors``,
            ``service_type`` and ``recommendations``. Empty ``connections``
            yields these neutral defaults unchanged.
        """
        result = {
            'probability_modifier': 1.0,
            'confidence_boost': 0.0,
            'is_whitelisted': False,
            'is_blacklisted': False,
            'matched_patterns': [],
            'risk_factors': [],
            'mitigating_factors': [],
            'service_type': ServiceType.UNKNOWN,
            'recommendations': [],
        }

        if not connections:
            return result

        dst_ips = set(conn.get('dst_ip', '') for conn in connections)
        ports = set(conn.get('dst_port', 0) for conn in connections)

        # Check whitelists (each whitelisted destination compounds the reduction)
        for ip in dst_ips:
            if ip in self.whitelist_ips:
                result['is_whitelisted'] = True
                result['probability_modifier'] *= 0.1
                result['mitigating_factors'].append(f"Destination IP {ip} is whitelisted")

        # Check blacklists
        for ip in dst_ips:
            if ip in self.blacklist_ips:
                result['is_blacklisted'] = True
                result['probability_modifier'] *= 3.0
                result['risk_factors'].append(f"Destination IP {ip} is blacklisted")

        if context:
            result = self._apply_context(result, connections, context, ports)

        # Apply custom rules
        for rule in self.custom_rules:
            try:
                modifier, reason = rule(connections, context)
                if modifier is not None:
                    result['probability_modifier'] *= modifier
                    if modifier < 1.0:
                        result['mitigating_factors'].append(reason)
                    elif modifier > 1.0:
                        result['risk_factors'].append(reason)
            except Exception:
                # Custom rules are user code: keep inference best-effort, but
                # do not hide the failure (it used to be silently swallowed).
                import logging
                logging.getLogger(__name__).warning(
                    "ContextInference custom rule %r raised; skipping it", rule, exc_info=True)

        return result

    def _apply_context(self, result: Dict, connections: List[Dict],
                       context: 'ConnectionContext', ports: set) -> Dict:
        """Apply context-based inference rules; mutates and returns ``result``."""
        # Process name analysis (normalized so paths / ".exe" names still match)
        if context.process_name:
            proc_lower = self._normalize_process_name(context.process_name)

            if proc_lower in self.KNOWN_LEGITIMATE_PROCESSES:
                result['mitigating_factors'].append(f"Known legitimate process: {context.process_name}")
                result['probability_modifier'] *= 0.5

            if proc_lower in self.SUSPICIOUS_PROCESSES:
                result['risk_factors'].append(f"Suspicious process: {context.process_name}")
                result['probability_modifier'] *= 1.5

            # SSH-specific checks
            if proc_lower in ('sshd', 'ssh', 'openssh') and 22 in ports:
                result['mitigating_factors'].append("SSH process on SSH port - expected behavior")
                result['probability_modifier'] *= 0.3
                result['service_type'] = ServiceType.SSH

        # Explicit known_good/known_bad flags
        if context.known_good:
            result['is_whitelisted'] = True
            result['probability_modifier'] *= 0.1
            result['mitigating_factors'].append("Explicitly marked as known good")

        if context.known_bad:
            result['is_blacklisted'] = True
            result['probability_modifier'] *= 5.0
            result['risk_factors'].append("Explicitly marked as known bad")

        # Reputation scores (0.0 bad .. 1.0 good)
        if context.ip_reputation is not None:
            if context.ip_reputation > 0.8:
                result['mitigating_factors'].append(f"Good IP reputation: {context.ip_reputation:.2f}")
                result['probability_modifier'] *= 0.6
            elif context.ip_reputation < 0.3:
                result['risk_factors'].append(f"Poor IP reputation: {context.ip_reputation:.2f}")
                result['probability_modifier'] *= 1.5

        if context.domain_reputation is not None:
            if context.domain_reputation > 0.8:
                result['mitigating_factors'].append(f"Good domain reputation: {context.domain_reputation:.2f}")
                result['probability_modifier'] *= 0.6
            elif context.domain_reputation < 0.3:
                result['risk_factors'].append(f"Poor domain reputation: {context.domain_reputation:.2f}")
                result['probability_modifier'] *= 1.5

        # TLS/JA3 analysis
        if context.tls_ja3:
            if context.tls_ja3 in self.KNOWN_C2_JA3:
                result['risk_factors'].append(f"Known C2 JA3 fingerprint: {context.tls_ja3}")
                result['probability_modifier'] *= 3.0

        # Certificate analysis (first matching pattern only)
        if context.certificate_subject:
            for pattern in self.SUSPICIOUS_CERT_PATTERNS:
                if re.search(pattern, context.certificate_subject, re.IGNORECASE):
                    result['risk_factors'].append(f"Suspicious certificate subject: {context.certificate_subject}")
                    result['probability_modifier'] *= 1.3
                    break

        if context.certificate_valid is False:
            result['risk_factors'].append("Invalid TLS certificate")
            result['probability_modifier'] *= 1.4

        # Threat intel match
        if context.threat_intel_match:
            result['is_blacklisted'] = True
            result['risk_factors'].append(f"Threat intel match: {context.threat_intel_match}")
            result['probability_modifier'] *= 5.0

        # DNS analysis
        if context.dns_queries:
            for query in context.dns_queries:
                # DNS names are case-insensitive and may carry a root dot.
                query_norm = self._normalize_domain(query)

                # Check against domain blacklist
                if query_norm in self.blacklist_domains:
                    result['risk_factors'].append(f"Blacklisted domain: {query}")
                    result['probability_modifier'] *= 2.0

                # Check against whitelist
                if query_norm in self.whitelist_domains:
                    result['mitigating_factors'].append(f"Whitelisted domain: {query}")
                    result['probability_modifier'] *= 0.5

                # DGA-like patterns (high entropy). Measured on the
                # case-folded name so mixed casing does not inflate entropy.
                if len(query_norm) > 20 and self._calculate_entropy(query_norm) > 3.5:
                    result['risk_factors'].append(f"Possible DGA domain: {query}")
                    result['probability_modifier'] *= 1.3

        # Geo analysis
        if context.geo_country:
            # Could integrate with threat intel for high-risk countries
            pass

        return result

    def _calculate_entropy(self, s: str) -> float:
        """Calculate Shannon entropy of a string in bits per character (0.0 for ``""``)."""
        if not s:
            return 0.0
        prob = [s.count(c) / len(s) for c in set(s)]
        return -sum(p * math.log2(p) for p in prob if p > 0)
| # ============================================================================ | |
| # NEURAL NETWORK COMPONENTS | |
| # ============================================================================ | |
class PositionalEncoding(nn.Module):
    """Positional encoding for transformer.

    Adds a fixed sinusoidal position table (registered as buffer ``pe`` with
    shape ``(1, max_len, d_model)``) to a batch-first input, then applies
    dropout. ``forward`` slices the table along dim 1, so inputs are expected
    as ``(batch, seq_len, d_model)`` with ``seq_len <= max_len``.

    Args:
        d_model: Embedding width. Odd widths are supported: the final,
            unpaired channel gets a sine term only.
        max_len: Number of positions precomputed.
        dropout: Dropout probability applied after the addition.
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even channel: ceil(d_model / 2) entries.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd channels; for odd d_model the
        # unsliced div_term would not fit and raise a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:, :seq_len])`` with ``seq_len = x.size(1)``."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
class LogBERTC2Sentinel(nn.Module):
    """LogBERT-based model for C2 beacon detection.

    Pipeline: feature projection -> sinusoidal positions -> batch-first
    transformer encoder -> (masked) mean pooling over the sequence -> five
    task heads.

    ``forward`` accepts ``(batch, seq, num_features)`` or
    ``(batch, num_features)``; the 2-D form is treated as a length-1
    sequence. ``mask`` is passed as ``src_key_padding_mask`` and ``True``
    entries are excluded from pooling.

    Returns a dict with ``c2_logits`` (raw, 1 unit), ``anomaly_score`` and
    ``evasion_score`` (sigmoid, 1 unit each), ``c2_type_logits``
    (``num_c2_types`` raw units) and ``confidence`` (sigmoid, 1 unit).
    """

    def __init__(self, config: C2SentinelConfig):
        super().__init__()
        self.config = config
        width = config.d_model
        half = width // 2

        # Feature projection
        self.feature_projection = nn.Sequential(
            nn.Linear(config.num_features, width),
            nn.LayerNorm(width),
            nn.GELU(),
            nn.Dropout(config.dropout),
        )

        # Positional encoding
        self.pos_encoder = PositionalEncoding(width, config.max_seq_length, config.dropout)

        # Transformer encoder (the layer is cloned num_encoder_layers times)
        block = nn.TransformerEncoderLayer(
            d_model=width,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            activation='gelu',
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(block, config.num_encoder_layers)

        # Multi-task heads. Creation order is kept fixed on purpose: it
        # determines parameter initialization order and state_dict layout.
        self.c2_head = self._make_head(width, half, 1, dropout=config.dropout)
        self.anomaly_head = self._make_head(width, half, 1, dropout=config.dropout, squash=True)
        self.evasion_head = self._make_head(width, half, 1, dropout=config.dropout, squash=True)
        self.c2_type_head = self._make_head(width, half, config.num_c2_types, dropout=config.dropout)
        self.confidence_head = self._make_head(width, width // 4, 1, squash=True)

    @staticmethod
    def _make_head(in_dim: int, hidden_dim: int, out_dim: int,
                   dropout: Optional[float] = None, squash: bool = False) -> nn.Sequential:
        """Build ``Linear -> GELU [-> Dropout] -> Linear [-> Sigmoid]``."""
        layers = [nn.Linear(in_dim, hidden_dim), nn.GELU()]
        if dropout is not None:
            layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(hidden_dim, out_dim))
        if squash:
            layers.append(nn.Sigmoid())
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        if x.dim() == 2:
            x = x.unsqueeze(1)  # single-step sequence

        hidden = self.pos_encoder(self.feature_projection(x))
        encoded = self.transformer_encoder(hidden, src_key_padding_mask=mask)

        if mask is None:
            pooled = encoded.mean(dim=1)
        else:
            # Average only over non-padded positions; clamp avoids 0-division.
            keep = (~mask).unsqueeze(-1).float()
            pooled = (encoded * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1)

        outputs = (
            ('c2_logits', self.c2_head),
            ('anomaly_score', self.anomaly_head),
            ('evasion_score', self.evasion_head),
            ('c2_type_logits', self.c2_type_head),
            ('confidence', self.confidence_head),
        )
        return {key: head(pooled) for key, head in outputs}
| # ============================================================================ | |
| # FEATURE EXTRACTION | |
| # ============================================================================ | |
| class FeatureExtractor: | |
| """Extracts 40-dimensional feature vectors from network traffic.""" | |
| C2_TYPES = [ | |
| 'unknown', 'metasploit', 'cobalt_strike', 'sliver', 'havoc', | |
| 'mythic', 'poshc2', 'merlin', 'empire', 'covenant', | |
| 'brute_ratel', 'koadic', 'pupy', 'silenttrinity', 'faction', | |
| 'ibombshell', 'godoh', 'dnscat2', 'iodine', 'dns_generic', | |
| 'http_custom', 'https_custom', 'websocket', 'domain_fronting', | |
| 'cloud_fronting', 'cdn_abuse', 'apt_generic', 'apt28', 'apt29', | |
| 'apt41', 'lazarus', 'fin7', 'turla', 'winnti', 'custom' | |
| ] | |
| METASPLOIT_PORTS = {4444, 4445, 5555} | |
| def __init__(self): | |
| self.connection_cache = defaultdict(list) | |
| self.destination_history = defaultdict(set) | |
def check_metasploit_signature(self, connections: List[Dict]) -> Tuple[bool, float]:
    """Check for Metasploit-specific signatures.

    Without a destination port in ``self.METASPLOIT_PORTS`` the result is
    always ``(False, 0.0)``. Otherwise confidence/indicators accumulate:

    * port 4444: +0.6, 2 indicators; else 4445 or 5555: +0.4, 1 indicator
    * mean inter-connection interval within 1-30 (timestamp units): +0.15
    * mean ``bytes_sent`` within 50-200: +0.1
    * one destination IP across at least 3 connections: +0.1

    Timestamps may be numeric epochs or ISO-8601 strings (a trailing ``Z``
    is accepted); unparsable values count as 0.

    Returns:
        ``(is_metasploit, confidence)``. ``is_metasploit`` requires at least
        two indicators and confidence >= 0.5; confidence is capped at 1.0.
    """
    if not connections:
        return False, 0.0

    def _to_epoch(ts) -> float:
        # Same parsing as extract_features, so string timestamps do not make
        # np.diff raise.
        if isinstance(ts, str):
            try:
                return datetime.fromisoformat(ts.replace('Z', '+00:00')).timestamp()
            except ValueError:
                return 0.0
        try:
            return float(ts)
        except (TypeError, ValueError):
            return 0.0

    confidence = 0.0
    indicators = 0

    ports = set(conn.get('dst_port', 0) for conn in connections)
    metasploit_port_match = ports & self.METASPLOIT_PORTS
    if not metasploit_port_match:
        return False, 0.0

    if 4444 in metasploit_port_match:
        confidence += 0.6
        indicators += 2
    elif 4445 in metasploit_port_match or 5555 in metasploit_port_match:
        confidence += 0.4
        indicators += 1

    # Short, regular callback interval
    if len(connections) > 1:
        timestamps = sorted(_to_epoch(conn.get('timestamp', 0)) for conn in connections)
        intervals = np.diff(timestamps)
        if len(intervals) > 0:
            mean_interval = np.mean(intervals)
            if 1 <= mean_interval <= 30:
                confidence += 0.15
                indicators += 1

    # Small stager-sized requests
    bytes_sent = [conn.get('bytes_sent', 0) for conn in connections]
    mean_size = np.mean(bytes_sent)
    if 50 <= mean_size <= 200:
        confidence += 0.1
        indicators += 1

    # Persistent single destination
    dst_ips = [conn.get('dst_ip', '') for conn in connections]
    if len(set(dst_ips)) == 1 and len(dst_ips) >= 3:
        confidence += 0.1
        indicators += 1

    is_metasploit = indicators >= 2 and confidence >= 0.5
    return is_metasploit, min(confidence, 1.0)
def check_ssh_keepalive(self, connections: List[Dict]) -> Tuple[bool, float]:
    """
    Check for SSH keepalive pattern to prevent false positives.

    SSH keepalive characteristics checked here:
    - Port 22 among the destination ports
    - At least 3 connection records
    - Very small packets (mean sent and mean recv both <= 100 bytes)
    - Bidirectional and nearly symmetric (0.5 <= mean sent / mean recv <= 2.0)
    - Very consistent sizes (size CV <= 0.2 in both directions)
    - Regular intervals (interval CV < 0.2), ideally close to a common
      keepalive period (15, 30, 60, 120, 180 or 300 seconds)

    Timestamps may be numeric epochs or ISO-8601 strings; unparsable values
    count as 0.

    Returns:
        ``(True, 0.95)`` for very regular intervals near a common period,
        ``(True, 0.85)`` for regular intervals otherwise, else ``(False, 0.0)``.
    """
    if not connections or len(connections) < 3:
        return False, 0.0

    def _to_epoch(ts) -> float:
        # Same parsing as extract_features, so string timestamps work.
        if isinstance(ts, str):
            try:
                return datetime.fromisoformat(ts.replace('Z', '+00:00')).timestamp()
            except ValueError:
                return 0.0
        try:
            return float(ts)
        except (TypeError, ValueError):
            return 0.0

    # Must be on SSH port
    ports = set(conn.get('dst_port', 0) for conn in connections)
    if 22 not in ports:
        return False, 0.0

    bytes_sent = [conn.get('bytes_sent', 0) for conn in connections]
    bytes_recv = [conn.get('bytes_recv', 0) for conn in connections]
    mean_sent = np.mean(bytes_sent)
    mean_recv = np.mean(bytes_recv)

    # Keepalive probes are tiny; larger packets mean real SSH traffic
    if mean_sent > 100 or mean_recv > 100:
        return False, 0.0

    # A keepalive is a request/reply exchange. With no reply bytes at all the
    # flow is one-way, which previously skipped the symmetry test and was
    # whitelisted as keepalive (a silent beacon to port 22 would pass).
    if mean_recv <= 0:
        return False, 0.0
    ratio = mean_sent / mean_recv
    if not (0.5 <= ratio <= 2.0):
        # Asymmetric = data transfer, not keepalive
        return False, 0.0

    # Keepalive packets are always the same size
    sent_cv = np.std(bytes_sent) / (mean_sent + 1e-6)
    recv_cv = np.std(bytes_recv) / (mean_recv + 1e-6)
    if sent_cv > 0.2 or recv_cv > 0.2:
        return False, 0.0

    # Keepalive timing is very regular (>= 3 records -> >= 2 intervals)
    timestamps = sorted(_to_epoch(conn.get('timestamp', 0)) for conn in connections)
    intervals = np.diff(timestamps)
    mean_interval = np.mean(intervals)
    interval_cv = np.std(intervals) / (mean_interval + 1e-6)

    common_keepalive_intervals = [15, 30, 60, 120, 180, 300]
    closest_match = min(common_keepalive_intervals, key=lambda x: abs(x - mean_interval))
    interval_match = abs(mean_interval - closest_match) / closest_match < 0.2

    if interval_cv < 0.15 and interval_match:
        return True, 0.95
    if interval_cv < 0.2:
        return True, 0.85
    return False, 0.0
def check_legitimate_patterns(self, connections: List[Dict]) -> Tuple[Optional[LegitimatePattern], float]:
    """
    Return the first entry of ``LEGITIMATE_PATTERNS`` that ``connections`` match.

    The byte statistics each pattern's ``matches`` needs (means and
    coefficients of variation of ``bytes_sent`` / ``bytes_recv``) are
    computed once up front and shared across patterns.

    Returns:
        ``(pattern, confidence)`` for the first match, else ``(None, 0.0)``.
    """
    if not connections:
        return None, 0.0

    sent = [c.get('bytes_sent', 0) for c in connections]
    recv = [c.get('bytes_recv', 0) for c in connections]
    sent_mean = np.mean(sent) if sent else 0
    recv_mean = np.mean(recv) if recv else 0

    shared_stats = {
        'mean_sent': sent_mean,
        'mean_recv': recv_mean,
        'sent_cv': np.std(sent) / (sent_mean + 1e-6) if sent else 0,
        'recv_cv': np.std(recv) / (recv_mean + 1e-6) if recv else 0,
    }

    # List order is priority order: stop at the first pattern that fits.
    for candidate in LEGITIMATE_PATTERNS:
        is_match, score = candidate.matches(connections, shared_stats)
        if is_match:
            return candidate, score
    return None, 0.0
def extract_features(self, connections: List[Dict]) -> np.ndarray:
    """Extract the 40-element behavioural feature vector for a group of connections.

    Layout (indices):
        0-9    timing       interval stats, regularity, periodicity, hour-of-day
        10-17  destination  fan-out, persistence, port profile, Metasploit-port hit
        18-27  payload      sent/received size statistics, size entropy, count
        28-35  evasion      jitter, lag-1 autocorrelation, bursts, session span,
                            protocol mix, web-port ratios
        36-39  advanced     night-time ratio, fast-beacon ratio, duration stats

    Args:
        connections: Connection records. Keys read: ``timestamp`` (epoch seconds
            or an ISO-8601 string), ``dst_ip``, ``dst_port``, ``bytes_sent``,
            ``bytes_recv``, ``protocol``, ``duration``. Missing keys default to
            0 / '' / 'tcp'.

    Returns:
        float32 array of shape (40,); all zeros for an empty input.
    """
    if not connections:
        return np.zeros(40, dtype=np.float32)

    features = np.zeros(40)
    n_conn = len(connections)

    # Parse timestamps: ISO strings -> epoch seconds; anything unparseable -> 0.
    # NOTE: naive ISO strings (no offset) are interpreted in the host's local
    # time zone by datetime.timestamp().
    parsed = []
    for conn in connections:
        ts = conn.get('timestamp', 0)
        if isinstance(ts, str):
            try:
                ts = datetime.fromisoformat(ts.replace('Z', '+00:00')).timestamp()
            except ValueError:
                ts = 0
        try:
            parsed.append(float(ts))
        except (TypeError, ValueError):
            # e.g. "timestamp": null in a JSON record
            parsed.append(0.0)
    timestamps = np.array(sorted(parsed))

    # === TIMING FEATURES (0-9) ===
    if len(timestamps) > 1:
        intervals = np.diff(timestamps)
        intervals = intervals[intervals > 0]  # ignore duplicate timestamps
        if len(intervals) > 0:
            features[0] = np.mean(intervals)
            features[1] = np.std(intervals)
            features[2] = np.std(intervals) / (np.mean(intervals) + 1e-6)  # coefficient of variation
            features[3] = np.median(intervals)
            features[4] = np.min(intervals)
            features[5] = np.max(intervals)
            if len(intervals) > 2:
                # Regularity: 1 - mean relative deviation from the middle sorted interval
                mode_estimate = np.sort(intervals)[len(intervals) // 2]
                regularity = 1.0 - np.mean(np.abs(intervals - mode_estimate) / (mode_estimate + 1e-6))
                features[6] = max(0, min(1, regularity))
            if len(intervals) >= 8:
                # Periodicity: share of spectral power in the strongest frequency bin
                fft = np.fft.fft(intervals - np.mean(intervals))
                power = np.abs(fft[:len(fft) // 2]) ** 2
                features[7] = np.max(power) / (np.sum(power) + 1e-6)
        # Hour of day derived from epoch seconds modulo one day (i.e. UTC hours)
        hours = [(ts % 86400) / 3600 for ts in timestamps]
        features[8] = np.std(hours) / 12.0
        features[9] = sum(1 for h in hours if 9 <= h <= 17) / len(hours)

    # === DESTINATION FEATURES (10-17) ===
    dst_ips = [conn.get('dst_ip', '') for conn in connections]
    dst_ports = [conn.get('dst_port', 0) for conn in connections]
    unique_dsts = len(set(dst_ips))
    features[10] = unique_dsts
    features[11] = unique_dsts / n_conn
    dst_counts = defaultdict(int)
    for ip in dst_ips:
        dst_counts[ip] += 1
    # Share of traffic going to the most-contacted destination
    features[12] = max(dst_counts.values()) / n_conn
    # Fraction of destinations contacted more than once
    features[13] = sum(1 for c in dst_counts.values() if c > 1) / len(dst_counts)
    features[14] = len(set(dst_ports))
    features[15] = 1.0 if 443 in dst_ports or 80 in dst_ports else 0.0
    features[16] = sum(1 for p in dst_ports if p > 10000) / len(dst_ports)
    features[17] = 1.0 if any(p in self.METASPLOIT_PORTS for p in dst_ports) else 0.0

    # === PAYLOAD FEATURES (18-27) ===
    bytes_sent = [conn.get('bytes_sent', 0) for conn in connections]
    bytes_recv = [conn.get('bytes_recv', 0) for conn in connections]
    features[18] = np.mean(bytes_sent)
    features[19] = np.std(bytes_sent)
    features[20] = np.std(bytes_sent) / (np.mean(bytes_sent) + 1e-6)
    features[21] = np.mean(bytes_recv)
    features[22] = np.std(bytes_recv)
    total_sent = sum(bytes_sent)
    total_recv = sum(bytes_recv)
    features[23] = total_sent / (total_recv + 1e-6) if total_recv else 0
    if len(bytes_sent) > 1:
        # Size repetition: 1 - (distinct sizes / samples)
        features[24] = 1.0 - (len(set(bytes_sent)) / len(bytes_sent))
    features[25] = sum(1 for b in bytes_sent if b < 500) / len(bytes_sent)  # small-request ratio
    size_hist, _ = np.histogram(bytes_sent, bins=10)
    size_hist = size_hist / (sum(size_hist) + 1e-6)
    entropy = -np.sum(size_hist * np.log2(size_hist + 1e-6))
    features[26] = entropy / 3.32  # normalised by ~log2(10 bins)
    features[27] = n_conn

    # === EVASION DETECTION FEATURES (28-35) ===
    if len(timestamps) > 5:
        intervals = np.diff(timestamps)  # raw gaps, zeros included
        jitter = np.abs(np.diff(intervals))
        features[28] = np.mean(jitter) / (np.mean(intervals) + 1e-6)
        centred = intervals - np.mean(intervals)
        autocorr = np.correlate(centred, centred, mode='full')
        autocorr = autocorr[len(autocorr) // 2:]  # keep lags 0..n-1
        features[29] = autocorr[1] / (autocorr[0] + 1e-6)  # lag-1 autocorrelation
    if len(timestamps) > 3:
        intervals = np.diff(timestamps)
        burst_threshold = np.mean(intervals) * 0.1
        features[30] = sum(1 for i in intervals if i < burst_threshold) / len(intervals)
    session_length = timestamps[-1] - timestamps[0]
    features[31] = min(session_length / 86400, 1.0)
    if len(timestamps) > 10:
        # NOTE(review): these "windows" are equal-sized index slices, so every
        # count equals window_size and this feature is always 1.0 here. A
        # time-bucketed count (e.g. np.histogram(timestamps, bins=5)) was
        # probably intended, but changing it alters the model's input
        # distribution; only change it together with retraining.
        window_size = len(timestamps) // 5
        window_counts = [window_size] * 5
        features[32] = 1.0 - (np.std(window_counts) / (np.mean(window_counts) + 1e-6))
    protocols = [conn.get('protocol', 'tcp').lower() for conn in connections]
    features[33] = 1.0 / len(set(protocols))  # 1.0 for a single protocol
    features[34] = sum(1 for p in dst_ports if p in [80, 443, 8080, 8443]) / len(dst_ports)
    features[35] = sum(1 for p in dst_ports if p == 443) / len(dst_ports)

    # === ADVANCED PATTERN FEATURES (36-39) ===
    # Night-time share: 00:00-06:00 by epoch-seconds-of-day
    features[36] = sum(1 for ts in timestamps if 0 <= (ts % 86400) / 3600 < 6) / len(timestamps)
    if len(timestamps) > 1:
        intervals = np.diff(timestamps)
        features[37] = sum(1 for i in intervals if 1 <= i <= 5) / len(intervals)  # fast-beacon ratio
    durations = [conn.get('duration', 0) for conn in connections]
    mean_duration = np.mean(durations)
    features[38] = mean_duration
    features[39] = np.std(durations) / (mean_duration + 1e-6) if mean_duration > 0 else 0

    return features.astype(np.float32)
| # ============================================================================ | |
| # LOG PARSING | |
| # ============================================================================ | |
class LogParser:
    """Parse common network-log formats into connection records.

    Every parser is a static method that takes one raw log line. On success it
    returns a dict with the keys ``timestamp``, ``src_ip``, ``dst_ip``,
    ``src_port``, ``dst_port``, ``protocol``, ``bytes_sent`` and ``bytes_recv``
    (Zeek and JSON records also carry ``duration``). A line that does not match
    the format, or whose fields cannot be converted, yields ``None``, so callers
    can try several parsers in turn.
    """

    # Conversion failures that mean "not valid input for this parser". Narrower
    # than a bare ``except:`` so KeyboardInterrupt/SystemExit still propagate.
    _PARSE_ERRORS = (ValueError, TypeError, AttributeError)

    # Linux iptables LOG line:
    #   "<Mon> <day> <hh:mm:ss> ... SRC=x.x.x.x DST=x.x.x.x ... SPT=n DPT=n [... LEN=n]"
    _IPTABLES_RE = re.compile(
        r'(\w{3}\s+\d+\s+\d+:\d+:\d+).*?SRC=(\d+\.\d+\.\d+\.\d+).*?DST=(\d+\.\d+\.\d+\.\d+)'
        r'.*?SPT=(\d+).*?DPT=(\d+)(?:.*?LEN=(\d+))?',
        re.IGNORECASE,
    )
    # iptables logs the transport protocol as PROTO=TCP / PROTO=UDP / ...
    _IPTABLES_PROTO_RE = re.compile(r'\bPROTO=(\w+)', re.IGNORECASE)

    # Windows Firewall: TimeGenerated=... SourceAddress=... DestAddress=... DestPort=...
    _WINDOWS_FW_RE = re.compile(
        r'TimeGenerated=(\S+).*?(?:SourceAddress|SourceIP)=(\d+\.\d+\.\d+\.\d+)'
        r'.*?(?:DestAddress|DestinationIP)=(\d+\.\d+\.\d+\.\d+).*?(?:DestPort|DestinationPort)=(\d+)',
        re.IGNORECASE,
    )

    # Generic key=value pairs (fallback format)
    _KV_RE = re.compile(r'(\w+)=(\S+)')

    @staticmethod
    def parse_zeek_conn(log_line: str) -> Optional[Dict]:
        """Parse one tab-separated Zeek/Bro ``conn.log`` line.

        Header/comment lines (starting with '#') and lines with fewer than 11
        columns return None. Zeek's unset marker '-' in a numeric column is
        mapped to 0; any other non-numeric value makes the line unparseable.
        """
        def num(value: str, cast):
            # Zeek writes '-' for unset fields
            return cast(value) if value != '-' else 0

        try:
            if log_line.startswith('#'):
                return None  # header / metadata line
            parts = log_line.strip().split('\t')
            # Columns: ts, uid, orig_h, orig_p, resp_h, resp_p, proto, service,
            #          duration, orig_bytes, resp_bytes
            if len(parts) < 11:
                return None
            return {
                'timestamp': float(parts[0]),
                'src_ip': parts[2],
                'src_port': num(parts[3], int),
                'dst_ip': parts[4],
                'dst_port': num(parts[5], int),
                'protocol': parts[6],
                'duration': num(parts[8], float),
                'bytes_sent': num(parts[9], int),
                'bytes_recv': num(parts[10], int),
            }
        except LogParser._PARSE_ERRORS:
            return None

    @staticmethod
    def parse_syslog(log_line: str, year: Optional[int] = None) -> Optional[Dict]:
        """Parse common syslog / firewall formats.

        Tried in order: Linux iptables LOG lines, Windows Firewall key=value
        lines, then a generic key=value fallback (which yields timestamp 0).

        Args:
            log_line: One raw log line.
            year: Year to assume for iptables/BSD-syslog timestamps, which carry
                none (e.g. "Jan 18 10:00:00"). Defaults to the current year of
                the local clock. Such timestamps are interpreted in local time.

        Returns:
            A connection record, or None if no format matched.
        """
        iptables_match = LogParser._IPTABLES_RE.search(log_line)
        if iptables_match:
            if year is None:
                year = datetime.now().year
            try:
                # %b is the locale's abbreviated month name (English in the C locale)
                dt = datetime.strptime(f"{year} {iptables_match.group(1)}", "%Y %b %d %H:%M:%S")
                proto_match = LogParser._IPTABLES_PROTO_RE.search(log_line)
                return {
                    'timestamp': dt.timestamp(),
                    'src_ip': iptables_match.group(2),
                    'dst_ip': iptables_match.group(3),
                    'src_port': int(iptables_match.group(4)),
                    'dst_port': int(iptables_match.group(5)),
                    'protocol': proto_match.group(1).lower() if proto_match else 'tcp',
                    'bytes_sent': int(iptables_match.group(6) or 0),
                    'bytes_recv': 0,
                }
            except ValueError:
                pass  # unparseable timestamp: try the other formats

        win_match = LogParser._WINDOWS_FW_RE.search(log_line)
        if win_match:
            try:
                dt = datetime.fromisoformat(win_match.group(1).replace('Z', '+00:00'))
                return {
                    'timestamp': dt.timestamp(),
                    'src_ip': win_match.group(2),
                    'dst_ip': win_match.group(3),
                    'src_port': 0,
                    'dst_port': int(win_match.group(4)),
                    'protocol': 'tcp',
                    'bytes_sent': 0,
                    'bytes_recv': 0,
                }
            except ValueError:
                pass  # unparseable TimeGenerated: try the generic format

        kv = dict(LogParser._KV_RE.findall(log_line))
        dst_ip = kv.get('dst') or kv.get('DST') or kv.get('DestAddress') or kv.get('dest_ip')
        dst_port = kv.get('dport') or kv.get('DPT') or kv.get('DestPort') or kv.get('dest_port')
        if dst_ip and dst_port:
            try:
                return {
                    'timestamp': 0,
                    'src_ip': kv.get('src') or kv.get('SRC') or kv.get('SourceAddress') or '',
                    'dst_ip': dst_ip,
                    'src_port': int(kv.get('sport') or kv.get('SPT') or kv.get('SourcePort') or 0),
                    'dst_port': int(dst_port),
                    'protocol': kv.get('proto') or kv.get('Protocol') or 'tcp',
                    'bytes_sent': int(kv.get('bytes') or kv.get('LEN') or 0),
                    'bytes_recv': 0,
                }
            except ValueError:
                pass  # non-numeric port/byte value
        return None

    @staticmethod
    def parse_csv(log_line: str, headers: List[str] = None) -> Optional[Dict]:
        """Parse one CSV log row.

        Columns are read by position:
        ``timestamp, src_ip, src_port, dst_ip, dst_port[, protocol, bytes_sent, bytes_recv]``.
        Without ``headers``, and for the header row itself (starting with
        "timestamp"), the line is skipped and None is returned.
        NOTE(review): ``headers`` is only checked for truthiness, not used to map columns.

        Returns:
            A connection record, or None. An unparseable ISO-8601 timestamp
            becomes 0; missing or non-digit port/byte columns become 0.
        """
        if not headers or log_line.startswith('timestamp'):
            return None  # no header context, or this is the header row

        parts = [part.strip() for part in log_line.strip().split(',')]
        if len(parts) < 5:
            return None

        def int_col(index: int) -> int:
            # Missing or non-digit columns default to 0
            if len(parts) > index and parts[index].isdigit():
                return int(parts[index])
            return 0

        try:
            ts = datetime.fromisoformat(parts[0].replace('Z', '+00:00')).timestamp()
        except ValueError:
            ts = 0

        try:
            return {
                'timestamp': ts,
                'src_ip': parts[1],
                'src_port': int_col(2),
                'dst_ip': parts[3],
                'dst_port': int_col(4),
                'protocol': parts[5] if len(parts) > 5 else 'tcp',
                'bytes_sent': int_col(6),
                'bytes_recv': int_col(7),
            }
        except ValueError:
            # str.isdigit() accepts some characters int() rejects (e.g. superscripts)
            return None

    @staticmethod
    def parse_json(log_line: str) -> Optional[Dict]:
        """Parse one JSON object per line (e.g. JSON-lines exports).

        Several aliases are accepted per field (``src_ip``/``source_ip``/``src``,
        ``dst_ip``/``dest_ip``/``dst``, ``bytes_sent``/``bytes_out``, ...).
        ``timestamp`` (or ``@timestamp``) is passed through unconverted.
        Returns None for invalid JSON, non-object JSON, or non-numeric
        port/byte/duration values.
        """
        try:
            data = json.loads(log_line)
            return {
                'timestamp': data.get('timestamp', data.get('@timestamp', 0)),
                'src_ip': data.get('src_ip', data.get('source_ip', data.get('src', ''))),
                'dst_ip': data.get('dst_ip', data.get('dest_ip', data.get('dst', ''))),
                'src_port': int(data.get('src_port', data.get('source_port', 0))),
                'dst_port': int(data.get('dst_port', data.get('dest_port', 0))),
                'protocol': data.get('protocol', 'tcp'),
                'bytes_sent': int(data.get('bytes_sent', data.get('bytes_out', 0))),
                'bytes_recv': int(data.get('bytes_recv', data.get('bytes_in', 0))),
                'duration': float(data.get('duration', 0)),
            }
        except LogParser._PARSE_ERRORS:
            # JSONDecodeError is a ValueError; non-object JSON raises AttributeError
            return None
| # ============================================================================ | |
| # RECONNAISSANCE SUPPORT | |
| # ============================================================================ | |
class ReconSupport:
    """
    Reconnaissance and enrichment support for scripting.

    Provides IP analysis, network intelligence, and enrichment functions
    useful for security automation and scripting. All public methods are
    classmethods and can be called on the class or on an instance.
    """

    # Known CDN/Cloud provider IP ranges (simplified - in production, use full lists)
    KNOWN_CDNS = {
        'cloudflare': ['104.16.0.0/12', '172.64.0.0/13', '131.0.72.0/22'],
        'aws': ['52.0.0.0/6', '54.0.0.0/6'],
        'google': ['35.190.0.0/16', '35.220.0.0/14', '142.250.0.0/15'],
        'azure': ['13.64.0.0/11', '40.64.0.0/10'],
        'akamai': ['23.0.0.0/12', '104.64.0.0/10'],
    }

    # Private IP ranges (analyze_ip itself relies on ipaddress' is_private flag)
    PRIVATE_RANGES = [
        ipaddress.ip_network('10.0.0.0/8'),
        ipaddress.ip_network('172.16.0.0/12'),
        ipaddress.ip_network('192.168.0.0/16'),
        ipaddress.ip_network('127.0.0.0/8'),
        ipaddress.ip_network('169.254.0.0/16'),
    ]

    @staticmethod
    def _to_epoch(ts: Any) -> float:
        """Convert a record timestamp (epoch number or ISO-8601 string) to epoch seconds.

        Unparseable strings and values that cannot be converted to float become 0.0,
        so mixed inputs still sort and subtract numerically.
        """
        if isinstance(ts, str):
            try:
                return datetime.fromisoformat(ts.replace('Z', '+00:00')).timestamp()
            except ValueError:
                return 0.0
        try:
            return float(ts)
        except (TypeError, ValueError):
            return 0.0

    @staticmethod
    def _in_network(ip_obj, range_str: str) -> bool:
        """Return True if ``ip_obj`` lies in CIDR ``range_str``; malformed ranges never match."""
        try:
            return ip_obj in ipaddress.ip_network(range_str)
        except ValueError:
            return False

    @staticmethod
    def _cv(values) -> float:
        """Coefficient of variation: std / (mean + 1e-6)."""
        return float(np.std(values) / (np.mean(values) + 1e-6))

    @classmethod
    def analyze_ip(cls, ip: str, resolve_dns: bool = True) -> Dict[str, Any]:
        """
        Analyze an IP address for reconnaissance purposes.

        Args:
            ip: IPv4 or IPv6 address.
            resolve_dns: When True (default), attempt a reverse-DNS lookup with
                ``socket.gethostbyaddr``. That call performs network I/O and may
                block; pass False for offline or bulk analysis.

        Returns:
            Dict with validity, private/loopback/multicast flags, CDN membership
            per ``KNOWN_CDNS``, IP version, reverse-DNS name (None if unresolved)
            and, for IPv4 only, the address as an integer. Invalid input yields
            ``is_valid=False`` with every other field left at its default.
        """
        result = {
            'ip': ip,
            'is_valid': False,
            'is_private': False,
            'is_loopback': False,
            'is_multicast': False,
            'is_cdn': False,
            'cdn_provider': None,
            'ip_version': None,
            'reverse_dns': None,
            'numeric': None,
        }
        try:
            ip_obj = ipaddress.ip_address(ip)
        except ValueError:
            return result

        result['is_valid'] = True
        result['ip_version'] = ip_obj.version
        result['is_private'] = ip_obj.is_private
        result['is_loopback'] = ip_obj.is_loopback
        result['is_multicast'] = ip_obj.is_multicast
        if isinstance(ip_obj, ipaddress.IPv4Address):
            result['numeric'] = int(ip_obj)  # for range analysis

        # First matching provider wins
        for provider, ranges in cls.KNOWN_CDNS.items():
            if any(cls._in_network(ip_obj, r) for r in ranges):
                result['is_cdn'] = True
                result['cdn_provider'] = provider
                break

        if resolve_dns:
            try:
                result['reverse_dns'] = socket.gethostbyaddr(str(ip_obj))[0]
            except (OSError, UnicodeError):
                pass  # no PTR record or resolver unavailable: leave as None
        return result

    @classmethod
    def analyze_connection_patterns(cls, connections: List[Dict], resolve_dns: bool = True) -> Dict[str, Any]:
        """
        Analyze connection patterns for reconnaissance / threat hunting.

        Args:
            connections: Connection records (``dst_ip``, ``dst_port``, ``bytes_sent``,
                ``bytes_recv``, ``timestamp`` as epoch seconds or ISO-8601 string).
            resolve_dns: Forwarded to ``analyze_ip`` for each destination.

        Returns:
            Dict with counts, timing and volume statistics, port distribution,
            per-destination enrichment and boolean pattern indicators; or
            ``{'error': ...}`` for an empty input.
        """
        if not connections:
            return {'error': 'No connections provided'}

        dst_ips = [conn.get('dst_ip', '') for conn in connections]
        dst_ports = [conn.get('dst_port', 0) for conn in connections]
        bytes_sent = [conn.get('bytes_sent', 0) for conn in connections]
        bytes_recv = [conn.get('bytes_recv', 0) for conn in connections]
        # Normalise ISO strings to epoch seconds so they sort and subtract numerically
        timestamps = sorted(cls._to_epoch(conn.get('timestamp', 0)) for conn in connections)
        intervals = np.diff(timestamps) if len(timestamps) > 1 else np.array([])
        has_intervals = len(intervals) > 0

        unique_dsts = set(dst_ips)
        dst_analysis = {ip: cls.analyze_ip(ip, resolve_dns=resolve_dns) for ip in unique_dsts if ip}

        port_counts = defaultdict(int)
        for port in dst_ports:
            port_counts[port] += 1

        return {
            'connection_count': len(connections),
            'unique_destinations': len(unique_dsts),
            'unique_ports': len(set(dst_ports)),
            # Timing analysis
            'timing': {
                'duration_seconds': timestamps[-1] - timestamps[0] if len(timestamps) > 1 else 0,
                'mean_interval': float(np.mean(intervals)) if has_intervals else 0,
                'interval_stddev': float(np.std(intervals)) if has_intervals else 0,
                'interval_cv': cls._cv(intervals) if has_intervals else 0,
            },
            # Volume analysis
            'volume': {
                'total_sent': sum(bytes_sent),
                'total_recv': sum(bytes_recv),
                'mean_sent': float(np.mean(bytes_sent)),
                'mean_recv': float(np.mean(bytes_recv)),
                'sent_recv_ratio': sum(bytes_sent) / (sum(bytes_recv) + 1e-6),
            },
            # Port distribution
            'ports': dict(port_counts),
            # Destination enrichment
            'destinations': dst_analysis,
            # Pattern indicators
            'indicators': {
                'single_destination': len(unique_dsts) == 1,
                'consistent_timing': cls._cv(intervals) < 0.3 if has_intervals else False,
                'consistent_sizes': cls._cv(bytes_sent) < 0.2 if np.mean(bytes_sent) > 0 else False,
                'uses_common_port': bool(set(dst_ports) & {80, 443, 53, 22}),
                'uses_high_port': any(p > 10000 for p in dst_ports),
                'has_cdn_destination': any(d.get('is_cdn', False) for d in dst_analysis.values()),
                'all_private_destinations': all(
                    d.get('is_private', False) for d in dst_analysis.values() if d.get('is_valid')
                ),
            },
        }

    @classmethod
    def generate_iocs(cls, connections: List[Dict], result: Any) -> Dict[str, List[str]]:
        """
        Generate Indicators of Compromise (IOCs) from an analysis result.

        Args:
            connections: The connection records that were analyzed.
            result: Either a dict or an attribute-style result object (such as
                ``AnalysisResult``); ``is_c2``, ``c2_type`` and ``evasion_score``
                are read from it.

        Returns:
            Dict of IOC lists (``ips``, ``ports``, ``timing_signatures``,
            ``behavioral_indicators``); all empty unless ``is_c2`` is truthy.
            ``ips`` and ``ports`` are sorted so the output is reproducible.
        """
        def lookup(name: str, default: Any = None) -> Any:
            if isinstance(result, dict):
                return result.get(name, default)
            return getattr(result, name, default)

        iocs = {
            'ips': [],
            'ports': [],
            'timing_signatures': [],
            'behavioral_indicators': [],
        }
        if not lookup('is_c2', False):
            return iocs

        iocs['ips'] = sorted({str(conn['dst_ip']) for conn in connections if conn.get('dst_ip')})
        iocs['ports'] = sorted({str(conn['dst_port']) for conn in connections if conn.get('dst_port')})

        timestamps = sorted(cls._to_epoch(conn.get('timestamp', 0)) for conn in connections)
        if len(timestamps) > 1:
            intervals = np.diff(timestamps)
            iocs['timing_signatures'].append(
                f"beacon_interval:{np.mean(intervals):.1f}s±{np.std(intervals):.1f}s"
            )

        c2_type = lookup('c2_type')
        if c2_type:
            iocs['behavioral_indicators'].append(f"c2_type:{c2_type}")
        if lookup('evasion_score', 0) > 0.5:
            iocs['behavioral_indicators'].append("evasion_detected")
        return iocs
| # ============================================================================ | |
| # MAIN API CLASS | |
| # ============================================================================ | |
@dataclass
class AnalysisResult:
    """Structured result from C2 analysis.

    The nine leading fields are required; the remaining fields default to
    empty/neutral values and are filled in as analysis phases run.
    """
    is_c2: bool
    c2_probability: float
    anomaly_score: float
    evasion_score: float
    confidence: float
    c2_type: str
    c2_type_confidence: float
    detection_method: str
    immediate_detection: bool
    # Context-based adjustments
    context_applied: bool = False
    original_probability: float = 0.0
    probability_modifier: float = 1.0
    # Legitimate pattern matching
    matched_legitimate_pattern: Optional[str] = None
    legitimate_confidence: float = 0.0
    # Risk analysis
    risk_factors: List[str] = field(default_factory=list)
    mitigating_factors: List[str] = field(default_factory=list)
    # Service classification
    service_type: str = "unknown"
    # Recommendations
    recommendations: List[str] = field(default_factory=list)
    # Raw features
    features: List[float] = field(default_factory=list)
    # Connection-level details for scripting
    connections_analyzed: int = 0
    suspicious_connections: List[Dict] = field(default_factory=list)
    iocs: Dict[str, Any] = field(default_factory=dict)
    time_range: Dict[str, float] = field(default_factory=dict)
    destination_summary: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a deep-copied plain-dict view of all fields (``dataclasses.asdict``)."""
        return asdict(self)

    def to_json(self, indent: int = 2) -> str:
        """Return JSON-formatted result for scripting (non-JSON values are str()-ed)."""
        return json.dumps(self.to_dict(), indent=indent, default=str)

    def to_ioc_format(self) -> Dict[str, Any]:
        """Return IOCs in a STIX-like shape for threat-intel platforms.

        ``confidence`` is scaled to STIX's 0-100 integer range. It is rounded
        rather than truncated (int(0.29 * 100) == 28 due to float error) and
        clamped to the valid range.
        """
        return {
            'type': 'indicator',
            'spec_version': '2.1',
            'pattern_type': 'c2-beacon',
            'valid_from': self.time_range.get('start'),
            'labels': ['malicious-activity', 'c2'] if self.is_c2 else ['benign'],
            'confidence': max(0, min(100, round(self.confidence * 100))),
            'indicators': self.iocs
        }

    def __repr__(self) -> str:
        status = "C2 DETECTED" if self.is_c2 else "Clean"
        return f"<AnalysisResult: {status} | prob={self.c2_probability:.3f} | type={self.c2_type}>"
| class C2Sentinel: | |
| """ | |
| Main API for LogBERT-C2Sentinel. | |
| Advanced C2 detection with context inference and reconnaissance support. | |
| Usage: | |
| # Load pre-trained model | |
| sentinel = C2Sentinel.load('c2_sentinel') | |
| # Basic analysis | |
| result = sentinel.analyze(connections) | |
| # With context | |
| context = ConnectionContext(process_name='sshd', known_good=True) | |
| result = sentinel.analyze(connections, context=context) | |
| # Batch analysis | |
| results = sentinel.analyze_batch([conn_list1, conn_list2, ...]) | |
| # With reconnaissance | |
| recon = sentinel.recon.analyze_connection_patterns(connections) | |
| iocs = sentinel.recon.generate_iocs(connections, result) | |
| """ | |
def __init__(self, model: LogBERTC2Sentinel, config: C2SentinelConfig, device: str = 'auto'):
    """Store the model/config, build helper components and place the model on a device.

    Args:
        model: The LogBERTC2Sentinel network used for ML scoring.
        config: C2SentinelConfig kept on the instance as ``self.config``.
        device: 'auto' selects 'cuda' when ``torch.cuda.is_available()`` is
            true and 'cpu' otherwise; any other value is passed to
            ``torch.device`` unchanged.
    """
    self.model, self.config = model, config

    # Helper components, constructed in this order
    self.feature_extractor = FeatureExtractor()
    self.log_parser = LogParser()
    self.context_engine = ContextInference()
    self.recon = ReconSupport()

    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.device = torch.device(device)

    # Move weights to the chosen device and switch the module to eval mode
    # (disables training-only behaviour such as dropout, if the model has any).
    self.model.to(self.device)
    self.model.eval()
| def analyze( | |
| self, | |
| connections: List[Dict], | |
| threshold: float = 0.5, | |
| context: Optional[ConnectionContext] = None, | |
| include_features: bool = False, | |
| strict_mode: bool = False | |
| ) -> AnalysisResult: | |
| """ | |
| Analyze connections for C2 activity. | |
| Args: | |
| connections: List of connection records | |
| threshold: Detection threshold (default 0.5, use 0.7 for fewer false positives) | |
| context: Optional ConnectionContext with additional metadata | |
| include_features: Include raw feature vector in result | |
| strict_mode: Require higher confidence for C2 detection | |
| Returns: | |
| AnalysisResult with comprehensive detection information | |
| """ | |
| ports = set(conn.get('dst_port', 0) for conn in connections) | |
| # Initialize result | |
| result = AnalysisResult( | |
| is_c2=False, | |
| c2_probability=0.0, | |
| anomaly_score=0.0, | |
| evasion_score=0.0, | |
| confidence=0.0, | |
| c2_type='none', | |
| c2_type_confidence=0.0, | |
| detection_method='none', | |
| immediate_detection=False, | |
| ) | |
| if not connections: | |
| return result | |
| # ================================================================ | |
| # PHASE 1: Check for known legitimate patterns FIRST | |
| # ================================================================ | |
| # Check SSH keepalive specifically (common false positive) | |
| is_ssh_keepalive, ssh_ka_confidence = self.feature_extractor.check_ssh_keepalive(connections) | |
| if is_ssh_keepalive: | |
| result.matched_legitimate_pattern = "ssh_keepalive" | |
| result.legitimate_confidence = ssh_ka_confidence | |
| result.service_type = ServiceType.SSH.value | |
| result.mitigating_factors.append(f"Matches SSH keepalive pattern (confidence: {ssh_ka_confidence:.2f})") | |
| result.detection_method = DetectionMethod.WHITELIST.value | |
| result.recommendations.append("SSH keepalive is normal system behavior") | |
| # SSH keepalive should NOT be flagged as C2 | |
| result.is_c2 = False | |
| result.c2_probability = 0.05 # Very low probability | |
| result.confidence = ssh_ka_confidence | |
| return result | |
| # Check other legitimate patterns | |
| matched_pattern, pattern_confidence = self.feature_extractor.check_legitimate_patterns(connections) | |
| if matched_pattern and pattern_confidence > 0.7: | |
| result.matched_legitimate_pattern = matched_pattern.name | |
| result.legitimate_confidence = pattern_confidence | |
| result.service_type = matched_pattern.service_type.value | |
| result.mitigating_factors.append(f"Matches {matched_pattern.name} pattern: {matched_pattern.description}") | |
| # ================================================================ | |
| # PHASE 2: Check for high-confidence C2 signatures | |
| # ================================================================ | |
| is_msf, msf_confidence = self.feature_extractor.check_metasploit_signature(connections) | |
| if is_msf: | |
| result.is_c2 = True | |
| result.c2_probability = msf_confidence | |
| result.anomaly_score = 0.95 | |
| result.evasion_score = 0.1 | |
| result.confidence = msf_confidence | |
| result.c2_type = 'metasploit' | |
| result.c2_type_confidence = msf_confidence | |
| result.immediate_detection = True | |
| result.detection_method = DetectionMethod.SIGNATURE.value | |
| result.risk_factors.append("Matches Metasploit signature (high-confidence C2 port + behavior)") | |
| if include_features: | |
| result.features = self.feature_extractor.extract_features(connections).tolist() | |
| return result | |
| # ================================================================ | |
| # PHASE 3: ML-based behavioral analysis | |
| # ================================================================ | |
| features = self.feature_extractor.extract_features(connections) | |
| features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(self.device) | |
| with torch.no_grad(): | |
| outputs = self.model(features_tensor) | |
| c2_prob = torch.sigmoid(outputs['c2_logits']).item() | |
| result.original_probability = c2_prob | |
| result.anomaly_score = outputs['anomaly_score'].item() | |
| result.evasion_score = outputs['evasion_score'].item() | |
| result.confidence = outputs['confidence'].item() | |
| # Get C2 type prediction | |
| c2_type_probs = F.softmax(outputs['c2_type_logits'], dim=-1) | |
| c2_type_idx = torch.argmax(c2_type_probs, dim=-1).item() | |
| result.c2_type = FeatureExtractor.C2_TYPES[c2_type_idx] | |
| result.c2_type_confidence = c2_type_probs[0, c2_type_idx].item() | |
| # ================================================================ | |
| # PHASE 4: Behavioral refinement | |
| # ================================================================ | |
| beacon_indicators = 0 # Initialize here so it's always defined | |
| dst_ips = set(conn.get('dst_ip', '') for conn in connections) | |
| bytes_recv = [conn.get('bytes_recv', 0) for conn in connections] | |
| bytes_sent = [conn.get('bytes_sent', 0) for conn in connections] | |
| recv_cv = np.std(bytes_recv) / (np.mean(bytes_recv) + 1e-6) if bytes_recv else 0 | |
| sent_cv = np.std(bytes_sent) / (np.mean(bytes_sent) + 1e-6) if bytes_sent else 0 | |
| total_sent = sum(bytes_sent) | |
| total_recv = sum(bytes_recv) | |
| req_resp_ratio = total_sent / (total_recv + 1e-6) if total_recv else float('inf') | |
| # Multiple destinations with high variance = likely benign | |
| if len(dst_ips) > 5 and bytes_recv and recv_cv > 2: | |
| c2_prob *= 0.4 | |
| result.mitigating_factors.append("Multiple destinations with high response variance") | |
| # Single destination analysis | |
| if len(dst_ips) == 1 and len(connections) >= 5: | |
| timestamps = sorted([c.get('timestamp', 0) for c in connections]) | |
| if len(timestamps) > 1: | |
| intervals = np.diff(timestamps) | |
| mean_interval = np.mean(intervals) if len(intervals) > 0 else 0 | |
| interval_cv = np.std(intervals) / (mean_interval + 1e-6) if mean_interval > 0 else 0 | |
| # Response variance analysis | |
| if recv_cv > 0.5: | |
| c2_prob *= 0.5 | |
| result.mitigating_factors.append("High response size variance (likely data retrieval)") | |
| elif recv_cv < 0.2 and sent_cv < 0.2: | |
| c2_prob = min(1.0, c2_prob * 1.4) | |
| result.risk_factors.append("Very consistent request/response sizes") | |
| # Request/response ratio | |
| if req_resp_ratio < 0.1: | |
| c2_prob *= 0.4 | |
| result.mitigating_factors.append("Asymmetric traffic (small requests, large responses)") | |
| elif 0.2 < req_resp_ratio < 0.8: | |
| c2_prob = min(1.0, c2_prob * 1.2) | |
| result.risk_factors.append("Balanced request/response ratio (C2-like)") | |
| # Beacon regularity | |
| if interval_cv < 0.3 and mean_interval > 0 and recv_cv < 0.3: | |
| c2_prob = min(1.0, c2_prob * 1.3) | |
| result.risk_factors.append("Regular timing with consistent sizes") | |
| # Slow beacon detection | |
| if mean_interval > 60 and recv_cv < 0.15 and sent_cv < 0.15: | |
| c2_prob = min(1.0, c2_prob * 1.5) | |
| result.risk_factors.append("APT-style slow beacon pattern") | |
| # ============================================================ | |
| # CRITICAL: Explicit beacon pattern override | |
| # When ALL classic beacon indicators are present, force detection | |
| # ============================================================ | |
| beacon_indicators = 0 | |
| # Indicator 1: Very regular timing (CV < 0.15) | |
| if interval_cv < 0.15: | |
| beacon_indicators += 1 | |
| # Indicator 2: Very consistent sizes (both sent and recv CV < 0.15) | |
| if recv_cv < 0.15 and sent_cv < 0.15: | |
| beacon_indicators += 1 | |
| # Indicator 3: Small packet sizes (typical heartbeat) | |
| mean_sent = np.mean(bytes_sent) if bytes_sent else 0 | |
| mean_recv = np.mean(bytes_recv) if bytes_recv else 0 | |
| if mean_sent < 500 and mean_recv < 500: | |
| beacon_indicators += 1 | |
| # Indicator 4: Regular interval in beacon range (5s - 300s) | |
| if 5 <= mean_interval <= 300: | |
| beacon_indicators += 1 | |
| # Indicator 5: Sufficient sample size | |
| if len(connections) >= 8: | |
| beacon_indicators += 1 | |
| # If 4+ of 5 indicators present, this is almost certainly C2 | |
| if beacon_indicators >= 4: | |
| c2_prob = max(c2_prob, 0.85) # Force high probability | |
| result.risk_factors.append(f"Classic C2 beacon pattern detected ({beacon_indicators}/5 indicators)") | |
| result.detection_method = DetectionMethod.BEHAVIORAL.value | |
| elif beacon_indicators >= 3: | |
| c2_prob = max(c2_prob, 0.65) # Likely C2 | |
| result.risk_factors.append(f"Probable C2 beacon pattern ({beacon_indicators}/5 indicators)") | |
| # ================================================================ | |
| # PHASE 5: Apply legitimate pattern discount | |
| # Balance between legitimate patterns and beacon indicators | |
| # ================================================================ | |
| # Check if we have a very strong beacon signal that should override patterns | |
| very_strong_beacon = False | |
| if beacon_indicators >= 5: | |
| very_strong_beacon = True | |
| elif beacon_indicators >= 4: | |
| # Check if timing and size CVs are extremely low (strong C2 signature) | |
| if len(connections) >= 5: | |
| timestamps = sorted([c.get('timestamp', 0) for c in connections]) | |
| if len(timestamps) > 1: | |
| intervals = np.diff(timestamps) | |
| interval_cv = np.std(intervals) / (np.mean(intervals) + 1e-6) | |
| if interval_cv < 0.1 and recv_cv < 0.1: | |
| very_strong_beacon = True | |
| if matched_pattern and pattern_confidence > 0.5: | |
| # Very strong beacon signals override legitimate patterns (except SSH on port 22) | |
| if very_strong_beacon and matched_pattern.name != 'ssh_keepalive': | |
| result.mitigating_factors.append(f"{matched_pattern.name} pattern overridden by very strong beacon signal") | |
| elif pattern_confidence >= 0.75 and not very_strong_beacon: | |
| # Strong legitimate pattern without strong beacon - apply full discount | |
| discount = 1.0 - (pattern_confidence * 0.8) # Up to 80% reduction | |
| c2_prob *= discount | |
| result.mitigating_factors.append(f"Strong {matched_pattern.name} pattern match (conf: {pattern_confidence:.0%})") | |
| result.detection_method = DetectionMethod.WHITELIST.value | |
| elif beacon_indicators >= 4 and pattern_confidence < 0.6: | |
| # Strong beacon + weak pattern match - beacon wins | |
| result.mitigating_factors.append(f"Weak {matched_pattern.name} match overridden by beacon indicators") | |
| elif beacon_indicators >= 3: | |
| # Moderate beacon + moderate pattern - apply reduced discount | |
| discount = 1.0 - (pattern_confidence * 0.4) # Max 40% reduction | |
| c2_prob *= discount | |
| result.mitigating_factors.append(f"{matched_pattern.name} pattern reduces probability by {(1-discount)*100:.0f}%") | |
| else: | |
| # Weak/no beacon - apply full discount | |
| discount = 1.0 - (pattern_confidence * 0.7) # Up to 70% reduction | |
| c2_prob *= discount | |
| result.mitigating_factors.append(f"{matched_pattern.name} pattern reduces probability by {(1-discount)*100:.0f}%") | |
| # ================================================================ | |
| # PHASE 6: Apply context inference (always check whitelist/blacklist) | |
| # ================================================================ | |
| # Always run inference to check whitelist/blacklist | |
| inference = self.context_engine.infer(connections, context) | |
| if inference['probability_modifier'] != 1.0 or context: | |
| result.context_applied = True | |
| result.probability_modifier = inference['probability_modifier'] | |
| c2_prob *= inference['probability_modifier'] | |
| result.risk_factors.extend(inference['risk_factors']) | |
| result.mitigating_factors.extend(inference['mitigating_factors']) | |
| result.recommendations.extend(inference['recommendations']) | |
| if inference['is_whitelisted']: | |
| result.mitigating_factors.append("Destination is whitelisted") | |
| if inference['is_blacklisted']: | |
| result.risk_factors.append("Destination is blacklisted") | |
| if inference['service_type'] != ServiceType.UNKNOWN: | |
| result.service_type = inference['service_type'].value | |
| # ================================================================ | |
| # PHASE 7: Final decision | |
| # ================================================================ | |
| # Apply strict mode if requested | |
| effective_threshold = threshold | |
| if strict_mode: | |
| effective_threshold = max(threshold, 0.7) | |
| result.c2_probability = min(max(c2_prob, 0.0), 1.0) | |
| result.is_c2 = result.c2_probability >= effective_threshold | |
| result.detection_method = DetectionMethod.ML.value if not result.context_applied else DetectionMethod.CONTEXT.value | |
| if result.is_c2: | |
| result.c2_type = FeatureExtractor.C2_TYPES[c2_type_idx] if c2_type_idx > 0 else 'unknown' | |
| else: | |
| result.c2_type = 'none' | |
| # Add recommendations based on analysis | |
| if result.is_c2: | |
| result.recommendations.append("Investigate destination IP for known C2 infrastructure") | |
| result.recommendations.append("Check for associated process and user activity") | |
| if result.evasion_score > 0.5: | |
| result.recommendations.append("C2 may be using evasion techniques - correlate with other telemetry") | |
| if include_features: | |
| result.features = features.tolist() | |
| # ================================================================ | |
| # PHASE 8: Populate machine-readable output fields | |
| # ================================================================ | |
| result.connections_analyzed = len(connections) | |
| # Time range | |
| timestamps = [c.get('timestamp', 0) for c in connections if c.get('timestamp')] | |
| if timestamps: | |
| result.time_range = { | |
| 'start': min(timestamps), | |
| 'end': max(timestamps), | |
| 'duration': max(timestamps) - min(timestamps) | |
| } | |
| # Destination summary | |
| dst_port_counts = {} | |
| for conn in connections: | |
| dst_ip = conn.get('dst_ip', '') | |
| dst_port = conn.get('dst_port', 0) | |
| key = f"{dst_ip}:{dst_port}" | |
| dst_port_counts[key] = dst_port_counts.get(key, 0) + 1 | |
| result.destination_summary = { | |
| 'unique_ips': list(dst_ips), | |
| 'unique_ports': list(ports), | |
| 'destinations': dst_port_counts, | |
| 'total_bytes_sent': total_sent, | |
| 'total_bytes_recv': total_recv | |
| } | |
| # Suspicious connections - mark each with a score | |
| if result.is_c2: | |
| # All connections to a detected C2 destination are suspicious | |
| for i, conn in enumerate(connections): | |
| result.suspicious_connections.append({ | |
| 'index': i, | |
| 'timestamp': conn.get('timestamp'), | |
| 'src_ip': conn.get('src_ip', ''), | |
| 'src_port': conn.get('src_port', 0), | |
| 'dst_ip': conn.get('dst_ip', ''), | |
| 'dst_port': conn.get('dst_port', 0), | |
| 'bytes_sent': conn.get('bytes_sent', 0), | |
| 'bytes_recv': conn.get('bytes_recv', 0), | |
| 'score': result.c2_probability | |
| }) | |
| # IOCs (Indicators of Compromise) | |
| if result.is_c2: | |
| result.iocs = { | |
| 'ip_addresses': list(dst_ips), | |
| 'ports': list(ports), | |
| 'c2_type': result.c2_type, | |
| 'timing_signature': { | |
| 'mean_interval': float(np.mean(np.diff(sorted(timestamps)))) if len(timestamps) > 1 else 0, | |
| 'interval_cv': float(np.std(np.diff(sorted(timestamps))) / (np.mean(np.diff(sorted(timestamps))) + 1e-6)) if len(timestamps) > 1 else 0 | |
| }, | |
| 'size_signature': { | |
| 'mean_bytes_sent': float(np.mean(bytes_sent)) if bytes_sent else 0, | |
| 'mean_bytes_recv': float(np.mean(bytes_recv)) if bytes_recv else 0, | |
| 'sent_cv': float(sent_cv), | |
| 'recv_cv': float(recv_cv) | |
| }, | |
| 'behavioral_indicators': result.risk_factors | |
| } | |
| return result | |
def analyze_batch(
    self,
    connection_groups: List[List[Dict]],
    threshold: float = 0.5,
    contexts: Optional[List[ConnectionContext]] = None,
    parallel: bool = True
) -> List[AnalysisResult]:
    """
    Run ``analyze`` over several connection groups, one result per group.

    Args:
        connection_groups: Connection lists; each list is analyzed on its own.
        threshold: Detection threshold forwarded to every ``analyze`` call.
        contexts: Optional contexts matched to groups by position. Groups
            past the end of this list (or every group, when it is empty or
            None) are analyzed with ``context=None``.
        parallel: Accepted for API compatibility; groups are analyzed
            sequentially and this flag is not consulted.

    Returns:
        AnalysisResults in the same order as ``connection_groups``.
    """
    n_contexts = len(contexts) if contexts else 0
    batch_results = []
    for position, group in enumerate(connection_groups):
        group_context = contexts[position] if position < n_contexts else None
        batch_results.append(
            self.analyze(group, threshold=threshold, context=group_context)
        )
    return batch_results
def analyze_logs(
    self,
    log_lines: List[str],
    group_by_dst: bool = True,
    threshold: float = 0.5,
    min_group_size: int = 3
) -> List[Dict]:
    """
    Analyze raw log lines for C2 activity.

    The input is first tried as one JSON document: a list of connection
    records, or a Graylog-style ``{"messages": [...]}`` wrapper. Records
    that cannot be normalized (e.g. a non-numeric port) are skipped
    individually, as are records without a destination IP. If that yields
    no connections, each line is tried in turn with the JSON, Zeek conn,
    and syslog parsers. When the first line is a header starting with
    ``timestamp,``, the CSV parser is also tried, using that header's
    column names.

    Args:
        log_lines: Raw log lines.
        group_by_dst: Analyze each destination IP separately. When False,
            all connections are analyzed as a single group.
        threshold: Detection threshold passed to ``analyze``.
        min_group_size: When grouping, destinations with fewer connections
            than this are not analyzed (default 3).

    Returns:
        Result dicts from ``AnalysisResult.to_dict()``, each with an added
        ``connection_count`` (plus ``dst_ip`` when grouping). The list is
        sorted by ``c2_probability`` in descending order. It is empty when
        nothing could be parsed.
    """
    def _to_epoch(ts):
        """Convert a timestamp field to epoch seconds (0 if unparseable)."""
        if not isinstance(ts, str):
            return ts
        try:
            return float(ts)  # epoch seconds serialized as a string
        except ValueError:
            pass
        try:
            return datetime.fromisoformat(ts.replace('Z', '+00:00')).timestamp()
        except ValueError:
            return 0

    def _from_json_item(item):
        """Normalize one JSON record; raises TypeError/ValueError if malformed."""
        # A combined 'bytes' field is used as bytes_sent when no
        # directional counters are present.
        bytes_val = int(item.get('bytes', 0))
        return {
            'timestamp': _to_epoch(item.get('timestamp', item.get('@timestamp', 0))),
            'src_ip': item.get('src_ip', item.get('source_ip', '')),
            'dst_ip': item.get('dst_ip', item.get('dest_ip', '')),
            'src_port': int(item.get('src_port', item.get('source_port', 0))),
            'dst_port': int(item.get('dst_port', item.get('dest_port', 0))),
            'protocol': item.get('protocol', 'tcp'),
            'bytes_sent': int(item.get('bytes_sent', item.get('bytes_out', bytes_val))),
            'bytes_recv': int(item.get('bytes_recv', item.get('bytes_in', 0))),
            'duration': float(item.get('duration', 0)),
        }

    connections = []

    # First try the whole input as a single JSON document.
    try:
        data = json.loads(''.join(log_lines))
    except (json.JSONDecodeError, TypeError, ValueError):
        data = None
    # Graylog-style nested JSON: {"messages": [...]}
    if isinstance(data, dict) and 'messages' in data:
        data = data['messages']
    if isinstance(data, list):
        for item in data:
            if not isinstance(item, dict):
                continue
            try:
                conn = _from_json_item(item)
            except (TypeError, ValueError):
                # Skip a malformed record instead of abandoning the rest.
                continue
            if conn['dst_ip']:
                connections.append(conn)

    # Fall back to line-by-line parsing with the known log formats.
    if not connections:
        csv_headers = None
        if log_lines and log_lines[0].strip().startswith('timestamp,'):
            csv_headers = [h.strip() for h in log_lines[0].strip().split(',')]
        for line_no, line in enumerate(log_lines):
            conn = (
                self.log_parser.parse_json(line)
                or self.log_parser.parse_zeek_conn(line)
                or self.log_parser.parse_syslog(line)
            )
            # Line 0 is the CSV header row itself, not a record.
            if not conn and csv_headers and line_no > 0:
                conn = self.log_parser.parse_csv(line, headers=csv_headers)
            if conn:
                connections.append(conn)

    if not connections:
        return []

    results = []
    if group_by_dst:
        grouped = defaultdict(list)
        for conn in connections:
            grouped[conn.get('dst_ip', 'unknown')].append(conn)
        for dst_ip, group_conns in grouped.items():
            if len(group_conns) < min_group_size:
                continue  # not enough connections to this destination
            result_dict = self.analyze(group_conns, threshold=threshold).to_dict()
            result_dict['dst_ip'] = dst_ip
            result_dict['connection_count'] = len(group_conns)
            results.append(result_dict)
    else:
        result_dict = self.analyze(connections, threshold=threshold).to_dict()
        result_dict['connection_count'] = len(connections)
        results.append(result_dict)
    return sorted(results, key=lambda r: r['c2_probability'], reverse=True)
def add_whitelist(
    self,
    ips: Optional[Union[str, List[str]]] = None,
    domains: Optional[Union[str, List[str]]] = None
):
    """
    Add IPs and/or domains to the context engine's whitelist.

    Args:
        ips: IP addresses to whitelist. A single string is treated as one IP.
        domains: Domains to whitelist. A single string is treated as one
            domain.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(ips, str):
        ips = [ips]
    if isinstance(domains, str):
        domains = [domains]
    for ip in ips or ():
        self.context_engine.add_whitelist_ip(ip)
    for domain in domains or ():
        self.context_engine.add_whitelist_domain(domain)
def add_blacklist(
    self,
    ips: Optional[Union[str, List[str]]] = None,
    domains: Optional[Union[str, List[str]]] = None
):
    """
    Add IPs and/or domains to the context engine's blacklist.

    Args:
        ips: IP addresses to blacklist. A single string is treated as one IP.
        domains: Domains to blacklist. A single string is treated as one
            domain.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(ips, str):
        ips = [ips]
    if isinstance(domains, str):
        domains = [domains]
    for ip in ips or ():
        self.context_engine.add_blacklist_ip(ip)
    for domain in domains or ():
        self.context_engine.add_blacklist_domain(domain)
def save(self, path: str):
    """
    Save model to safetensors format.

    Writes two files that share the stem of ``path``; any existing suffix is
    replaced. ``<stem>.safetensors`` holds the model state dict and
    ``<stem>.json`` holds the config. Both locations are then printed.

    Args:
        path: Output path; its suffix, if any, is discarded.
    """
    base = Path(path)
    model_path = base.with_suffix('.safetensors')
    config_path = base.with_suffix('.json')
    save_file(self.model.state_dict(), str(model_path))
    with open(config_path, 'w') as config_file:
        json.dump(self.config.to_dict(), config_file, indent=2)
    for label, location in (("Model", model_path), ("Config", config_path)):
        print(f"{label} saved to {location}")
@classmethod
def load(cls, path: str, device: str = 'auto') -> 'C2Sentinel':
    """
    Load model from safetensors format.

    ``path`` may name the weights file itself (``model.safetensors``) or
    just share its stem (``model``). Both files are located with
    ``with_suffix``, the same way ``save`` writes them: weights from
    ``<stem>.safetensors`` and config from ``<stem>.json``.

    Args:
        path: Weights file path or shared stem.
        device: Device forwarded to the constructor (default 'auto').

    Returns:
        A new instance wrapping the loaded model and config.

    Raises:
        FileNotFoundError: If the config JSON file does not exist.
    """
    path = Path(path)
    # with_suffix is a no-op when the path already ends in .safetensors,
    # so both accepted spellings resolve to the same weights file.
    model_path = path.with_suffix('.safetensors')
    config_path = path.with_suffix('.json')
    with open(config_path, 'r') as f:
        config = C2SentinelConfig.from_dict(json.load(f))
    model = LogBERTC2Sentinel(config)
    model.load_state_dict(load_file(str(model_path)))
    return cls(model, config, device)
@classmethod
def from_pretrained(cls, repo_id: str, device: str = 'auto', cache_dir: Optional[str] = None) -> 'C2Sentinel':
    """
    Load model from HuggingFace Hub.

    Downloads ``c2_sentinel.safetensors`` (weights) and ``c2_sentinel.json``
    (config) from the repository, then builds the model from them.

    Args:
        repo_id: HuggingFace repository ID (e.g., 'danielostrow/c2sentinel')
        device: Device to load model on ('auto', 'cpu', 'cuda', 'mps')
        cache_dir: Optional cache directory for downloaded files

    Returns:
        Loaded C2Sentinel instance

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError as err:
        # Chain the original error so the underlying import failure stays visible.
        raise ImportError(
            "huggingface_hub is required for from_pretrained. Install with: pip install huggingface_hub"
        ) from err

    # Download model files
    model_path = hf_hub_download(
        repo_id=repo_id,
        filename="c2_sentinel.safetensors",
        cache_dir=cache_dir
    )
    config_path = hf_hub_download(
        repo_id=repo_id,
        filename="c2_sentinel.json",
        cache_dir=cache_dir
    )

    # Load config
    with open(config_path, 'r') as f:
        config = C2SentinelConfig.from_dict(json.load(f))

    # Load model
    model = LogBERTC2Sentinel(config)
    model.load_state_dict(load_file(str(model_path)))
    return cls(model, config, device)
@classmethod
def create_new(cls, device: str = 'auto') -> 'C2Sentinel':
    """
    Create a new untrained model.

    Builds a LogBERTC2Sentinel from a default ``C2SentinelConfig()``.

    Args:
        device: Device forwarded to the constructor (default 'auto').
    """
    config = C2SentinelConfig()
    model = LogBERTC2Sentinel(config)
    return cls(model, config, device)
| # ============================================================================ | |
| # CONVENIENCE FUNCTIONS | |
| # ============================================================================ | |
def load_model(path: str, device: str = 'auto') -> C2Sentinel:
    """
    Return a C2Sentinel restored from the weights/config saved at ``path``.

    Thin convenience wrapper around ``C2Sentinel.load``.
    """
    sentinel = C2Sentinel.load(path, device)
    return sentinel
def create_model(device: str = 'auto') -> C2Sentinel:
    """
    Return a freshly initialised, untrained C2Sentinel.

    Thin convenience wrapper around ``C2Sentinel.create_new``.
    """
    fresh = C2Sentinel.create_new(device)
    return fresh
def quick_analyze(
    connections: List[Dict],
    model_path: str = 'c2_sentinel',
    device: str = 'auto',
    **analyze_kwargs
) -> AnalysisResult:
    """
    Quick one-shot analysis without keeping model in memory.

    Loads a model, analyzes a single connection group, and returns the
    result. The model is not cached, so each call pays the load cost; keep
    a C2Sentinel instance around for repeated analysis.

    Args:
        connections: Connection dicts to analyze.
        model_path: Path passed to ``C2Sentinel.load``.
        device: Device passed to ``C2Sentinel.load`` (default 'auto').
        **analyze_kwargs: Extra keyword arguments, such as ``threshold`` or
            ``context``, forwarded to ``C2Sentinel.analyze``.
    """
    sentinel = C2Sentinel.load(model_path, device)
    return sentinel.analyze(connections, **analyze_kwargs)
| # ============================================================================ | |
| # CLI AND TESTING | |
| # ============================================================================ | |
if __name__ == '__main__':
    # Demo: build an untrained model and run it on synthetic traffic samples.
    import random

    print("LogBERT-C2Sentinel v2.0: Advanced C2 Detection with Context Inference")
    print("=" * 70)
    sentinel = C2Sentinel.create_new()
    print(f"Model created with {sentinel.config.num_features} features")
    print(f"Device: {sentinel.device}")

    # Test 1: Metasploit signature detection
    print("\n[TEST 1] Metasploit Meterpreter (port 4444)...")
    msf_connections = [
        {'timestamp': 1000 + i*5, 'dst_ip': '192.168.1.100', 'dst_port': 4444,
         'bytes_sent': 150, 'bytes_recv': 400}
        for i in range(8)
    ]
    result = sentinel.analyze(msf_connections)
    print(f" {result}")

    # Test 2: SSH keepalive (should NOT be flagged)
    print("\n[TEST 2] SSH Keepalive (should be clean)...")
    ssh_keepalive = [
        {'timestamp': 1000 + i*30, 'dst_ip': '192.168.1.10', 'dst_port': 22,
         'bytes_sent': 48, 'bytes_recv': 48}
        for i in range(15)
    ]
    result = sentinel.analyze(ssh_keepalive)
    print(f" {result}")
    print(f" Matched pattern: {result.matched_legitimate_pattern}")
    print(f" Mitigating factors: {result.mitigating_factors}")

    # Test 3: SSH with context
    print("\n[TEST 3] SSH Keepalive with process context...")
    context = ConnectionContext(process_name='sshd', known_good=True)
    result = sentinel.analyze(ssh_keepalive, context=context)
    print(f" {result}")

    # Test 4: C2 beacon on 443
    print("\n[TEST 4] C2 Beacon on port 443...")
    c2_beacon = [
        {'timestamp': 1000 + i*60, 'dst_ip': '10.10.10.10', 'dst_port': 443,
         'bytes_sent': 200, 'bytes_recv': 500}
        for i in range(10)
    ]
    result = sentinel.analyze(c2_beacon)
    print(f" {result}")

    # Test 5: Benign browsing
    print("\n[TEST 5] Benign web browsing...")
    # A seeded RNG keeps the demo reproducible. Gaps of 5-120 s are
    # accumulated, so timestamps increase with irregular spacing.
    rng = random.Random(1337)
    browsing = []
    browse_ts = 1000.0
    for _ in range(15):
        browse_ts += rng.uniform(5, 120)
        browsing.append({
            'timestamp': browse_ts,
            'dst_ip': f"{rng.randint(1, 200)}.{rng.randint(0, 255)}.{rng.randint(0, 255)}.{rng.randint(1, 254)}",
            'dst_port': 443,
            'bytes_sent': rng.randint(500, 3000),
            'bytes_recv': rng.randint(10000, 500000),
        })
    result = sentinel.analyze(browsing)
    print(f" {result}")

    # Test reconnaissance support
    print("\n[TEST 6] Reconnaissance support...")
    ip_info = sentinel.recon.analyze_ip('104.16.132.229')
    print(f" IP Analysis: {ip_info}")

    print("\n" + "=" * 70)
    print("Model ready for deployment!")
    print("=" * 70)