danielostrow committed on
Commit
0c4b9f9
·
verified ·
1 Parent(s): 3ec3978

Remove training files from public view

Browse files
Files changed (1) hide show
  1. train_model_phase2.py +0 -704
train_model_phase2.py DELETED
@@ -1,704 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- C2Sentinel Training Script - Phase 2: Multi-Task Learning & Adversarial Hardening
4
-
5
- Phase 2 adds:
6
- 1. C2 type classification (10 framework types)
7
- 2. Adversarial beacon patterns (jitter, domain fronting, etc.)
8
- 3. Enhanced benign patterns to reduce false positives
9
- 4. Confidence calibration
10
- """
11
-
12
- import torch
13
- import torch.nn as nn
14
- import torch.optim as optim
15
- from torch.utils.data import Dataset, DataLoader
16
- import numpy as np
17
- import random
18
- from tqdm import tqdm
19
- import json
20
- import os
21
-
22
- from c2sentinel import (
23
- C2Sentinel, C2SentinelConfig, LogBERTC2Sentinel,
24
- FeatureExtractor
25
- )
26
- from safetensors.torch import save_file, load_file
27
-
28
-
29
class C2TrafficDatasetPhase2(Dataset):
    """Enhanced dataset with adversarial patterns and multi-task labels.

    Synthesizes four traffic populations: standard C2 beacons for ten
    framework profiles, adversarial C2 beacons using evasion tricks
    (high jitter, size variation, bursts, IP rotation), ordinary benign
    traffic, and edge-case benign traffic that superficially resembles
    beaconing. Each connection sequence is converted to a feature vector
    by FeatureExtractor; every sample carries a binary C2 label plus a
    framework-type id (0 = benign, 1-10 = C2 framework).
    """

    def __init__(self, num_samples=30000, normalize=True, norm_params_path=None):
        """Generate the synthetic dataset.

        Args:
            num_samples: total samples; split half C2 / half benign. One
                third of the C2 half is adversarial; one third of the
                benign half is edge-case look-alike traffic.
            normalize: if True, z-score the features using loaded or
                freshly computed statistics.
            norm_params_path: optional .npz with 'mean' and 'std'; when
                missing, stats are computed here and written to
                'normalization_params.npz' for reuse at inference time.
        """
        self.samples = []
        self.labels = []
        self.c2_types = []  # 0=benign, 1-10=C2 framework types
        self.feature_extractor = FeatureExtractor()

        print(f"Generating {num_samples} phase 2 training samples...")

        num_c2 = num_samples // 2
        num_adversarial = num_c2 // 3  # 1/3 of C2 samples are adversarial
        num_standard_c2 = num_c2 - num_adversarial
        num_benign = num_samples - num_c2

        # Generate standard C2 samples
        print(f"\n[1/4] Standard C2 samples ({num_standard_c2})...")
        for _ in tqdm(range(num_standard_c2), desc="Standard C2"):
            connections, c2_type = self._generate_standard_c2()
            features = self.feature_extractor.extract_features(connections)
            self.samples.append(features)
            self.labels.append(1)
            self.c2_types.append(c2_type)

        # Generate adversarial C2 samples (harder to detect)
        print(f"\n[2/4] Adversarial C2 samples ({num_adversarial})...")
        for _ in tqdm(range(num_adversarial), desc="Adversarial C2"):
            connections, c2_type = self._generate_adversarial_c2()
            features = self.feature_extractor.extract_features(connections)
            self.samples.append(features)
            self.labels.append(1)
            self.c2_types.append(c2_type)

        # Generate standard benign samples
        num_standard_benign = num_benign * 2 // 3
        print(f"\n[3/4] Standard benign samples ({num_standard_benign})...")
        for _ in tqdm(range(num_standard_benign), desc="Standard Benign"):
            connections = self._generate_benign_traffic()
            features = self.feature_extractor.extract_features(connections)
            self.samples.append(features)
            self.labels.append(0)
            self.c2_types.append(0)

        # Generate edge-case benign samples (look like C2 but aren't)
        num_edge_benign = num_benign - num_standard_benign
        print(f"\n[4/4] Edge-case benign samples ({num_edge_benign})...")
        for _ in tqdm(range(num_edge_benign), desc="Edge Benign"):
            connections = self._generate_edge_case_benign()
            features = self.feature_extractor.extract_features(connections)
            self.samples.append(features)
            self.labels.append(0)
            self.c2_types.append(0)

        # float32 labels because the detection head trains with BCEWithLogitsLoss
        self.samples = np.array(self.samples, dtype=np.float32)
        self.labels = np.array(self.labels, dtype=np.float32)
        self.c2_types = np.array(self.c2_types, dtype=np.int64)

        # Load or compute normalization
        if normalize:
            if norm_params_path and os.path.exists(norm_params_path):
                print(f"Loading normalization from {norm_params_path}")
                params = np.load(norm_params_path)
                self.mean = params['mean']
                self.std = params['std']
            else:
                self.mean = np.mean(self.samples, axis=0)
                self.std = np.std(self.samples, axis=0) + 1e-8  # epsilon avoids divide-by-zero
                np.savez('normalization_params.npz', mean=self.mean, std=self.std)

            self.samples = (self.samples - self.mean) / self.std
            print(f"Feature stats - mean range: [{self.mean.min():.2f}, {self.mean.max():.2f}]")

        # Shuffle
        indices = np.random.permutation(len(self.samples))
        self.samples = self.samples[indices]
        self.labels = self.labels[indices]
        self.c2_types = self.c2_types[indices]

        # Report distribution
        c2_count = np.sum(self.labels)
        print(f"\nDataset: {len(self.labels)} samples")
        print(f"  C2: {int(c2_count)} ({100*c2_count/len(self.labels):.1f}%)")
        print(f"  Benign: {int(len(self.labels) - c2_count)} ({100*(1 - c2_count/len(self.labels)):.1f}%)")
        print(f"  C2 type distribution: {np.bincount(self.c2_types[self.labels == 1].astype(int), minlength=11)[1:]}")

    def _generate_standard_c2(self):
        """Generate standard C2 beacon patterns.

        Returns:
            (connections, c2_type): a list of connection dicts and the
            framework id (1-10) the beacon was modeled on.
        """
        c2_type = random.randint(1, 10)

        # C2 type characteristics: per-framework beacon interval range (s),
        # jitter fraction, and typical listener ports.
        c2_profiles = {
            1: {'name': 'Metasploit', 'interval': (2, 15), 'jitter': 0.1, 'ports': [4444, 4445, 5555]},
            2: {'name': 'Cobalt Strike', 'interval': (30, 90), 'jitter': 0.2, 'ports': [443, 8443]},
            3: {'name': 'Empire', 'interval': (5, 30), 'jitter': 0.15, 'ports': [443, 8080]},
            4: {'name': 'Covenant', 'interval': (10, 60), 'jitter': 0.1, 'ports': [443, 80]},
            5: {'name': 'Sliver', 'interval': (30, 120), 'jitter': 0.3, 'ports': [443, 8888]},
            6: {'name': 'Brute Ratel', 'interval': (60, 180), 'jitter': 0.2, 'ports': [443]},
            7: {'name': 'Mythic', 'interval': (15, 60), 'jitter': 0.15, 'ports': [443, 7443]},
            8: {'name': 'PoshC2', 'interval': (10, 45), 'jitter': 0.2, 'ports': [443, 8000]},
            9: {'name': 'Havoc', 'interval': (20, 90), 'jitter': 0.25, 'ports': [443, 40056]},
            10: {'name': 'APT Custom', 'interval': (120, 600), 'jitter': 0.05, 'ports': [443, 80]},
        }

        profile = c2_profiles[c2_type]
        interval = random.uniform(*profile['interval'])
        jitter = profile['jitter']
        port = random.choice(profile['ports'])

        # Beacon characteristics
        bytes_sent = random.randint(60, 200)
        bytes_recv = random.randint(40, 150)

        dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
        num_connections = random.randint(12, 50)

        connections = []
        timestamp = 1705600000

        for _ in range(num_connections):
            # Jittered interval and near-constant packet sizes: the classic
            # beacon signature the detector is meant to learn.
            actual_interval = interval * (1 + random.uniform(-jitter, jitter))
            timestamp += actual_interval
            size_var = random.uniform(0.95, 1.05)

            connections.append({
                'timestamp': timestamp,
                'dst_ip': dst_ip,
                'dst_port': port,
                'bytes_sent': int(bytes_sent * size_var),
                'bytes_recv': int(bytes_recv * size_var),
                'protocol': 'tcp'
            })

        return connections, c2_type

    def _generate_adversarial_c2(self):
        """Generate adversarial C2 patterns that try to evade detection.

        Returns:
            (connections, c2_type): connection dicts plus a framework id;
            the id is sampled independently of the evasion technique.
        """
        c2_type = random.randint(1, 10)
        evasion = random.choice(['high_jitter', 'variable_size', 'burst_pattern', 'domain_rotation', 'mixed'])

        base_interval = random.uniform(30, 120)
        dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
        num_connections = random.randint(15, 60)

        connections = []
        timestamp = 1705600000

        if evasion == 'high_jitter':
            # High jitter to look like normal traffic
            jitter = random.uniform(0.4, 0.7)
            for _ in range(num_connections):
                actual_interval = base_interval * (1 + random.uniform(-jitter, jitter))
                timestamp += max(5, actual_interval)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(80, 150),
                    'bytes_recv': random.randint(50, 100),
                    'protocol': 'tcp'
                })

        elif evasion == 'variable_size':
            # Variable packet sizes but consistent timing
            for _ in range(num_connections):
                timestamp += base_interval * (1 + random.uniform(-0.1, 0.1))
                # Sizes vary more but still bounded
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(50, 500),
                    'bytes_recv': random.randint(40, 400),
                    'protocol': 'tcp'
                })

        elif evasion == 'burst_pattern':
            # Beacon with occasional bursts (simulating commands)
            for i in range(num_connections):
                if i % 8 == 0:
                    # Burst
                    for _ in range(random.randint(2, 5)):
                        timestamp += random.uniform(0.5, 3)
                        connections.append({
                            'timestamp': timestamp,
                            'dst_ip': dst_ip,
                            'dst_port': 443,
                            'bytes_sent': random.randint(200, 2000),
                            'bytes_recv': random.randint(500, 5000),
                            'protocol': 'tcp'
                        })
                else:
                    timestamp += base_interval * (1 + random.uniform(-0.15, 0.15))
                    connections.append({
                        'timestamp': timestamp,
                        'dst_ip': dst_ip,
                        'dst_port': 443,
                        'bytes_sent': random.randint(80, 120),
                        'bytes_recv': random.randint(50, 80),
                        'protocol': 'tcp'
                    })

        elif evasion == 'domain_rotation':
            # Multiple IPs (CDN-like) but same beacon pattern
            ips = [f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
                   for _ in range(random.randint(2, 4))]
            for _ in range(num_connections):
                timestamp += base_interval * (1 + random.uniform(-0.2, 0.2))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': random.choice(ips),
                    'dst_port': 443,
                    'bytes_sent': random.randint(80, 150),
                    'bytes_recv': random.randint(50, 100),
                    'protocol': 'tcp'
                })

        else:  # mixed
            # Mix of evasion techniques
            jitter = random.uniform(0.25, 0.5)
            for _ in range(num_connections):
                actual_interval = base_interval * (1 + random.uniform(-jitter, jitter))
                timestamp += max(3, actual_interval)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': random.choice([443, 8443]),
                    'bytes_sent': random.randint(60, 300),
                    'bytes_recv': random.randint(40, 200),
                    'protocol': 'tcp'
                })

        return connections, c2_type

    def _generate_benign_traffic(self):
        """Generate standard benign traffic patterns.

        Returns:
            A list of connection dicts for one randomly chosen benign
            behavior (browsing, API, streaming, interactive SSH, bulk
            download, or email).
        """
        pattern = random.choice(['browsing', 'api', 'streaming', 'interactive', 'download', 'email'])

        connections = []
        timestamp = 1705600000

        if pattern == 'browsing':
            # Many destinations, bursty timing, large asymmetric downloads.
            for _ in range(random.randint(10, 50)):
                timestamp += random.uniform(0.5, 60)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}",
                    'dst_port': random.choice([80, 443]),
                    'bytes_sent': random.randint(200, 5000),
                    'bytes_recv': random.randint(5000, 500000),
                    'protocol': 'tcp'
                })

        elif pattern == 'api':
            # Single destination, irregular timing, variable payloads.
            dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
            for _ in range(random.randint(15, 60)):
                timestamp += random.uniform(0.1, 30)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(100, 5000),
                    'bytes_recv': random.randint(200, 200000),
                    'protocol': 'tcp'
                })

        elif pattern == 'streaming':
            # Rapid, download-heavy flow to one destination.
            dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
            for _ in range(random.randint(30, 100)):
                timestamp += random.uniform(0.02, 2)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(30, 200),
                    'bytes_recv': random.randint(5000, 200000),
                    'protocol': 'tcp'
                })

        elif pattern == 'interactive':
            # SSH-like: mostly quick exchanges with occasional idle pauses.
            dst_ip = f"192.168.{random.randint(0,255)}.{random.randint(1,254)}"
            for _ in range(random.randint(20, 80)):
                if random.random() < 0.3:
                    timestamp += random.uniform(5, 60)
                else:
                    timestamp += random.uniform(0.1, 3)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 22,
                    'bytes_sent': random.randint(20, 1000),
                    'bytes_recv': random.randint(50, 30000),
                    'protocol': 'tcp'
                })

        elif pattern == 'download':
            # Sustained bulk transfer: dense timing, large receive sizes.
            dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
            for _ in range(random.randint(50, 200)):
                timestamp += random.uniform(0.01, 0.5)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(40, 100),
                    'bytes_recv': random.randint(10000, 65000),
                    'protocol': 'tcp'
                })

        else:  # email
            # Sparse, large, bidirectional transfers to mail-style ports.
            for _ in range(random.randint(5, 20)):
                timestamp += random.uniform(30, 300)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}",
                    'dst_port': random.choice([443, 993, 587]),
                    'bytes_sent': random.randint(500, 50000),
                    'bytes_recv': random.randint(1000, 500000),
                    'protocol': 'tcp'
                })

        return connections

    def _generate_edge_case_benign(self):
        """Generate benign traffic that looks like C2 but isn't.

        These are the false-positive traps: periodic, low-volume flows
        (heartbeats, monitoring, keepalives, polling, IoT telemetry) that
        share the beacon's regular timing but differ in payload behavior.
        """
        pattern = random.choice([
            'heartbeat', 'monitoring', 'sync', 'keepalive', 'polling', 'iot'
        ])

        connections = []
        timestamp = 1705600000
        dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"

        if pattern == 'heartbeat':
            # Regular heartbeat but with large, variable responses
            interval = random.uniform(30, 120)
            for _ in range(random.randint(15, 40)):
                timestamp += interval * (1 + random.uniform(-0.05, 0.05))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(50, 100),
                    'bytes_recv': random.randint(1000, 50000),  # Large variable responses
                    'protocol': 'tcp'
                })

        elif pattern == 'monitoring':
            # Regular monitoring checks with status responses
            interval = random.uniform(60, 300)
            for _ in range(random.randint(10, 30)):
                timestamp += interval * (1 + random.uniform(-0.1, 0.1))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': random.choice([443, 8443, 9090]),
                    'bytes_sent': random.randint(100, 500),
                    'bytes_recv': random.randint(500, 10000),
                    'protocol': 'tcp'
                })

        elif pattern == 'sync':
            # Periodic sync with variable data
            interval = random.uniform(300, 900)
            for _ in range(random.randint(8, 20)):
                timestamp += interval * (1 + random.uniform(-0.15, 0.15))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(1000, 100000),
                    'bytes_recv': random.randint(1000, 100000),
                    'protocol': 'tcp'
                })

        elif pattern == 'keepalive':
            # SSH/VPN keepalive - very regular but small
            interval = random.uniform(15, 60)
            for _ in range(random.randint(20, 60)):
                timestamp += interval * (1 + random.uniform(-0.02, 0.02))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': f"192.168.{random.randint(0,255)}.{random.randint(1,254)}",
                    'dst_port': random.choice([22, 1194, 443]),
                    'bytes_sent': random.randint(40, 80),
                    'bytes_recv': random.randint(40, 80),
                    'protocol': 'tcp'
                })

        elif pattern == 'polling':
            # API polling - regular but with variable responses
            interval = random.uniform(10, 60)
            for _ in range(random.randint(20, 50)):
                timestamp += interval * (1 + random.uniform(-0.2, 0.2))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(80, 200),
                    'bytes_recv': random.randint(100, 50000),  # Highly variable
                    'protocol': 'tcp'
                })

        else:  # iot
            # IoT device - regular small packets
            interval = random.uniform(60, 300)
            for _ in range(random.randint(15, 40)):
                timestamp += interval * (1 + random.uniform(-0.1, 0.1))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': random.choice([443, 8883, 1883]),
                    'bytes_sent': random.randint(50, 200),
                    'bytes_recv': random.randint(50, 200),
                    'protocol': 'tcp'
                })

        return connections

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors: features, binary label, and framework type."""
        return {
            'features': torch.tensor(self.samples[idx]),
            'label': torch.tensor(self.labels[idx]),
            'c2_type': torch.tensor(self.c2_types[idx])
        }
