danielostrow commited on
Commit
679e1eb
·
verified ·
1 Parent(s): 4a02ea8

Add phase 2 multi-task training script

Browse files
Files changed (1) hide show
  1. train_model_phase2.py +704 -0
train_model_phase2.py ADDED
@@ -0,0 +1,704 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ C2Sentinel Training Script - Phase 2: Multi-Task Learning & Adversarial Hardening
4
+
5
+ Phase 2 adds:
6
+ 1. C2 type classification (10 framework types)
7
+ 2. Adversarial beacon patterns (jitter, domain fronting, etc.)
8
+ 3. Enhanced benign patterns to reduce false positives
9
+ 4. Confidence calibration
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.optim as optim
15
+ from torch.utils.data import Dataset, DataLoader
16
+ import numpy as np
17
+ import random
18
+ from tqdm import tqdm
19
+ import json
20
+ import os
21
+
22
+ from c2sentinel import (
23
+ C2Sentinel, C2SentinelConfig, LogBERTC2Sentinel,
24
+ FeatureExtractor
25
+ )
26
+ from safetensors.torch import save_file, load_file
27
+
28
+
29
class C2TrafficDatasetPhase2(Dataset):
    """Enhanced dataset with adversarial patterns and multi-task labels.

    Each item carries two targets:
      * ``label``   — binary C2 flag (1.0 = C2 traffic, 0.0 = benign)
      * ``c2_type`` — framework class, 0 for benign and 1–10 for the C2
        framework that generated the beacon (see ``_generate_standard_c2``)

    The sample mix is: half C2 (of which one third uses evasion
    techniques), half benign (of which one third is "edge-case" benign
    traffic deliberately shaped to resemble beaconing).
    """

    def __init__(self, num_samples=30000, normalize=True, norm_params_path=None):
        """Generate and (optionally) normalize ``num_samples`` synthetic flows.

        Args:
            num_samples: total number of samples to synthesize.
            normalize: if True, z-score features; stats are loaded from
                ``norm_params_path`` when available, otherwise computed here
                and saved to ``normalization_params.npz``.
            norm_params_path: optional ``.npz`` with ``mean``/``std`` arrays
                (use the phase-1 params so features stay comparable).
        """
        self.samples = []
        self.labels = []
        self.c2_types = []  # 0=benign, 1-10=C2 framework types
        # NOTE(review): project-local extractor; assumed to map a list of
        # connection dicts to a fixed-length feature vector — confirm.
        self.feature_extractor = FeatureExtractor()

        print(f"Generating {num_samples} phase 2 training samples...")

        # 50/50 C2 vs benign split; sub-splits below.
        num_c2 = num_samples // 2
        num_adversarial = num_c2 // 3  # 1/3 of C2 samples are adversarial
        num_standard_c2 = num_c2 - num_adversarial
        num_benign = num_samples - num_c2

        # Generate standard C2 samples
        print(f"\n[1/4] Standard C2 samples ({num_standard_c2})...")
        for _ in tqdm(range(num_standard_c2), desc="Standard C2"):
            connections, c2_type = self._generate_standard_c2()
            features = self.feature_extractor.extract_features(connections)
            self.samples.append(features)
            self.labels.append(1)
            self.c2_types.append(c2_type)

        # Generate adversarial C2 samples (harder to detect)
        print(f"\n[2/4] Adversarial C2 samples ({num_adversarial})...")
        for _ in tqdm(range(num_adversarial), desc="Adversarial C2"):
            connections, c2_type = self._generate_adversarial_c2()
            features = self.feature_extractor.extract_features(connections)
            self.samples.append(features)
            self.labels.append(1)
            self.c2_types.append(c2_type)

        # Generate standard benign samples
        num_standard_benign = num_benign * 2 // 3
        print(f"\n[3/4] Standard benign samples ({num_standard_benign})...")
        for _ in tqdm(range(num_standard_benign), desc="Standard Benign"):
            connections = self._generate_benign_traffic()
            features = self.feature_extractor.extract_features(connections)
            self.samples.append(features)
            self.labels.append(0)
            self.c2_types.append(0)

        # Generate edge-case benign samples (look like C2 but aren't)
        num_edge_benign = num_benign - num_standard_benign
        print(f"\n[4/4] Edge-case benign samples ({num_edge_benign})...")
        for _ in tqdm(range(num_edge_benign), desc="Edge Benign"):
            connections = self._generate_edge_case_benign()
            features = self.feature_extractor.extract_features(connections)
            self.samples.append(features)
            self.labels.append(0)
            self.c2_types.append(0)

        # float32 features / labels for BCE; int64 type ids for CrossEntropy.
        self.samples = np.array(self.samples, dtype=np.float32)
        self.labels = np.array(self.labels, dtype=np.float32)
        self.c2_types = np.array(self.c2_types, dtype=np.int64)

        # Load or compute normalization
        if normalize:
            if norm_params_path and os.path.exists(norm_params_path):
                print(f"Loading normalization from {norm_params_path}")
                params = np.load(norm_params_path)
                self.mean = params['mean']
                self.std = params['std']
            else:
                self.mean = np.mean(self.samples, axis=0)
                self.std = np.std(self.samples, axis=0) + 1e-8  # guard /0
                np.savez('normalization_params.npz', mean=self.mean, std=self.std)

            self.samples = (self.samples - self.mean) / self.std
            print(f"Feature stats - mean range: [{self.mean.min():.2f}, {self.mean.max():.2f}]")

        # Shuffle
        indices = np.random.permutation(len(self.samples))
        self.samples = self.samples[indices]
        self.labels = self.labels[indices]
        self.c2_types = self.c2_types[indices]

        # Report distribution
        c2_count = np.sum(self.labels)
        print(f"\nDataset: {len(self.labels)} samples")
        print(f" C2: {int(c2_count)} ({100*c2_count/len(self.labels):.1f}%)")
        print(f" Benign: {int(len(self.labels) - c2_count)} ({100*(1 - c2_count/len(self.labels)):.1f}%)")
        print(f" C2 type distribution: {np.bincount(self.c2_types[self.labels == 1].astype(int), minlength=11)[1:]}")

    def _generate_standard_c2(self):
        """Generate standard C2 beacon patterns.

        Returns:
            (connections, c2_type): a list of connection dicts with a
            near-regular beacon interval/size, and the framework id (1-10).
        """
        c2_type = random.randint(1, 10)

        # C2 type characteristics: typical check-in interval range (s),
        # timing jitter fraction, and commonly used destination ports.
        c2_profiles = {
            1: {'name': 'Metasploit', 'interval': (2, 15), 'jitter': 0.1, 'ports': [4444, 4445, 5555]},
            2: {'name': 'Cobalt Strike', 'interval': (30, 90), 'jitter': 0.2, 'ports': [443, 8443]},
            3: {'name': 'Empire', 'interval': (5, 30), 'jitter': 0.15, 'ports': [443, 8080]},
            4: {'name': 'Covenant', 'interval': (10, 60), 'jitter': 0.1, 'ports': [443, 80]},
            5: {'name': 'Sliver', 'interval': (30, 120), 'jitter': 0.3, 'ports': [443, 8888]},
            6: {'name': 'Brute Ratel', 'interval': (60, 180), 'jitter': 0.2, 'ports': [443]},
            7: {'name': 'Mythic', 'interval': (15, 60), 'jitter': 0.15, 'ports': [443, 7443]},
            8: {'name': 'PoshC2', 'interval': (10, 45), 'jitter': 0.2, 'ports': [443, 8000]},
            9: {'name': 'Havoc', 'interval': (20, 90), 'jitter': 0.25, 'ports': [443, 40056]},
            10: {'name': 'APT Custom', 'interval': (120, 600), 'jitter': 0.05, 'ports': [443, 80]},
        }

        profile = c2_profiles[c2_type]
        interval = random.uniform(*profile['interval'])
        jitter = profile['jitter']
        port = random.choice(profile['ports'])

        # Beacon characteristics
        bytes_sent = random.randint(60, 200)
        bytes_recv = random.randint(40, 150)

        dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
        num_connections = random.randint(12, 50)

        connections = []
        timestamp = 1705600000  # fixed epoch start; only deltas matter

        for _ in range(num_connections):
            # Jittered interval and ~5% size variation per check-in.
            actual_interval = interval * (1 + random.uniform(-jitter, jitter))
            timestamp += actual_interval
            size_var = random.uniform(0.95, 1.05)

            connections.append({
                'timestamp': timestamp,
                'dst_ip': dst_ip,
                'dst_port': port,
                'bytes_sent': int(bytes_sent * size_var),
                'bytes_recv': int(bytes_recv * size_var),
                'protocol': 'tcp'
            })

        return connections, c2_type

    def _generate_adversarial_c2(self):
        """Generate adversarial C2 patterns that try to evade detection.

        One of five evasion styles is applied: high timing jitter, variable
        payload sizes, burst-on-command traffic, destination IP rotation,
        or a mix. Returns (connections, c2_type) like the standard path.
        """
        c2_type = random.randint(1, 10)
        evasion = random.choice(['high_jitter', 'variable_size', 'burst_pattern', 'domain_rotation', 'mixed'])

        base_interval = random.uniform(30, 120)
        dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
        num_connections = random.randint(15, 60)

        connections = []
        timestamp = 1705600000

        if evasion == 'high_jitter':
            # High jitter to look like normal traffic
            jitter = random.uniform(0.4, 0.7)
            for _ in range(num_connections):
                actual_interval = base_interval * (1 + random.uniform(-jitter, jitter))
                timestamp += max(5, actual_interval)  # never go backwards/too fast
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(80, 150),
                    'bytes_recv': random.randint(50, 100),
                    'protocol': 'tcp'
                })

        elif evasion == 'variable_size':
            # Variable packet sizes but consistent timing
            for _ in range(num_connections):
                timestamp += base_interval * (1 + random.uniform(-0.1, 0.1))
                # Sizes vary more but still bounded
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(50, 500),
                    'bytes_recv': random.randint(40, 400),
                    'protocol': 'tcp'
                })

        elif evasion == 'burst_pattern':
            # Beacon with occasional bursts (simulating commands)
            for i in range(num_connections):
                if i % 8 == 0:
                    # Burst
                    for _ in range(random.randint(2, 5)):
                        timestamp += random.uniform(0.5, 3)
                        connections.append({
                            'timestamp': timestamp,
                            'dst_ip': dst_ip,
                            'dst_port': 443,
                            'bytes_sent': random.randint(200, 2000),
                            'bytes_recv': random.randint(500, 5000),
                            'protocol': 'tcp'
                        })
                else:
                    timestamp += base_interval * (1 + random.uniform(-0.15, 0.15))
                    connections.append({
                        'timestamp': timestamp,
                        'dst_ip': dst_ip,
                        'dst_port': 443,
                        'bytes_sent': random.randint(80, 120),
                        'bytes_recv': random.randint(50, 80),
                        'protocol': 'tcp'
                    })

        elif evasion == 'domain_rotation':
            # Multiple IPs (CDN-like) but same beacon pattern
            ips = [f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
                   for _ in range(random.randint(2, 4))]
            for _ in range(num_connections):
                timestamp += base_interval * (1 + random.uniform(-0.2, 0.2))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': random.choice(ips),
                    'dst_port': 443,
                    'bytes_sent': random.randint(80, 150),
                    'bytes_recv': random.randint(50, 100),
                    'protocol': 'tcp'
                })

        else:  # mixed
            # Mix of evasion techniques
            jitter = random.uniform(0.25, 0.5)
            for _ in range(num_connections):
                actual_interval = base_interval * (1 + random.uniform(-jitter, jitter))
                timestamp += max(3, actual_interval)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': random.choice([443, 8443]),
                    'bytes_sent': random.randint(60, 300),
                    'bytes_recv': random.randint(40, 200),
                    'protocol': 'tcp'
                })

        return connections, c2_type

    def _generate_benign_traffic(self):
        """Generate standard benign traffic patterns.

        Picks one of six everyday traffic shapes (web browsing, API usage,
        media streaming, interactive SSH, bulk download, email) and returns
        the list of connection dicts.
        """
        pattern = random.choice(['browsing', 'api', 'streaming', 'interactive', 'download', 'email'])

        connections = []
        timestamp = 1705600000

        if pattern == 'browsing':
            # Irregular timing, large asymmetric transfers, many hosts.
            for _ in range(random.randint(10, 50)):
                timestamp += random.uniform(0.5, 60)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}",
                    'dst_port': random.choice([80, 443]),
                    'bytes_sent': random.randint(200, 5000),
                    'bytes_recv': random.randint(5000, 500000),
                    'protocol': 'tcp'
                })

        elif pattern == 'api':
            # Single host, variable timing and response sizes.
            dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
            for _ in range(random.randint(15, 60)):
                timestamp += random.uniform(0.1, 30)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(100, 5000),
                    'bytes_recv': random.randint(200, 200000),
                    'protocol': 'tcp'
                })

        elif pattern == 'streaming':
            # High-rate, download-heavy flow to one host.
            dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
            for _ in range(random.randint(30, 100)):
                timestamp += random.uniform(0.02, 2)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(30, 200),
                    'bytes_recv': random.randint(5000, 200000),
                    'protocol': 'tcp'
                })

        elif pattern == 'interactive':
            # SSH to a LAN host: bursty typing with idle gaps.
            dst_ip = f"192.168.{random.randint(0,255)}.{random.randint(1,254)}"
            for _ in range(random.randint(20, 80)):
                if random.random() < 0.3:
                    timestamp += random.uniform(5, 60)  # think-time pause
                else:
                    timestamp += random.uniform(0.1, 3)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 22,
                    'bytes_sent': random.randint(20, 1000),
                    'bytes_recv': random.randint(50, 30000),
                    'protocol': 'tcp'
                })

        elif pattern == 'download':
            # Rapid chunked transfer from one host.
            dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
            for _ in range(random.randint(50, 200)):
                timestamp += random.uniform(0.01, 0.5)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(40, 100),
                    'bytes_recv': random.randint(10000, 65000),
                    'protocol': 'tcp'
                })

        else:  # email
            # Few, widely spaced, large exchanges over mail ports.
            for _ in range(random.randint(5, 20)):
                timestamp += random.uniform(30, 300)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}",
                    'dst_port': random.choice([443, 993, 587]),
                    'bytes_sent': random.randint(500, 50000),
                    'bytes_recv': random.randint(1000, 500000),
                    'protocol': 'tcp'
                })

        return connections

    def _generate_edge_case_benign(self):
        """Generate benign traffic that looks like C2 but isn't.

        These periodic patterns (heartbeats, monitoring probes, syncs,
        keepalives, API polling, IoT telemetry) mimic beacon timing and are
        the main source of false positives the model must learn to reject.
        """
        pattern = random.choice([
            'heartbeat', 'monitoring', 'sync', 'keepalive', 'polling', 'iot'
        ])

        connections = []
        timestamp = 1705600000
        dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"

        if pattern == 'heartbeat':
            # Regular heartbeat but with large, variable responses
            interval = random.uniform(30, 120)
            for _ in range(random.randint(15, 40)):
                timestamp += interval * (1 + random.uniform(-0.05, 0.05))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(50, 100),
                    'bytes_recv': random.randint(1000, 50000),  # Large variable responses
                    'protocol': 'tcp'
                })

        elif pattern == 'monitoring':
            # Regular monitoring checks with status responses
            interval = random.uniform(60, 300)
            for _ in range(random.randint(10, 30)):
                timestamp += interval * (1 + random.uniform(-0.1, 0.1))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': random.choice([443, 8443, 9090]),
                    'bytes_sent': random.randint(100, 500),
                    'bytes_recv': random.randint(500, 10000),
                    'protocol': 'tcp'
                })

        elif pattern == 'sync':
            # Periodic sync with variable data
            interval = random.uniform(300, 900)
            for _ in range(random.randint(8, 20)):
                timestamp += interval * (1 + random.uniform(-0.15, 0.15))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(1000, 100000),
                    'bytes_recv': random.randint(1000, 100000),
                    'protocol': 'tcp'
                })

        elif pattern == 'keepalive':
            # SSH/VPN keepalive - very regular but small
            interval = random.uniform(15, 60)
            for _ in range(random.randint(20, 60)):
                timestamp += interval * (1 + random.uniform(-0.02, 0.02))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': f"192.168.{random.randint(0,255)}.{random.randint(1,254)}",
                    'dst_port': random.choice([22, 1194, 443]),
                    'bytes_sent': random.randint(40, 80),
                    'bytes_recv': random.randint(40, 80),
                    'protocol': 'tcp'
                })

        elif pattern == 'polling':
            # API polling - regular but with variable responses
            interval = random.uniform(10, 60)
            for _ in range(random.randint(20, 50)):
                timestamp += interval * (1 + random.uniform(-0.2, 0.2))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(80, 200),
                    'bytes_recv': random.randint(100, 50000),  # Highly variable
                    'protocol': 'tcp'
                })

        else:  # iot
            # IoT device - regular small packets
            interval = random.uniform(60, 300)
            for _ in range(random.randint(15, 40)):
                timestamp += interval * (1 + random.uniform(-0.1, 0.1))
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': random.choice([443, 8883, 1883]),
                    'bytes_sent': random.randint(50, 200),
                    'bytes_recv': random.randint(50, 200),
                    'protocol': 'tcp'
                })

        return connections

    def __len__(self) -> int:
        """Number of samples in the dataset."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors (features/label/c2_type)."""
        return {
            'features': torch.tensor(self.samples[idx]),
            'label': torch.tensor(self.labels[idx]),
            'c2_type': torch.tensor(self.c2_types[idx])
        }
