# c2sentinel/train_model_phase2.py
# Commit 679e1eb (verified) by danielostrow: "Add phase 2 multi-task training script"
#!/usr/bin/env python3
"""
C2Sentinel Training Script - Phase 2: Multi-Task Learning & Adversarial Hardening
Phase 2 adds:
1. C2 type classification (10 framework types)
2. Adversarial beacon patterns (jitter, domain fronting, etc.)
3. Enhanced benign patterns to reduce false positives
4. Confidence calibration
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from tqdm import tqdm
import json
import os
from c2sentinel import (
C2Sentinel, C2SentinelConfig, LogBERTC2Sentinel,
FeatureExtractor
)
from safetensors.torch import save_file, load_file
class C2TrafficDatasetPhase2(Dataset):
    """Enhanced dataset with adversarial patterns and multi-task labels.

    Generates synthetic connection windows, turns each window into a
    fixed-length feature vector via ``FeatureExtractor``, and attaches two
    labels per sample: a binary C2 flag and a framework-type id
    (0 = benign, 1-10 = C2 framework index).
    """
    def __init__(self, num_samples=30000, normalize=True, norm_params_path=None):
        """Build the dataset in memory.

        Args:
            num_samples: total samples. Half are C2 (a third of those
                adversarial), half benign (a third of those edge-case
                C2 look-alikes).
            normalize: z-score features when True.
            norm_params_path: optional ``.npz`` file with precomputed
                ``mean``/``std`` arrays; when absent, stats are computed
                from this dataset and written to 'normalization_params.npz'.
        """
        self.samples = []
        self.labels = []
        self.c2_types = []  # 0=benign, 1-10=C2 framework types
        self.feature_extractor = FeatureExtractor()
        print(f"Generating {num_samples} phase 2 training samples...")
        num_c2 = num_samples // 2
        num_adversarial = num_c2 // 3  # 1/3 of C2 samples are adversarial
        num_standard_c2 = num_c2 - num_adversarial
        num_benign = num_samples - num_c2
        # Generate standard C2 samples
        print(f"\n[1/4] Standard C2 samples ({num_standard_c2})...")
        for _ in tqdm(range(num_standard_c2), desc="Standard C2"):
            connections, c2_type = self._generate_standard_c2()
            features = self.feature_extractor.extract_features(connections)
            self.samples.append(features)
            self.labels.append(1)
            self.c2_types.append(c2_type)
        # Generate adversarial C2 samples (harder to detect)
        print(f"\n[2/4] Adversarial C2 samples ({num_adversarial})...")
        for _ in tqdm(range(num_adversarial), desc="Adversarial C2"):
            connections, c2_type = self._generate_adversarial_c2()
            features = self.feature_extractor.extract_features(connections)
            self.samples.append(features)
            self.labels.append(1)
            self.c2_types.append(c2_type)
        # Generate standard benign samples
        num_standard_benign = num_benign * 2 // 3
        print(f"\n[3/4] Standard benign samples ({num_standard_benign})...")
        for _ in tqdm(range(num_standard_benign), desc="Standard Benign"):
            connections = self._generate_benign_traffic()
            features = self.feature_extractor.extract_features(connections)
            self.samples.append(features)
            self.labels.append(0)
            self.c2_types.append(0)
        # Generate edge-case benign samples (look like C2 but aren't)
        num_edge_benign = num_benign - num_standard_benign
        print(f"\n[4/4] Edge-case benign samples ({num_edge_benign})...")
        for _ in tqdm(range(num_edge_benign), desc="Edge Benign"):
            connections = self._generate_edge_case_benign()
            features = self.feature_extractor.extract_features(connections)
            self.samples.append(features)
            self.labels.append(0)
            self.c2_types.append(0)
        # float32 features/labels for BCE; int64 type ids for CrossEntropyLoss
        self.samples = np.array(self.samples, dtype=np.float32)
        self.labels = np.array(self.labels, dtype=np.float32)
        self.c2_types = np.array(self.c2_types, dtype=np.int64)
        # Load or compute normalization
        if normalize:
            if norm_params_path and os.path.exists(norm_params_path):
                print(f"Loading normalization from {norm_params_path}")
                params = np.load(norm_params_path)
                self.mean = params['mean']
                self.std = params['std']
            else:
                self.mean = np.mean(self.samples, axis=0)
                # epsilon guards divide-by-zero on constant feature columns
                self.std = np.std(self.samples, axis=0) + 1e-8
                # NOTE(review): this overwrites any existing phase-1 params
                # in the working directory — confirm that is intended.
                np.savez('normalization_params.npz', mean=self.mean, std=self.std)
            self.samples = (self.samples - self.mean) / self.std
            print(f"Feature stats - mean range: [{self.mean.min():.2f}, {self.mean.max():.2f}]")
        # Shuffle
        indices = np.random.permutation(len(self.samples))
        self.samples = self.samples[indices]
        self.labels = self.labels[indices]
        self.c2_types = self.c2_types[indices]
        # Report distribution
        c2_count = np.sum(self.labels)
        print(f"\nDataset: {len(self.labels)} samples")
        print(f" C2: {int(c2_count)} ({100*c2_count/len(self.labels):.1f}%)")
        print(f" Benign: {int(len(self.labels) - c2_count)} ({100*(1 - c2_count/len(self.labels)):.1f}%)")
        print(f" C2 type distribution: {np.bincount(self.c2_types[self.labels == 1].astype(int), minlength=11)[1:]}")

    def _generate_standard_c2(self):
        """Generate standard C2 beacon patterns.

        Returns:
            tuple: (connections, c2_type) — a list of connection dicts and
            the framework-type label in 1..10.
        """
        c2_type = random.randint(1, 10)
        # C2 type characteristics: per-framework beacon interval range (s),
        # jitter fraction, and typical destination ports.
        c2_profiles = {
            1: {'name': 'Metasploit', 'interval': (2, 15), 'jitter': 0.1, 'ports': [4444, 4445, 5555]},
            2: {'name': 'Cobalt Strike', 'interval': (30, 90), 'jitter': 0.2, 'ports': [443, 8443]},
            3: {'name': 'Empire', 'interval': (5, 30), 'jitter': 0.15, 'ports': [443, 8080]},
            4: {'name': 'Covenant', 'interval': (10, 60), 'jitter': 0.1, 'ports': [443, 80]},
            5: {'name': 'Sliver', 'interval': (30, 120), 'jitter': 0.3, 'ports': [443, 8888]},
            6: {'name': 'Brute Ratel', 'interval': (60, 180), 'jitter': 0.2, 'ports': [443]},
            7: {'name': 'Mythic', 'interval': (15, 60), 'jitter': 0.15, 'ports': [443, 7443]},
            8: {'name': 'PoshC2', 'interval': (10, 45), 'jitter': 0.2, 'ports': [443, 8000]},
            9: {'name': 'Havoc', 'interval': (20, 90), 'jitter': 0.25, 'ports': [443, 40056]},
            10: {'name': 'APT Custom', 'interval': (120, 600), 'jitter': 0.05, 'ports': [443, 80]},
        }
        profile = c2_profiles[c2_type]
        interval = random.uniform(*profile['interval'])
        jitter = profile['jitter']
        port = random.choice(profile['ports'])
        # Beacon characteristics: small, near-constant request/response sizes
        bytes_sent = random.randint(60, 200)
        bytes_recv = random.randint(40, 150)
        dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
        num_connections = random.randint(12, 50)
        connections = []
        timestamp = 1705600000
        for _ in range(num_connections):
            actual_interval = interval * (1 + random.uniform(-jitter, jitter))
            timestamp += actual_interval
            size_var = random.uniform(0.95, 1.05)  # +/-5% size variation
            connections.append({
                'timestamp': timestamp,
                'dst_ip': dst_ip,
                'dst_port': port,
                'bytes_sent': int(bytes_sent * size_var),
                'bytes_recv': int(bytes_recv * size_var),
                'protocol': 'tcp'
            })
        return connections, c2_type

    def _generate_adversarial_c2(self):
        """Generate adversarial C2 patterns that try to evade detection.

        Returns:
            tuple: (connections, c2_type) — same shape as
            ``_generate_standard_c2`` but with one of five evasion styles.
        """
        c2_type = random.randint(1, 10)
        evasion = random.choice(['high_jitter', 'variable_size', 'burst_pattern', 'domain_rotation', 'mixed'])
        base_interval = random.uniform(30, 120)
        dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
        num_connections = random.randint(15, 60)
        connections = []
        timestamp = 1705600000
        if evasion == 'high_jitter':
            # High jitter to look like normal traffic
            jitter = random.uniform(0.4, 0.7)
            for _ in range(num_connections):
                actual_interval = base_interval * (1 + random.uniform(-jitter, jitter))
                timestamp += max(5, actual_interval)  # floor keeps timestamps monotonic
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(80, 150),
                    'bytes_recv': random.randint(50, 100),
                    'protocol': 'tcp'
                })
        elif evasion == 'variable_size':
            # Variable packet sizes but consistent timing
            for _ in range(num_connections):
                timestamp += base_interval * (1 + random.uniform(-0.1, 0.1))
                # Sizes vary more but still bounded
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(50, 500),
                    'bytes_recv': random.randint(40, 400),
                    'protocol': 'tcp'
                })
        elif evasion == 'burst_pattern':
            # Beacon with occasional bursts (simulating commands)
            for i in range(num_connections):
                if i % 8 == 0:
                    # Burst
                    for _ in range(random.randint(2, 5)):
                        timestamp += random.uniform(0.5, 3)
                        connections.append({
                            'timestamp': timestamp,
                            'dst_ip': dst_ip,
                            'dst_port': 443,
                            'bytes_sent': random.randint(200, 2000),
                            'bytes_recv': random.randint(500, 5000),
                            'protocol': 'tcp'
                        })
                else:
                    timestamp += base_interval * (1 + random.uniform(-0.15, 0.15))
                    connections.append({
                        'timestamp': timestamp,
                        'dst_ip': dst_ip,
                        'dst_port': 443,
                        'bytes_sent': random.randint(80, 120),
                        'bytes_recv': random.randint(50, 80),
                        'protocol': 'tcp'
                    })
        elif evasion == 'domain_rotation':
            # Multiple IPs (CDN-like) but same beacon pattern
            ips = [f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
                   for _ in range(random.randint(2, 4))]
            for _ in range(num_connections):
                timestamp += base_interval * (1 + random.uniform(-0.2, 0.2))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': random.choice(ips),
                    'dst_port': 443,
                    'bytes_sent': random.randint(80, 150),
                    'bytes_recv': random.randint(50, 100),
                    'protocol': 'tcp'
                })
        else:  # mixed
            # Mix of evasion techniques
            jitter = random.uniform(0.25, 0.5)
            for _ in range(num_connections):
                actual_interval = base_interval * (1 + random.uniform(-jitter, jitter))
                timestamp += max(3, actual_interval)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': random.choice([443, 8443]),
                    'bytes_sent': random.randint(60, 300),
                    'bytes_recv': random.randint(40, 200),
                    'protocol': 'tcp'
                })
        return connections, c2_type

    def _generate_benign_traffic(self):
        """Generate standard benign traffic patterns.

        Returns:
            list: connection dicts for one of six everyday traffic shapes
            (browsing, api, streaming, interactive, download, email).
        """
        pattern = random.choice(['browsing', 'api', 'streaming', 'interactive', 'download', 'email'])
        connections = []
        timestamp = 1705600000
        if pattern == 'browsing':
            # Many destinations, irregular timing, large downloads
            for _ in range(random.randint(10, 50)):
                timestamp += random.uniform(0.5, 60)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}",
                    'dst_port': random.choice([80, 443]),
                    'bytes_sent': random.randint(200, 5000),
                    'bytes_recv': random.randint(5000, 500000),
                    'protocol': 'tcp'
                })
        elif pattern == 'api':
            # Single endpoint, irregular timing, variable payloads
            dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
            for _ in range(random.randint(15, 60)):
                timestamp += random.uniform(0.1, 30)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(100, 5000),
                    'bytes_recv': random.randint(200, 200000),
                    'protocol': 'tcp'
                })
        elif pattern == 'streaming':
            # Rapid small requests, large media responses
            dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
            for _ in range(random.randint(30, 100)):
                timestamp += random.uniform(0.02, 2)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(30, 200),
                    'bytes_recv': random.randint(5000, 200000),
                    'protocol': 'tcp'
                })
        elif pattern == 'interactive':
            # SSH-like session on a LAN address: mostly fast keystrokes
            # with occasional think-time pauses
            dst_ip = f"192.168.{random.randint(0,255)}.{random.randint(1,254)}"
            for _ in range(random.randint(20, 80)):
                if random.random() < 0.3:
                    timestamp += random.uniform(5, 60)
                else:
                    timestamp += random.uniform(0.1, 3)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 22,
                    'bytes_sent': random.randint(20, 1000),
                    'bytes_recv': random.randint(50, 30000),
                    'protocol': 'tcp'
                })
        elif pattern == 'download':
            # Sustained chunked transfer from one host
            dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
            for _ in range(random.randint(50, 200)):
                timestamp += random.uniform(0.01, 0.5)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(40, 100),
                    'bytes_recv': random.randint(10000, 65000),
                    'protocol': 'tcp'
                })
        else:  # email
            # Infrequent, bulky exchanges to mail-related ports
            for _ in range(random.randint(5, 20)):
                timestamp += random.uniform(30, 300)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}",
                    'dst_port': random.choice([443, 993, 587]),
                    'bytes_sent': random.randint(500, 50000),
                    'bytes_recv': random.randint(1000, 500000),
                    'protocol': 'tcp'
                })
        return connections

    def _generate_edge_case_benign(self):
        """Generate benign traffic that looks like C2 but isn't.

        These hard negatives (regular, beacon-like cadences from legitimate
        services) are meant to reduce false positives.

        Returns:
            list: connection dicts for one edge-case pattern.
        """
        pattern = random.choice([
            'heartbeat', 'monitoring', 'sync', 'keepalive', 'polling', 'iot'
        ])
        connections = []
        timestamp = 1705600000
        dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
        if pattern == 'heartbeat':
            # Regular heartbeat but with large, variable responses
            interval = random.uniform(30, 120)
            for _ in range(random.randint(15, 40)):
                timestamp += interval * (1 + random.uniform(-0.05, 0.05))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(50, 100),
                    'bytes_recv': random.randint(1000, 50000),  # Large variable responses
                    'protocol': 'tcp'
                })
        elif pattern == 'monitoring':
            # Regular monitoring checks with status responses
            interval = random.uniform(60, 300)
            for _ in range(random.randint(10, 30)):
                timestamp += interval * (1 + random.uniform(-0.1, 0.1))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': random.choice([443, 8443, 9090]),
                    'bytes_sent': random.randint(100, 500),
                    'bytes_recv': random.randint(500, 10000),
                    'protocol': 'tcp'
                })
        elif pattern == 'sync':
            # Periodic sync with variable data
            interval = random.uniform(300, 900)
            for _ in range(random.randint(8, 20)):
                timestamp += interval * (1 + random.uniform(-0.15, 0.15))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(1000, 100000),
                    'bytes_recv': random.randint(1000, 100000),
                    'protocol': 'tcp'
                })
        elif pattern == 'keepalive':
            # SSH/VPN keepalive - very regular but small
            interval = random.uniform(15, 60)
            for _ in range(random.randint(20, 60)):
                timestamp += interval * (1 + random.uniform(-0.02, 0.02))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': f"192.168.{random.randint(0,255)}.{random.randint(1,254)}",
                    'dst_port': random.choice([22, 1194, 443]),
                    'bytes_sent': random.randint(40, 80),
                    'bytes_recv': random.randint(40, 80),
                    'protocol': 'tcp'
                })
        elif pattern == 'polling':
            # API polling - regular but with variable responses
            interval = random.uniform(10, 60)
            for _ in range(random.randint(20, 50)):
                timestamp += interval * (1 + random.uniform(-0.2, 0.2))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(80, 200),
                    'bytes_recv': random.randint(100, 50000),  # Highly variable
                    'protocol': 'tcp'
                })
        else:  # iot
            # IoT device - regular small packets (incl. MQTT ports)
            interval = random.uniform(60, 300)
            for _ in range(random.randint(15, 40)):
                timestamp += interval * (1 + random.uniform(-0.1, 0.1))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': random.choice([443, 8883, 1883]),
                    'bytes_sent': random.randint(50, 200),
                    'bytes_recv': random.randint(50, 200),
                    'protocol': 'tcp'
                })
        return connections

    def __len__(self):
        # Number of generated samples (rows of the feature matrix).
        return len(self.samples)

    def __getitem__(self, idx):
        # One training example as tensors; c2_type is 0 for benign rows.
        return {
            'features': torch.tensor(self.samples[idx]),
            'label': torch.tensor(self.labels[idx]),
            'c2_type': torch.tensor(self.c2_types[idx])
        }
def train_phase2(
    pretrained_path='c2_sentinel.safetensors',
    num_epochs=50,
    batch_size=32,
    learning_rate=0.00005,  # Lower LR for fine-tuning
    num_samples=30000
):
    """Phase 2 training with multi-task learning.

    Fine-tunes a (optionally pretrained) LogBERTC2Sentinel on two tasks:
    binary C2 detection (BCE-with-logits) and, for C2 samples only,
    10-way framework-type classification (cross-entropy, weighted 0.3).
    Saves the best checkpoint (by validation detection accuracy) to
    'c2_sentinel.safetensors' and early-stops after 10 stagnant epochs.

    Args:
        pretrained_path: safetensors file with phase-1 weights; trains
            from scratch with a warning if missing.
        num_epochs: maximum training epochs.
        batch_size: mini-batch size for both loaders.
        learning_rate: AdamW learning rate (kept low for fine-tuning).
        num_samples: dataset size passed to C2TrafficDatasetPhase2.

    Returns:
        tuple: (model, config) — the trained model (last-epoch weights in
        memory; best weights on disk) and its configuration.
    """
    print("=" * 70)
    print("C2Sentinel Phase 2 Training - Multi-Task Learning")
    print("=" * 70)
    config = C2SentinelConfig()
    model = LogBERTC2Sentinel(config)
    # Load pretrained weights
    if os.path.exists(pretrained_path):
        print(f"Loading pretrained weights from {pretrained_path}")
        state_dict = load_file(pretrained_path)
        model.load_state_dict(state_dict)
    else:
        print("WARNING: No pretrained weights found, training from scratch")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    model.to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}")
    # Create dataset with existing normalization params (reuse phase-1
    # stats when present so feature scaling matches the pretrained model)
    norm_path = 'normalization_params.npz' if os.path.exists('normalization_params.npz') else None
    dataset = C2TrafficDatasetPhase2(num_samples=num_samples, normalize=True, norm_params_path=norm_path)
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    print(f"Train: {train_size}, Val: {val_size}")
    # Multi-task loss: C2 detection + C2 type classification
    c2_criterion = nn.BCEWithLogitsLoss()
    type_criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore benign (0)
    # Lower LR for fine-tuning
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    # Cosine annealing
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    best_val_acc = 0
    patience = 10
    patience_counter = 0
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        train_c2_correct = 0
        train_type_correct = 0
        train_total = 0
        train_c2_samples = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False):
            features = batch['features'].to(device)
            labels = batch['label'].to(device)
            c2_types = batch['c2_type'].to(device)
            optimizer.zero_grad()
            outputs = model(features)
            # C2 detection loss. Use squeeze(-1), NOT squeeze(): a plain
            # squeeze() on a size-1 batch collapses (1, 1) logits to a
            # 0-dim tensor, and BCEWithLogitsLoss rejects the shape
            # mismatch against the (1,)-shaped labels.
            c2_loss = c2_criterion(outputs['c2_logits'].squeeze(-1), labels)
            # C2 type classification loss (only for C2 samples)
            c2_mask = labels == 1
            if c2_mask.sum() > 0 and 'type_logits' in outputs:
                type_loss = type_criterion(outputs['type_logits'][c2_mask], c2_types[c2_mask])
                loss = c2_loss + 0.3 * type_loss  # Weighted combination
            else:
                loss = c2_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()
            # C2 detection accuracy
            c2_preds = (torch.sigmoid(outputs['c2_logits'].squeeze(-1)) > 0.5).float()
            train_c2_correct += (c2_preds == labels).sum().item()
            train_total += labels.size(0)
            # Type classification accuracy (for C2 samples only)
            if c2_mask.sum() > 0 and 'type_logits' in outputs:
                type_preds = outputs['type_logits'][c2_mask].argmax(dim=1)
                train_type_correct += (type_preds == c2_types[c2_mask]).sum().item()
                train_c2_samples += c2_mask.sum().item()
        scheduler.step()
        # Validation
        model.eval()
        val_c2_correct = 0
        val_type_correct = 0
        val_total = 0
        val_c2_samples = 0
        val_loss = 0
        # Track per-class metrics
        val_tp, val_fp, val_tn, val_fn = 0, 0, 0, 0
        with torch.no_grad():
            for batch in val_loader:
                features = batch['features'].to(device)
                labels = batch['label'].to(device)
                c2_types = batch['c2_type'].to(device)
                outputs = model(features)
                # squeeze(-1) for the same size-1-batch reason as above
                # (val_loader has no drop_last, so a 1-sample batch is possible)
                c2_loss = c2_criterion(outputs['c2_logits'].squeeze(-1), labels)
                val_loss += c2_loss.item()
                c2_preds = (torch.sigmoid(outputs['c2_logits'].squeeze(-1)) > 0.5).float()
                val_c2_correct += (c2_preds == labels).sum().item()
                val_total += labels.size(0)
                # Confusion matrix stats
                val_tp += ((c2_preds == 1) & (labels == 1)).sum().item()
                val_fp += ((c2_preds == 1) & (labels == 0)).sum().item()
                val_tn += ((c2_preds == 0) & (labels == 0)).sum().item()
                val_fn += ((c2_preds == 0) & (labels == 1)).sum().item()
                c2_mask = labels == 1
                if c2_mask.sum() > 0 and 'type_logits' in outputs:
                    type_preds = outputs['type_logits'][c2_mask].argmax(dim=1)
                    val_type_correct += (type_preds == c2_types[c2_mask]).sum().item()
                    val_c2_samples += c2_mask.sum().item()
        train_c2_acc = 100 * train_c2_correct / train_total
        train_type_acc = 100 * train_type_correct / train_c2_samples if train_c2_samples > 0 else 0
        val_c2_acc = 100 * val_c2_correct / val_total
        val_type_acc = 100 * val_type_correct / val_c2_samples if val_c2_samples > 0 else 0
        precision = val_tp / (val_tp + val_fp) if (val_tp + val_fp) > 0 else 0
        recall = val_tp / (val_tp + val_fn) if (val_tp + val_fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}: Loss={train_loss/len(train_loader):.4f}, "
              f"Train C2={train_c2_acc:.1f}%, Val C2={val_c2_acc:.1f}%, "
              f"Type Acc={val_type_acc:.1f}%, P={precision:.2f}, R={recall:.2f}, F1={f1:.2f}")
        # Checkpoint on best validation detection accuracy; early-stop
        # after `patience` epochs without improvement.
        if val_c2_acc > best_val_acc:
            best_val_acc = val_c2_acc
            patience_counter = 0
            save_file(model.state_dict(), 'c2_sentinel.safetensors')
            print(f" -> Saved (Val: {val_c2_acc:.1f}%, F1: {f1:.2f})")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break
    print(f"\nBest validation C2 accuracy: {best_val_acc:.1f}%")
    return model, config
def test_adversarial():
    """Test model on adversarial patterns.

    Loads the saved sentinel ('c2_sentinel') and runs it over five
    hand-built scenarios: three evasive C2 beacons that should be flagged
    and two legitimate periodic services that should come back clean.
    Prints verdict, probability, and up to three risk factors per case.
    Note: inputs are built with the global `random` module, so exact
    values vary run to run unless the caller seeds it.
    """
    print("\n" + "=" * 70)
    print("Adversarial Pattern Testing")
    print("=" * 70)
    sentinel = C2Sentinel.load('c2_sentinel')
    test_cases = [
        # Evasive beacon: ~60s cadence with +/-50% jitter, tiny fixed sizes
        ("High-jitter Cobalt Strike", [
            {'timestamp': 1705600000 + i * 60 * (1 + random.uniform(-0.5, 0.5)),
             'dst_ip': '185.234.72.19', 'dst_port': 443,
             'bytes_sent': 92, 'bytes_recv': 48}
            for i in range(20)
        ]),
        # Every 5th check-in carries a larger "command" exchange
        ("Burst pattern beacon", [
            {'timestamp': 1705600000 + (i * 60 if i % 5 != 0 else i * 60 + random.uniform(0, 5)),
             'dst_ip': '45.33.32.156', 'dst_port': 443,
             'bytes_sent': 100 if i % 5 != 0 else random.randint(500, 2000),
             'bytes_recv': 60 if i % 5 != 0 else random.randint(1000, 5000)}
            for i in range(25)
        ]),
        # Fixed 45s cadence on a classic Metasploit port, randomized sizes
        ("Variable size beacon", [
            {'timestamp': 1705600000 + i * 45,
             'dst_ip': '10.10.10.10', 'dst_port': 4444,
             'bytes_sent': random.randint(50, 300),
             'bytes_recv': random.randint(40, 250)}
            for i in range(18)
        ]),
        # Benign control: perfectly regular LAN SSH keepalive
        ("SSH keepalive (should be clean)", [
            {'timestamp': 1705600000 + i * 30,
             'dst_ip': '192.168.1.50', 'dst_port': 22,
             'bytes_sent': 48, 'bytes_recv': 48}
            for i in range(20)
        ]),
        # Benign control: jittered polling with highly variable responses
        ("API polling (should be clean)", [
            {'timestamp': 1705600000 + i * random.uniform(25, 35),
             'dst_ip': '52.85.132.99', 'dst_port': 443,
             'bytes_sent': 150, 'bytes_recv': random.randint(500, 50000)}
            for i in range(25)
        ]),
    ]
    for name, connections in test_cases:
        result = sentinel.analyze(connections)
        status = "C2 DETECTED" if result.is_c2 else "Clean"
        print(f"\n{name}:")
        print(f" {status} (prob={result.c2_probability:.2%})")
        if result.risk_factors:
            # Show at most the top three explanations
            for rf in result.risk_factors[:3]:
                print(f" - {rf}")
if __name__ == '__main__':
    import argparse

    # CLI flags mirror train_phase2()'s keyword defaults one-to-one.
    cli = argparse.ArgumentParser()
    cli.add_argument('--epochs', type=int, default=50)
    cli.add_argument('--samples', type=int, default=30000)
    cli.add_argument('--batch-size', type=int, default=32)
    cli.add_argument('--lr', type=float, default=0.00005)
    cli.add_argument('--pretrained', type=str, default='c2_sentinel.safetensors')
    cli.add_argument('--test-only', action='store_true')
    opts = cli.parse_args()

    # Unless --test-only, fine-tune first; the adversarial sanity check
    # always runs afterwards.
    if not opts.test_only:
        train_phase2(
            pretrained_path=opts.pretrained,
            num_epochs=opts.epochs,
            batch_size=opts.batch_size,
            learning_rate=opts.lr,
            num_samples=opts.samples
        )
    test_adversarial()