khalooei
commited on
Commit
·
b2419d7
1
Parent(s):
ccadb41
update app
Browse files
app.py
CHANGED
|
@@ -178,15 +178,17 @@ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, nu
|
|
| 178 |
logs = ["BSM:: experiment is being started ..."]
|
| 179 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 180 |
|
|
|
|
| 181 |
dataset, _ = get_dataset_and_transform(dataset_name)
|
| 182 |
testloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
|
| 183 |
-
logs.append(f"{dataset_name} dataset loaded")
|
| 184 |
|
|
|
|
| 185 |
model = initialize_model(model_name, device)
|
| 186 |
-
logs.append(f"Model {model_name}
|
| 187 |
|
| 188 |
param_count, layer_count = get_model_stats(model)
|
| 189 |
-
logs.append(f"Model stats: {param_count}
|
| 190 |
|
| 191 |
all_attacks = {
|
| 192 |
'FGSM': FGSM(model, eps=0.03),
|
|
@@ -204,21 +206,27 @@ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, nu
|
|
| 204 |
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
| 205 |
output_dir = os.path.join(output_dir_base, f"{model_name}_{timestamp}")
|
| 206 |
os.makedirs(output_dir, exist_ok=True)
|
| 207 |
-
logs.append(f"Output directory: {output_dir}")
|
| 208 |
|
| 209 |
results = {atk: {'cm': [], 'mvl': []} for atk in attacks}
|
| 210 |
|
| 211 |
for i, (images, labels) in enumerate(testloader):
|
| 212 |
if i >= num_batches:
|
|
|
|
| 213 |
break
|
| 214 |
images, labels = images.to(device), labels.to(device)
|
| 215 |
logs.append(f"Processing batch {i+1}/{num_batches}...")
|
| 216 |
|
| 217 |
for atk_name, atk in attacks.items():
|
|
|
|
| 218 |
adv_images = atk(images, labels)
|
| 219 |
mvl_vals = compute_mvl(model, images, adv_images, device)
|
| 220 |
results[atk_name]['mvl'].append(mvl_vals)
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
|
| 223 |
cm_means = {atk: np.mean(results[atk]['cm']) for atk in attacks}
|
| 224 |
cm_stds = {atk: np.std(results[atk]['cm']) for atk in attacks}
|
|
@@ -279,6 +287,7 @@ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, nu
|
|
| 279 |
logs.append(f"Saved integrated MVL plot: {integrated_mvl_plot_path}")
|
| 280 |
|
| 281 |
processing_time = time.time() - start_time
|
|
|
|
| 282 |
|
| 283 |
stats = {
|
| 284 |
'Dataset': dataset_name,
|
|
@@ -309,7 +318,7 @@ paper_info_html = """
|
|
| 309 |
<div style="border: 1px solid #ccc; padding: 15px; border-radius: 8px; margin-bottom: 15px;">
|
| 310 |
<h2>Layer-wise Regularized Adversarial Training Using Layers Sustainability Analysis Framework</h2>
|
| 311 |
<h3>Authors</h3>
|
| 312 |
-
<p>Mohammad Khalooei, Mohammad Mehdi
|
| 313 |
|
| 314 |
<h3>Abstract</h3>
|
| 315 |
<ul>
|
|
@@ -348,7 +357,7 @@ def create_interface():
|
|
| 348 |
model_text = gr.Textbox(value="LeNet", visible=False, interactive=False, label="Model")
|
| 349 |
|
| 350 |
attack_input = gr.CheckboxGroup(choices=attacks, label="Select Attacks", value=attacks)
|
| 351 |
-
batch_input = gr.Slider(minimum=1, maximum=20, step=1, value=
|
| 352 |
run_button = gr.Button("Run Analysis")
|
| 353 |
|
| 354 |
error_output = gr.Textbox(label="Error", visible=False)
|
|
|
|
| 178 |
logs = ["BSM:: experiment is being started ..."]
|
| 179 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 180 |
|
| 181 |
+
logs.append(f"Loading {dataset_name} dataset...")
|
| 182 |
dataset, _ = get_dataset_and_transform(dataset_name)
|
| 183 |
testloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
|
| 184 |
+
logs.append(f"{dataset_name} dataset loaded with {len(testloader)} batches.")
|
| 185 |
|
| 186 |
+
logs.append(f"Initializing model {model_name} on {device}...")
|
| 187 |
model = initialize_model(model_name, device)
|
| 188 |
+
logs.append(f"Model {model_name} initialized.")
|
| 189 |
|
| 190 |
param_count, layer_count = get_model_stats(model)
|
| 191 |
+
logs.append(f"Model stats: Parameters = {param_count}, Layers = {layer_count}")
|
| 192 |
|
| 193 |
all_attacks = {
|
| 194 |
'FGSM': FGSM(model, eps=0.03),
|
|
|
|
| 206 |
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
| 207 |
output_dir = os.path.join(output_dir_base, f"{model_name}_{timestamp}")
|
| 208 |
os.makedirs(output_dir, exist_ok=True)
|
| 209 |
+
logs.append(f"Output directory created: {output_dir}")
|
| 210 |
|
| 211 |
results = {atk: {'cm': [], 'mvl': []} for atk in attacks}
|
| 212 |
|
| 213 |
for i, (images, labels) in enumerate(testloader):
|
| 214 |
if i >= num_batches:
|
| 215 |
+
logs.append(f"Reached batch limit: {num_batches}")
|
| 216 |
break
|
| 217 |
images, labels = images.to(device), labels.to(device)
|
| 218 |
logs.append(f"Processing batch {i+1}/{num_batches}...")
|
| 219 |
|
| 220 |
for atk_name, atk in attacks.items():
|
| 221 |
+
logs.append(f" Running attack: {atk_name} on batch {i+1}")
|
| 222 |
adv_images = atk(images, labels)
|
| 223 |
mvl_vals = compute_mvl(model, images, adv_images, device)
|
| 224 |
results[atk_name]['mvl'].append(mvl_vals)
|
| 225 |
+
batch_cm = np.mean(mvl_vals)
|
| 226 |
+
results[atk_name]['cm'].append(batch_cm)
|
| 227 |
+
logs.append(f" Attack {atk_name}: batch CM={batch_cm:.6f}")
|
| 228 |
+
|
| 229 |
+
logs.append("Finished processing batches, computing statistics...")
|
| 230 |
|
| 231 |
cm_means = {atk: np.mean(results[atk]['cm']) for atk in attacks}
|
| 232 |
cm_stds = {atk: np.std(results[atk]['cm']) for atk in attacks}
|
|
|
|
| 287 |
logs.append(f"Saved integrated MVL plot: {integrated_mvl_plot_path}")
|
| 288 |
|
| 289 |
processing_time = time.time() - start_time
|
| 290 |
+
logs.append(f"Processing completed in {processing_time:.2f} seconds")
|
| 291 |
|
| 292 |
stats = {
|
| 293 |
'Dataset': dataset_name,
|
|
|
|
| 318 |
<div style="border: 1px solid #ccc; padding: 15px; border-radius: 8px; margin-bottom: 15px;">
|
| 319 |
<h2>Layer-wise Regularized Adversarial Training Using Layers Sustainability Analysis Framework</h2>
|
| 320 |
<h3>Authors</h3>
|
| 321 |
+
<p>Mohammad Khalooei, Mohammad Mehdi Homayounpour, Maryam Amirmazlaghani</p>
|
| 322 |
|
| 323 |
<h3>Abstract</h3>
|
| 324 |
<ul>
|
|
|
|
| 357 |
model_text = gr.Textbox(value="LeNet", visible=False, interactive=False, label="Model")
|
| 358 |
|
| 359 |
attack_input = gr.CheckboxGroup(choices=attacks, label="Select Attacks", value=attacks)
|
| 360 |
+
batch_input = gr.Slider(minimum=1, maximum=20, step=1, value=2, label="Number of Batches")
|
| 361 |
run_button = gr.Button("Run Analysis")
|
| 362 |
|
| 363 |
error_output = gr.Textbox(label="Error", visible=False)
|