danielostrow committed on
Commit
4a02ea8
·
verified ·
1 Parent(s): e14e625

Add training script

Browse files
Files changed (1) hide show
  1. train_model.py +399 -0
train_model.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ C2Sentinel Training Script v2 - Improved training with proper normalization
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.optim as optim
9
+ from torch.utils.data import Dataset, DataLoader
10
+ import numpy as np
11
+ import random
12
+ from tqdm import tqdm
13
+ import json
14
+
15
+ from c2sentinel import (
16
+ C2Sentinel, C2SentinelConfig, LogBERTC2Sentinel,
17
+ FeatureExtractor
18
+ )
19
+ from safetensors.torch import save_file
20
+
21
+
22
class C2TrafficDataset(Dataset):
    """Synthetic dataset of labelled C2-beacon and benign connection windows.

    Half of the samples (rounded down) are simulated C2 beacons (label 1) and
    the remainder are simulated benign sessions (label 0).  Every window is a
    list of connection dicts that is turned into a feature vector by
    ``FeatureExtractor.extract_features``.

    Attributes:
        samples: float32 array of feature vectors, one row per sample.
        labels: float32 array, 1.0 for C2 and 0.0 for benign.
        c2_types: int64 array, the beacon profile id (1-10) or 0 for benign.
        mean, std: per-feature normalization statistics, or ``None`` when the
            dataset was built with ``normalize=False``.

    NOTE(review): the normalization statistics are computed over every
    generated sample, i.e. before any train/validation split a caller makes.
    """

    # Fixed epoch-seconds origin shared by all generated connection windows.
    _START_TIMESTAMP = 1705600000

    def __init__(self, num_samples=10000, normalize=True,
                 norm_params_path='normalization_params.npz'):
        """Generate, optionally normalize, and shuffle the synthetic samples.

        Args:
            num_samples: total number of windows to generate.
            normalize: if true, standardize each feature to zero mean and
                unit variance.
            norm_params_path: where to save the ``mean``/``std`` arrays when
                normalizing; pass ``None`` to skip writing the file.
        """
        self.samples = []
        self.labels = []
        self.c2_types = []
        self.feature_extractor = FeatureExtractor()
        # Always defined so callers can test ``dataset.mean is None`` instead
        # of hitting AttributeError when normalization is disabled.
        self.mean = None
        self.std = None

        print(f"Generating {num_samples} training samples...")

        num_c2 = num_samples // 2
        num_benign = num_samples - num_c2

        # Generate C2 samples
        for _ in tqdm(range(num_c2), desc="C2 samples"):
            connections, c2_type = self._generate_c2_traffic()
            self._append_sample(connections, label=1, c2_type=c2_type)

        # Generate benign samples
        for _ in tqdm(range(num_benign), desc="Benign samples"):
            connections = self._generate_benign_traffic()
            self._append_sample(connections, label=0, c2_type=0)

        self.samples = np.array(self.samples, dtype=np.float32)
        self.labels = np.array(self.labels, dtype=np.float32)
        self.c2_types = np.array(self.c2_types, dtype=np.int64)

        # Normalize features (critical for training stability)
        if normalize:
            self._normalize_features(norm_params_path)

        # Shuffle all three arrays with the same permutation so rows stay aligned.
        order = np.random.permutation(len(self.samples))
        self.samples = self.samples[order]
        self.labels = self.labels[order]
        self.c2_types = self.c2_types[order]

        # Labels are float32; cast so the counts print as integers.
        c2_count = int(np.sum(self.labels))
        print(f"C2 samples: {c2_count}, Benign: {len(self.labels) - c2_count}")

    def _append_sample(self, connections, label, c2_type):
        """Extract features from one connection window and record it with its labels."""
        features = self.feature_extractor.extract_features(connections)
        self.samples.append(features)
        self.labels.append(label)
        self.c2_types.append(c2_type)

    def _normalize_features(self, save_path):
        """Standardize ``self.samples`` in place and optionally persist the statistics.

        Sets ``self.mean`` and ``self.std`` (std is offset by 1e-8 so constant
        features do not divide by zero).  When ``save_path`` is not ``None``
        the two arrays are written there with ``np.savez``.
        """
        self.mean = np.mean(self.samples, axis=0)
        self.std = np.std(self.samples, axis=0) + 1e-8
        self.samples = (self.samples - self.mean) / self.std

        # Save normalization params
        if save_path is not None:
            np.savez(save_path, mean=self.mean, std=self.std)
        print(f"Feature stats - mean range: [{self.mean.min():.2f}, {self.mean.max():.2f}], "
              f"std range: [{self.std.min():.4f}, {self.std.max():.2f}]")

    @staticmethod
    def _random_ip():
        """Return a random dotted-quad address (first octet 1-223, last octet 1-254)."""
        return f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"

    def _generate_c2_traffic(self):
        """Generate C2 beacon traffic with clear patterns.

        Returns:
            ``(connections, c2_type)`` where ``connections`` is a list of
            connection dicts to a single destination and ``c2_type`` is the
            randomly drawn profile id in 1..10 (1-3 fast, 4-6 medium,
            7-10 slow beacon).
        """
        c2_type = random.randint(1, 10)

        # Strong C2 characteristics
        if c2_type <= 3:  # Fast beacon (Metasploit-style)
            interval = random.uniform(2, 15)
            jitter = random.uniform(0, 0.15)  # Low jitter
            port = random.choice([4444, 4445, 5555, 443])
            bytes_sent = random.randint(80, 200)
            bytes_recv = random.randint(40, 150)
        elif c2_type <= 6:  # Medium beacon (Cobalt Strike-style)
            interval = random.uniform(30, 90)
            jitter = random.uniform(0, 0.2)
            port = 443
            bytes_sent = random.randint(60, 150)
            bytes_recv = random.randint(40, 100)
        else:  # Slow beacon (APT-style)
            interval = random.uniform(120, 300)
            jitter = random.uniform(0, 0.1)  # Very low jitter for APT
            port = 443
            bytes_sent = random.randint(50, 120)
            bytes_recv = random.randint(40, 80)

        # Single destination (key C2 indicator)
        dst_ip = self._random_ip()
        num_connections = random.randint(10, 40)

        connections = []
        timestamp = self._START_TIMESTAMP

        for _ in range(num_connections):
            # Each gap is the base interval perturbed by at most +/- jitter.
            actual_interval = interval * (1 + random.uniform(-jitter, jitter))
            timestamp += actual_interval

            # Very consistent sizes (key C2 indicator)
            size_var = random.uniform(0.95, 1.05)

            connections.append({
                'timestamp': timestamp,
                'dst_ip': dst_ip,
                'dst_port': port,
                'bytes_sent': int(bytes_sent * size_var),
                'bytes_recv': int(bytes_recv * size_var),
                'protocol': 'tcp'
            })

        return connections, c2_type

    def _generate_benign_traffic(self):
        """Generate clearly benign traffic.

        One of four patterns is chosen at random: 'browsing' (a fresh
        destination per connection), 'api', 'streaming' (both single
        destination, port 443) or 'interactive' (single 192.168.x.x
        destination, port 22).

        Returns:
            A list of connection dicts with increasing timestamps.
        """
        pattern = random.choice(['browsing', 'api', 'streaming', 'interactive'])

        connections = []
        timestamp = self._START_TIMESTAMP

        if pattern == 'browsing':
            # Multiple destinations, highly variable sizes
            for _ in range(random.randint(10, 40)):
                timestamp += random.uniform(0.5, 45)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': self._random_ip(),
                    'dst_port': random.choice([80, 443]),
                    'bytes_sent': random.randint(200, 5000),
                    'bytes_recv': random.randint(5000, 500000),
                    'protocol': 'tcp'
                })

        elif pattern == 'api':
            # Single dest but HIGHLY variable response sizes
            dst_ip = self._random_ip()
            for _ in range(random.randint(15, 40)):
                timestamp += random.uniform(0.1, 20)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(100, 3000),
                    'bytes_recv': random.randint(200, 100000),  # Highly variable
                    'protocol': 'tcp'
                })

        elif pattern == 'streaming':
            # Large downloads, irregular timing
            dst_ip = self._random_ip()
            for _ in range(random.randint(20, 60)):
                timestamp += random.uniform(0.05, 3)
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 443,
                    'bytes_sent': random.randint(30, 200),
                    'bytes_recv': random.randint(5000, 150000),
                    'protocol': 'tcp'
                })

        else:  # interactive (ssh-like)
            dst_ip = f"192.168.{random.randint(0,255)}.{random.randint(1,254)}"
            for _ in range(random.randint(15, 50)):
                if random.random() < 0.3:
                    timestamp += random.uniform(3, 45)  # Thinking
                else:
                    timestamp += random.uniform(0.1, 2)  # Typing
                connections.append({
                    'timestamp': timestamp,
                    'dst_ip': dst_ip,
                    'dst_port': 22,
                    'bytes_sent': random.randint(20, 800),
                    'bytes_recv': random.randint(50, 20000),
                    'protocol': 'tcp'
                })

        return connections

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors: 'features', 'label', 'c2_type'."""
        return {
            'features': torch.tensor(self.samples[idx]),
            'label': torch.tensor(self.labels[idx]),
            'c2_type': torch.tensor(self.c2_types[idx])
        }
199
+
200
+
201
def warmup_cosine_factor(epoch, num_epochs, warmup_epochs=5):
    """Return the learning-rate multiplier for ``epoch``.

    The factor ramps linearly from ``1 / warmup_epochs`` to 1 during warmup,
    then follows a half-cosine from 1 down to 0 at ``num_epochs``.

    Args:
        epoch: zero-based epoch index supplied by the scheduler.
        num_epochs: total number of training epochs.
        warmup_epochs: number of linear warmup epochs.

    Returns:
        A float in [0, 1].
    """
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    # The scheduler also evaluates this once after the final epoch; when
    # num_epochs <= warmup_epochs the decay span would be zero, so clamp it
    # to avoid ZeroDivisionError.
    decay_span = max(1, num_epochs - warmup_epochs)
    progress = min(1.0, (epoch - warmup_epochs) / decay_span)
    return float(0.5 * (1 + np.cos(np.pi * progress)))


def flatten_logits(outputs):
    """Return ``outputs['c2_logits']`` as a 1-D tensor with one logit per sample.

    ``reshape(-1)`` is used instead of a bare ``squeeze()`` because
    ``squeeze()`` collapses a batch of size 1 to a 0-d tensor, which
    ``BCEWithLogitsLoss`` rejects against a ``(1,)`` label tensor.
    """
    return outputs['c2_logits'].reshape(-1)


def train_model(num_epochs=100, batch_size=32, learning_rate=0.0001, num_samples=20000):
    """Train with improved stability.

    Builds a ``LogBERTC2Sentinel`` from a default ``C2SentinelConfig``, trains
    it on a freshly generated ``C2TrafficDataset`` (90/10 train/validation
    split) with BCE-with-logits loss, AdamW, gradient clipping and a
    warmup + cosine learning-rate schedule.  Whenever validation accuracy
    improves, the weights are written to 'c2_sentinel.safetensors'; training
    stops early after 15 epochs without improvement.

    Args:
        num_epochs: maximum number of epochs.
        batch_size: mini-batch size for both loaders.
        learning_rate: peak AdamW learning rate (reached after warmup).
        num_samples: number of synthetic samples to generate.

    Returns:
        ``(model, config)``.  The returned model carries the weights of the
        last epoch that ran, which may differ from the best checkpoint saved
        on disk.
    """

    print("=" * 70)
    print("C2Sentinel Model Training v2")
    print("=" * 70)

    config = C2SentinelConfig()
    model = LogBERTC2Sentinel(config)

    # Initialize weights properly
    def init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight, gain=0.5)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
    model.apply(init_weights)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    model.to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}")

    dataset = C2TrafficDataset(num_samples=num_samples, normalize=True)

    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    print(f"Train: {train_size}, Val: {val_size}")

    # Simple BCE loss - focus on main task only
    criterion = nn.BCEWithLogitsLoss()

    # Lower LR with warmup
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.001)

    # Warmup + cosine decay
    warmup_epochs = 5
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda epoch: warmup_cosine_factor(epoch, num_epochs, warmup_epochs),
    )

    best_val_acc = 0
    patience = 15
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False):
            features = batch['features'].to(device)
            labels = batch['label'].to(device)

            optimizer.zero_grad()
            logits = flatten_logits(model(features))

            # Only C2 detection loss
            loss = criterion(logits, labels)

            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

            optimizer.step()

            train_loss += loss.item()
            predictions = (torch.sigmoid(logits) > 0.5).float()
            train_correct += (predictions == labels).sum().item()
            train_total += labels.size(0)

        # Read the LR before stepping the scheduler so the log line reports
        # the rate this epoch actually trained with, not the next epoch's.
        lr = optimizer.param_groups[0]['lr']
        scheduler.step()

        # Validation
        model.eval()
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for batch in val_loader:
                features = batch['features'].to(device)
                labels = batch['label'].to(device)

                logits = flatten_logits(model(features))

                predictions = (torch.sigmoid(logits) > 0.5).float()
                val_correct += (predictions == labels).sum().item()
                val_total += labels.size(0)

        # max(1, ...) guards the divisions when drop_last leaves no training
        # batches or the validation split is empty (very small num_samples).
        train_acc = 100 * train_correct / max(1, train_total)
        val_acc = 100 * val_correct / max(1, val_total)
        mean_train_loss = train_loss / max(1, len(train_loader))

        print(f"Epoch {epoch+1}: Loss={mean_train_loss:.4f}, "
              f"Train={train_acc:.1f}%, Val={val_acc:.1f}%, LR={lr:.6f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            save_file(model.state_dict(), 'c2_sentinel.safetensors')
            print(f" -> Saved (Val: {val_acc:.1f}%)")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    print(f"\nBest validation accuracy: {best_val_acc:.1f}%")
    return model, config
326
+
327
+
328
+ def test_model():
329
+ """Test the trained model."""
330
+ print("\n" + "=" * 70)
331
+ print("Testing Model")
332
+ print("=" * 70)
333
+
334
+ sentinel = C2Sentinel.load('c2_sentinel')
335
+
336
+ # Test 1: Cobalt Strike
337
+ print("\n[1] Cobalt Strike Beacon (60s interval)...")
338
+ cs = [{'timestamp': 1705600000 + i*60, 'dst_ip': '185.234.72.19', 'dst_port': 443,
339
+ 'bytes_sent': 92, 'bytes_recv': 48} for i in range(16)]
340
+ r = sentinel.analyze(cs)
341
+ print(f" {'βœ“ C2 DETECTED' if r.is_c2 else 'βœ— No C2'} (prob={r.c2_probability:.2%})")
342
+
343
+ # Test 2: Metasploit
344
+ print("\n[2] Metasploit Beacon (5s interval, port 4444)...")
345
+ msf = [{'timestamp': 1705600000 + i*5, 'dst_ip': '10.10.10.10', 'dst_port': 4444,
346
+ 'bytes_sent': 150, 'bytes_recv': 400} for i in range(20)]
347
+ r = sentinel.analyze(msf)
348
+ print(f" {'βœ“ C2 DETECTED' if r.is_c2 else 'βœ— No C2'} (prob={r.c2_probability:.2%})")
349
+
350
+ # Test 3: Slow APT beacon
351
+ print("\n[3] APT Slow Beacon (120s interval)...")
352
+ apt = [{'timestamp': 1705600000 + i*120, 'dst_ip': '45.33.32.156', 'dst_port': 443,
353
+ 'bytes_sent': 80, 'bytes_recv': 60} for i in range(12)]
354
+ r = sentinel.analyze(apt)
355
+ print(f" {'βœ“ C2 DETECTED' if r.is_c2 else 'βœ— No C2'} (prob={r.c2_probability:.2%})")
356
+
357
+ # Test 4: Web browsing (should be benign)
358
+ print("\n[4] Web Browsing (should be clean)...")
359
+ browse = [{'timestamp': 1705600000 + i*random.uniform(2, 30),
360
+ 'dst_ip': f"{random.randint(1,200)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}",
361
+ 'dst_port': 443, 'bytes_sent': random.randint(500, 3000),
362
+ 'bytes_recv': random.randint(10000, 500000)} for i in range(20)]
363
+ r = sentinel.analyze(browse)
364
+ print(f" {'βœ— C2 DETECTED (FP!)' if r.is_c2 else 'βœ“ Clean'} (prob={r.c2_probability:.2%})")
365
+
366
+ # Test 5: SSH keepalive
367
+ print("\n[5] SSH Keepalive (should be clean)...")
368
+ ssh = [{'timestamp': 1705600000 + i*30, 'dst_ip': '192.168.1.50', 'dst_port': 22,
369
+ 'bytes_sent': 48, 'bytes_recv': 48} for i in range(15)]
370
+ r = sentinel.analyze(ssh)
371
+ print(f" {'βœ— C2 DETECTED (FP!)' if r.is_c2 else 'βœ“ Clean'} (prob={r.c2_probability:.2%})")
372
+ print(f" Pattern: {r.matched_legitimate_pattern}")
373
+
374
+ # Test 6: API calls (should be benign)
375
+ print("\n[6] API Calls (should be clean)...")
376
+ api = [{'timestamp': 1705600000 + i*random.uniform(0.5, 10),
377
+ 'dst_ip': '52.85.132.99', 'dst_port': 443,
378
+ 'bytes_sent': random.randint(100, 2000),
379
+ 'bytes_recv': random.randint(500, 80000)} for i in range(25)]
380
+ r = sentinel.analyze(api)
381
+ print(f" {'βœ— C2 DETECTED (FP!)' if r.is_c2 else 'βœ“ Clean'} (prob={r.c2_probability:.2%})")
382
+
383
+
384
def main():
    """Command-line entry point: optionally train, then run the smoke tests.

    Flags: --epochs, --samples, --batch-size and --lr configure training;
    --test-only skips training and only runs ``test_model``.
    """
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--epochs', type=int, default=100)
    cli.add_argument('--samples', type=int, default=20000)
    cli.add_argument('--batch-size', type=int, default=32)
    cli.add_argument('--lr', type=float, default=0.0001)
    cli.add_argument('--test-only', action='store_true')
    opts = cli.parse_args()

    # Training is skipped with --test-only; the tests run in both modes.
    if not opts.test_only:
        train_model(
            num_epochs=opts.epochs,
            batch_size=opts.batch_size,
            learning_rate=opts.lr,
            num_samples=opts.samples,
        )
    test_model()


if __name__ == '__main__':
    main()