LorenzoNava committed on
Commit
6792b7f
·
1 Parent(s): 18ef60c

feat: Add production-grade Gradio training interface with real-time monitoring

Features:
- Real-time training progress with live metrics
- Interactive hyperparameter configuration (defaults: 10 epochs, batch size 16)
- Live visualization with Plotly (loss, accuracy, F1 score, learning rate)
- Thread-safe training state management (see the sketch after this list)
- Automatic model export to local directory
- Training logs streaming
- GPU/CPU automatic detection
- Early stopping and checkpoint management
- Production-ready error handling
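
The real-time features above share one pattern: training runs in a daemon thread that mutates a lock-guarded state object, and a `gr.Timer` polls that state every 2 seconds to refresh the UI. A minimal sketch of that pattern (simplified from app.py; the counter loop below is a stand-in for the real training job):

```python
import threading
import time

import gradio as gr

state = {"step": 0}
lock = threading.Lock()

def fake_training_job():
    # Stand-in for train_model(): mutate shared state under the lock
    for _ in range(100):
        time.sleep(0.5)
        with lock:
            state["step"] += 1

def start():
    threading.Thread(target=fake_training_job, daemon=True).start()
    return "Training started"

def poll():
    # Called by the timer; reads state under the same lock
    with lock:
        return f"step {state['step']}"

with gr.Blocks() as demo:
    start_btn = gr.Button("Start")
    status = gr.Textbox(label="Status")
    start_btn.click(start, outputs=status)
    gr.Timer(value=2).tick(fn=poll, outputs=status)

demo.launch()
```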

Optimal hyperparameters:
- Epochs: 10 (for best quality)
- Batch size: 16 (effective: 64 with gradient accumulation)
- Learning rate: 2e-5 with cosine schedule
- Warmup ratio: 0.1
- Gradient accumulation: 4 steps
- Early stopping: patience of 5 evaluations
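
How these settings combine, as a quick sanity check (a sketch; the sample count is a hypothetical placeholder, and app.py logs the true step count at train start):

```python
# How the commit's hyperparameters interact. The sample count below is a
# hypothetical placeholder; the real dataset has ~300K CVE-CWE pairs.
per_device_batch = 16
grad_accum_steps = 4
effective_batch = per_device_batch * grad_accum_steps   # 64

num_samples = 300_000                                   # placeholder
num_epochs = 10
steps_per_epoch = num_samples // effective_batch        # 4687 optimizer steps
total_steps = steps_per_epoch * num_epochs              # 46870

warmup_ratio = 0.1
warmup_steps = int(total_steps * warmup_ratio)          # 4687: LR ramps up, then
                                                        # decays on a cosine curve
print(effective_batch, total_steps, warmup_steps)
```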

Model will be exported to:
/Users/lorenzo/Documents/Claude Code/projects/mcps/mcp-cwe-identifier/models/deberta-cwe-final
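
Once training completes, the exported model can be loaded for inference along the lines below (an untested sketch with the absolute paths shortened; note that app.py writes `cwe_label_mapping.json` to the training output directory `models/deberta-cwe`, not to the export directory):

```python
# Minimal inference sketch, assuming training completed with the default
# paths from app.py's TrainingConfig (shortened here to relative paths).
import json

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_dir = "models/deberta-cwe-final"                      # exported model
mapping_path = "models/deberta-cwe/cwe_label_mapping.json"  # saved during training

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir).eval()
with open(mapping_path) as f:
    id_to_cwe = json.load(f)["id_to_cwe"]                   # JSON keys are strings

text = "Buffer overflow in the parser allows remote code execution."  # example CVE text
inputs = tokenizer(text, truncation=True, max_length=256, return_tensors="pt")
with torch.no_grad():
    pred_id = model(**inputs).logits.argmax(dim=-1).item()
print(id_to_cwe[str(pred_id)])                              # e.g. "CWE-787"
```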

Files changed (3)
  1. .gitignore +58 -0
  2. app.py +802 -0
  3. requirements.txt +24 -0
.gitignore ADDED
@@ -0,0 +1,58 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual environments
+ venv/
+ env/
+ ENV/
+
+ # Model checkpoints and outputs
+ models/
+ checkpoints/
+ outputs/
+ *.pt
+ *.pth
+ *.bin
+ *.safetensors
+
+ # Logs
+ logs/
+ *.log
+ wandb/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Jupyter
+ .ipynb_checkpoints/
+
+ # Temporary files
+ tmp/
+ temp/
app.py ADDED
@@ -0,0 +1,802 @@
+ #!/usr/bin/env python3
+ """
+ DeBERTa CWE Classification - Fine-Tuning Interface
+ ====================================================
+
+ Production-grade Gradio interface for training DeBERTa models
+ on the CVE-CWE classification task with real-time monitoring.
+
+ Features:
+ - Real-time training progress with live metrics
+ - Interactive hyperparameter configuration
+ - GPU/CPU automatic detection
+ - Checkpoint management and recovery
+ - Model export to local directory
+ - Training logs streaming
+ - Performance visualization
+
+ Author: Berghem - Smart Information Security
+ License: MIT
+ """
+
+ import os
+ import sys
+ import json
+ import time
+ import threading
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple
+ from dataclasses import dataclass, asdict
+ import queue
+ import warnings
+
+ import gradio as gr
+ import torch
+ import numpy as np
+ import pandas as pd
+ import plotly.graph_objects as go
+ from plotly.subplots import make_subplots
+ from datasets import load_dataset, Dataset
+ from transformers import (
+     AutoTokenizer,
+     DebertaV2Tokenizer,
+     AutoModelForSequenceClassification,
+     TrainingArguments,
+     Trainer,
+     TrainerCallback,
+     EarlyStoppingCallback,
+ )
+ from sklearn.metrics import accuracy_score, f1_score, classification_report
+
+ warnings.filterwarnings('ignore')
+
+ # ============================================================================
+ # CONFIGURATION
+ # ============================================================================
+
+ @dataclass
+ class TrainingConfig:
+     """Training configuration with optimal defaults"""
+
+     # Model selection
+     model_name: str = "microsoft/deberta-v3-base"  # base, large, small
+
+     # Dataset
+     dataset_name: str = "stasvinokur/cve-and-cwe-dataset-1999-2025"
+     max_length: int = 256
+
+     # Training hyperparameters (OPTIMAL SETTINGS)
+     num_epochs: int = 10  # More epochs for better quality
+     batch_size: int = 16  # Larger batch size for stability
+     learning_rate: float = 2e-5
+     weight_decay: float = 0.01
+     warmup_ratio: float = 0.1
+     gradient_accumulation_steps: int = 4  # Effective batch size: 64
+
+     # Optimization
+     max_grad_norm: float = 1.0
+     adam_epsilon: float = 1e-8
+     lr_scheduler_type: str = "cosine"  # Smoother decay than linear
+
+     # Evaluation and checkpointing
+     eval_steps: int = 500
+     save_steps: int = 500
+     logging_steps: int = 50
+     save_total_limit: int = 3
+
+     # Early stopping
+     early_stopping_patience: int = 5
+     early_stopping_threshold: float = 0.001
+
+     # Output
+     output_dir: str = "/Users/lorenzo/Documents/Claude Code/projects/mcps/mcp-cwe-identifier/models/deberta-cwe"
+     local_export_dir: str = "/Users/lorenzo/Documents/Claude Code/projects/mcps/mcp-cwe-identifier/models"
+
+     # Hardware
+     use_fp16: bool = True  # Mixed precision for speed
+     dataloader_num_workers: int = 4
+
+     def to_dict(self) -> dict:
+         return asdict(self)
+
+
+ # Model configurations
+ MODEL_CONFIGS = {
+     "DeBERTa-v3-Small (44M params, fast)": "microsoft/deberta-v3-small",
+     "DeBERTa-v3-Base (86M params, recommended)": "microsoft/deberta-v3-base",
+     "DeBERTa-v3-Large (435M params, best quality)": "microsoft/deberta-v3-large",
+ }
+
+ # ============================================================================
+ # TRAINING STATE MANAGEMENT
+ # ============================================================================
+
+ class TrainingState:
+     """Thread-safe training state management"""
+
+     def __init__(self):
+         self.is_training = False
+         self.current_epoch = 0
+         self.total_epochs = 0
+         self.current_step = 0
+         self.total_steps = 0
+         self.train_loss = []
+         self.eval_loss = []
+         self.eval_accuracy = []
+         self.eval_f1 = []
+         self.learning_rates = []
+         self.logs = []
+         self.best_accuracy = 0.0
+         self.best_f1 = 0.0
+         self.training_start_time = None
+         self.training_end_time = None
+         self.lock = threading.Lock()
+         self.log_queue = queue.Queue()
+
+     def reset(self):
+         """Reset state for new training run"""
+         with self.lock:
+             self.is_training = False
+             self.current_epoch = 0
+             self.current_step = 0
+             self.train_loss = []
+             self.eval_loss = []
+             self.eval_accuracy = []
+             self.eval_f1 = []
+             self.learning_rates = []
+             self.logs = []
+             self.best_accuracy = 0.0
+             self.best_f1 = 0.0
+             self.training_start_time = None
+             self.training_end_time = None
+
+     def add_log(self, message: str):
+         """Add log message"""
+         timestamp = time.strftime("%H:%M:%S")
+         log_entry = f"[{timestamp}] {message}"
+         with self.lock:
+             self.logs.append(log_entry)
+             self.log_queue.put(log_entry)
+
+     def get_logs(self) -> str:
+         """Get all logs as string"""
+         with self.lock:
+             return "\n".join(self.logs[-100:])  # Last 100 lines
+
+     def get_progress(self) -> Dict:
+         """Get current progress"""
+         with self.lock:
+             elapsed = 0
+             if self.training_start_time:
+                 end_time = self.training_end_time or time.time()
+                 elapsed = end_time - self.training_start_time
+
+             return {
+                 "is_training": self.is_training,
+                 "epoch": f"{self.current_epoch}/{self.total_epochs}",
+                 "step": f"{self.current_step}/{self.total_steps}",
+                 "progress": self.current_step / max(self.total_steps, 1),
+                 "elapsed_time": f"{elapsed/60:.1f} min",
+                 "best_accuracy": f"{self.best_accuracy*100:.2f}%",
+                 "best_f1": f"{self.best_f1*100:.2f}%",
+             }
+
+ # Global training state
+ training_state = TrainingState()
+
+ # ============================================================================
+ # GRADIO CALLBACK FOR REAL-TIME UPDATES
+ # ============================================================================
+
+ class GradioProgressCallback(TrainerCallback):
+     """Custom callback that streams progress to the Gradio UI"""
+
+     def __init__(self, state: TrainingState):
+         self.state = state
+
+     def on_train_begin(self, args, state, control, **kwargs):
+         self.state.training_start_time = time.time()
+         self.state.is_training = True
+         self.state.total_epochs = int(args.num_train_epochs)
+         self.state.total_steps = state.max_steps
+         self.state.add_log("🚀 Training started!")
+         self.state.add_log(f"📊 Total epochs: {self.state.total_epochs}")
+         self.state.add_log(f"📈 Total steps: {self.state.total_steps}")
+
+     def on_epoch_begin(self, args, state, control, **kwargs):
+         self.state.current_epoch = int(state.epoch) if state.epoch else 0
+         self.state.add_log(f"\n{'='*60}")
+         self.state.add_log(f"📊 Epoch {self.state.current_epoch + 1}/{self.state.total_epochs}")
+         self.state.add_log(f"{'='*60}")
+
+     def on_log(self, args, state, control, logs=None, **kwargs):
+         if logs:
+             self.state.current_step = state.global_step
+
+             # Training loss
+             if "loss" in logs:
+                 self.state.train_loss.append((state.global_step, logs["loss"]))
+                 self.state.add_log(f"📉 Step {state.global_step}: Loss = {logs['loss']:.4f}")
+
+             # Learning rate
+             if "learning_rate" in logs:
+                 self.state.learning_rates.append((state.global_step, logs["learning_rate"]))
+
+             # Evaluation metrics
+             if "eval_loss" in logs:
+                 self.state.eval_loss.append((state.global_step, logs["eval_loss"]))
+                 self.state.add_log(f"📊 Evaluation Loss: {logs['eval_loss']:.4f}")
+
+             if "eval_accuracy" in logs:
+                 self.state.eval_accuracy.append((state.global_step, logs["eval_accuracy"]))
+                 self.state.best_accuracy = max(self.state.best_accuracy, logs["eval_accuracy"])
+                 self.state.add_log(f"🎯 Evaluation Accuracy: {logs['eval_accuracy']*100:.2f}%")
+
+             if "eval_f1_weighted" in logs:
+                 self.state.eval_f1.append((state.global_step, logs["eval_f1_weighted"]))
+                 self.state.best_f1 = max(self.state.best_f1, logs["eval_f1_weighted"])
+                 self.state.add_log(f"🎯 Evaluation F1 (weighted): {logs['eval_f1_weighted']*100:.2f}%")
+
+     def on_epoch_end(self, args, state, control, **kwargs):
+         elapsed = time.time() - self.state.training_start_time
+         self.state.add_log(f"✅ Epoch {self.state.current_epoch + 1} completed")
+         self.state.add_log(f"⏱️ Time elapsed: {elapsed/60:.1f} minutes")
+
+     def on_train_end(self, args, state, control, **kwargs):
+         self.state.training_end_time = time.time()
+         self.state.is_training = False
+         total_time = self.state.training_end_time - self.state.training_start_time
+         self.state.add_log(f"\n{'='*60}")
+         self.state.add_log("✅ TRAINING COMPLETED!")
+         self.state.add_log(f"{'='*60}")
+         self.state.add_log(f"⏱️ Total time: {total_time/60:.1f} minutes")
+         self.state.add_log(f"🎯 Best Accuracy: {self.state.best_accuracy*100:.2f}%")
+         self.state.add_log(f"🎯 Best F1 Score: {self.state.best_f1*100:.2f}%")
+
+ # ============================================================================
+ # DATASET PREPARATION
+ # ============================================================================
+
+ class CVECWEDataset:
+     """Prepare the CVE→CWE dataset for training"""
+
+     def __init__(self, tokenizer, config: TrainingConfig):
+         self.tokenizer = tokenizer
+         self.config = config
+         self.cwe_to_id = {}
+         self.id_to_cwe = {}
+
+     def load_and_prepare(self) -> Tuple[Dict[str, Dataset], int]:
+         """Load and prepare dataset"""
+         training_state.add_log("📦 Loading dataset...")
+
+         try:
+             dataset = load_dataset(self.config.dataset_name)
+             training_state.add_log(f"✅ Dataset loaded: {len(dataset['train']):,} training samples")
+         except Exception as e:
+             training_state.add_log(f"❌ Failed to load dataset: {e}")
+             raise
+
+         # Build CWE label mapping
+         training_state.add_log("🏷️ Building CWE label mapping...")
+         self._build_label_mapping(dataset['train'])
+         num_labels = len(self.cwe_to_id)
+         training_state.add_log(f"✅ Found {num_labels} unique CWE classes")
+
+         # Tokenize
+         training_state.add_log("🔤 Tokenizing dataset...")
+         tokenized = self._tokenize_dataset(dataset)
+         training_state.add_log("✅ Dataset prepared successfully")
+
+         return tokenized, num_labels
+
+     def _build_label_mapping(self, dataset):
+         """Build CWE → ID mapping"""
+         all_cwes = set()
+
+         for example in dataset:
+             cwe = example.get('CWE-ID')
+             if cwe and isinstance(cwe, str) and cwe.startswith('CWE-'):
+                 all_cwes.add(cwe)
+
+         sorted_cwes = sorted(all_cwes)
+         self.cwe_to_id = {cwe: idx for idx, cwe in enumerate(sorted_cwes)}
+         self.id_to_cwe = {idx: cwe for cwe, idx in self.cwe_to_id.items()}
+
+         # Save mapping (json.dump turns the integer keys of id_to_cwe into strings)
+         mapping_file = Path(self.config.output_dir) / "cwe_label_mapping.json"
+         mapping_file.parent.mkdir(parents=True, exist_ok=True)
+         with open(mapping_file, 'w') as f:
+             json.dump({
+                 'cwe_to_id': self.cwe_to_id,
+                 'id_to_cwe': self.id_to_cwe,
+                 'num_labels': len(self.cwe_to_id)
+             }, f, indent=2)
+
+     def _tokenize_dataset(self, dataset):
+         """Tokenize dataset"""
+
+         def tokenize_function(examples):
+             descriptions = examples.get('DESCRIPTION', [])
+             cwes = examples.get('CWE-ID', [])
+
+             labels = [
+                 self.cwe_to_id.get(cwe, -1) if cwe and cwe.startswith('CWE-') else -1
+                 for cwe in cwes
+             ]
+
+             tokenized = self.tokenizer(
+                 descriptions,
+                 truncation=True,
+                 padding='max_length',
+                 max_length=self.config.max_length,
+                 return_tensors=None
+             )
+
+             tokenized['labels'] = labels
+             return tokenized
+
+         tokenized = dataset.map(
+             tokenize_function,
+             batched=True,
+             desc="Tokenizing",
+             remove_columns=dataset['train'].column_names
+         )
+
+         # Filter out rows whose CWE was not in the label mapping (label -1)
+         tokenized = tokenized.filter(lambda x: x['labels'] >= 0)
+
+         return tokenized
+
+ # ============================================================================
+ # TRAINING FUNCTION
+ # ============================================================================
+
+ def compute_metrics(eval_pred):
+     """Compute evaluation metrics"""
+     predictions, labels = eval_pred
+     predictions = np.argmax(predictions, axis=1)
+
+     accuracy = accuracy_score(labels, predictions)
+     f1_macro = f1_score(labels, predictions, average='macro', zero_division=0)
+     f1_weighted = f1_score(labels, predictions, average='weighted', zero_division=0)
+
+     return {
+         'accuracy': accuracy,
+         'f1_macro': f1_macro,
+         'f1_weighted': f1_weighted,
+     }
+
+
+ def train_model(config: TrainingConfig):
+     """Main training function"""
+
+     try:
+         # Reset state
+         training_state.reset()
+
+         # Device detection
+         if torch.cuda.is_available():
+             device = "cuda"
+             device_name = torch.cuda.get_device_name(0)  # already includes the vendor on recent GPUs
+         elif torch.backends.mps.is_available():
+             device = "mps"
+             device_name = "Apple Silicon (M-series)"
+         else:
+             device = "cpu"
+             device_name = "CPU"
+
+         training_state.add_log(f"🖥️ Device: {device_name}")
+
+         # Load tokenizer
+         training_state.add_log(f"📚 Loading tokenizer: {config.model_name}")
+         tokenizer = DebertaV2Tokenizer.from_pretrained(config.model_name)
+
+         # Prepare dataset
+         dataset_prep = CVECWEDataset(tokenizer, config)
+         tokenized_dataset, num_labels = dataset_prep.load_and_prepare()
+
+         # Load model
+         training_state.add_log(f"🤖 Loading model: {config.model_name}")
+         training_state.add_log(f"🎯 Output classes: {num_labels} CWEs")
+
+         model = AutoModelForSequenceClassification.from_pretrained(
+             config.model_name,
+             num_labels=num_labels,
+             problem_type="single_label_classification"
+         )
+
+         # Count parameters
+         total_params = sum(p.numel() for p in model.parameters())
+         trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+         training_state.add_log(f"📊 Total parameters: {total_params:,}")
+         training_state.add_log(f"📊 Trainable parameters: {trainable_params:,}")
+
+         # Training arguments
+         training_args = TrainingArguments(
+             output_dir=config.output_dir,
+             num_train_epochs=config.num_epochs,
+             per_device_train_batch_size=config.batch_size,
+             per_device_eval_batch_size=config.batch_size * 2,
+             learning_rate=config.learning_rate,
+             weight_decay=config.weight_decay,
+             warmup_ratio=config.warmup_ratio,
+             gradient_accumulation_steps=config.gradient_accumulation_steps,
+             max_grad_norm=config.max_grad_norm,
+             adam_epsilon=config.adam_epsilon,
+             lr_scheduler_type=config.lr_scheduler_type,
+             fp16=config.use_fp16 and device == "cuda",
+             logging_dir=f"{config.output_dir}/logs",
+             logging_steps=config.logging_steps,
+             logging_first_step=True,
+             eval_strategy="steps",
+             eval_steps=config.eval_steps,
+             save_strategy="steps",
+             save_steps=config.save_steps,
+             save_total_limit=config.save_total_limit,
+             load_best_model_at_end=True,
+             metric_for_best_model="accuracy",
+             greater_is_better=True,
+             report_to="none",
+             dataloader_num_workers=config.dataloader_num_workers,
+         )
+
+         # Initialize trainer
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=tokenized_dataset["train"],
+             eval_dataset=tokenized_dataset.get("validation") or tokenized_dataset.get("test"),
+             tokenizer=tokenizer,
+             compute_metrics=compute_metrics,
+             callbacks=[
+                 GradioProgressCallback(training_state),
+                 EarlyStoppingCallback(
+                     early_stopping_patience=config.early_stopping_patience,
+                     early_stopping_threshold=config.early_stopping_threshold,
+                 )
+             ]
+         )
+
+         # Train
+         training_state.add_log("\n" + "="*60)
+         training_state.add_log("🚀 STARTING TRAINING")
+         training_state.add_log("="*60)
+
+         train_result = trainer.train()
+
+         # Save final model
+         training_state.add_log("\n💾 Saving final model...")
+         trainer.save_model(config.output_dir)
+         tokenizer.save_pretrained(config.output_dir)
+
+         # Save to local export directory
+         local_model_dir = Path(config.local_export_dir) / "deberta-cwe-final"
+         local_model_dir.mkdir(parents=True, exist_ok=True)
+         trainer.save_model(str(local_model_dir))
+         tokenizer.save_pretrained(str(local_model_dir))
+         training_state.add_log(f"✅ Model exported to: {local_model_dir}")
+
+         # Save metrics
+         metrics_file = Path(config.output_dir) / "training_metrics.json"
+         with open(metrics_file, 'w') as f:
+             json.dump(train_result.metrics, f, indent=2)
+
+         # Final evaluation
+         if "test" in tokenized_dataset or "validation" in tokenized_dataset:
+             test_dataset = tokenized_dataset.get("test") or tokenized_dataset.get("validation")
+             eval_results = trainer.evaluate(test_dataset)
+
+             training_state.add_log("\n" + "="*60)
+             training_state.add_log("📊 FINAL EVALUATION RESULTS")
+             training_state.add_log("="*60)
+             training_state.add_log(f"✅ Accuracy: {eval_results['eval_accuracy']*100:.2f}%")
+             training_state.add_log(f"✅ F1 Score (macro): {eval_results['eval_f1_macro']*100:.2f}%")
+             training_state.add_log(f"✅ F1 Score (weighted): {eval_results['eval_f1_weighted']*100:.2f}%")
+
+             eval_file = Path(config.output_dir) / "evaluation_results.json"
+             with open(eval_file, 'w') as f:
+                 json.dump(eval_results, f, indent=2)
+
+         training_state.add_log("\n✅ Training completed successfully!")
+
+     except Exception as e:
+         training_state.add_log(f"\n❌ Training failed: {str(e)}")
+         training_state.is_training = False
+         raise
+
+ # ============================================================================
+ # VISUALIZATION FUNCTIONS
+ # ============================================================================
+
+ def create_metrics_plot():
+     """Create interactive metrics plot"""
+     if not training_state.train_loss and not training_state.eval_accuracy:
+         # Empty plot
+         fig = go.Figure()
+         fig.add_annotation(
+             text="Training not started yet",
+             xref="paper", yref="paper",
+             x=0.5, y=0.5, showarrow=False,
+             font=dict(size=20, color="gray")
+         )
+         fig.update_layout(
+             title="Training Metrics",
+             xaxis_title="Step",
+             yaxis_title="Value",
+             template="plotly_white",
+             height=400
+         )
+         return fig
+
+     # Create subplots
+     fig = make_subplots(
+         rows=2, cols=2,
+         subplot_titles=("Training Loss", "Evaluation Accuracy", "Evaluation F1 Score", "Learning Rate"),
+         vertical_spacing=0.12,
+         horizontal_spacing=0.1
+     )
+
+     # Training loss
+     if training_state.train_loss:
+         steps, losses = zip(*training_state.train_loss)
+         fig.add_trace(
+             go.Scatter(x=steps, y=losses, mode='lines', name='Train Loss', line=dict(color='red')),
+             row=1, col=1
+         )
+
+     # Evaluation accuracy
+     if training_state.eval_accuracy:
+         steps, accs = zip(*training_state.eval_accuracy)
+         fig.add_trace(
+             go.Scatter(x=steps, y=accs, mode='lines+markers', name='Eval Accuracy', line=dict(color='blue')),
+             row=1, col=2
+         )
+
+     # Evaluation F1
+     if training_state.eval_f1:
+         steps, f1s = zip(*training_state.eval_f1)
+         fig.add_trace(
+             go.Scatter(x=steps, y=f1s, mode='lines+markers', name='Eval F1', line=dict(color='green')),
+             row=2, col=1
+         )
+
+     # Learning rate
+     if training_state.learning_rates:
+         steps, lrs = zip(*training_state.learning_rates)
+         fig.add_trace(
+             go.Scatter(x=steps, y=lrs, mode='lines', name='Learning Rate', line=dict(color='orange')),
+             row=2, col=2
+         )
+
+     fig.update_layout(
+         showlegend=False,
+         template="plotly_white",
+         height=600,
+         title_text="Training Metrics Dashboard",
+         title_font_size=20
+     )
+
+     # Update axes labels
+     fig.update_xaxes(title_text="Step", row=2, col=1)
+     fig.update_xaxes(title_text="Step", row=2, col=2)
+     fig.update_yaxes(title_text="Loss", row=1, col=1)
+     fig.update_yaxes(title_text="Accuracy", row=1, col=2)
+     fig.update_yaxes(title_text="F1 Score", row=2, col=1)
+     fig.update_yaxes(title_text="LR", row=2, col=2)
+
+     return fig
+
+
+ def create_progress_info():
+     """Create progress information HTML"""
+     progress = training_state.get_progress()
+
+     if progress["is_training"]:
+         status_color = "green"
+         status_text = "🟢 TRAINING IN PROGRESS"
+     else:
+         status_color = "gray"
+         status_text = "⚪ READY"
+
+     html = f"""
+     <div style="padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+                 border-radius: 10px; color: white; font-family: 'Arial', sans-serif;">
+         <h2 style="margin: 0 0 15px 0; font-size: 24px;">{status_text}</h2>
+         <div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 15px;">
+             <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 8px;">
+                 <div style="font-size: 12px; opacity: 0.8;">EPOCH</div>
+                 <div style="font-size: 24px; font-weight: bold;">{progress['epoch']}</div>
+             </div>
+             <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 8px;">
+                 <div style="font-size: 12px; opacity: 0.8;">STEP</div>
+                 <div style="font-size: 24px; font-weight: bold;">{progress['step']}</div>
+             </div>
+             <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 8px;">
+                 <div style="font-size: 12px; opacity: 0.8;">TIME ELAPSED</div>
+                 <div style="font-size: 24px; font-weight: bold;">{progress['elapsed_time']}</div>
+             </div>
+             <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 8px;">
+                 <div style="font-size: 12px; opacity: 0.8;">BEST ACCURACY</div>
+                 <div style="font-size: 24px; font-weight: bold;">{progress['best_accuracy']}</div>
+             </div>
+             <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 8px;">
+                 <div style="font-size: 12px; opacity: 0.8;">BEST F1 SCORE</div>
+                 <div style="font-size: 24px; font-weight: bold;">{progress['best_f1']}</div>
+             </div>
+             <div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 8px;">
+                 <div style="font-size: 12px; opacity: 0.8;">PROGRESS</div>
+                 <div style="font-size: 24px; font-weight: bold;">{progress['progress']*100:.1f}%</div>
+             </div>
+         </div>
+     </div>
+     """
+
+     return html
+
+ # ============================================================================
+ # GRADIO INTERFACE
+ # ============================================================================
+
+ def start_training(model_choice, epochs, batch_size, learning_rate, warmup_ratio,
+                    grad_accum, use_early_stopping):
+     """Start training in a background thread"""
+
+     if training_state.is_training:
+         return "❌ Training already in progress!"
+
+     # Update config
+     config = TrainingConfig()
+     config.model_name = MODEL_CONFIGS[model_choice]
+     config.num_epochs = int(epochs)
+     config.batch_size = int(batch_size)
+     config.learning_rate = float(learning_rate)
+     config.warmup_ratio = float(warmup_ratio)
+     config.gradient_accumulation_steps = int(grad_accum)
+
+     if not use_early_stopping:
+         config.early_stopping_patience = 999  # Effectively disabled
+
+     # Start training in background thread
+     thread = threading.Thread(target=train_model, args=(config,), daemon=True)
+     thread.start()
+
+     return "✅ Training started! Check the logs and metrics below for progress."
+
+
+ def update_ui():
+     """Update UI with current state"""
+     return (
+         create_progress_info(),
+         create_metrics_plot(),
+         training_state.get_logs(),
+         gr.update(interactive=not training_state.is_training),  # Enable/disable start button
+     )
+
+
+ # Build Gradio interface
+ with gr.Blocks(title="DeBERTa CWE Classification Training", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+     # 🚀 DeBERTa CWE Classification - Fine-Tuning Dashboard
+
+     Train state-of-the-art DeBERTa models for CVE→CWE classification with real-time monitoring.
+
+     **Dataset:** stasvinokur/cve-and-cwe-dataset-1999-2025 (~300K CVE-CWE pairs)
+
+     **Task:** Single-label classification of vulnerabilities to Common Weakness Enumeration (CWE) classes
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown("### ⚙️ Training Configuration")
+
+             model_choice = gr.Dropdown(
+                 choices=list(MODEL_CONFIGS.keys()),
+                 value="DeBERTa-v3-Base (86M params, recommended)",
+                 label="Model Architecture",
+                 info="Larger models = better quality but slower training"
+             )
+
+             epochs = gr.Slider(
+                 minimum=1, maximum=20, value=10, step=1,
+                 label="Number of Epochs",
+                 info="Recommended: 10 for optimal quality"
+             )
+
+             batch_size = gr.Slider(
+                 minimum=4, maximum=32, value=16, step=4,
+                 label="Batch Size per Device",
+                 info="Larger = faster training, more memory"
+             )
+
+             learning_rate = gr.Slider(
+                 minimum=1e-6, maximum=1e-4, value=2e-5, step=1e-6,
+                 label="Learning Rate",
+                 info="Default: 2e-5 (recommended for DeBERTa)"
+             )
+
+             warmup_ratio = gr.Slider(
+                 minimum=0.0, maximum=0.3, value=0.1, step=0.01,
+                 label="Warmup Ratio",
+                 info="Fraction of training for LR warmup"
+             )
+
+             grad_accum = gr.Slider(
+                 minimum=1, maximum=8, value=4, step=1,
+                 label="Gradient Accumulation Steps",
+                 info="Effective batch size = batch_size × this value"
+             )
+
+             use_early_stopping = gr.Checkbox(
+                 value=True,
+                 label="Enable Early Stopping",
+                 info="Stop if no improvement for 5 evaluations"
+             )
+
+             start_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
+             status_msg = gr.Textbox(label="Status", interactive=False)
+
+             gr.Markdown("""
+             ### 📊 Expected Training Time
+             - **Base model (GPU):** ~2-3 hours
+             - **Base model (CPU):** ~10-12 hours
+             - **Large model (GPU):** ~6-8 hours
+
+             ### 💾 Output Location
+             Model will be saved to:
+             `/Users/lorenzo/Documents/Claude Code/projects/mcps/mcp-cwe-identifier/models/deberta-cwe-final`
+             """)
+
+         with gr.Column(scale=2):
+             gr.Markdown("### 📈 Training Progress")
+
+             progress_html = gr.HTML(create_progress_info())
+             metrics_plot = gr.Plot(create_metrics_plot())
+
+             gr.Markdown("### 📝 Training Logs")
+             logs_box = gr.Textbox(
+                 label="Live Training Logs",
+                 lines=15,
+                 max_lines=20,
+                 interactive=False,
+                 show_copy_button=True
+             )
+
+     # Event handlers
+     start_btn.click(
+         fn=start_training,
+         inputs=[model_choice, epochs, batch_size, learning_rate, warmup_ratio,
+                 grad_accum, use_early_stopping],
+         outputs=status_msg
+     )
+
+     # Auto-refresh UI every 2 seconds using a timer
+     refresh_timer = gr.Timer(value=2, active=True)
+     refresh_timer.tick(
+         fn=update_ui,
+         outputs=[progress_html, metrics_plot, logs_box, start_btn]
+     )
+
+     gr.Markdown("""
+     ---
+     ### 🎯 Next Steps After Training
+     1. **Test Model:** Use the trained model for CWE prediction
+     2. **Integrate:** Update MCP server to use the new model
+     3. **Benchmark:** Compare against existing models
+     4. **Deploy:** Push to production environment
+
+     **Developed by:** Berghem - Smart Information Security | **License:** MIT
+     """)
+
+ # ============================================================================
+ # LAUNCH
+ # ============================================================================
+
+ if __name__ == "__main__":
+     demo.queue()  # Enable queuing for better concurrency
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,
+         show_error=True
+     )
requirements.txt ADDED
@@ -0,0 +1,24 @@
+ # Core ML/DL frameworks
+ torch>=2.0.0
+ transformers>=4.41.0  # app.py uses the eval_strategy argument, added in 4.41
+ datasets>=2.14.0
+ tokenizers>=0.15.0
+
+ # Gradio for UI
+ gradio==5.49.1
+
+ # Data processing and visualization
+ numpy>=1.24.0
+ pandas>=2.0.0
+ plotly>=5.18.0
+ scikit-learn>=1.3.0
+
+ # Required by the transformers Trainer
+ accelerate>=0.24.0
+
+ # For better performance
+ sentencepiece>=0.1.99
+ protobuf>=3.20.0
+
+ # Utils
+ tqdm>=4.66.0
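
Quick start (assuming a local checkout containing these three files): install the pinned dependencies with `pip install -r requirements.txt`, then run `python app.py`; per app.py, the dashboard binds to 0.0.0.0 on port 7860.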