| |
| """ |
| C2Sentinel Training Script - Phase 2: Multi-Task Learning & Adversarial Hardening |
| |
| Phase 2 adds: |
| 1. C2 type classification (10 framework types) |
| 2. Adversarial beacon patterns (jitter, domain fronting, etc.) |
| 3. Enhanced benign patterns to reduce false positives |
| 4. Confidence calibration |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import Dataset, DataLoader |
| import numpy as np |
| import random |
| from tqdm import tqdm |
| import json |
| import os |
|
|
| from c2sentinel import ( |
| C2Sentinel, C2SentinelConfig, LogBERTC2Sentinel, |
| FeatureExtractor |
| ) |
| from safetensors.torch import save_file, load_file |
|
|
|
|
class C2TrafficDatasetPhase2(Dataset):
    """Enhanced dataset with adversarial patterns and multi-task labels.

    Generates four synthetic traffic populations:
      1. Standard C2 beacons (10 framework profiles),
      2. Adversarial C2 beacons (jitter / size / burst / rotation evasions),
      3. Standard benign traffic (browsing, API, streaming, ...),
      4. Edge-case benign traffic that superficially resembles beacons.

    Each item carries a binary C2 label and a C2-type label (0 = benign,
    1-10 = framework index) for multi-task training.
    """

    def __init__(self, num_samples: int = 30000, normalize: bool = True,
                 norm_params_path: str | None = None):
        """Build the full synthetic dataset in memory.

        Args:
            num_samples: Total number of windows to generate (split 50/50
                between C2 and benign).
            normalize: If True, z-score features; params are loaded from
                ``norm_params_path`` when available, otherwise computed here
                and saved to ``normalization_params.npz``.
            norm_params_path: Optional path to a previously saved .npz with
                ``mean`` and ``std`` arrays.
        """
        # Raw accumulators; converted to numpy arrays after generation.
        self.samples = []
        self.labels = []
        self.c2_types = []
        # Converts a connection list into a fixed-length feature vector.
        # NOTE(review): extract_features' exact feature layout is defined in
        # the c2sentinel package — not visible here.
        self.feature_extractor = FeatureExtractor()

        print(f"Generating {num_samples} phase 2 training samples...")

        # Split: half C2 (one third of those adversarial), half benign.
        num_c2 = num_samples // 2
        num_adversarial = num_c2 // 3
        num_standard_c2 = num_c2 - num_adversarial
        num_benign = num_samples - num_c2

        # --- Standard C2 beacons ---
        print(f"\n[1/4] Standard C2 samples ({num_standard_c2})...")
        for _ in tqdm(range(num_standard_c2), desc="Standard C2"):
            connections, c2_type = self._generate_standard_c2()
            features = self.feature_extractor.extract_features(connections)
            self.samples.append(features)
            self.labels.append(1)
            self.c2_types.append(c2_type)

        # --- Adversarial C2 beacons (evasion techniques) ---
        print(f"\n[2/4] Adversarial C2 samples ({num_adversarial})...")
        for _ in tqdm(range(num_adversarial), desc="Adversarial C2"):
            connections, c2_type = self._generate_adversarial_c2()
            features = self.feature_extractor.extract_features(connections)
            self.samples.append(features)
            self.labels.append(1)
            self.c2_types.append(c2_type)

        # --- Ordinary benign traffic (two thirds of the benign half) ---
        num_standard_benign = num_benign * 2 // 3
        print(f"\n[3/4] Standard benign samples ({num_standard_benign})...")
        for _ in tqdm(range(num_standard_benign), desc="Standard Benign"):
            connections = self._generate_benign_traffic()
            features = self.feature_extractor.extract_features(connections)
            self.samples.append(features)
            self.labels.append(0)
            self.c2_types.append(0)  # type 0 == benign / no framework

        # --- Benign traffic that mimics beaconing (false-positive hardening) ---
        num_edge_benign = num_benign - num_standard_benign
        print(f"\n[4/4] Edge-case benign samples ({num_edge_benign})...")
        for _ in tqdm(range(num_edge_benign), desc="Edge Benign"):
            connections = self._generate_edge_case_benign()
            features = self.feature_extractor.extract_features(connections)
            self.samples.append(features)
            self.labels.append(0)
            self.c2_types.append(0)

        self.samples = np.array(self.samples, dtype=np.float32)
        self.labels = np.array(self.labels, dtype=np.float32)
        self.c2_types = np.array(self.c2_types, dtype=np.int64)

        # Feature normalization: reuse saved params when available so that
        # phase-2 training stays consistent with a phase-1 checkpoint.
        if normalize:
            if norm_params_path and os.path.exists(norm_params_path):
                print(f"Loading normalization from {norm_params_path}")
                params = np.load(norm_params_path)
                # NOTE(review): loaded std is used as-is (the 1e-8 epsilon is
                # only added when computed fresh below) — assumes the saved
                # std already includes it.
                self.mean = params['mean']
                self.std = params['std']
            else:
                self.mean = np.mean(self.samples, axis=0)
                self.std = np.std(self.samples, axis=0) + 1e-8  # avoid div-by-zero
                np.savez('normalization_params.npz', mean=self.mean, std=self.std)

            self.samples = (self.samples - self.mean) / self.std
            print(f"Feature stats - mean range: [{self.mean.min():.2f}, {self.mean.max():.2f}]")

        # Shuffle all three parallel arrays with a single permutation so
        # sample/label/type alignment is preserved.
        indices = np.random.permutation(len(self.samples))
        self.samples = self.samples[indices]
        self.labels = self.labels[indices]
        self.c2_types = self.c2_types[indices]

        # Summary statistics for the generated dataset.
        c2_count = np.sum(self.labels)
        print(f"\nDataset: {len(self.labels)} samples")
        print(f"  C2: {int(c2_count)} ({100*c2_count/len(self.labels):.1f}%)")
        print(f"  Benign: {int(len(self.labels) - c2_count)} ({100*(1 - c2_count/len(self.labels)):.1f}%)")
        print(f"  C2 type distribution: {np.bincount(self.c2_types[self.labels == 1].astype(int), minlength=11)[1:]}")

    def _generate_standard_c2(self):
        """Generate standard C2 beacon patterns.

        Returns:
            (connections, c2_type): a list of connection dicts with near-
            periodic timing and small, stable payload sizes, plus the
            framework index (1-10) the profile was drawn from.
        """
        c2_type = random.randint(1, 10)

        # Per-framework beacon profiles: interval range (seconds), timing
        # jitter fraction, and typical destination ports.
        c2_profiles = {
            1: {'name': 'Metasploit', 'interval': (2, 15), 'jitter': 0.1, 'ports': [4444, 4445, 5555]},
            2: {'name': 'Cobalt Strike', 'interval': (30, 90), 'jitter': 0.2, 'ports': [443, 8443]},
            3: {'name': 'Empire', 'interval': (5, 30), 'jitter': 0.15, 'ports': [443, 8080]},
            4: {'name': 'Covenant', 'interval': (10, 60), 'jitter': 0.1, 'ports': [443, 80]},
            5: {'name': 'Sliver', 'interval': (30, 120), 'jitter': 0.3, 'ports': [443, 8888]},
            6: {'name': 'Brute Ratel', 'interval': (60, 180), 'jitter': 0.2, 'ports': [443]},
            7: {'name': 'Mythic', 'interval': (15, 60), 'jitter': 0.15, 'ports': [443, 7443]},
            8: {'name': 'PoshC2', 'interval': (10, 45), 'jitter': 0.2, 'ports': [443, 8000]},
            9: {'name': 'Havoc', 'interval': (20, 90), 'jitter': 0.25, 'ports': [443, 40056]},
            10: {'name': 'APT Custom', 'interval': (120, 600), 'jitter': 0.05, 'ports': [443, 80]},
        }

        profile = c2_profiles[c2_type]
        interval = random.uniform(*profile['interval'])
        jitter = profile['jitter']
        port = random.choice(profile['ports'])

        # Small, nearly constant check-in payloads — the classic beacon tell.
        bytes_sent = random.randint(60, 200)
        bytes_recv = random.randint(40, 150)

        dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
        num_connections = random.randint(12, 50)

        connections = []
        timestamp = 1705600000  # arbitrary fixed epoch start for all samples

        for _ in range(num_connections):
            # Periodic interval with bounded multiplicative jitter.
            actual_interval = interval * (1 + random.uniform(-jitter, jitter))
            timestamp += actual_interval
            size_var = random.uniform(0.95, 1.05)  # ±5% payload-size wobble

            connections.append({
                'timestamp': timestamp,
                'dst_ip': dst_ip,
                'dst_port': port,
                'bytes_sent': int(bytes_sent * size_var),
                'bytes_recv': int(bytes_recv * size_var),
                'protocol': 'tcp'
            })

        return connections, c2_type

    def _generate_adversarial_c2(self):
        """Generate adversarial C2 patterns that try to evade detection.

        One of five evasion strategies is picked at random; the returned
        c2_type is still a framework index (1-10) for the type head.
        """
        c2_type = random.randint(1, 10)
        evasion = random.choice(['high_jitter', 'variable_size', 'burst_pattern', 'domain_rotation', 'mixed'])

        base_interval = random.uniform(30, 120)
        dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
        num_connections = random.randint(15, 60)

        connections = []
        timestamp = 1705600000

        if evasion == 'high_jitter':
            # 40-70% timing jitter to smear the beacon's periodicity.
            jitter = random.uniform(0.4, 0.7)
            for _ in range(num_connections):
                actual_interval = base_interval * (1 + random.uniform(-jitter, jitter))
                timestamp += max(5, actual_interval)  # keep timestamps monotonic
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(80, 150),
                    'bytes_recv': random.randint(50, 100),
                    'protocol': 'tcp'
                })

        elif evasion == 'variable_size':
            # Keep timing regular but randomize payload sizes widely.
            for _ in range(num_connections):
                timestamp += base_interval * (1 + random.uniform(-0.1, 0.1))

                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(50, 500),
                    'bytes_recv': random.randint(40, 400),
                    'protocol': 'tcp'
                })

        elif evasion == 'burst_pattern':
            # Quiet beaconing punctuated by short activity bursts (tasking).
            # NOTE(review): the inner burst loop appends 2-5 extra
            # connections per burst, so the total can exceed num_connections.
            for i in range(num_connections):
                if i % 8 == 0:
                    # Burst: a few rapid, larger transfers.
                    for _ in range(random.randint(2, 5)):
                        timestamp += random.uniform(0.5, 3)
                        connections.append({
                            'timestamp': timestamp,
                            'dst_ip': dst_ip,
                            'dst_port': 443,
                            'bytes_sent': random.randint(200, 2000),
                            'bytes_recv': random.randint(500, 5000),
                            'protocol': 'tcp'
                        })
                else:
                    timestamp += base_interval * (1 + random.uniform(-0.15, 0.15))
                    connections.append({
                        'timestamp': timestamp,
                        'dst_ip': dst_ip,
                        'dst_port': 443,
                        'bytes_sent': random.randint(80, 120),
                        'bytes_recv': random.randint(50, 80),
                        'protocol': 'tcp'
                    })

        elif evasion == 'domain_rotation':
            # Rotate among 2-4 destination IPs to dilute per-host stats.
            ips = [f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
                   for _ in range(random.randint(2, 4))]
            for _ in range(num_connections):
                timestamp += base_interval * (1 + random.uniform(-0.2, 0.2))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': random.choice(ips),
                    'dst_port': 443,
                    'bytes_sent': random.randint(80, 150),
                    'bytes_recv': random.randint(50, 100),
                    'protocol': 'tcp'
                })

        else:
            # 'mixed': moderate jitter plus variable sizes and ports.
            jitter = random.uniform(0.25, 0.5)
            for _ in range(num_connections):
                actual_interval = base_interval * (1 + random.uniform(-jitter, jitter))
                timestamp += max(3, actual_interval)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': random.choice([443, 8443]),
                    'bytes_sent': random.randint(60, 300),
                    'bytes_recv': random.randint(40, 200),
                    'protocol': 'tcp'
                })

        return connections, c2_type

    def _generate_benign_traffic(self):
        """Generate standard benign traffic patterns.

        Returns a connection list for one of six everyday workload shapes;
        unlike C2 beacons these have irregular timing and large, asymmetric
        payloads.
        """
        pattern = random.choice(['browsing', 'api', 'streaming', 'interactive', 'download', 'email'])

        connections = []
        timestamp = 1705600000

        if pattern == 'browsing':
            # Many destinations, bursty timing, large downloads.
            for _ in range(random.randint(10, 50)):
                timestamp += random.uniform(0.5, 60)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}",
                    'dst_port': random.choice([80, 443]),
                    'bytes_sent': random.randint(200, 5000),
                    'bytes_recv': random.randint(5000, 500000),
                    'protocol': 'tcp'
                })

        elif pattern == 'api':
            # Single endpoint, irregular request cadence, varied responses.
            dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
            for _ in range(random.randint(15, 60)):
                timestamp += random.uniform(0.1, 30)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(100, 5000),
                    'bytes_recv': random.randint(200, 200000),
                    'protocol': 'tcp'
                })

        elif pattern == 'streaming':
            # High-frequency small requests, very large inbound chunks.
            dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
            for _ in range(random.randint(30, 100)):
                timestamp += random.uniform(0.02, 2)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(30, 200),
                    'bytes_recv': random.randint(5000, 200000),
                    'protocol': 'tcp'
                })

        elif pattern == 'interactive':
            # SSH-like session on a private address: mostly rapid keystroke
            # traffic with occasional idle pauses (30% of steps).
            dst_ip = f"192.168.{random.randint(0,255)}.{random.randint(1,254)}"
            for _ in range(random.randint(20, 80)):
                if random.random() < 0.3:
                    timestamp += random.uniform(5, 60)
                else:
                    timestamp += random.uniform(0.1, 3)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 22,
                    'bytes_sent': random.randint(20, 1000),
                    'bytes_recv': random.randint(50, 30000),
                    'protocol': 'tcp'
                })

        elif pattern == 'download':
            # Rapid sequential chunks from one host (large file transfer).
            dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
            for _ in range(random.randint(50, 200)):
                timestamp += random.uniform(0.01, 0.5)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(40, 100),
                    'bytes_recv': random.randint(10000, 65000),
                    'protocol': 'tcp'
                })

        else:
            # 'email': infrequent, large, bidirectional transfers.
            for _ in range(random.randint(5, 20)):
                timestamp += random.uniform(30, 300)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}",
                    'dst_port': random.choice([443, 993, 587]),
                    'bytes_sent': random.randint(500, 50000),
                    'bytes_recv': random.randint(1000, 500000),
                    'protocol': 'tcp'
                })

        return connections

    def _generate_edge_case_benign(self):
        """Generate benign traffic that looks like C2 but isn't.

        These are the false-positive hard cases: periodic, small-payload
        patterns (heartbeats, keepalives, polling, IoT telemetry) that the
        model must learn to separate from real beacons.
        """
        pattern = random.choice([
            'heartbeat', 'monitoring', 'sync', 'keepalive', 'polling', 'iot'
        ])

        connections = []
        timestamp = 1705600000
        dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"

        if pattern == 'heartbeat':
            # Very regular check-ins, but with large inbound responses
            # (unlike a beacon's tiny replies).
            interval = random.uniform(30, 120)
            for _ in range(random.randint(15, 40)):
                timestamp += interval * (1 + random.uniform(-0.05, 0.05))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(50, 100),
                    'bytes_recv': random.randint(1000, 50000),
                    'protocol': 'tcp'
                })

        elif pattern == 'monitoring':
            # Metrics push on a slow schedule with moderate payloads.
            interval = random.uniform(60, 300)
            for _ in range(random.randint(10, 30)):
                timestamp += interval * (1 + random.uniform(-0.1, 0.1))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': random.choice([443, 8443, 9090]),
                    'bytes_sent': random.randint(100, 500),
                    'bytes_recv': random.randint(500, 10000),
                    'protocol': 'tcp'
                })

        elif pattern == 'sync':
            # Periodic bulk sync: long intervals, symmetric large transfers.
            interval = random.uniform(300, 900)
            for _ in range(random.randint(8, 20)):
                timestamp += interval * (1 + random.uniform(-0.15, 0.15))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(1000, 100000),
                    'bytes_recv': random.randint(1000, 100000),
                    'protocol': 'tcp'
                })

        elif pattern == 'keepalive':
            # VPN/SSH keepalive on a private address: tiny symmetric packets
            # at near-perfect regularity (2% jitter).
            interval = random.uniform(15, 60)
            for _ in range(random.randint(20, 60)):
                timestamp += interval * (1 + random.uniform(-0.02, 0.02))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': f"192.168.{random.randint(0,255)}.{random.randint(1,254)}",
                    'dst_port': random.choice([22, 1194, 443]),
                    'bytes_sent': random.randint(40, 80),
                    'bytes_recv': random.randint(40, 80),
                    'protocol': 'tcp'
                })

        elif pattern == 'polling':
            # App polling an API: small requests, wildly varied responses.
            interval = random.uniform(10, 60)
            for _ in range(random.randint(20, 50)):
                timestamp += interval * (1 + random.uniform(-0.2, 0.2))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(80, 200),
                    'bytes_recv': random.randint(100, 50000),
                    'protocol': 'tcp'
                })

        else:
            # 'iot': MQTT-style telemetry — small symmetric messages.
            interval = random.uniform(60, 300)
            for _ in range(random.randint(15, 40)):
                timestamp += interval * (1 + random.uniform(-0.1, 0.1))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': random.choice([443, 8883, 1883]),
                    'bytes_sent': random.randint(50, 200),
                    'bytes_recv': random.randint(50, 200),
                    'protocol': 'tcp'
                })

        return connections

    def __len__(self) -> int:
        """Return the number of generated samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors.

        Keys: 'features' (float32 vector), 'label' (float32 binary C2
        label), 'c2_type' (int64 framework index, 0 for benign).
        """
        return {
            'features': torch.tensor(self.samples[idx]),
            'label': torch.tensor(self.labels[idx]),
            'c2_type': torch.tensor(self.c2_types[idx])
        }
|
|
|
|
def train_phase2(
    pretrained_path='c2_sentinel.safetensors',
    num_epochs=50,
    batch_size=32,
    learning_rate=0.00005,
    num_samples=30000
):
    """Phase 2 training with multi-task learning.

    Fine-tunes (or trains from scratch, if no checkpoint exists) the
    LogBERT-style model on the phase-2 dataset with two heads: a binary
    C2 detector (BCE) and a 10-way C2 framework classifier (CE, weighted
    0.3, computed only on C2-labeled samples).

    Args:
        pretrained_path: Path to a phase-1 safetensors checkpoint to warm
            start from; the best phase-2 model is saved back to
            'c2_sentinel.safetensors'.
        num_epochs: Maximum epochs (early stopping with patience 10).
        batch_size: Mini-batch size.
        learning_rate: AdamW learning rate (cosine-annealed over epochs).
        num_samples: Number of synthetic samples to generate.

    Returns:
        (model, config): the trained model (last state in memory; best
        state is on disk) and its config.
    """

    print("=" * 70)
    print("C2Sentinel Phase 2 Training - Multi-Task Learning")
    print("=" * 70)

    config = C2SentinelConfig()
    model = LogBERTC2Sentinel(config)

    # Warm start from phase 1 when a checkpoint is available.
    if os.path.exists(pretrained_path):
        print(f"Loading pretrained weights from {pretrained_path}")
        state_dict = load_file(pretrained_path)
        model.load_state_dict(state_dict)
    else:
        print("WARNING: No pretrained weights found, training from scratch")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}")

    # Reuse phase-1 normalization params if present so features stay
    # consistent with the pretrained weights.
    norm_path = 'normalization_params.npz' if os.path.exists('normalization_params.npz') else None
    dataset = C2TrafficDatasetPhase2(num_samples=num_samples, normalize=True, norm_params_path=norm_path)

    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    # drop_last on train only; val batches may therefore have size 1,
    # which is why squeeze(-1) (not squeeze()) is used on logits below.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    print(f"Train: {train_size}, Val: {val_size}")

    # Binary C2 head + type head; ignore_index=0 skips the benign type id.
    c2_criterion = nn.BCEWithLogitsLoss()
    type_criterion = nn.CrossEntropyLoss(ignore_index=0)

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)

    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    best_val_acc = 0
    patience = 10
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        train_c2_correct = 0
        train_type_correct = 0
        train_total = 0
        train_c2_samples = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False):
            features = batch['features'].to(device)
            labels = batch['label'].to(device)
            c2_types = batch['c2_type'].to(device)

            optimizer.zero_grad()
            outputs = model(features)

            # BUGFIX: squeeze(-1) instead of squeeze() — a bare squeeze()
            # collapses a size-1 batch to a 0-d tensor, breaking the loss
            # shape match against 1-d labels.
            c2_loss = c2_criterion(outputs['c2_logits'].squeeze(-1), labels)

            # Type loss only on C2-labeled samples (benign has no type).
            c2_mask = labels == 1
            if c2_mask.sum() > 0 and 'type_logits' in outputs:
                type_loss = type_criterion(outputs['type_logits'][c2_mask], c2_types[c2_mask])
                loss = c2_loss + 0.3 * type_loss
            else:
                loss = c2_loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_loss += loss.item()

            # Binary accuracy at a 0.5 probability threshold.
            c2_preds = (torch.sigmoid(outputs['c2_logits'].squeeze(-1)) > 0.5).float()
            train_c2_correct += (c2_preds == labels).sum().item()
            train_total += labels.size(0)

            # Type accuracy over C2 samples only.
            if c2_mask.sum() > 0 and 'type_logits' in outputs:
                type_preds = outputs['type_logits'][c2_mask].argmax(dim=1)
                train_type_correct += (type_preds == c2_types[c2_mask]).sum().item()
                train_c2_samples += c2_mask.sum().item()

        scheduler.step()

        # ---- Validation ----
        model.eval()
        val_c2_correct = 0
        val_type_correct = 0
        val_total = 0
        val_c2_samples = 0
        val_loss = 0

        # Confusion-matrix counters for precision/recall/F1.
        val_tp, val_fp, val_tn, val_fn = 0, 0, 0, 0

        with torch.no_grad():
            for batch in val_loader:
                features = batch['features'].to(device)
                labels = batch['label'].to(device)
                c2_types = batch['c2_type'].to(device)

                outputs = model(features)

                c2_loss = c2_criterion(outputs['c2_logits'].squeeze(-1), labels)
                val_loss += c2_loss.item()

                c2_preds = (torch.sigmoid(outputs['c2_logits'].squeeze(-1)) > 0.5).float()
                val_c2_correct += (c2_preds == labels).sum().item()
                val_total += labels.size(0)

                val_tp += ((c2_preds == 1) & (labels == 1)).sum().item()
                val_fp += ((c2_preds == 1) & (labels == 0)).sum().item()
                val_tn += ((c2_preds == 0) & (labels == 0)).sum().item()
                val_fn += ((c2_preds == 0) & (labels == 1)).sum().item()

                c2_mask = labels == 1
                if c2_mask.sum() > 0 and 'type_logits' in outputs:
                    type_preds = outputs['type_logits'][c2_mask].argmax(dim=1)
                    val_type_correct += (type_preds == c2_types[c2_mask]).sum().item()
                    val_c2_samples += c2_mask.sum().item()

        train_c2_acc = 100 * train_c2_correct / train_total
        train_type_acc = 100 * train_type_correct / train_c2_samples if train_c2_samples > 0 else 0
        val_c2_acc = 100 * val_c2_correct / val_total
        val_type_acc = 100 * val_type_correct / val_c2_samples if val_c2_samples > 0 else 0

        precision = val_tp / (val_tp + val_fp) if (val_tp + val_fp) > 0 else 0
        recall = val_tp / (val_tp + val_fn) if (val_tp + val_fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        lr = optimizer.param_groups[0]['lr']

        # FIX: val_loss was previously accumulated but never reported.
        print(f"Epoch {epoch+1}: Loss={train_loss/len(train_loader):.4f}, "
              f"Val Loss={val_loss/len(val_loader):.4f}, "
              f"Train C2={train_c2_acc:.1f}%, Val C2={val_c2_acc:.1f}%, "
              f"Type Acc={val_type_acc:.1f}%, P={precision:.2f}, R={recall:.2f}, F1={f1:.2f}")

        # Checkpoint on best validation accuracy; early-stop on plateau.
        if val_c2_acc > best_val_acc:
            best_val_acc = val_c2_acc
            patience_counter = 0
            save_file(model.state_dict(), 'c2_sentinel.safetensors')
            print(f"  -> Saved (Val: {val_c2_acc:.1f}%, F1: {f1:.2f})")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    print(f"\nBest validation C2 accuracy: {best_val_acc:.1f}%")
    return model, config
|
|
|
|
def test_adversarial():
    """Run the trained sentinel against hand-built evasion and benign fixtures."""
    banner = "=" * 70
    print("\n" + banner)
    print("Adversarial Pattern Testing")
    print(banner)

    sentinel = C2Sentinel.load('c2_sentinel')

    base_ts = 1705600000

    # High-jitter beacon: 60s period smeared by up to ±50%.
    high_jitter = []
    for i in range(20):
        high_jitter.append({
            'timestamp': base_ts + i * 60 * (1 + random.uniform(-0.5, 0.5)),
            'dst_ip': '185.234.72.19', 'dst_port': 443,
            'bytes_sent': 92, 'bytes_recv': 48,
        })

    # Burst beacon: every 5th check-in carries a larger tasking exchange.
    burst = []
    for i in range(25):
        if i % 5 != 0:
            ts = base_ts + i * 60
            sent, recv = 100, 60
        else:
            ts = base_ts + i * 60 + random.uniform(0, 5)
            sent = random.randint(500, 2000)
            recv = random.randint(1000, 5000)
        burst.append({
            'timestamp': ts,
            'dst_ip': '45.33.32.156', 'dst_port': 443,
            'bytes_sent': sent, 'bytes_recv': recv,
        })

    # Size-randomized beacon on a classic Metasploit port.
    variable_size = []
    for i in range(18):
        variable_size.append({
            'timestamp': base_ts + i * 45,
            'dst_ip': '10.10.10.10', 'dst_port': 4444,
            'bytes_sent': random.randint(50, 300),
            'bytes_recv': random.randint(40, 250),
        })

    # Benign control: perfectly regular SSH keepalive on a LAN host.
    ssh_keepalive = []
    for i in range(20):
        ssh_keepalive.append({
            'timestamp': base_ts + i * 30,
            'dst_ip': '192.168.1.50', 'dst_port': 22,
            'bytes_sent': 48, 'bytes_recv': 48,
        })

    # Benign control: API polling with large, varied responses.
    api_polling = []
    for i in range(25):
        api_polling.append({
            'timestamp': base_ts + i * random.uniform(25, 35),
            'dst_ip': '52.85.132.99', 'dst_port': 443,
            'bytes_sent': 150, 'bytes_recv': random.randint(500, 50000),
        })

    test_cases = [
        ("High-jitter Cobalt Strike", high_jitter),
        ("Burst pattern beacon", burst),
        ("Variable size beacon", variable_size),
        ("SSH keepalive (should be clean)", ssh_keepalive),
        ("API polling (should be clean)", api_polling),
    ]

    for name, connections in test_cases:
        result = sentinel.analyze(connections)
        verdict = "C2 DETECTED" if result.is_c2 else "Clean"
        print(f"\n{name}:")
        print(f" {verdict} (prob={result.c2_probability:.2%})")
        if result.risk_factors:
            for factor in result.risk_factors[:3]:
                print(f" - {factor}")
|
|
|
if __name__ == '__main__':
    import argparse

    # CLI: run phase-2 training (default) or only the adversarial checks.
    cli = argparse.ArgumentParser()
    cli.add_argument('--epochs', type=int, default=50)
    cli.add_argument('--samples', type=int, default=30000)
    cli.add_argument('--batch-size', type=int, default=32)
    cli.add_argument('--lr', type=float, default=0.00005)
    cli.add_argument('--pretrained', type=str, default='c2_sentinel.safetensors')
    cli.add_argument('--test-only', action='store_true')
    opts = cli.parse_args()

    # Train unless asked to skip; adversarial testing runs either way.
    if not opts.test_only:
        train_phase2(
            pretrained_path=opts.pretrained,
            num_epochs=opts.epochs,
            batch_size=opts.batch_size,
            learning_rate=opts.lr,
            num_samples=opts.samples,
        )
    test_adversarial()
|
|