456
+
457
+
458
def train_phase2(
    pretrained_path='c2_sentinel.safetensors',
    num_epochs=50,
    batch_size=32,
    learning_rate=0.00005,  # Lower LR for fine-tuning
    num_samples=30000
):
    """Phase 2 training with multi-task learning.

    Fine-tunes LogBERTC2Sentinel on the phase-2 dataset with two objectives:
    binary C2 detection (BCE-with-logits on ``c2_logits``) plus C2 framework
    classification (cross-entropy on ``type_logits``, computed only on
    C2-labelled samples; the benign class 0 is ignored).

    Bug fix: the original used ``outputs['c2_logits'].squeeze()``, which
    collapses ALL size-1 dimensions. The validation loader has no
    ``drop_last``, so a final batch of size 1 produced a 0-dim tensor that
    mismatches the ``[1]``-shaped label tensor in the loss and comparisons.
    ``reshape(-1)`` always yields a 1-D tensor of length B.

    Args:
        pretrained_path: safetensors checkpoint to warm-start from; trains
            from scratch (with a warning) when the file does not exist.
        num_epochs: maximum epochs; early stopping may end training sooner.
        batch_size: mini-batch size for both loaders.
        learning_rate: AdamW learning rate (kept low for fine-tuning).
        num_samples: number of synthetic samples to generate.

    Returns:
        (model, config): the trained model (best weights also written to
        ``c2_sentinel.safetensors``) and its configuration.
    """
    print("=" * 70)
    print("C2Sentinel Phase 2 Training - Multi-Task Learning")
    print("=" * 70)

    config = C2SentinelConfig()
    model = LogBERTC2Sentinel(config)

    # Load pretrained weights
    if os.path.exists(pretrained_path):
        print(f"Loading pretrained weights from {pretrained_path}")
        state_dict = load_file(pretrained_path)
        model.load_state_dict(state_dict)
    else:
        print("WARNING: No pretrained weights found, training from scratch")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}")

    # Reuse the phase-1 normalization params when present so the feature
    # scaling matches what the pretrained weights were trained on.
    norm_path = 'normalization_params.npz' if os.path.exists('normalization_params.npz') else None
    dataset = C2TrafficDatasetPhase2(num_samples=num_samples, normalize=True, norm_params_path=norm_path)

    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    print(f"Train: {train_size}, Val: {val_size}")

    # Multi-task loss: C2 detection + C2 type classification
    c2_criterion = nn.BCEWithLogitsLoss()
    type_criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore benign (0)

    # Lower LR for fine-tuning
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)

    # Cosine annealing
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    best_val_acc = 0
    patience = 10
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        train_c2_correct = 0
        train_type_correct = 0
        train_total = 0
        train_c2_samples = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False):
            features = batch['features'].to(device)
            labels = batch['label'].to(device)
            c2_types = batch['c2_type'].to(device)

            optimizer.zero_grad()
            outputs = model(features)

            # reshape(-1) (not squeeze()) so a batch of size 1 still gives a
            # 1-D [B] tensor matching `labels`.
            c2_logits = outputs['c2_logits'].reshape(-1)

            # C2 detection loss
            c2_loss = c2_criterion(c2_logits, labels)

            # C2 type classification loss (only for C2 samples)
            c2_mask = labels == 1
            if c2_mask.sum() > 0 and 'type_logits' in outputs:
                type_loss = type_criterion(outputs['type_logits'][c2_mask], c2_types[c2_mask])
                loss = c2_loss + 0.3 * type_loss  # Weighted combination
            else:
                loss = c2_loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_loss += loss.item()

            # C2 detection accuracy
            c2_preds = (torch.sigmoid(c2_logits) > 0.5).float()
            train_c2_correct += (c2_preds == labels).sum().item()
            train_total += labels.size(0)

            # Type classification accuracy (for C2 samples only)
            if c2_mask.sum() > 0 and 'type_logits' in outputs:
                type_preds = outputs['type_logits'][c2_mask].argmax(dim=1)
                train_type_correct += (type_preds == c2_types[c2_mask]).sum().item()
                train_c2_samples += c2_mask.sum().item()

        scheduler.step()

        # Validation
        model.eval()
        val_c2_correct = 0
        val_type_correct = 0
        val_total = 0
        val_c2_samples = 0
        val_loss = 0

        # Track per-class metrics
        val_tp, val_fp, val_tn, val_fn = 0, 0, 0, 0

        with torch.no_grad():
            for batch in val_loader:
                features = batch['features'].to(device)
                labels = batch['label'].to(device)
                c2_types = batch['c2_type'].to(device)

                outputs = model(features)

                # Same reshape(-1) fix; the val loader has no drop_last, so
                # the last batch can legitimately have size 1.
                c2_logits = outputs['c2_logits'].reshape(-1)
                c2_loss = c2_criterion(c2_logits, labels)
                val_loss += c2_loss.item()

                c2_preds = (torch.sigmoid(c2_logits) > 0.5).float()
                val_c2_correct += (c2_preds == labels).sum().item()
                val_total += labels.size(0)

                # Confusion matrix stats
                val_tp += ((c2_preds == 1) & (labels == 1)).sum().item()
                val_fp += ((c2_preds == 1) & (labels == 0)).sum().item()
                val_tn += ((c2_preds == 0) & (labels == 0)).sum().item()
                val_fn += ((c2_preds == 0) & (labels == 1)).sum().item()

                c2_mask = labels == 1
                if c2_mask.sum() > 0 and 'type_logits' in outputs:
                    type_preds = outputs['type_logits'][c2_mask].argmax(dim=1)
                    val_type_correct += (type_preds == c2_types[c2_mask]).sum().item()
                    val_c2_samples += c2_mask.sum().item()

        train_c2_acc = 100 * train_c2_correct / train_total
        train_type_acc = 100 * train_type_correct / train_c2_samples if train_c2_samples > 0 else 0
        val_c2_acc = 100 * val_c2_correct / val_total
        val_type_acc = 100 * val_type_correct / val_c2_samples if val_c2_samples > 0 else 0

        precision = val_tp / (val_tp + val_fp) if (val_tp + val_fp) > 0 else 0
        recall = val_tp / (val_tp + val_fn) if (val_tp + val_fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        lr = optimizer.param_groups[0]['lr']  # NOTE: currently unlogged; kept for debugging

        print(f"Epoch {epoch+1}: Loss={train_loss/len(train_loader):.4f}, "
              f"Train C2={train_c2_acc:.1f}%, Val C2={val_c2_acc:.1f}%, "
              f"Type Acc={val_type_acc:.1f}%, P={precision:.2f}, R={recall:.2f}, F1={f1:.2f}")

        # Checkpoint on best val accuracy; early-stop after `patience`
        # epochs without improvement.
        if val_c2_acc > best_val_acc:
            best_val_acc = val_c2_acc
            patience_counter = 0
            save_file(model.state_dict(), 'c2_sentinel.safetensors')
            print(f" -> Saved (Val: {val_c2_acc:.1f}%, F1: {f1:.2f})")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    print(f"\nBest validation C2 accuracy: {best_val_acc:.1f}%")
    return model, config
628
+
629
+
630
def test_adversarial():
    """Smoke-test the saved model against evasive C2 and benign look-alikes.

    Builds five hand-crafted connection sequences — three evasive beacons
    that should trip detection and two periodic-but-benign flows that should
    not — then prints the verdict and top risk factors for each.
    """
    print("\n" + "=" * 70)
    print("Adversarial Pattern Testing")
    print("=" * 70)

    sentinel = C2Sentinel.load('c2_sentinel')

    scenarios = [
        ("High-jitter Cobalt Strike", [
            {'timestamp': 1705600000 + k * 60 * (1 + random.uniform(-0.5, 0.5)),
             'dst_ip': '185.234.72.19', 'dst_port': 443,
             'bytes_sent': 92, 'bytes_recv': 48}
            for k in range(20)
        ]),
        ("Burst pattern beacon", [
            {'timestamp': 1705600000 + (k * 60 if k % 5 != 0 else k * 60 + random.uniform(0, 5)),
             'dst_ip': '45.33.32.156', 'dst_port': 443,
             'bytes_sent': 100 if k % 5 != 0 else random.randint(500, 2000),
             'bytes_recv': 60 if k % 5 != 0 else random.randint(1000, 5000)}
            for k in range(25)
        ]),
        ("Variable size beacon", [
            {'timestamp': 1705600000 + k * 45,
             'dst_ip': '10.10.10.10', 'dst_port': 4444,
             'bytes_sent': random.randint(50, 300),
             'bytes_recv': random.randint(40, 250)}
            for k in range(18)
        ]),
        ("SSH keepalive (should be clean)", [
            {'timestamp': 1705600000 + k * 30,
             'dst_ip': '192.168.1.50', 'dst_port': 22,
             'bytes_sent': 48, 'bytes_recv': 48}
            for k in range(20)
        ]),
        ("API polling (should be clean)", [
            {'timestamp': 1705600000 + k * random.uniform(25, 35),
             'dst_ip': '52.85.132.99', 'dst_port': 443,
             'bytes_sent': 150, 'bytes_recv': random.randint(500, 50000)}
            for k in range(25)
        ]),
    ]

    for title, conns in scenarios:
        verdict = sentinel.analyze(conns)
        if verdict.is_c2:
            status = "C2 DETECTED"
        else:
            status = "Clean"
        print(f"\n{title}:")
        print(f" {status} (prob={verdict.c2_probability:.2%})")
        # Show at most the three strongest risk factors, if any exist.
        if verdict.risk_factors:
            for factor in verdict.risk_factors[:3]:
                print(f"    - {factor}")
681
+
682
+
683
if __name__ == '__main__':
    import argparse

    # CLI mirrors train_phase2's keyword arguments, plus --test-only to
    # skip training and just run the adversarial evaluation.
    cli = argparse.ArgumentParser()
    cli.add_argument('--epochs', type=int, default=50)
    cli.add_argument('--samples', type=int, default=30000)
    cli.add_argument('--batch-size', type=int, default=32)
    cli.add_argument('--lr', type=float, default=0.00005)
    cli.add_argument('--pretrained', type=str, default='c2_sentinel.safetensors')
    cli.add_argument('--test-only', action='store_true')
    opts = cli.parse_args()

    # Train unless --test-only was given; the adversarial evaluation runs
    # in both modes.
    if not opts.test_only:
        train_phase2(
            pretrained_path=opts.pretrained,
            num_epochs=opts.epochs,
            batch_size=opts.batch_size,
            learning_rate=opts.lr,
            num_samples=opts.samples
        )
    test_adversarial()