danielostrow commited on
Commit
3ec3978
·
verified ·
1 Parent(s): 0207ad9

Remove training files from public view

Browse files
Files changed (1) hide show
  1. train_model.py +0 -399
train_model.py DELETED
@@ -1,399 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- C2Sentinel Training Script v2 - Improved training with proper normalization
4
- """
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.optim as optim
9
- from torch.utils.data import Dataset, DataLoader
10
- import numpy as np
11
- import random
12
- from tqdm import tqdm
13
- import json
14
-
15
- from c2sentinel import (
16
- C2Sentinel, C2SentinelConfig, LogBERTC2Sentinel,
17
- FeatureExtractor
18
- )
19
- from safetensors.torch import save_file
20
-
21
-
22
- class C2TrafficDataset(Dataset):
23
- """Dataset with normalized features."""
24
-
25
- def __init__(self, num_samples=10000, normalize=True):
26
- self.samples = []
27
- self.labels = []
28
- self.c2_types = []
29
- self.feature_extractor = FeatureExtractor()
30
-
31
- print(f"Generating {num_samples} training samples...")
32
-
33
- num_c2 = num_samples // 2
34
- num_benign = num_samples - num_c2
35
-
36
- # Generate C2 samples
37
- for _ in tqdm(range(num_c2), desc="C2 samples"):
38
- connections, c2_type = self._generate_c2_traffic()
39
- features = self.feature_extractor.extract_features(connections)
40
- self.samples.append(features)
41
- self.labels.append(1)
42
- self.c2_types.append(c2_type)
43
-
44
- # Generate benign samples
45
- for _ in tqdm(range(num_benign), desc="Benign samples"):
46
- connections = self._generate_benign_traffic()
47
- features = self.feature_extractor.extract_features(connections)
48
- self.samples.append(features)
49
- self.labels.append(0)
50
- self.c2_types.append(0)
51
-
52
- self.samples = np.array(self.samples, dtype=np.float32)
53
- self.labels = np.array(self.labels, dtype=np.float32)
54
- self.c2_types = np.array(self.c2_types, dtype=np.int64)
55
-
56
- # Normalize features (critical for training stability)
57
- if normalize:
58
- self.mean = np.mean(self.samples, axis=0)
59
- self.std = np.std(self.samples, axis=0) + 1e-8
60
- self.samples = (self.samples - self.mean) / self.std
61
-
62
- # Save normalization params
63
- np.savez('normalization_params.npz', mean=self.mean, std=self.std)
64
- print(f"Feature stats - mean range: [{self.mean.min():.2f}, {self.mean.max():.2f}], "
65
- f"std range: [{self.std.min():.4f}, {self.std.max():.2f}]")
66
-
67
- # Shuffle
68
- indices = np.random.permutation(len(self.samples))
69
- self.samples = self.samples[indices]
70
- self.labels = self.labels[indices]
71
- self.c2_types = self.c2_types[indices]
72
-
73
- print(f"C2 samples: {np.sum(self.labels)}, Benign: {len(self.labels) - np.sum(self.labels)}")
74
-
75
- def _generate_c2_traffic(self):
76
- """Generate C2 beacon traffic with clear patterns."""
77
- c2_type = random.randint(1, 10)
78
-
79
- # Strong C2 characteristics
80
- if c2_type <= 3: # Fast beacon (Metasploit-style)
81
- interval = random.uniform(2, 15)
82
- jitter = random.uniform(0, 0.15) # Low jitter
83
- port = random.choice([4444, 4445, 5555, 443])
84
- bytes_sent = random.randint(80, 200)
85
- bytes_recv = random.randint(40, 150)
86
- elif c2_type <= 6: # Medium beacon (Cobalt Strike-style)
87
- interval = random.uniform(30, 90)
88
- jitter = random.uniform(0, 0.2)
89
- port = 443
90
- bytes_sent = random.randint(60, 150)
91
- bytes_recv = random.randint(40, 100)
92
- else: # Slow beacon (APT-style)
93
- interval = random.uniform(120, 300)
94
- jitter = random.uniform(0, 0.1) # Very low jitter for APT
95
- port = 443
96
- bytes_sent = random.randint(50, 120)
97
- bytes_recv = random.randint(40, 80)
98
-
99
- # Single destination (key C2 indicator)
100
- dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
101
- num_connections = random.randint(10, 40)
102
-
103
- connections = []
104
- timestamp = 1705600000
105
-
106
- for _ in range(num_connections):
107
- actual_interval = interval * (1 + random.uniform(-jitter, jitter))
108
- timestamp += actual_interval
109
-
110
- # Very consistent sizes (key C2 indicator)
111
- size_var = random.uniform(0.95, 1.05)
112
-
113
- connections.append({
114
- 'timestamp': timestamp,
115
- 'dst_ip': dst_ip,
116
- 'dst_port': port,
117
- 'bytes_sent': int(bytes_sent * size_var),
118
- 'bytes_recv': int(bytes_recv * size_var),
119
- 'protocol': 'tcp'
120
- })
121
-
122
- return connections, c2_type
123
-
124
- def _generate_benign_traffic(self):
125
- """Generate clearly benign traffic."""
126
- pattern = random.choice(['browsing', 'api', 'streaming', 'interactive'])
127
-
128
- connections = []
129
- timestamp = 1705600000
130
-
131
- if pattern == 'browsing':
132
- # Multiple destinations, highly variable sizes
133
- for _ in range(random.randint(10, 40)):
134
- timestamp += random.uniform(0.5, 45)
135
- connections.append({
136
- 'timestamp': timestamp,
137
- 'dst_ip': f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}",
138
- 'dst_port': random.choice([80, 443]),
139
- 'bytes_sent': random.randint(200, 5000),
140
- 'bytes_recv': random.randint(5000, 500000),
141
- 'protocol': 'tcp'
142
- })
143
-
144
- elif pattern == 'api':
145
- # Single dest but HIGHLY variable response sizes
146
- dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
147
- for _ in range(random.randint(15, 40)):
148
- timestamp += random.uniform(0.1, 20)
149
- connections.append({
150
- 'timestamp': timestamp,
151
- 'dst_ip': dst_ip,
152
- 'dst_port': 443,
153
- 'bytes_sent': random.randint(100, 3000),
154
- 'bytes_recv': random.randint(200, 100000), # Highly variable
155
- 'protocol': 'tcp'
156
- })
157
-
158
- elif pattern == 'streaming':
159
- # Large downloads, irregular timing
160
- dst_ip = f"{random.randint(1,223)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
161
- for _ in range(random.randint(20, 60)):
162
- timestamp += random.uniform(0.05, 3)
163
- connections.append({
164
- 'timestamp': timestamp,
165
- 'dst_ip': dst_ip,
166
- 'dst_port': 443,
167
- 'bytes_sent': random.randint(30, 200),
168
- 'bytes_recv': random.randint(5000, 150000),
169
- 'protocol': 'tcp'
170
- })
171
-
172
- else: # interactive (ssh-like)
173
- dst_ip = f"192.168.{random.randint(0,255)}.{random.randint(1,254)}"
174
- for _ in range(random.randint(15, 50)):
175
- if random.random() < 0.3:
176
- timestamp += random.uniform(3, 45) # Thinking
177
- else:
178
- timestamp += random.uniform(0.1, 2) # Typing
179
- connections.append({
180
- 'timestamp': timestamp,
181
- 'dst_ip': dst_ip,
182
- 'dst_port': 22,
183
- 'bytes_sent': random.randint(20, 800),
184
- 'bytes_recv': random.randint(50, 20000),
185
- 'protocol': 'tcp'
186
- })
187
-
188
- return connections
189
-
190
- def __len__(self):
191
- return len(self.samples)
192
-
193
- def __getitem__(self, idx):
194
- return {
195
- 'features': torch.tensor(self.samples[idx]),
196
- 'label': torch.tensor(self.labels[idx]),
197
- 'c2_type': torch.tensor(self.c2_types[idx])
198
- }
199
-
200
-
201
- def train_model(num_epochs=100, batch_size=32, learning_rate=0.0001, num_samples=20000):
202
- """Train with improved stability."""
203
-
204
- print("=" * 70)
205
- print("C2Sentinel Model Training v2")
206
- print("=" * 70)
207
-
208
- config = C2SentinelConfig()
209
- model = LogBERTC2Sentinel(config)
210
-
211
- # Initialize weights properly
212
- def init_weights(m):
213
- if isinstance(m, nn.Linear):
214
- nn.init.xavier_uniform_(m.weight, gain=0.5)
215
- if m.bias is not None:
216
- nn.init.zeros_(m.bias)
217
- model.apply(init_weights)
218
-
219
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
220
- print(f"Device: {device}")
221
- model.to(device)
222
-
223
- # Count parameters
224
- total_params = sum(p.numel() for p in model.parameters())
225
- print(f"Model parameters: {total_params:,}")
226
-
227
- dataset = C2TrafficDataset(num_samples=num_samples, normalize=True)
228
-
229
- train_size = int(0.9 * len(dataset))
230
- val_size = len(dataset) - train_size
231
- train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
232
-
233
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
234
- val_loader = DataLoader(val_dataset, batch_size=batch_size)
235
-
236
- print(f"Train: {train_size}, Val: {val_size}")
237
-
238
- # Simple BCE loss - focus on main task only
239
- criterion = nn.BCEWithLogitsLoss()
240
-
241
- # Lower LR with warmup
242
- optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.001)
243
-
244
- # Warmup + cosine decay
245
- warmup_epochs = 5
246
- def lr_lambda(epoch):
247
- if epoch < warmup_epochs:
248
- return (epoch + 1) / warmup_epochs
249
- return 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (num_epochs - warmup_epochs)))
250
-
251
- scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
252
-
253
- best_val_acc = 0
254
- patience = 15
255
- patience_counter = 0
256
-
257
- for epoch in range(num_epochs):
258
- model.train()
259
- train_loss = 0
260
- train_correct = 0
261
- train_total = 0
262
-
263
- for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False):
264
- features = batch['features'].to(device)
265
- labels = batch['label'].to(device)
266
-
267
- optimizer.zero_grad()
268
- outputs = model(features)
269
-
270
- # Only C2 detection loss
271
- loss = criterion(outputs['c2_logits'].squeeze(), labels)
272
-
273
- loss.backward()
274
-
275
- # Gradient clipping
276
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
277
-
278
- optimizer.step()
279
-
280
- train_loss += loss.item()
281
- predictions = (torch.sigmoid(outputs['c2_logits'].squeeze()) > 0.5).float()
282
- train_correct += (predictions == labels).sum().item()
283
- train_total += labels.size(0)
284
-
285
- scheduler.step()
286
-
287
- # Validation
288
- model.eval()
289
- val_correct = 0
290
- val_total = 0
291
- val_loss = 0
292
-
293
- with torch.no_grad():
294
- for batch in val_loader:
295
- features = batch['features'].to(device)
296
- labels = batch['label'].to(device)
297
-
298
- outputs = model(features)
299
- loss = criterion(outputs['c2_logits'].squeeze(), labels)
300
- val_loss += loss.item()
301
-
302
- predictions = (torch.sigmoid(outputs['c2_logits'].squeeze()) > 0.5).float()
303
- val_correct += (predictions == labels).sum().item()
304
- val_total += labels.size(0)
305
-
306
- train_acc = 100 * train_correct / train_total
307
- val_acc = 100 * val_correct / val_total
308
- lr = optimizer.param_groups[0]['lr']
309
-
310
- print(f"Epoch {epoch+1}: Loss={train_loss/len(train_loader):.4f}, "
311
- f"Train={train_acc:.1f}%, Val={val_acc:.1f}%, LR={lr:.6f}")
312
-
313
- if val_acc > best_val_acc:
314
- best_val_acc = val_acc
315
- patience_counter = 0
316
- save_file(model.state_dict(), 'c2_sentinel.safetensors')
317
- print(f" -> Saved (Val: {val_acc:.1f}%)")
318
- else:
319
- patience_counter += 1
320
- if patience_counter >= patience:
321
- print(f"Early stopping at epoch {epoch+1}")
322
- break
323
-
324
- print(f"\nBest validation accuracy: {best_val_acc:.1f}%")
325
- return model, config
326
-
327
-
328
- def test_model():
329
- """Test the trained model."""
330
- print("\n" + "=" * 70)
331
- print("Testing Model")
332
- print("=" * 70)
333
-
334
- sentinel = C2Sentinel.load('c2_sentinel')
335
-
336
- # Test 1: Cobalt Strike
337
- print("\n[1] Cobalt Strike Beacon (60s interval)...")
338
- cs = [{'timestamp': 1705600000 + i*60, 'dst_ip': '185.234.72.19', 'dst_port': 443,
339
- 'bytes_sent': 92, 'bytes_recv': 48} for i in range(16)]
340
- r = sentinel.analyze(cs)
341
- print(f" {'✓ C2 DETECTED' if r.is_c2 else '✗ No C2'} (prob={r.c2_probability:.2%})")
342
-
343
- # Test 2: Metasploit
344
- print("\n[2] Metasploit Beacon (5s interval, port 4444)...")
345
- msf = [{'timestamp': 1705600000 + i*5, 'dst_ip': '10.10.10.10', 'dst_port': 4444,
346
- 'bytes_sent': 150, 'bytes_recv': 400} for i in range(20)]
347
- r = sentinel.analyze(msf)
348
- print(f" {'✓ C2 DETECTED' if r.is_c2 else '✗ No C2'} (prob={r.c2_probability:.2%})")
349
-
350
- # Test 3: Slow APT beacon
351
- print("\n[3] APT Slow Beacon (120s interval)...")
352
- apt = [{'timestamp': 1705600000 + i*120, 'dst_ip': '45.33.32.156', 'dst_port': 443,
353
- 'bytes_sent': 80, 'bytes_recv': 60} for i in range(12)]
354
- r = sentinel.analyze(apt)
355
- print(f" {'✓ C2 DETECTED' if r.is_c2 else '✗ No C2'} (prob={r.c2_probability:.2%})")
356
-
357
- # Test 4: Web browsing (should be benign)
358
- print("\n[4] Web Browsing (should be clean)...")
359
- browse = [{'timestamp': 1705600000 + i*random.uniform(2, 30),
360
- 'dst_ip': f"{random.randint(1,200)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}",
361
- 'dst_port': 443, 'bytes_sent': random.randint(500, 3000),
362
- 'bytes_recv': random.randint(10000, 500000)} for i in range(20)]
363
- r = sentinel.analyze(browse)
364
- print(f" {'✗ C2 DETECTED (FP!)' if r.is_c2 else '✓ Clean'} (prob={r.c2_probability:.2%})")
365
-
366
- # Test 5: SSH keepalive
367
- print("\n[5] SSH Keepalive (should be clean)...")
368
- ssh = [{'timestamp': 1705600000 + i*30, 'dst_ip': '192.168.1.50', 'dst_port': 22,
369
- 'bytes_sent': 48, 'bytes_recv': 48} for i in range(15)]
370
- r = sentinel.analyze(ssh)
371
- print(f" {'✗ C2 DETECTED (FP!)' if r.is_c2 else '✓ Clean'} (prob={r.c2_probability:.2%})")
372
- print(f" Pattern: {r.matched_legitimate_pattern}")
373
-
374
- # Test 6: API calls (should be benign)
375
- print("\n[6] API Calls (should be clean)...")
376
- api = [{'timestamp': 1705600000 + i*random.uniform(0.5, 10),
377
- 'dst_ip': '52.85.132.99', 'dst_port': 443,
378
- 'bytes_sent': random.randint(100, 2000),
379
- 'bytes_recv': random.randint(500, 80000)} for i in range(25)]
380
- r = sentinel.analyze(api)
381
- print(f" {'✗ C2 DETECTED (FP!)' if r.is_c2 else '✓ Clean'} (prob={r.c2_probability:.2%})")
382
-
383
-
384
- if __name__ == '__main__':
385
- import argparse
386
- parser = argparse.ArgumentParser()
387
- parser.add_argument('--epochs', type=int, default=100)
388
- parser.add_argument('--samples', type=int, default=20000)
389
- parser.add_argument('--batch-size', type=int, default=32)
390
- parser.add_argument('--lr', type=float, default=0.0001)
391
- parser.add_argument('--test-only', action='store_true')
392
- args = parser.parse_args()
393
-
394
- if args.test_only:
395
- test_model()
396
- else:
397
- train_model(num_epochs=args.epochs, batch_size=args.batch_size,
398
- learning_rate=args.lr, num_samples=args.samples)
399
- test_model()