| |
| """ |
| C2Sentinel - Network Traffic C2 Beacon Detection Model |
| |
| A machine learning model for detecting Command and Control (C2) beacon |
| communications in network traffic. Built on a fine-tuned LogBERT transformer |
| architecture. |
| |
| Author: Daniel Ostrow |
| Website: https://neuralintellect.com |
| |
| Features: |
| - Detection of 34+ C2 framework behavioral patterns across all ports |
| - Smart context inference for additional metadata (process, DNS, reputation) |
| - Legitimate service pattern recognition (SSH keepalive, health checks) |
| - Reconnaissance support (IP enrichment, IOC generation) |
| - Comprehensive scripting API for automation |
| |
| Uses safetensors format for secure model serialization. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import json |
| import math |
| import socket |
| import struct |
| import hashlib |
| from pathlib import Path |
| from typing import Dict, List, Tuple, Optional, Union, Any, Callable |
| from dataclasses import dataclass, asdict, field |
| from collections import defaultdict |
| from enum import Enum |
| import re |
| from datetime import datetime |
| import ipaddress |
|
|
| |
| from safetensors.torch import save_file, load_file |
|
|
|
|
| |
| |
| |
|
|
class DetectionMethod(Enum):
    """Names the mechanism that produced a classification result."""

    SIGNATURE = "signature"
    BEHAVIORAL = "behavioral"
    ML = "ml"
    CONTEXT = "context"
    HEURISTIC = "heuristic"
    WHITELIST = "whitelist"
|
|
|
|
class TrafficType(Enum):
    """Categories a traffic sample can be classified into."""

    C2_BEACON = "c2_beacon"
    C2_EXFIL = "c2_exfiltration"
    C2_LATERAL = "c2_lateral_movement"
    LEGITIMATE = "legitimate"
    SUSPICIOUS = "suspicious"
    UNKNOWN = "unknown"
|
|
|
|
class ServiceType(Enum):
    """Service categories used as context when interpreting traffic."""

    SSH = "ssh"
    HTTP = "http"
    HTTPS = "https"
    DNS = "dns"
    DATABASE = "database"
    API = "api"
    STREAMING = "streaming"
    GAMING = "gaming"
    VPN = "vpn"
    MONITORING = "monitoring"
    UNKNOWN = "unknown"
|
|
|
|
@dataclass
class C2SentinelConfig:
    """Hyperparameters and metadata for the LogBERT-C2Sentinel model."""

    num_features: int = 40
    d_model: int = 256
    nhead: int = 8
    num_encoder_layers: int = 6
    dim_feedforward: int = 1024
    dropout: float = 0.1
    max_seq_length: int = 512
    num_c2_types: int = 35
    version: str = "2.0.0"

    def to_dict(self) -> dict:
        """Return every field of this config as a plain dictionary."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> 'C2SentinelConfig':
        """Build a config from ``d``; keys that are not config fields are ignored."""
        known_fields = cls.__dataclass_fields__
        kwargs = {}
        for key, value in d.items():
            if key in known_fields:
                kwargs[key] = value
        return cls(**kwargs)
|
|
|
|
| |
# Ports treated as C2 indicators. Each of these also appears in the 'ports'
# list of at least one C2_SIGNATURES entry.
C2_INDICATOR_PORTS = {4444, 4445, 5555, 31337, 40056}

# Common web/DNS (and alternate web) ports tracked separately from the
# indicator ports above.
C2_COMMON_PORTS = {80, 443, 53, 8080, 8443, 8888}
|
|
| |
# Destination port -> service category expected on that port.
LEGITIMATE_SERVICE_PORTS = {
    # Remote access, web and DNS
    22: ServiceType.SSH,
    80: ServiceType.HTTP,
    443: ServiceType.HTTPS,
    53: ServiceType.DNS,
    # Databases
    3306: ServiceType.DATABASE,
    5432: ServiceType.DATABASE,
    6379: ServiceType.DATABASE,
    27017: ServiceType.DATABASE,
    # Application / API servers
    5000: ServiceType.API,
    3000: ServiceType.API,
    8080: ServiceType.API,
    # Monitoring
    9090: ServiceType.MONITORING,
    3100: ServiceType.MONITORING,
}
|
|
| |
# Per-framework signature parameters. Every entry carries:
#   ports          - list of destination ports
#   interval_range - (low, high) pair
#   packet_sizes   - list of (low, high) pairs
#   jitter_range   - (low, high) pair
C2_SIGNATURES = {
    'metasploit': dict(
        ports=[4444, 4445, 5555],
        interval_range=(1, 30),
        packet_sizes=[(50, 200), (500, 2000)],
        jitter_range=(0.0, 0.3),
    ),
    'cobalt_strike': dict(
        ports=[50050],
        interval_range=(30, 300),
        packet_sizes=[(68, 200), (200, 1000)],
        jitter_range=(0.0, 0.5),
    ),
    'sliver': dict(
        ports=[8888, 31337],
        interval_range=(5, 60),
        packet_sizes=[(100, 500)],
        jitter_range=(0.0, 0.3),
    ),
    'havoc': dict(
        ports=[40056],
        interval_range=(2, 30),
        packet_sizes=[(64, 256)],
        jitter_range=(0.0, 0.2),
    ),
}
|
|
|
|
| |
| |
| |
|
|
@dataclass
class LegitimatePattern:
    """Describes one known-benign traffic pattern and how to recognise it."""

    name: str
    service_type: ServiceType
    port: Optional[int] = None
    ports: Optional[List[int]] = None
    min_packet_size: int = 0
    max_packet_size: int = 100000
    symmetric_ratio: Tuple[float, float] = (0.0, 10.0)
    # NOTE(review): max_interval_cv and max_size_cv are stored but never read
    # by matches() below - confirm whether they are consumed elsewhere.
    max_interval_cv: float = 1.0
    max_size_cv: float = 1.0
    description: str = ""

    def matches(self, connections: List[Dict], stats: Dict) -> Tuple[bool, float]:
        """
        Decide whether ``connections`` fit this pattern.

        ``stats`` is read for the optional keys 'recv_cv' and 'sent_cv'
        (both default to 0 when absent).

        Returns (matches, confidence): (True, 0.8) on a match, otherwise
        (False, 0.0).
        """
        no_match = (False, 0.0)
        if not connections:
            return no_match

        # Port gate: the single required port must be present, and/or at
        # least one of the alternative ports.
        seen_ports = {conn.get('dst_port', 0) for conn in connections}
        if self.port and self.port not in seen_ports:
            return no_match
        if self.ports and seen_ports.isdisjoint(self.ports):
            return no_match

        sent_sizes = [conn.get('bytes_sent', 0) for conn in connections]
        recv_sizes = [conn.get('bytes_recv', 0) for conn in connections]

        # Size gate applies to sent bytes only (connections is non-empty
        # here, so sent_sizes always has at least one element).
        if max(sent_sizes) > self.max_packet_size or min(sent_sizes) < self.min_packet_size:
            return no_match

        # Sent/received ratio gate; skipped when nothing was received.
        total_recv = sum(recv_sizes)
        if total_recv > 0:
            low, high = self.symmetric_ratio
            ratio = sum(sent_sizes) / total_recv
            if not (low <= ratio <= high):
                return no_match

        # Very uniform sizes in BOTH directions are accepted only for the
        # "ssh_keepalive" pattern; every other pattern rejects them
        # (presumably because such uniformity resembles beaconing).
        sizes_uniform = stats.get('recv_cv', 0) < 0.3 and stats.get('sent_cv', 0) < 0.3
        if sizes_uniform and self.name != "ssh_keepalive":
            return no_match

        return True, 0.8
