jpjoshi committed (verified)
Commit 7246e2b · 1 Parent(s): bd4c971

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ adaptive[[:space:]]waf/csic_database.csv filter=lfs diff=lfs merge=lfs -text
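
Note: .gitattributes patterns cannot contain a literal space, so the "adaptive waf/" folder is written with the POSIX character class [[:space:]]. A sketch of how such an entry is typically produced (hypothetical invocation; the upload here was performed by huggingface_hub, which adds the entry automatically):

    git lfs track "adaptive waf/csic_database.csv"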
CodeBlocks.lnk ADDED
Binary file (949 Bytes)

Comet.lnk ADDED
Binary file (2.47 kB)

MinGW Installer.lnk ADDED
Binary file (879 Bytes)

Visual Studio Code.lnk ADDED
Binary file (1.4 kB)
adaptive waf/best_student_waf_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4fd9dff4320234038a2875a7e320b7cdb47034fc43d5686fc2cea15ed7a73222
+ size 265501020
adaptive waf/csic_database.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c420f0bc0464376de75b6c419a0ac226fe69fe12c8ac4908843273721e44e637
+ size 29539583
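
Both additions above are Git LFS pointer files rather than the payloads themselves: each records the LFS spec version, the sha256 object id of the blob, and its size in bytes (roughly 265 MB for the distilled model checkpoint, 30 MB for the CSV). The general pointer layout is:

    version https://git-lfs.github.com/spec/v1
    oid sha256:<64-hex-digest>
    size <size-in-bytes>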
adaptive waf/model.ipynb ADDED
@@ -0,0 +1,860 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "60707c2d",
+ "metadata": {
+ "vscode": {
+ "languageId": "python"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# --- 0. Install Libraries ---\n",
+ "import subprocess\n",
+ "import sys\n",
+ "\n",
+ "def install_if_missing(package):\n",
+ "    try:\n",
+ "        __import__(package)\n",
+ "    except ImportError:\n",
+ "        subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", package])\n",
+ "\n",
+ "packages_to_check = ['transformers', 'openpyxl', 'tqdm']\n",
+ "for package in packages_to_check:\n",
+ "    install_if_missing(package)\n",
+ "\n",
+ "# --- 1. Import Libraries ---\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.utils.data import Dataset, DataLoader, random_split\n",
+ "from sklearn.preprocessing import LabelEncoder\n",
+ "from sklearn.metrics import classification_report, confusion_matrix, f1_score, accuracy_score\n",
+ "from sklearn.tree import DecisionTreeClassifier, export_text\n",
+ "from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "import seaborn as sns\n",
+ "from tqdm.auto import tqdm\n",
+ "import warnings\n",
+ "import os\n",
+ "warnings.filterwarnings('ignore')\n",
+ "\n",
+ "# --- 2. Kaggle Environment Detection ---\n",
+ "def setup_kaggle_environment():\n",
+ "    is_kaggle = os.environ.get('KAGGLE_KERNEL_RUN_TYPE') is not None\n",
+ "    if is_kaggle:\n",
+ "        print(\"Kaggle environment detected!\")\n",
+ "        input_dir = '/kaggle/input'\n",
+ "        working_dir = '/kaggle/working'\n",
+ "        dataset_files = []\n",
+ "        if os.path.exists(input_dir):\n",
+ "            for root, dirs, files in os.walk(input_dir):\n",
+ "                for file in files:\n",
+ "                    if file.endswith(('.xlsx', '.csv')):\n",
+ "                        dataset_files.append(os.path.join(root, file))\n",
+ "            print(f\"Available dataset files: {dataset_files}\")\n",
+ "        file_path = dataset_files[0] if dataset_files else '/kaggle/input/csic_database.xlsx'\n",
+ "        return {'file_path': file_path, 'output_dir': working_dir, 'is_kaggle': True}\n",
+ "    else:\n",
+ "        print(\"Local environment detected!\")\n",
+ "        return {'file_path': './csic_database.xlsx', 'output_dir': './', 'is_kaggle': False}\n",
+ "\n",
+ "env_config = setup_kaggle_environment()\n",
+ "\n",
+ "# --- 3. Configuration ---\n",
+ "CONFIG = {\n",
+ "    'file_path': env_config['file_path'],\n",
+ "    'output_dir': env_config['output_dir'],\n",
+ "    'batch_size': 16 if torch.cuda.is_available() else 8,\n",
+ "    'max_length': 512,\n",
+ "    'learning_rate': 2e-5,\n",
+ "    'num_epochs': 3,\n",
+ "    'max_depth': 10,\n",
+ "    'test_split': 0.2,\n",
+ "    'random_seed': 42,\n",
+ "    'is_kaggle': env_config['is_kaggle']\n",
+ "}\n",
+ "\n",
+ "torch.manual_seed(CONFIG['random_seed'])\n",
+ "np.random.seed(CONFIG['random_seed'])\n",
+ "print(f\"Configuration: {CONFIG}\")\n",
+ "\n",
+ "# --- 4. Data Loading ---\n",
+ "def load_and_preprocess_data(file_path):\n",
+ "    try:\n",
+ "        print(f\"Loading dataset from: {file_path}\")\n",
+ "        if not os.path.exists(file_path) and CONFIG['is_kaggle']:\n",
+ "            input_dir = '/kaggle/input'\n",
+ "            print(f\"File not found. Searching in {input_dir}...\")\n",
+ "            found = None\n",
+ "            for root, dirs, files in os.walk(input_dir):\n",
+ "                for file in files:\n",
+ "                    if 'csic' in file.lower() or 'dataset' in file.lower():\n",
+ "                        found = os.path.join(root, file)\n",
+ "                        break\n",
+ "                # break out of both loops once a candidate file is found\n",
+ "                if found:\n",
+ "                    break\n",
+ "            if found:\n",
+ "                file_path = found\n",
+ "        if file_path.endswith('.xlsx'):\n",
+ "            df = pd.read_excel(file_path)\n",
+ "        elif file_path.endswith('.csv'):\n",
+ "            df = pd.read_csv(file_path)\n",
+ "        else:\n",
+ "            try: df = pd.read_excel(file_path)\n",
+ "            except Exception: df = pd.read_csv(file_path)\n",
+ "        print(f\"Dataset loaded! Shape: {df.shape}\")\n",
+ "        label_columns = ['classification','label','class','target']\n",
+ "        label_col = next((c for c in label_columns if c in df.columns), None)\n",
+ "        if label_col:\n",
+ "            print(f\"Label distribution:\\n{df[label_col].value_counts()}\")\n",
+ "        return df\n",
+ "    except Exception as e:\n",
+ "        print(f\"Error loading file: {e}\")\n",
+ "        return None\n",
+ "\n",
+ "# --- 5. Preprocess ---\n",
+ "def preprocess_data(df):\n",
+ "    print(\"Preprocessing data...\")\n",
+ "    label_columns = ['classification','label','class','target']\n",
+ "    label_col = next((c for c in label_columns if c in df.columns), None)\n",
+ "    if label_col and label_col != 'label':\n",
+ "        df.rename(columns={label_col:'label'}, inplace=True)\n",
+ "    length_columns = ['lenght','length','len']\n",
+ "    for col in length_columns: df.drop(col, axis=1, inplace=True, errors='ignore')\n",
+ "    potential_text_cols = [\n",
+ "        'Method','method','HTTP_Method','User-Agent','user-agent','useragent','User_Agent',\n",
+ "        'Pragma','pragma','Cache-Control','cache-control','Cache_Control','Accept','accept',\n",
+ "        'Accept-encoding','accept-encoding','Accept_Encoding','Accept-charset','accept-charset','Accept_Charset',\n",
+ "        'language','Language','lang','host','Host','hostname','cookie','Cookie','cookies',\n",
+ "        'content-type','Content-Type','Content_Type','contenttype','connection','Connection','URL','url','uri','path',\n",
+ "        'content','Content','payload','data','body']\n",
+ "    available_text_cols = []\n",
+ "    for col in df.columns:\n",
+ "        if col in potential_text_cols or any(k in col.lower() for k in ['method','agent','url','content','header']):\n",
+ "            available_text_cols.append(col)\n",
+ "            df[col] = df[col].astype(str).fillna('')\n",
+ "    print(f\"Available text columns: {available_text_cols}\")\n",
+ "    if available_text_cols:\n",
+ "        combined_parts = [f'{col}: '+df[col].astype(str) for col in available_text_cols]\n",
+ "        df['combined_text'] = combined_parts[0]\n",
+ "        for part in combined_parts[1:]:\n",
+ "            df['combined_text'] += ' '+part\n",
+ "    else:\n",
+ "        text_cols = df.select_dtypes(include=['object']).columns.tolist()\n",
+ "        if 'label' in text_cols: text_cols.remove('label')\n",
+ "        if text_cols:\n",
+ "            print(f\"Using all object columns as text: {text_cols}\")\n",
+ "            combined_parts = [f'{col}: '+df[col].astype(str).fillna('') for col in text_cols]\n",
+ "            df['combined_text'] = combined_parts[0]\n",
+ "            for part in combined_parts[1:]:\n",
+ "                df['combined_text'] += ' '+part\n",
+ "        else:\n",
+ "            print(\"No text columns found!\")\n",
+ "            return None,None,None\n",
+ "    if 'label' not in df.columns:\n",
+ "        print(\"No 'label' column found!\")\n",
+ "        return None,None,None\n",
+ "    combined_text = df['combined_text']\n",
+ "    y_raw = df['label']\n",
+ "    label_encoder = LabelEncoder()\n",
+ "    y = label_encoder.fit_transform(y_raw)\n",
+ "    print(f\"Classes: {label_encoder.classes_}\")\n",
+ "    return combined_text,y,label_encoder\n",
+ "\n",
+ "# --- 6. Dataset ---\n",
+ "class CSICBertDataset(Dataset):\n",
+ "    def __init__(self, encodings, labels):\n",
+ "        self.encodings = encodings; self.labels = labels\n",
+ "    def __len__(self): return len(self.labels)\n",
+ "    def __getitem__(self, idx):\n",
+ "        item = {k:v[idx] for k,v in self.encodings.items()}\n",
+ "        item['labels'] = self.labels[idx]; return item\n",
+ "\n",
+ "# --- 7. Model ---\n",
+ "class BertClassifier(nn.Module):\n",
+ "    def __init__(self, n_classes, dropout_rate=0.3):\n",
+ "        super().__init__()\n",
+ "        self.bert = BertModel.from_pretrained('bert-base-uncased')\n",
+ "        self.dropout = nn.Dropout(dropout_rate)\n",
+ "        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)\n",
+ "    def forward(self,input_ids,attention_mask):\n",
+ "        outputs = self.bert(input_ids=input_ids,attention_mask=attention_mask)\n",
+ "        cls_embedding = self.dropout(outputs.last_hidden_state[:,0,:])\n",
+ "        logits = self.classifier(cls_embedding)\n",
+ "        return logits,cls_embedding\n",
+ "\n",
+ "# --- 8. Training with checkpointing ---\n",
+ "def train_model(model, train_loader, val_loader, device, config,\n",
+ "                optimizer, scheduler, resume_epoch=0,\n",
+ "                train_losses=None, val_losses=None, val_accuracies=None):\n",
+ "    if train_losses is None: train_losses=[]\n",
+ "    if val_losses is None: val_losses=[]\n",
+ "    if val_accuracies is None: val_accuracies=[]\n",
+ "    criterion = nn.CrossEntropyLoss()\n",
+ "    for epoch in range(resume_epoch, config['num_epochs']):\n",
+ "        model.train(); total_train_loss=0\n",
+ "        for batch in tqdm(train_loader,desc=f'Epoch {epoch+1}/{config[\"num_epochs\"]} [Train]'):\n",
+ "            input_ids=batch['input_ids'].to(device)\n",
+ "            attention_mask=batch['attention_mask'].to(device)\n",
+ "            labels=batch['labels'].to(device)\n",
+ "            optimizer.zero_grad()\n",
+ "            logits,_=model(input_ids,attention_mask)\n",
+ "            loss=criterion(logits,labels); loss.backward()\n",
+ "            torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)\n",
+ "            optimizer.step(); scheduler.step()\n",
+ "            total_train_loss+=loss.item()\n",
+ "        avg_train_loss=total_train_loss/len(train_loader)\n",
+ "        train_losses.append(avg_train_loss)\n",
+ "        # validation\n",
+ "        model.eval(); total_val_loss=0; correct=0; total=0\n",
+ "        with torch.no_grad():\n",
+ "            for batch in tqdm(val_loader,desc=f'Epoch {epoch+1}/{config[\"num_epochs\"]} [Val]'):\n",
+ "                input_ids=batch['input_ids'].to(device)\n",
+ "                attention_mask=batch['attention_mask'].to(device)\n",
+ "                labels=batch['labels'].to(device)\n",
+ "                logits,_=model(input_ids,attention_mask)\n",
+ "                loss=criterion(logits,labels); total_val_loss+=loss.item()\n",
+ "                preds=torch.argmax(logits,dim=1)\n",
+ "                correct+=(preds==labels).sum().item(); total+=labels.size(0)\n",
+ "        avg_val_loss=total_val_loss/len(val_loader); val_acc=correct/total\n",
+ "        val_losses.append(avg_val_loss); val_accuracies.append(val_acc)\n",
+ "        print(f\"Epoch {epoch+1}: Train Loss {avg_train_loss:.4f} Val Loss {avg_val_loss:.4f} Val Acc {val_acc:.4f}\")\n",
+ "        # save checkpoint (model + optimizer + scheduler state, so training can resume exactly)\n",
+ "        checkpoint_path=os.path.join(config['output_dir'],'bert_checkpoint.pt')\n",
+ "        torch.save({\n",
+ "            'epoch':epoch+1,\n",
+ "            'model_state_dict':model.state_dict(),\n",
+ "            'optimizer_state_dict':optimizer.state_dict(),\n",
+ "            'scheduler_state_dict':scheduler.state_dict(),\n",
+ "            'train_losses':train_losses,\n",
+ "            'val_losses':val_losses,\n",
+ "            'val_accuracies':val_accuracies\n",
+ "        },checkpoint_path)\n",
+ "        print(f\"Checkpoint saved at {checkpoint_path}\")\n",
+ "    return train_losses,val_losses,val_accuracies\n",
+ "\n",
+ "# --- 9. Evaluation ---\n",
+ "def evaluate_model(model, test_loader, device, label_encoder, config):\n",
+ "    model.eval(); all_preds=[]; all_trues=[]\n",
+ "    with torch.no_grad():\n",
+ "        for batch in tqdm(test_loader,desc='Evaluating'):\n",
+ "            input_ids=batch['input_ids'].to(device)\n",
+ "            attention_mask=batch['attention_mask'].to(device)\n",
+ "            labels=batch['labels'].to(device)\n",
+ "            logits,_=model(input_ids,attention_mask)\n",
+ "            preds=torch.argmax(logits,dim=1)\n",
+ "            all_preds.extend(preds.cpu().tolist()); all_trues.extend(labels.cpu().tolist())\n",
+ "    accuracy=accuracy_score(all_trues,all_preds)\n",
+ "    f1w=f1_score(all_trues,all_preds,average='weighted')\n",
+ "    print(f\"Test Accuracy: {accuracy:.4f} Weighted F1: {f1w:.4f}\")\n",
+ "    print(classification_report(all_trues,all_preds,target_names=label_encoder.classes_.astype(str)))\n",
+ "    cm=confusion_matrix(all_trues,all_preds)\n",
+ "    plt.figure(figsize=(10,8))\n",
+ "    sns.heatmap(cm,annot=True,fmt='d',cmap='Blues',\n",
+ "                xticklabels=label_encoder.classes_.astype(str),\n",
+ "                yticklabels=label_encoder.classes_.astype(str))\n",
+ "    plt.title('Confusion Matrix'); plt.ylabel('True'); plt.xlabel('Predicted')\n",
+ "    plt.tight_layout()\n",
+ "    plt.savefig(os.path.join(config['output_dir'],'confusion_matrix.png'),dpi=300)\n",
+ "    plt.show()\n",
+ "    return all_preds,all_trues,accuracy,f1w\n",
+ "\n",
+ "# --- 10. Main ---\n",
+ "def main():\n",
+ "    print(\"=\"*60)\n",
+ "    print(\"CSIC BERT CLASSIFIER WITH RESUME SUPPORT\")\n",
+ "    print(\"=\"*60)\n",
+ "    df=load_and_preprocess_data(CONFIG['file_path'])\n",
+ "    if df is None: return\n",
+ "    combined_text,y,label_encoder=preprocess_data(df)\n",
+ "    if combined_text is None: return\n",
+ "    y_tensor=torch.tensor(y,dtype=torch.long)\n",
+ "    tokenizer=BertTokenizer.from_pretrained('bert-base-uncased')\n",
+ "    tokenized_inputs=tokenizer(\n",
+ "        combined_text.tolist(),\n",
+ "        padding='max_length',truncation=True,max_length=CONFIG['max_length'],return_tensors=\"pt\")\n",
+ "    dataset=CSICBertDataset(tokenized_inputs,y_tensor)\n",
+ "    total_size=len(dataset)\n",
+ "    test_size=int(CONFIG['test_split']*total_size)\n",
+ "    train_val_size=total_size-test_size\n",
+ "    val_size=int(0.1*total_size)\n",
+ "    train_size=train_val_size-val_size\n",
+ "    train_ds,val_ds,test_ds=random_split(dataset,[train_size,val_size,test_size],\n",
+ "                                         generator=torch.Generator().manual_seed(CONFIG['random_seed']))\n",
+ "    print(f\"Splits - Train: {train_size} Val: {val_size} Test: {test_size}\")\n",
+ "    train_loader=DataLoader(train_ds,batch_size=CONFIG['batch_size'],shuffle=True)\n",
+ "    val_loader=DataLoader(val_ds,batch_size=CONFIG['batch_size'],shuffle=False)\n",
+ "    test_loader=DataLoader(test_ds,batch_size=CONFIG['batch_size'],shuffle=False)\n",
+ "    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "    model=BertClassifier(n_classes=len(label_encoder.classes_)).to(device)\n",
+ "    optimizer=torch.optim.AdamW(model.parameters(),lr=CONFIG['learning_rate'])\n",
+ "    total_steps=len(train_loader)*CONFIG['num_epochs']\n",
+ "    scheduler=get_linear_schedule_with_warmup(\n",
+ "        optimizer,num_warmup_steps=int(0.1*total_steps),num_training_steps=total_steps)\n",
+ "    # resume logic: restore model/optimizer/scheduler state if a checkpoint exists\n",
+ "    checkpoint_path=os.path.join(CONFIG['output_dir'],'bert_checkpoint.pt')\n",
+ "    resume_epoch=0; train_losses=val_losses=val_accuracies=None\n",
+ "    if os.path.exists(checkpoint_path):\n",
+ "        print(f\"Resuming from checkpoint: {checkpoint_path}\")\n",
+ "        checkpoint=torch.load(checkpoint_path,map_location=device)\n",
+ "        model.load_state_dict(checkpoint['model_state_dict'])\n",
+ "        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
+ "        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n",
+ "        resume_epoch=checkpoint['epoch']\n",
+ "        train_losses=checkpoint['train_losses']; val_losses=checkpoint['val_losses']; val_accuracies=checkpoint['val_accuracies']\n",
+ "    train_losses,val_losses,val_accuracies=train_model(\n",
+ "        model,train_loader,val_loader,device,CONFIG,optimizer,scheduler,resume_epoch,\n",
+ "        train_losses,val_losses,val_accuracies)\n",
+ "    # evaluate\n",
+ "    all_preds,all_trues,accuracy,f1w=evaluate_model(model,test_loader,device,label_encoder,CONFIG)\n",
+ "    # save final model\n",
+ "    model_path=os.path.join(CONFIG['output_dir'],'bert_classifier_model.pt')\n",
+ "    torch.save({'model_state_dict':model.state_dict(),\n",
+ "                'label_encoder':label_encoder,\n",
+ "                'config':CONFIG,\n",
+ "                'test_accuracy':accuracy,\n",
+ "                'f1_score':f1w},model_path)\n",
+ "    print(f\"Model saved to {model_path}\")\n",
+ "    print(\"Done.\")\n",
+ "\n",
+ "if __name__==\"__main__\":\n",
+ "    main()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "44056c38",
+ "metadata": {
+ "vscode": {
+ "languageId": "python"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# ==============================================================================\n",
+ "# Complete BERT → DistilBERT + MLP Knowledge Distillation Pipeline\n",
+ "# CSIC 2010 Web Application Attack Detector (Adaptive WAF)\n",
+ "# ==============================================================================\n",
+ "\n",
+ "# --- 0. Setup and Imports ---\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "from torch.utils.data import DataLoader, Dataset, random_split\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from sklearn.preprocessing import LabelEncoder\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.metrics import roc_curve, auc, classification_report, accuracy_score, f1_score\n",
+ "from sklearn.tree import DecisionTreeClassifier, export_text  # XAI imports\n",
+ "import matplotlib.pyplot as plt\n",
+ "import seaborn as sns\n",
+ "from transformers import (\n",
+ "    BertTokenizer, BertModel,\n",
+ "    DistilBertModel,\n",
+ "    get_linear_schedule_with_warmup\n",
+ ")\n",
+ "from torch.optim import AdamW\n",
+ "from tqdm.auto import tqdm\n",
+ "import warnings\n",
+ "import os\n",
+ "\n",
+ "warnings.filterwarnings('ignore')\n",
+ "torch.manual_seed(42)\n",
+ "np.random.seed(42)\n",
+ "\n",
+ "# --- 1. Configuration & Environment ---\n",
+ "CHECKPOINT_PATH = '/kaggle/working/bert_classifier_model.pt'\n",
+ "DATASET_PATH = '/kaggle/input/csic-2010-web-application-attacks/csic_database.csv'\n",
+ "BEST_MODEL_PATH = '/kaggle/working/best_student_waf_model.pt'\n",
+ "MAX_LENGTH = 512\n",
+ "BATCH_SIZE = 16\n",
+ "NUM_EPOCHS = 5\n",
+ "LEARNING_RATE = 2e-5\n",
+ "OUTPUT_DIR = '/kaggle/working'\n",
+ "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "XAI_MAX_DEPTH = 10\n",
+ "\n",
+ "# --- 2. Dataset and Data Loading ---\n",
+ "class TextDataset(Dataset):\n",
+ "    def __init__(self, texts, labels, tokenizer, max_length):\n",
+ "        self.texts = texts\n",
+ "        self.labels = labels\n",
+ "        self.tokenizer = tokenizer\n",
+ "        self.max_length = max_length\n",
+ "\n",
+ "    def __len__(self):\n",
+ "        return len(self.texts)\n",
+ "\n",
+ "    def __getitem__(self, idx):\n",
+ "        text = str(self.texts[idx])\n",
+ "        label = self.labels[idx]\n",
+ "\n",
+ "        encoding = self.tokenizer(\n",
+ "            text,\n",
+ "            truncation=True,\n",
+ "            padding='max_length',\n",
+ "            max_length=self.max_length,\n",
+ "            return_tensors='pt'\n",
+ "        )\n",
+ "\n",
+ "        return {\n",
+ "            'input_ids': encoding['input_ids'].flatten(),\n",
+ "            'attention_mask': encoding['attention_mask'].flatten(),\n",
+ "            'label': torch.tensor(label, dtype=torch.long)\n",
+ "        }\n",
+ "\n",
+ "def load_csic_dataset(file_path):\n",
+ "    try:\n",
+ "        df = pd.read_csv(file_path)\n",
+ "    except FileNotFoundError:\n",
+ "        print(f\"Error: Dataset not found at {file_path}. Please check the path.\")\n",
+ "        return None, None, None\n",
+ "\n",
+ "    text_columns = ['Method', 'User-Agent', 'Pragma', 'Cache-Control', 'Accept',\n",
+ "                    'Accept-encoding', 'Accept-charset', 'language', 'host',\n",
+ "                    'cookie', 'content-type', 'connection', 'content', 'URL']\n",
+ "\n",
+ "    df['combined_text'] = ''\n",
+ "    for col in text_columns:\n",
+ "        if col in df.columns:\n",
+ "            df['combined_text'] += df[col].fillna('').astype(str) + ' '\n",
+ "\n",
+ "    df['combined_text'] = df['combined_text'].str.strip()\n",
+ "\n",
+ "    texts = df['combined_text'].values\n",
+ "    labels_raw = df['classification'].values\n",
+ "\n",
+ "    le = LabelEncoder()\n",
+ "    labels = le.fit_transform(labels_raw)\n",
+ "\n",
+ "    print(f\"Dataset loaded! Shape: {df.shape}\")\n",
+ "    print(f\"Label distribution:\\n{df['classification'].value_counts()}\")\n",
+ "    return texts, labels, le\n",
+ "\n",
+ "# --- 3. Model Architectures ---\n",
+ "class TeacherBERT(nn.Module):\n",
+ "    def __init__(self, n_classes, model_name='bert-base-uncased', dropout_rate=0.3):\n",
+ "        super(TeacherBERT, self).__init__()\n",
+ "        self.bert = BertModel.from_pretrained(model_name)\n",
+ "        self.dropout = nn.Dropout(dropout_rate)\n",
+ "        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)\n",
+ "\n",
+ "    def forward(self, input_ids, attention_mask):\n",
+ "        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
+ "        cls_embedding = self.dropout(outputs.last_hidden_state[:, 0, :])\n",
+ "        logits = self.classifier(cls_embedding)\n",
+ "        return logits\n",
+ "\n",
+ "class StudentDistilBERT(nn.Module):\n",
+ "    def __init__(self, model_name='distilbert-base-uncased', num_classes=2, dropout=0.1):\n",
+ "        super(StudentDistilBERT, self).__init__()\n",
+ "        self.distilbert = DistilBertModel.from_pretrained(model_name)\n",
+ "        self.dropout = nn.Dropout(dropout)\n",
+ "        self.classifier = nn.Linear(self.distilbert.config.hidden_size, num_classes)\n",
+ "\n",
+ "    def forward(self, input_ids, attention_mask):\n",
+ "        outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)\n",
+ "        pooled_output = outputs.last_hidden_state[:, 0]\n",
+ "        pooled_output_dropped = self.dropout(pooled_output)\n",
+ "        logits = self.classifier(pooled_output_dropped)\n",
+ "        return logits, pooled_output\n",
+ "\n",
+ "class StudentMLP(nn.Module):\n",
+ "    def __init__(self, vocab_size=30522, embed_dim=128, hidden_dims=[256, 128],\n",
+ "                 num_classes=2, dropout=0.3):\n",
+ "        super(StudentMLP, self).__init__()\n",
+ "        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)\n",
+ "        self.dropout = nn.Dropout(dropout)\n",
+ "\n",
+ "        layers = []\n",
+ "        input_dim = embed_dim\n",
+ "        for hidden_dim in hidden_dims:\n",
+ "            layers.extend([\n",
+ "                nn.Linear(input_dim, hidden_dim),\n",
+ "                nn.ReLU(),\n",
+ "                nn.Dropout(dropout)\n",
+ "            ])\n",
+ "            input_dim = hidden_dim\n",
+ "\n",
+ "        layers.append(nn.Linear(input_dim, num_classes))\n",
+ "        self.mlp = nn.Sequential(*layers)\n",
+ "\n",
+ "    def forward(self, input_ids, attention_mask):\n",
+ "        embeddings = self.embedding(input_ids)\n",
+ "        mask = attention_mask.unsqueeze(-1).float()\n",
+ "        masked_embeddings = embeddings * mask\n",
+ "        pooled = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)\n",
+ "        pooled = self.dropout(pooled)\n",
+ "        logits = self.mlp(pooled)\n",
+ "        return logits\n",
+ "\n",
+ "# --- 4. Distillation Loss and Training ---\n",
+ "class DistillationLoss(nn.Module):\n",
+ "    def __init__(self, alpha=0.7, temperature=4.0):\n",
+ "        super(DistillationLoss, self).__init__()\n",
+ "        self.alpha = alpha\n",
+ "        self.temperature = temperature\n",
+ "        self.kl_div = nn.KLDivLoss(reduction='batchmean')\n",
+ "        self.ce_loss = nn.CrossEntropyLoss()\n",
+ "\n",
+ "    def forward(self, student_logits, teacher_logits, labels):\n",
+ "        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)\n",
+ "        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=1)\n",
+ "        distillation_loss = self.kl_div(student_log_probs, teacher_probs) * (self.temperature ** 2)\n",
+ "        student_loss = self.ce_loss(student_logits, labels)\n",
+ "        total_loss = self.alpha * distillation_loss + (1 - self.alpha) * student_loss\n",
+ "        return total_loss, distillation_loss, student_loss\n",
+ "\n",
+ "def train_student_model(student_model, teacher_model, train_loader, val_loader, device, name):\n",
+ "    student_model.to(device)\n",
+ "    teacher_model.to(device)\n",
+ "    teacher_model.eval()\n",
+ "\n",
+ "    optimizer = AdamW(student_model.parameters(), lr=LEARNING_RATE)\n",
+ "    distillation_criterion = DistillationLoss(alpha=0.7, temperature=4.0)\n",
+ "\n",
+ "    total_steps = len(train_loader) * NUM_EPOCHS\n",
+ "    scheduler = get_linear_schedule_with_warmup(\n",
+ "        optimizer, num_warmup_steps=0, num_training_steps=total_steps\n",
+ "    )\n",
+ "\n",
+ "    train_losses = []\n",
+ "    val_accuracies = []\n",
+ "\n",
+ "    print(f\"\\n--- Starting Distillation for {name} (Epochs: {NUM_EPOCHS}) ---\")\n",
+ "    for epoch in range(NUM_EPOCHS):\n",
+ "        student_model.train()\n",
+ "        total_loss = 0\n",
+ "\n",
+ "        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS} [{name}]')\n",
+ "        for batch in progress_bar:\n",
+ "            input_ids = batch['input_ids'].to(device)\n",
+ "            attention_mask = batch['attention_mask'].to(device)\n",
+ "            labels = batch['label'].to(device)\n",
+ "\n",
+ "            # the teacher only provides soft targets, so no gradients are needed\n",
+ "            with torch.no_grad():\n",
+ "                teacher_logits = teacher_model(input_ids, attention_mask)\n",
+ "\n",
+ "            if name == 'DistilBERT':\n",
+ "                student_logits, _ = student_model(input_ids, attention_mask)\n",
+ "            else:\n",
+ "                student_logits = student_model(input_ids, attention_mask)\n",
+ "\n",
+ "            loss, dist_loss, student_loss = distillation_criterion(\n",
+ "                student_logits, teacher_logits, labels\n",
+ "            )\n",
+ "\n",
+ "            optimizer.zero_grad()\n",
+ "            loss.backward()\n",
+ "            torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)\n",
+ "            optimizer.step()\n",
+ "            scheduler.step()\n",
+ "\n",
+ "            total_loss += loss.item()\n",
+ "\n",
+ "            progress_bar.set_postfix({\n",
+ "                'Loss': f'{loss.item():.4f}',\n",
+ "                'Dist': f'{dist_loss.item():.4f}',\n",
+ "            })\n",
+ "\n",
+ "        avg_train_loss = total_loss / len(train_loader)\n",
+ "        train_losses.append(avg_train_loss)\n",
+ "\n",
+ "        val_accuracy, _ = evaluate_model(student_model, val_loader, device, is_distilbert=(name=='DistilBERT'))\n",
+ "        val_accuracies.append(val_accuracy)\n",
+ "\n",
+ "        print(f'Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')\n",
+ "\n",
+ "    return train_losses, val_accuracies\n",
+ "\n",
+ "# --- 5. Evaluation, Checkpoint Loading, and Plotting ---\n",
+ "def load_teacher_checkpoint(checkpoint_path, n_classes):\n",
+ "    print(\"Loading teacher checkpoint...\")\n",
+ "\n",
+ "    teacher_model = TeacherBERT(n_classes=n_classes)\n",
+ "    teacher_model.to(DEVICE)\n",
+ "\n",
+ "    try:\n",
+ "        # weights_only=False because the checkpoint also stores a LabelEncoder object\n",
+ "        teacher_ckpt = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)\n",
+ "\n",
+ "        if 'model_state_dict' in teacher_ckpt:\n",
+ "            state_dict = teacher_ckpt['model_state_dict']\n",
+ "\n",
+ "            # strip a 'module.' prefix left over from DataParallel training, if any\n",
+ "            new_state_dict = {}\n",
+ "            for k, v in state_dict.items():\n",
+ "                if k.startswith('module.'):\n",
+ "                    k = k[7:]\n",
+ "                new_state_dict[k] = v\n",
+ "\n",
+ "            teacher_model.load_state_dict(new_state_dict)\n",
+ "\n",
+ "            print(\"✓ Teacher weights loaded successfully!\")\n",
+ "            return teacher_model, teacher_ckpt.get('label_encoder'), teacher_ckpt.get('config')\n",
+ "        else:\n",
+ "            print(\"✗ Checkpoint file is missing 'model_state_dict'. Cannot load weights.\")\n",
+ "            return None, None, None\n",
+ "\n",
+ "    except Exception as e:\n",
+ "        print(f\"✗ Failed to load teacher checkpoint: {e}\")\n",
+ "        print(\"Using randomly initialized BERT Teacher. Expect poor performance.\")\n",
+ "        return teacher_model, None, None\n",
+ "\n",
+ "def evaluate_model(model, dataloader, device, return_probs=False, is_distilbert=False):\n",
+ "    model.eval()\n",
+ "    all_preds = []\n",
+ "    all_labels = []\n",
+ "    all_probs = []\n",
+ "    total_loss = 0\n",
+ "    criterion = nn.CrossEntropyLoss()\n",
+ "\n",
+ "    with torch.no_grad():\n",
+ "        for batch in tqdm(dataloader, desc=\"Evaluating\"):\n",
+ "            input_ids = batch['input_ids'].to(device)\n",
+ "            attention_mask = batch['attention_mask'].to(device)\n",
+ "            labels = batch['label'].to(device)\n",
+ "\n",
+ "            if is_distilbert and isinstance(model, StudentDistilBERT):\n",
+ "                logits, _ = model(input_ids, attention_mask)\n",
+ "            else:\n",
+ "                logits = model(input_ids, attention_mask)\n",
+ "\n",
+ "            loss = criterion(logits, labels)\n",
+ "            total_loss += loss.item()\n",
+ "\n",
+ "            probs = F.softmax(logits, dim=1)\n",
+ "            preds = torch.argmax(logits, dim=1)\n",
+ "\n",
+ "            all_preds.extend(preds.cpu().numpy())\n",
+ "            all_labels.extend(labels.cpu().numpy())\n",
+ "            all_probs.extend(probs.cpu().numpy())\n",
+ "\n",
+ "    accuracy = accuracy_score(all_labels, all_preds)\n",
+ "    avg_loss = total_loss / len(dataloader)\n",
+ "\n",
+ "    if return_probs:\n",
+ "        return accuracy, avg_loss, all_labels, all_preds, all_probs\n",
+ "    return accuracy, avg_loss\n",
+ "\n",
+ "def plot_roc_comparison(models_data, save_path=None):\n",
+ "    plt.figure(figsize=(12, 8))\n",
+ "    colors = ['blue', 'red', 'green', 'orange']\n",
+ "\n",
+ "    for i, (name, labels, probs) in enumerate(models_data):\n",
+ "        y_score = np.array(probs)[:, 1] if len(probs[0]) > 1 else np.array(probs)\n",
+ "\n",
+ "        fpr, tpr, _ = roc_curve(labels, y_score)\n",
+ "        roc_auc = auc(fpr, tpr)\n",
+ "\n",
+ "        plt.plot(fpr, tpr, color=colors[i % len(colors)], lw=2,\n",
+ "                 label=f'{name} (AUC = {roc_auc:.4f})')\n",
+ "\n",
+ "    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', alpha=0.5)\n",
+ "    plt.xlim([0.0, 1.0])\n",
+ "    plt.ylim([0.0, 1.05])\n",
+ "    plt.xlabel('False Positive Rate', fontsize=12)\n",
+ "    plt.ylabel('True Positive Rate', fontsize=12)\n",
+ "    plt.title('ROC Curve Comparison: Teacher vs Student Models', fontsize=14, fontweight='bold')\n",
+ "    plt.legend(loc=\"lower right\", fontsize=11)\n",
+ "    plt.grid(True, alpha=0.3)\n",
+ "\n",
+ "    if save_path:\n",
+ "        plt.savefig(save_path, dpi=300, bbox_inches='tight')\n",
+ "    plt.show()\n",
+ "\n",
+ "# --- 6. Main Distillation Pipeline ---\n",
+ "\n",
+ "def main_distillation_pipeline():\n",
+ "    print(\"=\" * 80)\n",
+ "    print(\"BERT → DistilBERT + MLP Knowledge Distillation Pipeline\")\n",
+ "    print(\"CSIC 2010 Web Application Attacks Dataset\")\n",
+ "    print(\"=\" * 80)\n",
+ "\n",
+ "    # Load and preprocess data\n",
+ "    texts, encoded_labels, label_encoder = load_csic_dataset(DATASET_PATH)\n",
+ "    if texts is None: return\n",
+ "\n",
+ "    num_classes = len(label_encoder.classes_)\n",
+ "\n",
+ "    # Use the entire dataset (approx. 61k samples); split 60% train / 20% val / 20% test\n",
+ "    X_train, X_test, y_train, y_test = train_test_split(\n",
+ "        texts, encoded_labels, test_size=0.2, random_state=42, stratify=encoded_labels\n",
+ "    )\n",
+ "    X_train, X_val, y_train, y_val = train_test_split(\n",
+ "        X_train, y_train, test_size=0.25, random_state=42, stratify=y_train\n",
+ "    )\n",
+ "\n",
+ "    print(f\"Data Splits - Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}\")\n",
+ "\n",
+ "    # Initialize tokenizers and DataLoaders\n",
+ "    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
+ "\n",
+ "    train_dataset = TextDataset(X_train, y_train, tokenizer, MAX_LENGTH)\n",
+ "    val_loader = DataLoader(TextDataset(X_val, y_val, tokenizer, MAX_LENGTH), batch_size=BATCH_SIZE, shuffle=False)\n",
+ "    test_loader = DataLoader(TextDataset(X_test, y_test, tokenizer, MAX_LENGTH), batch_size=BATCH_SIZE, shuffle=False)\n",
+ "\n",
+ "    # --- Step 1: Initialize and Load Teacher Model ---\n",
+ "    teacher_model, _, _ = load_teacher_checkpoint(CHECKPOINT_PATH, num_classes)\n",
+ "    if teacher_model is None:\n",
+ "        print(\"Cannot proceed with distillation without a valid Teacher model.\")\n",
+ "        return\n",
+ "\n",
+ "    # --- Step 2: Initialize Student Models ---\n",
+ "    student_distilbert = StudentDistilBERT(num_classes=num_classes)\n",
+ "    student_mlp = StudentMLP(num_classes=num_classes, vocab_size=tokenizer.vocab_size)\n",
+ "\n",
+ "    print(f\"\\nModel Parameters (Teacher: {sum(p.numel() for p in teacher_model.parameters()):,} | DistilBERT: {sum(p.numel() for p in student_distilbert.parameters()):,} | MLP: {sum(p.numel() for p in student_mlp.parameters()):,})\")\n",
+ "\n",
+ "    # --- Step 3: Train Students ---\n",
+ "    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
+ "    distilbert_metrics = train_student_model(student_distilbert, teacher_model, train_loader, val_loader, DEVICE, 'DistilBERT')\n",
+ "    mlp_metrics = train_student_model(student_mlp, teacher_model, train_loader, val_loader, DEVICE, 'MLP')\n",
+ "\n",
+ "    # --- Step 4: Final Evaluation and Model Saving ---\n",
+ "    print(\"\\n\" + \"=\"*80)\n",
+ "    print(\"FINAL EVALUATION on Test Set & BEST MODEL SAVING\")\n",
+ "    print(\"=\"*80)\n",
+ "\n",
+ "    models = {\n",
+ "        'Teacher BERT': teacher_model,\n",
+ "        'Student DistilBERT': student_distilbert,\n",
+ "        'Student MLP': student_mlp\n",
+ "    }\n",
+ "\n",
+ "    best_f1 = -1\n",
+ "    best_model_name = \"\"\n",
+ "    best_model_data = None\n",
+ "    models_roc_data = []\n",
+ "\n",
+ "    for name, model in models.items():\n",
+ "        is_distilbert_eval = (\"DistilBERT\" in name)\n",
+ "\n",
+ "        accuracy, loss, labels, preds, probs = evaluate_model(\n",
+ "            model, test_loader, DEVICE, return_probs=True, is_distilbert=is_distilbert_eval\n",
+ "        )\n",
+ "\n",
+ "        models_roc_data.append((name, labels, probs))\n",
+ "        f1w = f1_score(labels, preds, average='weighted')\n",
+ "\n",
+ "        print(f\"\\n{name} - Accuracy: {accuracy:.4f}, Weighted F1: {f1w:.4f}, Loss: {loss:.4f}\")\n",
+ "        print(classification_report(labels, preds, target_names=label_encoder.classes_.astype(str)))\n",
+ "\n",
+ "        # only the students are deployment candidates; the teacher is the reference\n",
+ "        if \"Teacher\" not in name and f1w > best_f1:\n",
+ "            best_f1 = f1w\n",
+ "            best_model_name = name\n",
+ "            best_model_data = {\n",
+ "                'model_state_dict': model.state_dict(),\n",
+ "                'label_encoder': label_encoder,\n",
+ "                'config': {'max_length': MAX_LENGTH, 'batch_size': BATCH_SIZE},\n",
+ "                'test_accuracy': accuracy,\n",
+ "                'f1_score': f1w,\n",
+ "                'model_architecture': name\n",
+ "            }\n",
+ "\n",
+ "    if best_model_data:\n",
+ "        torch.save(best_model_data, BEST_MODEL_PATH)\n",
+ "        print(\"\\n\" + \"=\"*80)\n",
+ "        print(f\"✅ FINAL DEPLOYMENT MODEL SAVED: {best_model_name} (F1: {best_f1:.4f})\")\n",
+ "        print(f\"File: {BEST_MODEL_PATH}\")\n",
+ "        print(\"=\"*80)\n",
+ "    else:\n",
+ "        print(\"\\n❌ Could not save the best model.\")\n",
+ "\n",
+ "    # --- Step 5: Visualization ---\n",
+ "    print(\"\\n--- Visualizing ROC Curves ---\")\n",
+ "    plot_roc_comparison(models_roc_data, os.path.join(OUTPUT_DIR, 'roc_comparison.png'))\n",
+ "\n",
+ "    print(\"\\n✓ Cybersecurity knowledge distillation pipeline completed successfully!\")\n",
+ "\n",
+ "    # --- Step 6: Trigger XAI Agent ---\n",
+ "    if best_model_data and 'DistilBERT' in best_model_data.get('model_architecture', ''):\n",
+ "        main_xai_agent(X_test, y_test, label_encoder.classes_.tolist())\n",
+ "    else:\n",
+ "        print(\"\\nSkipping XAI Agent: The best model was not DistilBERT or data was unavailable.\")\n",
+ "\n",
+ "# --- 7. XAI Core Functions ---\n",
+ "\n",
+ "def extract_features(model, dataloader, device) -> np.ndarray:\n",
+ "    model.eval()\n",
+ "    all_features = []\n",
+ "\n",
+ "    with torch.no_grad():\n",
+ "        for batch in tqdm(dataloader, desc=\"Extracting Features for XAI\"):\n",
+ "            input_ids = batch['input_ids'].to(device)\n",
+ "            attention_mask = batch['attention_mask'].to(device)\n",
+ "\n",
+ "            _, features = model(input_ids, attention_mask)\n",
+ "            all_features.append(features.cpu().numpy())\n",
+ "\n",
+ "    return np.concatenate(all_features, axis=0)\n",
+ "\n",
+ "def generate_xai_rules(X_features: np.ndarray, y_labels: np.ndarray, feature_names: list, class_names: list) -> str:\n",
+ "    print(\"\\nTraining Decision Tree Surrogate Model...\")\n",
+ "\n",
+ "    dt_model = DecisionTreeClassifier(max_depth=XAI_MAX_DEPTH, random_state=42)\n",
+ "    dt_model.fit(X_features, y_labels)\n",
+ "\n",
+ "    dt_preds = dt_model.predict(X_features)\n",
+ "    dt_acc = accuracy_score(y_labels, dt_preds)\n",
+ "    print(f\"Decision Tree (Surrogate) Accuracy on Extracted Features: {dt_acc:.4f}\")\n",
+ "\n",
+ "    # export_text expects string class names; the encoded labels may be integers\n",
+ "    rules = export_text(\n",
+ "        dt_model,\n",
+ "        feature_names=feature_names,\n",
+ "        class_names=[str(c) for c in class_names]\n",
+ "    )\n",
+ "    return rules\n",
+ "\n",
+ "def main_xai_agent(X_test, y_test, class_names_list):\n",
+ "    print(\"\\n\" + \"=\"*80)\n",
+ "    print(\"XAI AGENT: Rule Generation for Adaptive WAF (Surrogate Model)\")\n",
+ "    print(\"=\"*80)\n",
+ "\n",
+ "    # weights_only=False: the saved dict contains a LabelEncoder, not just tensors\n",
+ "    checkpoint = torch.load(BEST_MODEL_PATH, map_location=DEVICE, weights_only=False)\n",
+ "    num_classes = len(class_names_list)\n",
+ "\n",
+ "    model = StudentDistilBERT(num_classes=num_classes).to(DEVICE)\n",
+ "    model.load_state_dict(checkpoint['model_state_dict'])\n",
+ "\n",
+ "    print(f\"Loading '{checkpoint.get('model_architecture')}' with F1-Score: {checkpoint['f1_score']:.4f} for XAI...\")\n",
+ "\n",
+ "    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
+ "    test_dataset = TextDataset(X_test, y_test, tokenizer, MAX_LENGTH)\n",
+ "    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
+ "\n",
+ "    X_features = extract_features(model, test_loader, DEVICE)\n",
+ "    feature_names = [f'CLS_Dim_{i}' for i in range(X_features.shape[1])]\n",
+ "\n",
+ "    print(f\"Features extracted! Shape: {X_features.shape}\")\n",
+ "\n",
+ "    xai_rules = generate_xai_rules(X_features, y_test, feature_names, class_names_list)\n",
+ "\n",
+ "    rules_path = os.path.join(OUTPUT_DIR, 'waf_xai_rules.txt')\n",
+ "    with open(rules_path, 'w') as f:\n",
+ "        f.write(xai_rules)\n",
+ "\n",
+ "    print(\"\\n\" + \"=\"*80)\n",
+ "    print(\"✅ XAI RULE GENERATION COMPLETE\")\n",
+ "    print(f\"Rules saved to: {rules_path}\")\n",
+ "    print(\"Sample Rules (Decision Tree Surrogate):\")\n",
+ "    print(\"=\"*80)\n",
+ "    print('\\n'.join(xai_rules.split('\\n')[:15]))\n",
+ "    print(\"... (Rules Truncated) ...\")\n",
+ "    print(\"=\"*80)\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ "    main_distillation_pipeline()"
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
desktop.ini ADDED
Binary file (282 Bytes)