456
-
457
-
458
def train_phase2(
    pretrained_path='c2_sentinel.safetensors',
    num_epochs=50,
    batch_size=32,
    learning_rate=0.00005,  # Lower LR for fine-tuning
    num_samples=30000
):
    """Phase 2 training with multi-task learning.

    Fine-tunes LogBERTC2Sentinel on the phase 2 dataset with two heads:
    binary C2 detection (BCEWithLogitsLoss) and, on C2 samples only, a
    framework-type classifier (CrossEntropyLoss, weighted 0.3). The best
    checkpoint by validation C2 accuracy is saved to
    'c2_sentinel.safetensors'; early stopping triggers after `patience`
    epochs without improvement.

    Args:
        pretrained_path: phase 1 safetensors weights; trains from scratch
            when the file is missing.
        num_epochs: maximum number of training epochs.
        batch_size: mini-batch size for both loaders.
        learning_rate: AdamW learning rate (kept low for fine-tuning).
        num_samples: number of synthetic samples to generate.

    Returns:
        (model, config) tuple; the returned model carries the
        best-validation weights when a checkpoint was saved.
    """

    print("=" * 70)
    print("C2Sentinel Phase 2 Training - Multi-Task Learning")
    print("=" * 70)

    config = C2SentinelConfig()
    model = LogBERTC2Sentinel(config)

    # Load pretrained weights
    if os.path.exists(pretrained_path):
        print(f"Loading pretrained weights from {pretrained_path}")
        state_dict = load_file(pretrained_path)
        model.load_state_dict(state_dict)
    else:
        print("WARNING: No pretrained weights found, training from scratch")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}")

    # Create dataset with existing normalization params
    norm_path = 'normalization_params.npz' if os.path.exists('normalization_params.npz') else None
    dataset = C2TrafficDatasetPhase2(num_samples=num_samples, normalize=True, norm_params_path=norm_path)

    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    print(f"Train: {train_size}, Val: {val_size}")

    # Multi-task loss: C2 detection + C2 type classification
    c2_criterion = nn.BCEWithLogitsLoss()
    type_criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore benign (0)

    # Lower LR for fine-tuning
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)

    # Cosine annealing
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    best_val_acc = 0
    patience = 10
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        train_c2_correct = 0
        train_type_correct = 0
        train_total = 0
        train_c2_samples = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False):
            features = batch['features'].to(device)
            labels = batch['label'].to(device)
            c2_types = batch['c2_type'].to(device)

            optimizer.zero_grad()
            outputs = model(features)

            # C2 detection loss. squeeze(-1) rather than squeeze(): a bare
            # squeeze() would also collapse a size-1 batch dimension,
            # producing a 0-d logit tensor that mismatches the 1-d labels.
            c2_logits = outputs['c2_logits'].squeeze(-1)
            c2_loss = c2_criterion(c2_logits, labels)

            # C2 type classification loss (only for C2 samples)
            c2_mask = labels == 1
            if c2_mask.sum() > 0 and 'type_logits' in outputs:
                type_loss = type_criterion(outputs['type_logits'][c2_mask], c2_types[c2_mask])
                loss = c2_loss + 0.3 * type_loss  # Weighted combination
            else:
                loss = c2_loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_loss += loss.item()

            # C2 detection accuracy
            c2_preds = (torch.sigmoid(c2_logits) > 0.5).float()
            train_c2_correct += (c2_preds == labels).sum().item()
            train_total += labels.size(0)

            # Type classification accuracy (for C2 samples only)
            if c2_mask.sum() > 0 and 'type_logits' in outputs:
                type_preds = outputs['type_logits'][c2_mask].argmax(dim=1)
                train_type_correct += (type_preds == c2_types[c2_mask]).sum().item()
                train_c2_samples += c2_mask.sum().item()

        scheduler.step()

        # Validation
        model.eval()
        val_c2_correct = 0
        val_type_correct = 0
        val_total = 0
        val_c2_samples = 0
        val_loss = 0

        # Track per-class metrics
        val_tp, val_fp, val_tn, val_fn = 0, 0, 0, 0

        with torch.no_grad():
            for batch in val_loader:
                features = batch['features'].to(device)
                labels = batch['label'].to(device)
                c2_types = batch['c2_type'].to(device)

                outputs = model(features)

                # squeeze(-1) again: the val loader has no drop_last, so a
                # final batch of one sample is possible here.
                c2_logits = outputs['c2_logits'].squeeze(-1)
                c2_loss = c2_criterion(c2_logits, labels)
                val_loss += c2_loss.item()

                c2_preds = (torch.sigmoid(c2_logits) > 0.5).float()
                val_c2_correct += (c2_preds == labels).sum().item()
                val_total += labels.size(0)

                # Confusion matrix stats
                val_tp += ((c2_preds == 1) & (labels == 1)).sum().item()
                val_fp += ((c2_preds == 1) & (labels == 0)).sum().item()
                val_tn += ((c2_preds == 0) & (labels == 0)).sum().item()
                val_fn += ((c2_preds == 0) & (labels == 1)).sum().item()

                c2_mask = labels == 1
                if c2_mask.sum() > 0 and 'type_logits' in outputs:
                    type_preds = outputs['type_logits'][c2_mask].argmax(dim=1)
                    val_type_correct += (type_preds == c2_types[c2_mask]).sum().item()
                    val_c2_samples += c2_mask.sum().item()

        train_c2_acc = 100 * train_c2_correct / train_total
        train_type_acc = 100 * train_type_correct / train_c2_samples if train_c2_samples > 0 else 0
        val_c2_acc = 100 * val_c2_correct / val_total
        val_type_acc = 100 * val_type_correct / val_c2_samples if val_c2_samples > 0 else 0

        precision = val_tp / (val_tp + val_fp) if (val_tp + val_fp) > 0 else 0
        recall = val_tp / (val_tp + val_fn) if (val_tp + val_fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        lr = optimizer.param_groups[0]['lr']

        print(f"Epoch {epoch+1}: Loss={train_loss/len(train_loader):.4f}, "
              f"Train C2={train_c2_acc:.1f}%, Val C2={val_c2_acc:.1f}%, "
              f"Type Acc={val_type_acc:.1f}%, P={precision:.2f}, R={recall:.2f}, F1={f1:.2f}")

        if val_c2_acc > best_val_acc:
            best_val_acc = val_c2_acc
            patience_counter = 0
            save_file(model.state_dict(), 'c2_sentinel.safetensors')
            print(f"  -> Saved (Val: {val_c2_acc:.1f}%, F1: {f1:.2f})")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    print(f"\nBest validation C2 accuracy: {best_val_acc:.1f}%")

    # Return the best checkpoint rather than whatever the last (possibly
    # early-stopped, worse) epoch left in memory.
    if best_val_acc > 0 and os.path.exists('c2_sentinel.safetensors'):
        model.load_state_dict(load_file('c2_sentinel.safetensors'))
        model.to(device)

    return model, config
628
-
629
-
630
def test_adversarial():
    """Probe the saved model with evasive C2 traces and benign look-alikes."""
    banner = "=" * 70
    print("\n" + banner)
    print("Adversarial Pattern Testing")
    print(banner)

    detector = C2Sentinel.load('c2_sentinel')

    # Each scenario is (description, synthetic connection list). The first
    # three should trip detection; the last two should come back clean.
    scenarios = [
        ("High-jitter Cobalt Strike", [
            {'timestamp': 1705600000 + i * 60 * (1 + random.uniform(-0.5, 0.5)),
             'dst_ip': '185.234.72.19', 'dst_port': 443,
             'bytes_sent': 92, 'bytes_recv': 48}
            for i in range(20)
        ]),
        ("Burst pattern beacon", [
            {'timestamp': 1705600000 + (i * 60 if i % 5 != 0 else i * 60 + random.uniform(0, 5)),
             'dst_ip': '45.33.32.156', 'dst_port': 443,
             'bytes_sent': 100 if i % 5 != 0 else random.randint(500, 2000),
             'bytes_recv': 60 if i % 5 != 0 else random.randint(1000, 5000)}
            for i in range(25)
        ]),
        ("Variable size beacon", [
            {'timestamp': 1705600000 + i * 45,
             'dst_ip': '10.10.10.10', 'dst_port': 4444,
             'bytes_sent': random.randint(50, 300),
             'bytes_recv': random.randint(40, 250)}
            for i in range(18)
        ]),
        ("SSH keepalive (should be clean)", [
            {'timestamp': 1705600000 + i * 30,
             'dst_ip': '192.168.1.50', 'dst_port': 22,
             'bytes_sent': 48, 'bytes_recv': 48}
            for i in range(20)
        ]),
        ("API polling (should be clean)", [
            {'timestamp': 1705600000 + i * random.uniform(25, 35),
             'dst_ip': '52.85.132.99', 'dst_port': 443,
             'bytes_sent': 150, 'bytes_recv': random.randint(500, 50000)}
            for i in range(25)
        ]),
    ]

    for title, conns in scenarios:
        verdict = detector.analyze(conns)
        if verdict.is_c2:
            status = "C2 DETECTED"
        else:
            status = "Clean"
        print(f"\n{title}:")
        print(f"  {status} (prob={verdict.c2_probability:.2%})")
        # Show at most the top three contributing risk factors.
        for rf in (verdict.risk_factors or [])[:3]:
            print(f"    - {rf}")
681
-
682
-
683
if __name__ == '__main__':
    import argparse

    # CLI mirrors train_phase2()'s defaults; --test-only skips training.
    cli = argparse.ArgumentParser()
    cli.add_argument('--epochs', type=int, default=50)
    cli.add_argument('--samples', type=int, default=30000)
    cli.add_argument('--batch-size', type=int, default=32)
    cli.add_argument('--lr', type=float, default=0.00005)
    cli.add_argument('--pretrained', type=str, default='c2_sentinel.safetensors')
    cli.add_argument('--test-only', action='store_true')
    opts = cli.parse_args()

    # Train first unless only the adversarial evaluation was requested;
    # either way, finish with the adversarial test pass.
    if not opts.test_only:
        train_phase2(
            pretrained_path=opts.pretrained,
            num_epochs=opts.epochs,
            batch_size=opts.batch_size,
            learning_rate=opts.lr,
            num_samples=opts.samples
        )
    test_adversarial()