|
|
|
|
| |
# Known-benign traffic patterns. Order matters to any consumer that returns
# the first pattern whose matches() succeeds.
LEGITIMATE_PATTERNS = [
    # --- SSH (port 22) ---
    LegitimatePattern(
        name="ssh_keepalive",
        service_type=ServiceType.SSH,
        port=22,
        min_packet_size=20,
        max_packet_size=100,
        symmetric_ratio=(0.8, 1.2),
        max_interval_cv=0.3,
        max_size_cv=0.15,
        description="SSH keepalive probes - small symmetric packets at regular intervals",
    ),
    LegitimatePattern(
        name="ssh_interactive",
        service_type=ServiceType.SSH,
        port=22,
        min_packet_size=20,
        max_packet_size=50000,
        symmetric_ratio=(0.01, 100.0),
        max_interval_cv=2.0,
        max_size_cv=2.0,
        description="Interactive SSH session with variable human-driven timing",
    ),
    # --- Monitoring ---
    LegitimatePattern(
        name="health_check",
        service_type=ServiceType.MONITORING,
        ports=[80, 443, 8080, 8443, 9090],
        min_packet_size=50,
        max_packet_size=10000,
        symmetric_ratio=(0.01, 0.5),
        max_interval_cv=0.3,
        max_size_cv=1.0,
        description="Health check endpoint with variable response sizes",
    ),
    # --- Databases ---
    LegitimatePattern(
        name="database_heartbeat",
        service_type=ServiceType.DATABASE,
        ports=[3306, 5432, 6379, 27017],
        min_packet_size=20,
        max_packet_size=100000,
        symmetric_ratio=(0.01, 100.0),
        max_interval_cv=0.3,
        max_size_cv=5.0,
        description="Database connection with variable query responses",
    ),
    # --- Web APIs ---
    LegitimatePattern(
        name="websocket_ping",
        service_type=ServiceType.API,
        ports=[80, 443, 8080],
        min_packet_size=10,
        max_packet_size=100000,
        symmetric_ratio=(0.001, 100.0),
        max_interval_cv=0.5,
        max_size_cv=5.0,
        description="WebSocket connection with ping/pong and data pushes",
    ),
]
|
|
|
|
| |
| |
| |
|
|
@dataclass
class ConnectionContext:
    """
    Optional side information about a set of connections.

    Every field defaults to None, so callers supply only what they know;
    to_dict() reports just the fields that were actually set.
    """

    # Process information
    process_name: Optional[str] = None
    process_path: Optional[str] = None
    process_pid: Optional[int] = None
    parent_process: Optional[str] = None
    command_line: Optional[str] = None

    # DNS / TLS / HTTP metadata
    dns_queries: Optional[List[str]] = None
    resolved_hostname: Optional[str] = None
    tls_sni: Optional[str] = None
    tls_ja3: Optional[str] = None
    tls_ja3s: Optional[str] = None
    certificate_issuer: Optional[str] = None
    certificate_subject: Optional[str] = None
    certificate_valid: Optional[bool] = None
    http_user_agent: Optional[str] = None
    http_host: Optional[str] = None

    # Reputation and threat intelligence
    ip_reputation: Optional[float] = None
    domain_reputation: Optional[float] = None
    known_good: Optional[bool] = None
    known_bad: Optional[bool] = None
    threat_intel_match: Optional[str] = None

    # Source host
    source_hostname: Optional[str] = None
    source_user: Optional[str] = None
    source_is_server: Optional[bool] = None
    source_is_workstation: Optional[bool] = None

    # Geography and free-form tags
    geo_country: Optional[str] = None
    geo_asn: Optional[str] = None
    tags: Optional[List[str]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields whose value is not None as a dictionary."""
        populated: Dict[str, Any] = {}
        for field_name, value in asdict(self).items():
            if value is not None:
                populated[field_name] = value
        return populated
|
|
|
|
class ContextInference:
    """
    Context inference engine.

    Combines allow/deny lists, optional ConnectionContext metadata and
    user-supplied rules into a multiplicative ``probability_modifier``
    (below 1.0 lowers suspicion, above 1.0 raises it) plus human-readable
    risk and mitigating factors.
    """

    # Process names whose presence lowers the modifier.
    KNOWN_LEGITIMATE_PROCESSES = {
        'sshd', 'ssh', 'openssh', 'dropbear',
        'chrome', 'firefox', 'safari', 'edge', 'brave',
        'curl', 'wget', 'httpd', 'nginx', 'apache2',
        'python', 'python3', 'node', 'java', 'ruby',
        'postgres', 'mysql', 'mongod', 'redis-server',
        'docker', 'containerd', 'kubelet',
        'systemd', 'init', 'launchd',
        'prometheus', 'grafana', 'telegraf',
        'code', 'code-server', 'vim', 'emacs',
        'git', 'git-remote-https',
        'apt', 'yum', 'dnf', 'brew', 'pip',
        'zoom', 'slack', 'teams', 'discord',
        'spotify', 'vlc', 'mpv',
    }

    # Process names whose presence raises the modifier.
    SUSPICIOUS_PROCESSES = {
        'powershell', 'cmd', 'wscript', 'cscript', 'mshta',
        'rundll32', 'regsvr32', 'msiexec',
        'nc', 'netcat', 'ncat', 'socat',
        'mimikatz', 'procdump', 'psexec',
        'beacon', 'payload', 'implant', 'agent',
    }

    # JA3 fingerprints treated as known C2 clients.
    KNOWN_C2_JA3 = {
        '72a589da586844d7f0818ce684948eea',
        '51c64c77e60f3980eea90869b68c58a8',
    }

    # Regexes (matched case-insensitively) that flag a certificate subject.
    SUSPICIOUS_CERT_PATTERNS = [
        r'localhost',
        r'test\.',
        r'example\.',
        r'\.local$',
        r'^C2',
        r'beacon',
    ]

    def __init__(self):
        self.whitelist_ips: set = set()
        self.whitelist_domains: set = set()
        self.blacklist_ips: set = set()
        self.blacklist_domains: set = set()
        self.custom_rules: List[Callable] = []

    def add_whitelist_ip(self, ip: str):
        """Add IP to whitelist."""
        self.whitelist_ips.add(ip)

    def add_whitelist_domain(self, domain: str):
        """Add domain to whitelist (stored lower-cased)."""
        self.whitelist_domains.add(domain.lower())

    def add_blacklist_ip(self, ip: str):
        """Add IP to blacklist."""
        self.blacklist_ips.add(ip)

    def add_blacklist_domain(self, domain: str):
        """Add domain to blacklist (stored lower-cased)."""
        self.blacklist_domains.add(domain.lower())

    def add_custom_rule(self, rule: Callable[[List[Dict], 'ConnectionContext'], Tuple[Optional[float], str]]):
        """
        Add custom inference rule.

        Rule should return (probability_modifier, reason) or (None, "") to skip.
        It is called with the connection list and the context passed to
        infer(), which may be None.
        """
        self.custom_rules.append(rule)

    def infer(self, connections: List[Dict], context: Optional['ConnectionContext'] = None) -> Dict[str, Any]:
        """
        Perform context-based inference.

        Returns a result dictionary whose 'probability_modifier' can be used
        to scale a detection probability. With no connections the neutral
        default result is returned unchanged.
        """
        result = {
            'probability_modifier': 1.0,
            'confidence_boost': 0.0,
            'is_whitelisted': False,
            'is_blacklisted': False,
            'matched_patterns': [],
            'risk_factors': [],
            'mitigating_factors': [],
            'service_type': ServiceType.UNKNOWN,
            'recommendations': [],
        }

        if not connections:
            return result

        dst_ips = set(conn.get('dst_ip', '') for conn in connections)
        ports = set(conn.get('dst_port', 0) for conn in connections)

        # Allow-listed destinations: each hit scales the modifier by 0.1.
        for ip in dst_ips:
            if ip in self.whitelist_ips:
                result['is_whitelisted'] = True
                result['probability_modifier'] *= 0.1
                result['mitigating_factors'].append(f"Destination IP {ip} is whitelisted")

        # Deny-listed destinations: each hit scales the modifier by 3.0.
        for ip in dst_ips:
            if ip in self.blacklist_ips:
                result['is_blacklisted'] = True
                result['probability_modifier'] *= 3.0
                result['risk_factors'].append(f"Destination IP {ip} is blacklisted")

        if context:
            result = self._apply_context(result, connections, context, ports)

        # User-supplied rules run last and see the same inputs as infer().
        for rule in self.custom_rules:
            try:
                modifier, reason = rule(connections, context)
                if modifier is not None:
                    result['probability_modifier'] *= modifier
                    if modifier < 1.0:
                        result['mitigating_factors'].append(reason)
                    elif modifier > 1.0:
                        result['risk_factors'].append(reason)
            except Exception:
                # A faulty rule must not abort inference, but the failure
                # should be visible rather than silently discarded.
                import logging
                logging.getLogger(__name__).warning(
                    "Custom inference rule %r failed; skipping it", rule, exc_info=True
                )

        return result

    @staticmethod
    def _normalize_process_name(name: str) -> str:
        """Reduce a process name to a lower-case basename without a '.exe' suffix."""
        base = name.replace('\\', '/').rsplit('/', 1)[-1].lower()
        if base.endswith('.exe'):
            base = base[:-4]
        return base

    def _apply_context(self, result: Dict, connections: List[Dict],
                       context: 'ConnectionContext', ports: set) -> Dict:
        """
        Apply context-based inference rules.

        Mutates and returns ``result``. Each rule that fires multiplies
        'probability_modifier' and records a risk or mitigating factor.
        """
        # --- Process ---
        if context.process_name:
            # Compare on a normalized name so values such as
            # "powershell.exe" or "/usr/bin/curl" match the bare names held
            # in the process sets.
            proc_lower = self._normalize_process_name(context.process_name)

            if proc_lower in self.KNOWN_LEGITIMATE_PROCESSES:
                result['mitigating_factors'].append(f"Known legitimate process: {context.process_name}")
                result['probability_modifier'] *= 0.5

            if proc_lower in self.SUSPICIOUS_PROCESSES:
                result['risk_factors'].append(f"Suspicious process: {context.process_name}")
                result['probability_modifier'] *= 1.5

            # SSH binary talking to port 22 (applied on top of the
            # legitimate-process factor above).
            if proc_lower in ('sshd', 'ssh', 'openssh') and 22 in ports:
                result['mitigating_factors'].append("SSH process on SSH port - expected behavior")
                result['probability_modifier'] *= 0.3
                result['service_type'] = ServiceType.SSH

        # --- Explicit verdicts ---
        if context.known_good:
            result['is_whitelisted'] = True
            result['probability_modifier'] *= 0.1
            result['mitigating_factors'].append("Explicitly marked as known good")

        if context.known_bad:
            result['is_blacklisted'] = True
            result['probability_modifier'] *= 5.0
            result['risk_factors'].append("Explicitly marked as known bad")

        # --- Reputation: above 0.8 counts as good, below 0.3 as poor ---
        if context.ip_reputation is not None:
            if context.ip_reputation > 0.8:
                result['mitigating_factors'].append(f"Good IP reputation: {context.ip_reputation:.2f}")
                result['probability_modifier'] *= 0.6
            elif context.ip_reputation < 0.3:
                result['risk_factors'].append(f"Poor IP reputation: {context.ip_reputation:.2f}")
                result['probability_modifier'] *= 1.5

        if context.domain_reputation is not None:
            if context.domain_reputation > 0.8:
                result['mitigating_factors'].append(f"Good domain reputation: {context.domain_reputation:.2f}")
                result['probability_modifier'] *= 0.6
            elif context.domain_reputation < 0.3:
                result['risk_factors'].append(f"Poor domain reputation: {context.domain_reputation:.2f}")
                result['probability_modifier'] *= 1.5

        # --- TLS fingerprint ---
        if context.tls_ja3:
            if context.tls_ja3 in self.KNOWN_C2_JA3:
                result['risk_factors'].append(f"Known C2 JA3 fingerprint: {context.tls_ja3}")
                result['probability_modifier'] *= 3.0

        # --- Certificate: at most one subject-pattern penalty ---
        if context.certificate_subject:
            for pattern in self.SUSPICIOUS_CERT_PATTERNS:
                if re.search(pattern, context.certificate_subject, re.IGNORECASE):
                    result['risk_factors'].append(f"Suspicious certificate subject: {context.certificate_subject}")
                    result['probability_modifier'] *= 1.3
                    break

        # Only an explicit False counts; None means "unknown".
        if context.certificate_valid is False:
            result['risk_factors'].append("Invalid TLS certificate")
            result['probability_modifier'] *= 1.4

        # --- Threat intelligence ---
        if context.threat_intel_match:
            result['is_blacklisted'] = True
            result['risk_factors'].append(f"Threat intel match: {context.threat_intel_match}")
            result['probability_modifier'] *= 5.0

        # --- DNS queries ---
        if context.dns_queries:
            for query in context.dns_queries:
                query_lower = query.lower()

                if query_lower in self.blacklist_domains:
                    result['risk_factors'].append(f"Blacklisted domain: {query}")
                    result['probability_modifier'] *= 2.0

                if query_lower in self.whitelist_domains:
                    result['mitigating_factors'].append(f"Whitelisted domain: {query}")
                    result['probability_modifier'] *= 0.5

                # Long, high-entropy names are flagged as possibly
                # algorithmically generated.
                if len(query) > 20 and self._calculate_entropy(query) > 3.5:
                    result['risk_factors'].append(f"Possible DGA domain: {query}")
                    result['probability_modifier'] *= 1.3

        # context.geo_country is currently not used by any rule.
        return result

    def _calculate_entropy(self, s: str) -> float:
        """Calculate Shannon entropy (bits per character) of a string."""
        if not s:
            return 0.0
        prob = [s.count(c) / len(s) for c in set(s)]
        return -sum(p * math.log2(p) for p in prob if p > 0)
|
|
|
|
| |
| |
| |
|
|
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding for batch-first transformer inputs.

    A fixed table of shape (1, max_len, d_model) is built once and stored as
    the buffer ``pe``: even feature columns hold sines, odd columns cosines.
    forward() adds the first ``x.size(1)`` rows to ``x`` and applies dropout.
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        """
        Args:
            d_model: Feature width of the inputs (odd widths are supported).
            max_len: Number of positions in the table; inputs with more
                positions than this cannot be encoded.
            dropout: Dropout probability applied after adding the encoding.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the cosine term to fit; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings along dimension 1 of ``x``, then apply dropout."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
|
|
|
|
class LogBERTC2Sentinel(nn.Module):
    """
    Transformer encoder with five prediction heads for C2 beacon detection.

    Feature vectors are projected to ``d_model``, given positional
    encodings, passed through the encoder stack, mean-pooled over the
    sequence dimension and then fed to every head.
    """

    def __init__(self, config: C2SentinelConfig):
        super().__init__()
        self.config = config

        width = config.d_model
        half_width = width // 2
        drop_p = config.dropout

        def build_head(out_dim: int, squash: bool) -> nn.Sequential:
            # Linear -> GELU -> Dropout -> Linear, optionally followed by a
            # Sigmoid. Layer order is fixed so parameter names stay stable.
            layers = [
                nn.Linear(width, half_width),
                nn.GELU(),
                nn.Dropout(drop_p),
                nn.Linear(half_width, out_dim),
            ]
            if squash:
                layers.append(nn.Sigmoid())
            return nn.Sequential(*layers)

        # Input projection: raw features -> model width.
        self.feature_projection = nn.Sequential(
            nn.Linear(config.num_features, width),
            nn.LayerNorm(width),
            nn.GELU(),
            nn.Dropout(drop_p),
        )

        # Positional information for sequence inputs.
        self.pos_encoder = PositionalEncoding(width, config.max_seq_length, drop_p)

        # Encoder stack (batch-first, GELU activation).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=width,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            dropout=drop_p,
            activation='gelu',
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, config.num_encoder_layers)

        # Heads. c2_head and c2_type_head emit raw logits; the others end in
        # a Sigmoid.
        self.c2_head = build_head(1, squash=False)
        self.anomaly_head = build_head(1, squash=True)
        self.evasion_head = build_head(1, squash=True)
        self.c2_type_head = build_head(config.num_c2_types, squash=False)

        # The confidence head is narrower and has no dropout layer.
        quarter_width = width // 4
        self.confidence_head = nn.Sequential(
            nn.Linear(width, quarter_width),
            nn.GELU(),
            nn.Linear(quarter_width, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Run the model.

        Args:
            x: Feature tensor. A 2-D input gets a length-1 dimension inserted
                at position 1 before encoding.
            mask: Optional boolean padding mask handed to the encoder as
                ``src_key_padding_mask``; positions where it is True are
                excluded from pooling.

        Returns:
            Dict with 'c2_logits', 'anomaly_score', 'evasion_score',
            'c2_type_logits' and 'confidence'.
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)

        hidden = self.pos_encoder(self.feature_projection(x))
        encoded = self.transformer_encoder(hidden, src_key_padding_mask=mask)

        if mask is None:
            pooled = encoded.mean(dim=1)
        else:
            # Average over unmasked positions only; the clamp avoids division
            # by zero when a row is fully masked.
            keep = (~mask).unsqueeze(-1).float()
            pooled = (encoded * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1)

        outputs: Dict[str, torch.Tensor] = {}
        outputs['c2_logits'] = self.c2_head(pooled)
        outputs['anomaly_score'] = self.anomaly_head(pooled)
        outputs['evasion_score'] = self.evasion_head(pooled)
        outputs['c2_type_logits'] = self.c2_type_head(pooled)
        outputs['confidence'] = self.confidence_head(pooled)
        return outputs
|
|
|
|
| |
| |
| |
|
|
class FeatureExtractor:
    """
    Extracts 40-dimensional feature vectors from network traffic.

    Connection records are dictionaries; the keys read here are 'timestamp',
    'dst_ip', 'dst_port', 'bytes_sent', 'bytes_recv', 'protocol' and
    'duration', each with a default when absent. Timestamps may be numbers
    (epoch seconds) or ISO-8601 strings.
    """

    C2_TYPES = [
        'unknown', 'metasploit', 'cobalt_strike', 'sliver', 'havoc',
        'mythic', 'poshc2', 'merlin', 'empire', 'covenant',
        'brute_ratel', 'koadic', 'pupy', 'silenttrinity', 'faction',
        'ibombshell', 'godoh', 'dnscat2', 'iodine', 'dns_generic',
        'http_custom', 'https_custom', 'websocket', 'domain_fronting',
        'cloud_fronting', 'cdn_abuse', 'apt_generic', 'apt28', 'apt29',
        'apt41', 'lazarus', 'fin7', 'turla', 'winnti', 'custom'
    ]

    METASPLOIT_PORTS = {4444, 4445, 5555}

    def __init__(self):
        self.connection_cache = defaultdict(list)
        self.destination_history = defaultdict(set)

    @staticmethod
    def _to_epoch(ts) -> float:
        """Convert a numeric or ISO-8601 string timestamp to float seconds.

        Strings that cannot be parsed map to 0.0.
        """
        if isinstance(ts, str):
            try:
                return datetime.fromisoformat(ts.replace('Z', '+00:00')).timestamp()
            except (ValueError, OverflowError, OSError):
                return 0.0
        return float(ts)

    @classmethod
    def _sorted_timestamps(cls, connections: List[Dict]) -> np.ndarray:
        """Return the connections' timestamps as an ascending float array."""
        return np.array(sorted(cls._to_epoch(conn.get('timestamp', 0)) for conn in connections))

    def check_metasploit_signature(self, connections: List[Dict]) -> Tuple[bool, float]:
        """
        Check for Metasploit-specific signatures.

        A Metasploit port must be present; timing, payload size and
        destination persistence then add confidence.

        Returns (is_metasploit, confidence); the verdict requires at least
        two indicators and confidence >= 0.5.
        """
        if not connections:
            return False, 0.0

        confidence = 0.0
        indicators = 0

        ports = set(conn.get('dst_port', 0) for conn in connections)
        metasploit_port_match = ports & self.METASPLOIT_PORTS
        if not metasploit_port_match:
            return False, 0.0

        # Port 4444 counts as two indicators; the other ports as one.
        if 4444 in metasploit_port_match:
            confidence += 0.6
            indicators += 2
        elif 4445 in metasploit_port_match or 5555 in metasploit_port_match:
            confidence += 0.4
            indicators += 1

        # NOTE(review): SOURCE lost its indentation; the size and destination
        # checks below are taken to be siblings of this timing check rather
        # than nested inside it - confirm against the original file.
        if len(connections) > 1:
            # Normalized timestamps so ISO-8601 strings work here too.
            intervals = np.diff(self._sorted_timestamps(connections))
            if len(intervals) > 0:
                mean_interval = np.mean(intervals)
                if 1 <= mean_interval <= 30:
                    confidence += 0.15
                    indicators += 1

        bytes_sent = [conn.get('bytes_sent', 0) for conn in connections]
        if bytes_sent:
            mean_size = np.mean(bytes_sent)
            if 50 <= mean_size <= 200:
                confidence += 0.1
                indicators += 1

        dst_ips = [conn.get('dst_ip', '') for conn in connections]
        if dst_ips:
            unique_dsts = len(set(dst_ips))
            if unique_dsts == 1 and len(dst_ips) >= 3:
                confidence += 0.1
                indicators += 1

        is_metasploit = indicators >= 2 and confidence >= 0.5
        return is_metasploit, min(confidence, 1.0)

    def check_ssh_keepalive(self, connections: List[Dict]) -> Tuple[bool, float]:
        """
        Check for SSH keepalive pattern to prevent false positives.

        Criteria applied, in order:
        - at least 3 connections, port 22 among the destination ports
        - mean sent and mean received size both <= 100 bytes
        - sent/received mean ratio within [0.5, 2.0] (when anything was received)
        - size coefficient of variation <= 0.2 in both directions
        - interval coefficient of variation < 0.2

        Returns (is_ssh_keepalive, confidence); confidence is 0.95 when the
        mean interval is also within 20% of a common keepalive interval and
        the interval CV is below 0.15, otherwise 0.85.
        """
        if not connections or len(connections) < 3:
            return False, 0.0

        ports = set(conn.get('dst_port', 0) for conn in connections)
        if 22 not in ports:
            return False, 0.0

        bytes_sent = [conn.get('bytes_sent', 0) for conn in connections]
        bytes_recv = [conn.get('bytes_recv', 0) for conn in connections]
        if not bytes_sent or not bytes_recv:
            return False, 0.0

        mean_sent = np.mean(bytes_sent)
        mean_recv = np.mean(bytes_recv)

        # Keepalives are tiny in both directions.
        if mean_sent > 100 or mean_recv > 100:
            return False, 0.0

        # ... and roughly symmetric.
        if mean_recv > 0:
            ratio = mean_sent / mean_recv
            if not (0.5 <= ratio <= 2.0):
                return False, 0.0

        # ... with very consistent sizes.
        sent_cv = np.std(bytes_sent) / (mean_sent + 1e-6)
        recv_cv = np.std(bytes_recv) / (mean_recv + 1e-6)
        if sent_cv > 0.2 or recv_cv > 0.2:
            return False, 0.0

        # Normalized timestamps so ISO-8601 strings work here too.
        timestamps = self._sorted_timestamps(connections)
        if len(timestamps) > 1:
            intervals = np.diff(timestamps)
            if len(intervals) > 0:
                mean_interval = np.mean(intervals)
                # NOTE(review): if every timestamp is identical (e.g. all
                # missing), interval_cv is 0 and this still returns 0.85.
                interval_cv = np.std(intervals) / (mean_interval + 1e-6)

                common_keepalive_intervals = [15, 30, 60, 120, 180, 300]
                closest_match = min(common_keepalive_intervals, key=lambda x: abs(x - mean_interval))
                interval_match = abs(mean_interval - closest_match) / closest_match < 0.2

                if interval_cv < 0.15 and interval_match:
                    confidence = 0.95
                elif interval_cv < 0.2:
                    confidence = 0.85
                else:
                    return False, 0.0

                return True, confidence

        return False, 0.0

    def check_legitimate_patterns(self, connections: List[Dict]) -> Tuple[Optional['LegitimatePattern'], float]:
        """
        Check if connections match any known legitimate patterns.

        Patterns in LEGITIMATE_PATTERNS are tried in order and the first
        match wins.

        Returns (matched_pattern, confidence) or (None, 0.0)
        """
        if not connections:
            return None, 0.0

        bytes_sent = [conn.get('bytes_sent', 0) for conn in connections]
        bytes_recv = [conn.get('bytes_recv', 0) for conn in connections]

        stats = {
            'mean_sent': np.mean(bytes_sent) if bytes_sent else 0,
            'mean_recv': np.mean(bytes_recv) if bytes_recv else 0,
            'sent_cv': np.std(bytes_sent) / (np.mean(bytes_sent) + 1e-6) if bytes_sent else 0,
            'recv_cv': np.std(bytes_recv) / (np.mean(bytes_recv) + 1e-6) if bytes_recv else 0,
        }

        for pattern in LEGITIMATE_PATTERNS:
            matches, confidence = pattern.matches(connections, stats)
            if matches:
                return pattern, confidence

        return None, 0.0

    def extract_features(self, connections: List[Dict]) -> np.ndarray:
        """
        Extract 40 features from connection records.

        Returns a float32 vector of length 40 (all zeros for empty input).
        Index groups: 0-9 timing, 10-17 destinations and ports, 18-27
        payload sizes and volume, 28-32 jitter/burst/session shape,
        33-35 protocol and port mix, 36-37 night-time and fast-interval
        ratios, 38-39 connection duration.
        """
        if not connections:
            # float32 to match the dtype returned for non-empty input.
            return np.zeros(40, dtype=np.float32)

        features = np.zeros(40)
        timestamps = self._sorted_timestamps(connections)

        # ---- Timing (0-9) ----
        if len(timestamps) > 1:
            intervals = np.diff(timestamps)
            intervals = intervals[intervals > 0]  # ignore simultaneous events

            if len(intervals) > 0:
                features[0] = np.mean(intervals)
                features[1] = np.std(intervals)
                features[2] = np.std(intervals) / (np.mean(intervals) + 1e-6)
                features[3] = np.median(intervals)
                features[4] = np.min(intervals)
                features[5] = np.max(intervals)

                if len(intervals) > 2:
                    # Regularity: 1 minus mean relative deviation from the
                    # middle sorted interval, clipped to [0, 1].
                    sorted_intervals = np.sort(intervals)
                    mode_estimate = sorted_intervals[len(sorted_intervals)//2]
                    regularity = 1.0 - np.mean(np.abs(intervals - mode_estimate) / (mode_estimate + 1e-6))
                    features[6] = max(0, min(1, regularity))

                if len(intervals) >= 8:
                    # Share of spectral power held by the strongest frequency bin.
                    fft = np.fft.fft(intervals - np.mean(intervals))
                    power = np.abs(fft[:len(fft)//2])**2
                    features[7] = np.max(power) / (np.sum(power) + 1e-6)

            # NOTE(review): SOURCE lost its indentation; the hour-of-day
            # features are assumed to sit inside the len(timestamps) > 1
            # guard - confirm against the original file.
            hours = [(ts % 86400) / 3600 for ts in timestamps]
            features[8] = np.std(hours) / 12.0
            business_hours = sum(1 for h in hours if 9 <= h <= 17) / len(hours)
            features[9] = business_hours

        # ---- Destinations and ports (10-17) ----
        dst_ips = [conn.get('dst_ip', '') for conn in connections]
        dst_ports = [conn.get('dst_port', 0) for conn in connections]

        unique_dsts = len(set(dst_ips))
        features[10] = unique_dsts
        features[11] = unique_dsts / len(connections) if connections else 0

        if dst_ips:
            dst_counts = defaultdict(int)
            for ip in dst_ips:
                dst_counts[ip] += 1
            max_persistence = max(dst_counts.values())
            features[12] = max_persistence / len(connections)
            features[13] = len([c for c in dst_counts.values() if c > 1]) / len(dst_counts) if dst_counts else 0

        unique_ports = len(set(dst_ports))
        features[14] = unique_ports
        features[15] = 1.0 if 443 in dst_ports or 80 in dst_ports else 0.0

        high_port_ratio = sum(1 for p in dst_ports if p > 10000) / len(dst_ports) if dst_ports else 0
        features[16] = high_port_ratio

        msf_port_hit = any(p in self.METASPLOIT_PORTS for p in dst_ports)
        features[17] = 1.0 if msf_port_hit else 0.0

        # ---- Payload sizes and volume (18-27) ----
        bytes_sent = [conn.get('bytes_sent', 0) for conn in connections]
        bytes_recv = [conn.get('bytes_recv', 0) for conn in connections]

        if bytes_sent:
            features[18] = np.mean(bytes_sent)
            features[19] = np.std(bytes_sent)
            features[20] = np.std(bytes_sent) / (np.mean(bytes_sent) + 1e-6)

        if bytes_recv:
            features[21] = np.mean(bytes_recv)
            features[22] = np.std(bytes_recv)

        total_sent = sum(bytes_sent)
        total_recv = sum(bytes_recv)
        features[23] = total_sent / (total_recv + 1e-6) if total_recv else 0

        if len(bytes_sent) > 1:
            # Share of repeated sent sizes.
            unique_sizes = len(set(bytes_sent))
            features[24] = 1.0 - (unique_sizes / len(bytes_sent))

        features[25] = sum(1 for b in bytes_sent if b < 500) / len(bytes_sent) if bytes_sent else 0

        if bytes_sent:
            # Entropy of a 10-bin size histogram, scaled by 3.32 (~log2(10)).
            size_hist, _ = np.histogram(bytes_sent, bins=10)
            size_hist = size_hist / (sum(size_hist) + 1e-6)
            entropy = -np.sum(size_hist * np.log2(size_hist + 1e-6))
            features[26] = entropy / 3.32

        features[27] = len(connections)

        # ---- Jitter, bursts and session shape (28-32) ----
        if len(timestamps) > 5:
            intervals = np.diff(timestamps)  # unfiltered, unlike features 0-7
            if len(intervals) > 0:
                jitter_pattern = np.abs(np.diff(intervals))
                if len(jitter_pattern) > 0:
                    features[28] = np.mean(jitter_pattern) / (np.mean(intervals) + 1e-6)

                # Lag-1 autocorrelation of the intervals.
                centered = intervals - np.mean(intervals)
                autocorr = np.correlate(centered, centered, mode='full')
                autocorr = autocorr[len(autocorr)//2:]
                if len(autocorr) > 1:
                    features[29] = autocorr[1] / (autocorr[0] + 1e-6)

        if len(timestamps) > 3:
            # Fraction of intervals shorter than 10% of the mean interval.
            intervals = np.diff(timestamps)
            burst_threshold = np.mean(intervals) * 0.1
            bursts = sum(1 for i in intervals if i < burst_threshold)
            features[30] = bursts / len(intervals) if intervals.size > 0 else 0

        if timestamps.size > 0:
            # Session length as a fraction of a day, capped at 1.
            session_length = timestamps[-1] - timestamps[0]
            features[31] = min(session_length / 86400, 1.0)

        if len(timestamps) > 10:
            # NOTE(review): every window "count" here equals window_size, so
            # this feature is always 1.0 when computed. Left unchanged because
            # trained weights may depend on it; a per-time-window count was
            # probably intended.
            window_size = len(timestamps) // 5
            window_counts = [window_size] * 5
            features[32] = 1.0 - (np.std(window_counts) / (np.mean(window_counts) + 1e-6))

        # ---- Protocol and port mix (33-35) ----
        protocols = [conn.get('protocol', 'tcp').lower() for conn in connections]
        unique_protocols = len(set(protocols))
        features[33] = 1.0 if unique_protocols == 1 else 1.0 / unique_protocols

        features[34] = sum(1 for p in dst_ports if p in [80, 443, 8080, 8443]) / len(dst_ports) if dst_ports else 0
        features[35] = sum(1 for p in dst_ports if p == 443) / len(dst_ports) if dst_ports else 0

        # ---- Night-time and fast-interval ratios (36-37) ----
        if timestamps.size > 0:
            night_hours = sum(1 for ts in timestamps if 0 <= (ts % 86400) / 3600 < 6)
            features[36] = night_hours / len(timestamps)

        if len(timestamps) > 1:
            intervals = np.diff(timestamps)
            fast_beacon_ratio = sum(1 for i in intervals if 1 <= i <= 5) / len(intervals) if len(intervals) > 0 else 0
            features[37] = fast_beacon_ratio

        # ---- Connection duration (38-39) ----
        durations = [conn.get('duration', 0) for conn in connections]
        if durations:
            features[38] = np.mean(durations)
            features[39] = np.std(durations) / (np.mean(durations) + 1e-6) if np.mean(durations) > 0 else 0

        return features.astype(np.float32)
|
|
|
|
| |
| |
| |
|
|
class LogParser:
    """
    Parses various log formats into connection records.

    Every parser is best-effort: it returns a dict on success and None when
    the line does not fit the format or holds malformed values.
    """

    @staticmethod
    def parse_zeek_conn(log_line: str) -> Optional[Dict]:
        """
        Parse Zeek/Bro conn.log format.

        Expects a tab-separated line with at least 15 columns. Columns 0, 2-6
        and 8-10 are read as timestamp, source/destination address and port,
        protocol, duration and byte counts. A '-' in the duration or byte
        columns is read as 0.
        """
        try:
            parts = log_line.strip().split('\t')
            if len(parts) >= 15:
                return {
                    'timestamp': float(parts[0]),
                    'src_ip': parts[2],
                    'src_port': int(parts[3]),
                    'dst_ip': parts[4],
                    'dst_port': int(parts[5]),
                    'protocol': parts[6],
                    'duration': float(parts[8]) if parts[8] != '-' else 0,
                    'bytes_sent': int(parts[9]) if parts[9] != '-' else 0,
                    'bytes_recv': int(parts[10]) if parts[10] != '-' else 0
                }
        except (ValueError, TypeError, AttributeError):
            # Malformed number or non-string input: not a usable record.
            pass
        return None

    @staticmethod
    def parse_syslog(log_line: str) -> Optional[Dict]:
        """
        Parse common syslog/netflow formats.

        Two layouts are recognised:
        - ``<YYYY-MM-DD[T ]HH:MM:SS> ... a.b.c.d:sport -> e.f.g.h:dport``;
          the timestamp is returned as the matched string.
        - ``src=a.b.c.d ... dst=e.f.g.h ... sport=N ... dport=M``; no
          timestamp is captured, so 'timestamp' is 0.

        Protocol is always reported as 'tcp' and byte counts as 0.
        """
        patterns = [
            r'(\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}:\d{2}).*?(\d+\.\d+\.\d+\.\d+):(\d+)\s*->\s*(\d+\.\d+\.\d+\.\d+):(\d+)',
            r'src=(\d+\.\d+\.\d+\.\d+).*?dst=(\d+\.\d+\.\d+\.\d+).*?sport=(\d+).*?dport=(\d+)',
        ]

        for pattern in patterns:
            match = re.search(pattern, log_line)
            if not match:
                continue

            groups = match.groups()
            if len(groups) == 5:
                timestamp, src_ip, src_port, dst_ip, dst_port = groups
            elif len(groups) == 4:
                # key=value layout: group order is src, dst, sport, dport.
                src_ip, dst_ip, src_port, dst_port = groups
                timestamp = 0
            else:
                continue

            try:
                return {
                    'timestamp': timestamp,
                    'src_ip': src_ip,
                    'src_port': int(src_port),
                    'dst_ip': dst_ip,
                    'dst_port': int(dst_port),
                    'protocol': 'tcp',
                    'bytes_sent': 0,
                    'bytes_recv': 0
                }
            except ValueError:
                continue
        return None

    @staticmethod
    def parse_json(log_line: str) -> Optional[Dict]:
        """
        Parse JSON log format.

        The line must decode to a JSON object. Alternate key names are
        accepted ('@timestamp', 'source_ip'/'src', 'dest_ip'/'dst',
        'source_port', 'dest_port', 'bytes_out', 'bytes_in'); missing values
        fall back to 0, '' or 'tcp'.
        """
        try:
            data = json.loads(log_line)
            return {
                'timestamp': data.get('timestamp', data.get('@timestamp', 0)),
                'src_ip': data.get('src_ip', data.get('source_ip', data.get('src', ''))),
                'dst_ip': data.get('dst_ip', data.get('dest_ip', data.get('dst', ''))),
                'src_port': int(data.get('src_port', data.get('source_port', 0))),
                'dst_port': int(data.get('dst_port', data.get('dest_port', 0))),
                'protocol': data.get('protocol', 'tcp'),
                'bytes_sent': int(data.get('bytes_sent', data.get('bytes_out', 0))),
                'bytes_recv': int(data.get('bytes_recv', data.get('bytes_in', 0))),
                'duration': float(data.get('duration', 0))
            }
        except (ValueError, TypeError, AttributeError, OverflowError):
            # Invalid JSON, a non-object payload, or non-numeric field values.
            return None
|
|
|
|
| |
| |
| |
|
|
class ReconSupport:
    """
    Reconnaissance and enrichment support for scripting.

    Provides IP analysis, network intelligence, and enrichment functions
    useful for security automation and scripting.
    """

    # CIDR ranges attributed to each provider. Entries are parsed with
    # strict=False because '54.0.0.0/6' has host bits set and would
    # otherwise be rejected by ipaddress.ip_network.
    KNOWN_CDNS = {
        'cloudflare': ['104.16.0.0/12', '172.64.0.0/13', '131.0.72.0/22'],
        'aws': ['52.0.0.0/6', '54.0.0.0/6'],
        'google': ['35.190.0.0/16', '35.220.0.0/14', '142.250.0.0/15'],
        'azure': ['13.64.0.0/11', '40.64.0.0/10'],
        'akamai': ['23.0.0.0/12', '104.64.0.0/10'],
    }

    # Private, loopback and link-local IPv4 networks.
    PRIVATE_RANGES = [
        ipaddress.ip_network('10.0.0.0/8'),
        ipaddress.ip_network('172.16.0.0/12'),
        ipaddress.ip_network('192.168.0.0/16'),
        ipaddress.ip_network('127.0.0.0/8'),
        ipaddress.ip_network('169.254.0.0/16'),
    ]

    @classmethod
    def _match_cdn(cls, ip_obj) -> Optional[str]:
        """Return the first KNOWN_CDNS provider whose range contains ip_obj."""
        for provider, cidrs in cls.KNOWN_CDNS.items():
            for cidr in cidrs:
                try:
                    network = ipaddress.ip_network(cidr, strict=False)
                except ValueError:
                    # Unparseable table entry: skip it, keep checking the rest.
                    continue
                if ip_obj in network:
                    return provider
        return None

    @classmethod
    def analyze_ip(cls, ip: str, resolve_dns: bool = True) -> Dict[str, Any]:
        """
        Analyze an IP address for reconnaissance purposes.

        Args:
            ip: Address to analyze (IPv4 or IPv6 text).
            resolve_dns: When True (default) attempt a reverse DNS lookup,
                which is a blocking network call. Pass False for offline use.

        Returns:
            Enrichment data about the IP. For an unparseable address only
            ``ip`` is populated and ``is_valid`` stays False. ``numeric`` is
            filled for IPv4 addresses only; ``reverse_dns`` stays None when
            the lookup is disabled or fails.
        """
        result = {
            'ip': ip,
            'is_valid': False,
            'is_private': False,
            'is_loopback': False,
            'is_multicast': False,
            'is_cdn': False,
            'cdn_provider': None,
            'ip_version': None,
            'reverse_dns': None,
            'numeric': None,
        }

        try:
            ip_obj = ipaddress.ip_address(ip)
        except ValueError:
            return result

        result['is_valid'] = True
        result['ip_version'] = ip_obj.version
        result['is_private'] = ip_obj.is_private
        result['is_loopback'] = ip_obj.is_loopback
        result['is_multicast'] = ip_obj.is_multicast

        if isinstance(ip_obj, ipaddress.IPv4Address):
            result['numeric'] = int(ip_obj)

        provider = cls._match_cdn(ip_obj)
        if provider is not None:
            result['is_cdn'] = True
            result['cdn_provider'] = provider

        if resolve_dns:
            try:
                result['reverse_dns'] = socket.gethostbyaddr(str(ip_obj))[0]
            except OSError:
                # socket.herror / socket.gaierror: no PTR record or no resolver.
                pass

        return result

    @classmethod
    def analyze_connection_patterns(
        cls,
        connections: List[Dict],
        resolve_dns: bool = True
    ) -> Dict[str, Any]:
        """
        Analyze connection patterns for reconnaissance.

        Provides high-level pattern analysis useful for threat hunting.

        Args:
            connections: Connection records; the keys read are ``dst_ip``,
                ``dst_port``, ``bytes_sent``, ``bytes_recv`` and ``timestamp``
                (missing keys default to ``''`` or ``0``).
            resolve_dns: Forwarded to analyze_ip for every unique destination.

        Returns:
            A dict with counts plus ``timing``, ``volume``, ``ports``,
            ``destinations`` and ``indicators`` sections, or
            ``{'error': 'No connections provided'}`` for empty input.
        """
        if not connections:
            return {'error': 'No connections provided'}

        dst_ips = [conn.get('dst_ip', '') for conn in connections]
        dst_ports = [conn.get('dst_port', 0) for conn in connections]
        bytes_sent = [conn.get('bytes_sent', 0) for conn in connections]
        bytes_recv = [conn.get('bytes_recv', 0) for conn in connections]

        timestamps = sorted(conn.get('timestamp', 0) for conn in connections)
        intervals = np.diff(timestamps) if len(timestamps) > 1 else []
        has_intervals = len(intervals) > 0

        # Coefficient of variation of the inter-connection gaps, computed once
        # and shared by the timing section and the indicator below.
        interval_cv = float(np.std(intervals) / (np.mean(intervals) + 1e-6)) if has_intervals else 0

        unique_dsts = set(dst_ips)
        dst_analysis = {
            ip: cls.analyze_ip(ip, resolve_dns=resolve_dns)
            for ip in unique_dsts
            if ip
        }
        valid_dsts = [info for info in dst_analysis.values() if info.get('is_valid')]

        port_counts = defaultdict(int)
        for port in dst_ports:
            port_counts[port] += 1

        if bytes_sent and np.mean(bytes_sent) > 0:
            consistent_sizes = float(np.std(bytes_sent) / (np.mean(bytes_sent) + 1e-6)) < 0.2
        else:
            consistent_sizes = False

        total_sent = sum(bytes_sent)
        total_recv = sum(bytes_recv)

        return {
            'connection_count': len(connections),
            'unique_destinations': len(unique_dsts),
            'unique_ports': len(set(dst_ports)),

            'timing': {
                'duration_seconds': timestamps[-1] - timestamps[0] if len(timestamps) > 1 else 0,
                'mean_interval': float(np.mean(intervals)) if has_intervals else 0,
                'interval_stddev': float(np.std(intervals)) if has_intervals else 0,
                'interval_cv': interval_cv,
            },

            'volume': {
                'total_sent': total_sent,
                'total_recv': total_recv,
                'mean_sent': float(np.mean(bytes_sent)) if bytes_sent else 0,
                'mean_recv': float(np.mean(bytes_recv)) if bytes_recv else 0,
                'sent_recv_ratio': total_sent / (total_recv + 1e-6) if bytes_recv else 0,
            },

            'ports': dict(port_counts),

            'destinations': dst_analysis,

            'indicators': {
                'single_destination': len(unique_dsts) == 1,
                'consistent_timing': interval_cv < 0.3 if has_intervals else False,
                'consistent_sizes': consistent_sizes,
                'uses_common_port': bool(set(dst_ports) & {80, 443, 53, 22}),
                'uses_high_port': any(p > 10000 for p in dst_ports),
                'has_cdn_destination': any(d.get('is_cdn', False) for d in dst_analysis.values()),
                # Require at least one valid destination so the flag is not
                # vacuously True when nothing could be analyzed.
                'all_private_destinations': bool(valid_dsts) and all(
                    d.get('is_private', False) for d in valid_dsts
                ),
            },
        }

    @classmethod
    def generate_iocs(
        cls,
        connections: List[Dict],
        result: Union[Dict[str, Any], "AnalysisResult"]
    ) -> Dict[str, List[str]]:
        """
        Generate Indicators of Compromise (IOCs) from analysis.

        Args:
            connections: Connection records that were analyzed.
            result: Analysis outcome, either as a dict or as an object
                exposing ``is_c2``, ``c2_type`` and ``evasion_score``
                attributes (such as an AnalysisResult).

        Returns:
            IOCs suitable for threat intelligence sharing, with keys ``ips``,
            ``ports``, ``timing_signatures`` and ``behavioral_indicators``.
            All lists are empty unless the result is flagged as C2. IPs and
            ports are de-duplicated and sorted so output is deterministic.
        """
        if isinstance(result, dict):
            lookup = result.get
        else:
            def lookup(key, default=None):
                return getattr(result, key, default)

        iocs = {
            'ips': [],
            'ports': [],
            'timing_signatures': [],
            'behavioral_indicators': [],
        }

        if not lookup('is_c2', False):
            return iocs

        dst_ips = {conn.get('dst_ip', '') for conn in connections if conn.get('dst_ip')}
        iocs['ips'] = sorted(dst_ips, key=str)

        dst_ports = {str(conn.get('dst_port', 0)) for conn in connections if conn.get('dst_port')}
        # Shorter strings first, then lexicographic: numeric order for ports.
        iocs['ports'] = sorted(dst_ports, key=lambda text: (len(text), text))

        timestamps = sorted(conn.get('timestamp', 0) for conn in connections)
        if len(timestamps) > 1:
            intervals = np.diff(timestamps)
            mean_interval = np.mean(intervals)
            iocs['timing_signatures'].append(f"beacon_interval:{mean_interval:.1f}s±{np.std(intervals):.1f}s")

        c2_type = lookup('c2_type')
        if c2_type:
            iocs['behavioral_indicators'].append(f"c2_type:{c2_type}")
        if lookup('evasion_score', 0) > 0.5:
            iocs['behavioral_indicators'].append("evasion_detected")

        return iocs
|
|
|
|
| |
| |
| |
|
|
@dataclass
class AnalysisResult:
    """Structured outcome of one C2 analysis pass.

    The first nine fields are required and carry the verdict and scores; the
    remaining fields are optional detail that defaults to "nothing recorded".
    """
    # --- verdict and scores (required) ---
    is_c2: bool
    c2_probability: float
    anomaly_score: float
    evasion_score: float
    confidence: float
    c2_type: str
    c2_type_confidence: float
    detection_method: str
    immediate_detection: bool

    # --- context adjustment bookkeeping ---
    context_applied: bool = False
    original_probability: float = 0.0
    probability_modifier: float = 1.0

    # --- legitimate-pattern match, if any ---
    matched_legitimate_pattern: Optional[str] = None
    legitimate_confidence: float = 0.0

    # --- human-readable reasons for and against the verdict ---
    risk_factors: List[str] = field(default_factory=list)
    mitigating_factors: List[str] = field(default_factory=list)

    # --- service classification ---
    service_type: str = "unknown"

    # --- suggested follow-up actions ---
    recommendations: List[str] = field(default_factory=list)

    # --- raw feature vector (left empty unless a caller fills it) ---
    features: List[float] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict, built with dataclasses.asdict."""
        as_mapping = asdict(self)
        return as_mapping

    def __repr__(self) -> str:
        """Return a one-line summary: verdict, probability and C2 type."""
        if self.is_c2:
            status = "🚨 C2 DETECTED"
        else:
            status = "✅ Clean"
        return f"<AnalysisResult: {status} | prob={self.c2_probability:.3f} | type={self.c2_type}>"
|
|
|
|
class C2Sentinel:
    """
    Main API for LogBERT-C2Sentinel.

    Advanced C2 detection with context inference and reconnaissance support.

    Usage:
        # Load pre-trained model
        sentinel = C2Sentinel.load('c2_sentinel')

        # Basic analysis
        result = sentinel.analyze(connections)

        # With context
        context = ConnectionContext(process_name='sshd', known_good=True)
        result = sentinel.analyze(connections, context=context)

        # Batch analysis
        results = sentinel.analyze_batch([conn_list1, conn_list2, ...])

        # With reconnaissance
        recon = sentinel.recon.analyze_connection_patterns(connections)
        iocs = sentinel.recon.generate_iocs(connections, result.to_dict())
    """

    def __init__(self, model: LogBERTC2Sentinel, config: C2SentinelConfig, device: str = 'auto'):
        """Wire up the helper components and move the model to its device.

        Args:
            model: Network used for scoring; it is put in eval mode.
            config: Configuration the model was built from.
            device: Torch device string, or 'auto' to use CUDA when
                available and CPU otherwise.
        """
        self.model = model
        self.config = config
        self.feature_extractor = FeatureExtractor()
        self.log_parser = LogParser()
        self.context_engine = ContextInference()
        self.recon = ReconSupport()

        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _apply_behavioral_adjustments(
        connections: List[Dict],
        c2_prob: float,
        result: AnalysisResult
    ) -> float:
        """Scale c2_prob by traffic-shape heuristics and record the reasons.

        Appends to ``result.risk_factors`` / ``result.mitigating_factors`` and
        returns the adjusted probability; upward adjustments are capped at 1.0.
        """
        dst_ips = set(conn.get('dst_ip', '') for conn in connections)
        bytes_recv = [conn.get('bytes_recv', 0) for conn in connections]
        bytes_sent = [conn.get('bytes_sent', 0) for conn in connections]

        # Coefficient of variation of payload sizes in each direction.
        recv_cv = np.std(bytes_recv) / (np.mean(bytes_recv) + 1e-6) if bytes_recv else 0
        sent_cv = np.std(bytes_sent) / (np.mean(bytes_sent) + 1e-6) if bytes_sent else 0
        total_sent = sum(bytes_sent)
        total_recv = sum(bytes_recv)
        req_resp_ratio = total_sent / (total_recv + 1e-6) if total_recv else float('inf')

        # Many destinations with highly variable responses.
        if len(dst_ips) > 5 and bytes_recv and recv_cv > 2:
            c2_prob *= 0.4
            result.mitigating_factors.append("Multiple destinations with high response variance")

        # The remaining rules only apply to a single destination with enough
        # samples to measure timing.
        if not (len(dst_ips) == 1 and len(connections) >= 5):
            return c2_prob

        timestamps = sorted(c.get('timestamp', 0) for c in connections)
        intervals = np.diff(timestamps)
        mean_interval = np.mean(intervals) if len(intervals) > 0 else 0
        interval_cv = np.std(intervals) / (mean_interval + 1e-6) if mean_interval > 0 else 0

        if recv_cv > 0.5:
            c2_prob *= 0.5
            result.mitigating_factors.append("High response size variance (likely data retrieval)")
        elif recv_cv < 0.2 and sent_cv < 0.2:
            c2_prob = min(1.0, c2_prob * 1.4)
            result.risk_factors.append("Very consistent request/response sizes")

        if req_resp_ratio < 0.1:
            c2_prob *= 0.4
            result.mitigating_factors.append("Asymmetric traffic (small requests, large responses)")
        elif 0.2 < req_resp_ratio < 0.8:
            c2_prob = min(1.0, c2_prob * 1.2)
            result.risk_factors.append("Balanced request/response ratio (C2-like)")

        if interval_cv < 0.3 and mean_interval > 0 and recv_cv < 0.3:
            c2_prob = min(1.0, c2_prob * 1.3)
            result.risk_factors.append("Regular timing with consistent sizes")

        if mean_interval > 60 and recv_cv < 0.15 and sent_cv < 0.15:
            c2_prob = min(1.0, c2_prob * 1.5)
            result.risk_factors.append("APT-style slow beacon pattern")

        return c2_prob

    def analyze(
        self,
        connections: List[Dict],
        threshold: float = 0.5,
        context: Optional[ConnectionContext] = None,
        include_features: bool = False,
        strict_mode: bool = False
    ) -> AnalysisResult:
        """
        Analyze connections for C2 activity.

        The checks run in this order: SSH keepalive whitelist (returns
        early), legitimate-pattern match, Metasploit signature (returns
        early), model inference, traffic-shape adjustments, legitimate-pattern
        discount, context inference, and finally thresholding.

        Args:
            connections: List of connection records
            threshold: Detection threshold (default 0.5, use 0.7 for fewer false positives)
            context: Optional ConnectionContext with additional metadata
            include_features: Include raw feature vector in result
            strict_mode: Require higher confidence for C2 detection
                (the effective threshold becomes at least 0.7)

        Returns:
            AnalysisResult with comprehensive detection information. Empty
            input yields the default "not C2" result.
        """
        result = AnalysisResult(
            is_c2=False,
            c2_probability=0.0,
            anomaly_score=0.0,
            evasion_score=0.0,
            confidence=0.0,
            c2_type='none',
            c2_type_confidence=0.0,
            detection_method='none',
            immediate_detection=False,
        )

        if not connections:
            return result

        # --- Whitelist: SSH keepalive short-circuits everything else ---
        is_ssh_keepalive, ssh_ka_confidence = self.feature_extractor.check_ssh_keepalive(connections)
        if is_ssh_keepalive:
            result.matched_legitimate_pattern = "ssh_keepalive"
            result.legitimate_confidence = ssh_ka_confidence
            result.service_type = ServiceType.SSH.value
            result.mitigating_factors.append(f"Matches SSH keepalive pattern (confidence: {ssh_ka_confidence:.2f})")
            result.detection_method = DetectionMethod.WHITELIST.value
            result.recommendations.append("SSH keepalive is normal system behavior")

            result.is_c2 = False
            result.c2_probability = 0.05
            result.confidence = ssh_ka_confidence
            return result

        # --- Other legitimate patterns: recorded now, discounted later ---
        matched_pattern, pattern_confidence = self.feature_extractor.check_legitimate_patterns(connections)
        if matched_pattern and pattern_confidence > 0.7:
            result.matched_legitimate_pattern = matched_pattern.name
            result.legitimate_confidence = pattern_confidence
            result.service_type = matched_pattern.service_type.value
            result.mitigating_factors.append(f"Matches {matched_pattern.name} pattern: {matched_pattern.description}")

        # --- Signature: Metasploit is reported immediately ---
        is_msf, msf_confidence = self.feature_extractor.check_metasploit_signature(connections)
        if is_msf:
            result.is_c2 = True
            result.c2_probability = msf_confidence
            result.anomaly_score = 0.95
            result.evasion_score = 0.1
            result.confidence = msf_confidence
            result.c2_type = 'metasploit'
            result.c2_type_confidence = msf_confidence
            result.immediate_detection = True
            result.detection_method = DetectionMethod.SIGNATURE.value
            result.risk_factors.append("Matches Metasploit signature (high-confidence C2 port + behavior)")
            if include_features:
                result.features = self.feature_extractor.extract_features(connections).tolist()
            return result

        # --- Model inference ---
        features = self.feature_extractor.extract_features(connections)
        features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(features_tensor)

        c2_prob = torch.sigmoid(outputs['c2_logits']).item()
        result.original_probability = c2_prob
        result.anomaly_score = outputs['anomaly_score'].item()
        result.evasion_score = outputs['evasion_score'].item()
        result.confidence = outputs['confidence'].item()

        c2_type_probs = F.softmax(outputs['c2_type_logits'], dim=-1)
        c2_type_idx = torch.argmax(c2_type_probs, dim=-1).item()
        result.c2_type = FeatureExtractor.C2_TYPES[c2_type_idx]
        result.c2_type_confidence = c2_type_probs[0, c2_type_idx].item()

        # --- Traffic-shape adjustments ---
        c2_prob = self._apply_behavioral_adjustments(connections, c2_prob, result)

        # --- Legitimate-pattern discount (up to 70% at confidence 1.0) ---
        if matched_pattern and pattern_confidence > 0.5:
            discount = 1.0 - (pattern_confidence * 0.7)
            c2_prob *= discount
            result.mitigating_factors.append(f"Legitimate pattern match reduces probability by {(1-discount)*100:.0f}%")

        # --- Context inference ---
        inference = self.context_engine.infer(connections, context)

        if inference['probability_modifier'] != 1.0 or context:
            result.context_applied = True
            result.probability_modifier = inference['probability_modifier']
            c2_prob *= inference['probability_modifier']

        # NOTE(review): the source's indentation was lost; the merges below
        # are treated as unconditional (not nested under the `if` above) -
        # confirm against the original layout.
        result.risk_factors.extend(inference['risk_factors'])
        result.mitigating_factors.extend(inference['mitigating_factors'])
        result.recommendations.extend(inference['recommendations'])

        if inference['is_whitelisted']:
            result.mitigating_factors.append("Destination is whitelisted")
        if inference['is_blacklisted']:
            result.risk_factors.append("Destination is blacklisted")

        if inference['service_type'] != ServiceType.UNKNOWN:
            result.service_type = inference['service_type'].value

        # --- Final decision ---
        effective_threshold = threshold
        if strict_mode:
            effective_threshold = max(threshold, 0.7)

        result.c2_probability = min(max(c2_prob, 0.0), 1.0)
        result.is_c2 = result.c2_probability >= effective_threshold
        result.detection_method = DetectionMethod.ML.value if not result.context_applied else DetectionMethod.CONTEXT.value

        if result.is_c2:
            # Index 0 is reported as 'unknown' rather than by its table name.
            result.c2_type = FeatureExtractor.C2_TYPES[c2_type_idx] if c2_type_idx > 0 else 'unknown'
        else:
            result.c2_type = 'none'

        if result.is_c2:
            result.recommendations.append("Investigate destination IP for known C2 infrastructure")
            result.recommendations.append("Check for associated process and user activity")
            if result.evasion_score > 0.5:
                result.recommendations.append("C2 may be using evasion techniques - correlate with other telemetry")

        if include_features:
            result.features = features.tolist()

        return result

    def analyze_batch(
        self,
        connection_groups: List[List[Dict]],
        threshold: float = 0.5,
        contexts: Optional[List[ConnectionContext]] = None,
        parallel: bool = True
    ) -> List[AnalysisResult]:
        """
        Analyze multiple connection groups.

        Args:
            connection_groups: List of connection lists to analyze
            threshold: Detection threshold
            contexts: Optional list of contexts (one per group); groups
                beyond the end of this list are analyzed without context
            parallel: Currently ignored - groups are analyzed one at a time.
                Kept so existing callers continue to work.

        Returns:
            List of AnalysisResults, in the same order as the input groups
        """
        results = []
        for i, connections in enumerate(connection_groups):
            context = contexts[i] if contexts and i < len(contexts) else None
            results.append(self.analyze(connections, threshold=threshold, context=context))
        return results

    def analyze_logs(
        self,
        log_lines: List[str],
        group_by_dst: bool = True,
        threshold: float = 0.5
    ) -> List[Dict]:
        """Analyze raw log lines for C2 activity.

        Each line is tried as JSON, then Zeek conn.log, then syslog; lines no
        parser accepts are dropped.

        Args:
            log_lines: Raw log lines.
            group_by_dst: When True, analyze each destination IP separately
                and skip destinations with fewer than 3 connections. When
                False, analyze all parsed connections as one group.
            threshold: Detection threshold passed to analyze.

        Returns:
            Result dicts (AnalysisResult.to_dict plus ``connection_count`` and,
            when grouped, ``dst_ip``) sorted by ``c2_probability`` descending.
        """
        connections = []
        for line in log_lines:
            conn = (
                self.log_parser.parse_json(line)
                or self.log_parser.parse_zeek_conn(line)
                or self.log_parser.parse_syslog(line)
            )
            if conn:
                connections.append(conn)

        if not connections:
            return []

        results = []

        if group_by_dst:
            grouped = defaultdict(list)
            for conn in connections:
                grouped[conn.get('dst_ip', 'unknown')].append(conn)

            for dst_ip, group_conns in grouped.items():
                if len(group_conns) < 3:
                    continue
                result_dict = self.analyze(group_conns, threshold).to_dict()
                result_dict['dst_ip'] = dst_ip
                result_dict['connection_count'] = len(group_conns)
                results.append(result_dict)
        else:
            result_dict = self.analyze(connections, threshold).to_dict()
            result_dict['connection_count'] = len(connections)
            results.append(result_dict)

        return sorted(results, key=lambda x: x['c2_probability'], reverse=True)

    def add_whitelist(self, ips: Optional[List[str]] = None, domains: Optional[List[str]] = None):
        """Add IPs or domains to whitelist."""
        for ip in ips or ():
            self.context_engine.add_whitelist_ip(ip)
        for domain in domains or ():
            self.context_engine.add_whitelist_domain(domain)

    def add_blacklist(self, ips: Optional[List[str]] = None, domains: Optional[List[str]] = None):
        """Add IPs or domains to blacklist."""
        for ip in ips or ():
            self.context_engine.add_blacklist_ip(ip)
        for domain in domains or ():
            self.context_engine.add_blacklist_domain(domain)

    def save(self, path: str):
        """Save model to safetensors format.

        Writes ``<path>.safetensors`` (weights) and ``<path>.json`` (config);
        any existing suffix on *path* is replaced.
        """
        path = Path(path)
        model_path = path.with_suffix('.safetensors')
        save_file(self.model.state_dict(), str(model_path))

        config_path = path.with_suffix('.json')
        with open(config_path, 'w') as f:
            json.dump(self.config.to_dict(), f, indent=2)

        print(f"Model saved to {model_path}")
        print(f"Config saved to {config_path}")

    @classmethod
    def _from_files(cls, model_path, config_path, device: str) -> 'C2Sentinel':
        """Build an instance from a JSON config file and a safetensors weights file."""
        with open(config_path, 'r') as f:
            config = C2SentinelConfig.from_dict(json.load(f))

        model = LogBERTC2Sentinel(config)
        model.load_state_dict(load_file(str(model_path)))

        return cls(model, config, device)

    @classmethod
    def load(cls, path: str, device: str = 'auto') -> 'C2Sentinel':
        """Load model from safetensors format.

        Args:
            path: Base path, with or without a suffix. Weights are read from
                ``<path>.safetensors`` and config from ``<path>.json``.
            device: Device passed to the constructor.
        """
        base = Path(path)
        # with_suffix is a no-op for a path already ending in '.safetensors',
        # so one code path covers both spellings.
        return cls._from_files(base.with_suffix('.safetensors'), base.with_suffix('.json'), device)

    @classmethod
    def from_pretrained(cls, repo_id: str, device: str = 'auto', cache_dir: Optional[str] = None) -> 'C2Sentinel':
        """
        Load model from HuggingFace Hub.

        Args:
            repo_id: HuggingFace repository ID (e.g., 'danielostrow/c2sentinel')
            device: Device to load model on ('auto', 'cpu', 'cuda', 'mps')
            cache_dir: Optional cache directory for downloaded files

        Returns:
            Loaded C2Sentinel instance

        Raises:
            ImportError: If huggingface_hub is not installed.
        """
        try:
            from huggingface_hub import hf_hub_download
        except ImportError as err:
            raise ImportError(
                "huggingface_hub is required for from_pretrained. Install with: pip install huggingface_hub"
            ) from err

        model_path = hf_hub_download(
            repo_id=repo_id,
            filename="c2_sentinel.safetensors",
            cache_dir=cache_dir
        )
        config_path = hf_hub_download(
            repo_id=repo_id,
            filename="c2_sentinel.json",
            cache_dir=cache_dir
        )

        return cls._from_files(model_path, config_path, device)

    @classmethod
    def create_new(cls, device: str = 'auto') -> 'C2Sentinel':
        """Create a new untrained model."""
        config = C2SentinelConfig()
        model = LogBERTC2Sentinel(config)
        return cls(model, config, device)
|
|
|
|
| |
| |
| |
|
|
def load_model(path: str, device: str = 'auto') -> C2Sentinel:
    """Return the C2Sentinel that C2Sentinel.load restores from *path*."""
    sentinel = C2Sentinel.load(path, device)
    return sentinel
|
|
|
|
def create_model(device: str = 'auto') -> C2Sentinel:
    """Return a fresh, untrained C2Sentinel built by C2Sentinel.create_new."""
    sentinel = C2Sentinel.create_new(device)
    return sentinel
|
|
|
|
def quick_analyze(connections: List[Dict], model_path: str = 'c2_sentinel') -> AnalysisResult:
    """Load the model at *model_path*, analyze *connections* once, and return the result.

    The model is loaded on every call and not retained afterwards.
    """
    return C2Sentinel.load(model_path).analyze(connections)
|
|
|
|
| |
| |
| |
|
|
if __name__ == '__main__':
    # Smoke-test demo: builds an untrained model and runs a few synthetic
    # traffic samples through it, printing each result.
    import random

    def _steady_flow(count, spacing, dst_ip, dst_port, sent, recv):
        """Build `count` same-sized connections to one endpoint, `spacing` seconds apart."""
        return [
            {'timestamp': 1000 + i*spacing, 'dst_ip': dst_ip, 'dst_port': dst_port,
             'bytes_sent': sent, 'bytes_recv': recv}
            for i in range(count)
        ]

    rule = "=" * 70

    print("LogBERT-C2Sentinel v2.0: Advanced C2 Detection with Context Inference")
    print(rule)

    sentinel = C2Sentinel.create_new()
    print(f"Model created with {sentinel.config.num_features} features")
    print(f"Device: {sentinel.device}")

    # Fast, fixed-size traffic to port 4444.
    print("\n[TEST 1] Metasploit Meterpreter (port 4444)...")
    msf_connections = _steady_flow(8, 5, '192.168.1.100', 4444, 150, 400)
    result = sentinel.analyze(msf_connections)
    print(f"  {result}")

    # Small symmetric packets to port 22 every 30 seconds.
    print("\n[TEST 2] SSH Keepalive (should be clean)...")
    ssh_keepalive = _steady_flow(15, 30, '192.168.1.10', 22, 48, 48)
    result = sentinel.analyze(ssh_keepalive)
    print(f"  {result}")
    print(f"  Matched pattern: {result.matched_legitimate_pattern}")
    print(f"  Mitigating factors: {result.mitigating_factors}")

    # Same traffic, this time with process metadata supplied.
    print("\n[TEST 3] SSH Keepalive with process context...")
    context = ConnectionContext(process_name='sshd', known_good=True)
    result = sentinel.analyze(ssh_keepalive, context=context)
    print(f"  {result}")

    # Regular one-minute traffic to a single host on 443.
    print("\n[TEST 4] C2 Beacon on port 443...")
    c2_beacon = _steady_flow(10, 60, '10.10.10.10', 443, 200, 500)
    result = sentinel.analyze(c2_beacon)
    print(f"  {result}")

    # Randomised destinations, timing and sizes (unseeded, so output varies).
    print("\n[TEST 5] Benign web browsing...")
    browsing = [
        {'timestamp': 1000 + i*random.uniform(5, 120),
         'dst_ip': f"{random.randint(1,200)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}",
         'dst_port': 443,
         'bytes_sent': random.randint(500, 3000),
         'bytes_recv': random.randint(10000, 500000)}
        for i in range(15)
    ]
    result = sentinel.analyze(browsing)
    print(f"  {result}")

    print("\n[TEST 6] Reconnaissance support...")
    ip_info = sentinel.recon.analyze_ip('104.16.132.229')
    print(f"  IP Analysis: {ip_info}")

    print("\n" + rule)
    print("Model ready for deployment!")
    print(rule)
|
|