| | from utils import * |
| | from src.configs.safetynet_config import SafetyNetConfig |
| | from utils.safetynet.vae_ae_train import Attention_DataProcessing, Train, Test, Detector_Stats |
| | from src.configs.spylab_model_config import spylab_create_config |
| | from src.configs.anthropic_model_config import anthropic_create_config |
| |
|
| | import plotly.graph_objects as go |
| |
|
| |
|
| |
|
class Visualization:
    """Plotly visualizations comparing reconstruction-loss distributions of
    backdoor detectors (AE, VAE, PCA, Mahalanobis, Beatrix, CROW, ...).

    All methods are static; the class is used purely as a namespace.
    Relies on module-level names from ``from utils import *``:
    ``json``, ``np``, ``os``, ``tqdm``, ``pprint``.
    """

    @staticmethod
    def data_processing_for_crow(
            other_layer_idx,
            vanilla_path="utils/data/llama2/ae_vae/vanilla/cosine_analysis.json",
            harmful_path="utils/data/llama2/ae_vae/backdoored/cosine_analysis.json"):
        """Load CROW cosine-analysis results and centre the "normal" scores.

        Each JSON file holds, per category ("normal"/"harmful"), one list of
        scores per layer pair: index 0 is the (previous, current) pair and
        index 1 the (current, next) pair.  The "normal" scores of each model
        are shifted by the mean "harmful" score of the same model, and the
        vanilla scores are split 80/20 into train/val.

        Args:
            other_layer_idx: ``'prev'`` selects pair index 0, ``'next'``
                selects pair index 1.
            vanilla_path: cosine-analysis JSON produced on the vanilla model.
            harmful_path: cosine-analysis JSON produced on the backdoored model.

        Returns:
            dict with keys ``normal_losses`` (train split), ``val_losses``
            (val split) and ``harmful_losses`` (centred backdoor scores).

        Raises:
            ValueError: if ``other_layer_idx`` is neither 'prev' nor 'next'.
        """
        with open(vanilla_path, "r") as f:
            vanilla_data = json.load(f)

        with open(harmful_path, "r") as f_:
            backdoor_data = json.load(f_)

        if other_layer_idx == 'prev':
            layer_idx = 0
        elif other_layer_idx == "next":
            layer_idx = 1
        else:
            # Previously fell through with layer_idx unbound -> NameError.
            raise ValueError(
                f"other_layer_idx must be 'prev' or 'next', got {other_layer_idx!r}")

        mean_harmful_vanilla = np.mean(np.array(vanilla_data["harmful"][layer_idx]))
        mean_harmful_backdoor = np.mean(np.array(backdoor_data["harmful"][layer_idx]))

        # Centre each model's "normal" scores on its own mean harmful score.
        vanilla_data_stats = [i - float(mean_harmful_vanilla) for i in vanilla_data["normal"][layer_idx]]

        split_idx = int(len(vanilla_data_stats) * 0.8)  # 80/20 train/val split
        vanilla_data_stats_train = vanilla_data_stats[:split_idx]
        vanilla_data_stats_val = vanilla_data_stats[split_idx:]
        backdoor_data_stats = [i - float(mean_harmful_backdoor) for i in backdoor_data["normal"][layer_idx]]

        return {
            "normal_losses": vanilla_data_stats_train,
            "val_losses": vanilla_data_stats_val,
            "harmful_losses": backdoor_data_stats
        }

    @staticmethod
    def plot_all_layers_violin(model_name, model_type, save_path,
                               config: "SafetyNetConfig", max_layers=32):
        """Create a split-violin plot of loss distributions for all layers.

        For every layer with an exported loss JSON, draws three half-violins
        (normal-train, harmful, normal-val) at that layer's x position, then
        writes the figure to ``{save_path}_all_layers_violin.pdf``.

        Args:
            model_name: directory key under ``utils/data/``.
            model_type: loss-file prefix (e.g. detector/model variant).
            save_path: output path prefix for the PDF.
            config: SafetyNetConfig; only ``config.model_name`` is used here.
            max_layers: highest layer index (exclusive) to look for.

        Returns:
            The plotly Figure, or None when no layer data was found.
        """
        fig = go.Figure()
        layers_data = {}

        for layer_idx in range(max_layers):
            data_path = f"utils/data/{model_name}/{model_type}_loss/layer_{layer_idx}_{model_type}_loss.json"
            try:
                with open(data_path, "r") as f:
                    layers_data[layer_idx] = json.load(f)
            except FileNotFoundError:
                # Skip layers without exported losses instead of crashing;
                # this also makes the empty-data guard below reachable.
                continue

        if not layers_data:
            print("No layer data found!")
            return None

        for i, (layer_idx, data) in tqdm(enumerate(layers_data.items()),
                                         total=len(layers_data)):
            x_pos = f'L{layer_idx}'

            # Normal (train) losses: left half of the violin.
            fig.add_trace(go.Violin(
                y=data["normal_losses"],
                x=[x_pos] * len(data["normal_losses"]),
                name='Normal (Train)',
                side='negative',
                fillcolor="#4DB6AC",
                line_color='#00695C',
                box_visible=True,
                meanline_visible=True,
                points=False,
                width=0.7,
                legendgroup='normal_train',
                showlegend=(i == 0)  # one legend entry per group
            ))

            # Harmful losses: right half, widest band.
            fig.add_trace(go.Violin(
                y=data["harmful_losses"],
                x=[x_pos] * len(data["harmful_losses"]),
                name='Harmful',
                side='positive',
                fillcolor="#BA68C8",
                line_color='#6A1B9A',
                box_visible=True,
                meanline_visible=True,
                points=False,
                width=0.5,
                legendgroup='harmful',
                showlegend=(i == 0)
            ))

            # Normal (val) losses: narrow overlay on the right half.
            fig.add_trace(go.Violin(
                y=data["val_losses"],
                x=[x_pos] * len(data["val_losses"]),
                name='Normal (Val)',
                side='positive',
                fillcolor="#3498DB",
                line_color='#2874A6',
                box_visible=True,
                meanline_visible=True,
                points=False,
                width=0.3,
                legendgroup='normal_val',
                showlegend=(i == 0)
            ))

        fig.update_layout(
            title=dict(
                text=f'{config.model_name} Loss Distribution Across All Layers ({model_type.upper()})',
                x=0.5,
                y=0.98,
                xanchor='center',
                yanchor='top',
                font=dict(
                    family="Times New Roman",
                    size=30,
                    color="black"
                )
            ),
            xaxis_title='Layer Index',
            yaxis_title='Reconstruction Loss',
            width=max(800, len(layers_data) * 60),  # widen with layer count
            height=500,
            showlegend=True,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=0.95,
                xanchor="center",
                x=0.5,
                font=dict(size=25, family="Times New Roman")
            ),
            plot_bgcolor='#FFFEF7',
            paper_bgcolor='white',
            font=dict(family="Times New Roman", size=20),
            margin=dict(
                t=70,
                b=20,
                l=20,
                r=0
            )
        )
        fig.update_xaxes(
            showgrid=True,
            gridcolor='rgba(128, 128, 128, 0.2)',
            showline=False,
            tickangle=45 if len(layers_data) > 10 else 0,  # avoid label overlap
            tickfont=dict(
                family="Times New Roman",
                size=25,
                color="black"
            ),
            title_font=dict(
                family="Times New Roman",
                size=22,
                color="black"
            )
        )

        fig.update_yaxes(
            showgrid=True,
            gridcolor='rgba(128, 128, 128, 0.2)',
            showline=False,
            range=[0, None],  # losses are non-negative
            tickfont=dict(
                family="Times New Roman",
                size=25,
                color="black"
            ),
            title_font=dict(
                family="Times New Roman",
                size=32,
                color="black"
            )
        )

        fig.write_image(f"{save_path}_all_layers_violin.pdf", height=1000, width=1500, scale=3)
        return fig

    @staticmethod
    def plot_detectors_comparison(model_name,
                                  detector_types,
                                  other_layer_idx,
                                  current_layer_idx,
                                  save_path,
                                  config: "SafetyNetConfig",
                                  model_type,
                                  args
                                  ):
        """Compare detector types (AE, VAE, PCA, CROW, ...) at one layer.

        For each detector the normal/harmful/val losses are loaded (CROW
        variants via :meth:`data_processing_for_crow`, all others from
        per-dataset JSON files), min-max normalised to [0, 1], and drawn as
        split violins.  Comprehensive metrics are computed with
        ``Detector_Stats`` and merged into a JSON results file next to the
        exported PDF.

        Args:
            model_name: model directory key for data/result paths.
            detector_types: detector identifiers to compare.
            other_layer_idx: 'prev' or 'next' — CROW layer pairing and
                output-file naming.
            current_layer_idx: layer index under comparison.
            save_path: output path prefix for the PDF and JSON.
            config: SafetyNetConfig; only ``config.model_name`` is used here.
            model_type: model variant used in data paths and titles.
            args: CLI namespace; ``args.dataset`` and ``args.model_type``
                are read.

        Returns:
            The plotly Figure.

        Raises:
            ValueError: for an unknown ``other_layer_idx`` or ``args.dataset``
                (previously both crashed later with NameError on an unbound
                path variable).
        """
        if other_layer_idx not in ("prev", "next"):
            raise ValueError(
                f"other_layer_idx must be 'prev' or 'next', got {other_layer_idx!r}")
        if args.dataset not in ("mad", "spylab", "anthropic"):
            raise ValueError(
                f"dataset must be 'mad', 'spylab' or 'anthropic', got {args.dataset!r}")

        fig = go.Figure()
        results = {}

        for i, detector_type in enumerate(detector_types):

            if detector_type in ("crow", "obfuscated_sim_crow", "obfuscated_ae_crow"):
                # CROW variants: scores come from the cosine-analysis JSONs.
                if args.dataset == "mad":
                    data = Visualization.data_processing_for_crow(
                        other_layer_idx=other_layer_idx,
                        harmful_path=f"utils/data/llama2/ae_vae/{model_type}/cosine_analysis.json"
                    )
                elif args.dataset == "spylab":
                    data = Visualization.data_processing_for_crow(
                        other_layer_idx=other_layer_idx,
                        vanilla_path=f"utils/spylab_data/llama2/vanilla/cosine_analysis.json",
                        harmful_path=f"utils/spylab_data/llama2/{model_type}/cosine_analysis.json"
                    )
                else:  # "anthropic" (validated above)
                    data = Visualization.data_processing_for_crow(
                        other_layer_idx=other_layer_idx,
                        vanilla_path=f"safetynet/utils/anthropic_data/{model_name}/vanilla/cosine_analysis.json",
                        harmful_path=f"safetynet/utils/anthropic_data/{model_name}/{model_type}/cosine_analysis.json"
                    )
            else:
                # All other detectors: precomputed per-layer loss JSONs.
                if args.dataset == "mad":
                    data_path = f"utils/data/{model_name}/{detector_type}_loss/layer_{current_layer_idx}_{detector_type}_loss.json"
                elif args.dataset == "spylab":
                    data_path = f"utils/spylab_data/{model_name}/{args.model_type}_{detector_type}_loss/layer_{current_layer_idx}_{args.model_type}_{detector_type}_loss.json"
                else:  # "anthropic" (validated above)
                    data_path = f"safetynet/utils/anthropic_data/{model_name}/{args.model_type}_{detector_type}_loss/layer_{current_layer_idx}_{args.model_type}_{detector_type}_loss.json"
                with open(data_path, "r") as f:
                    data = json.load(f)

            normal_losses = np.array(data["normal_losses"])
            harmful_losses = np.array(data["harmful_losses"])
            val_losses = np.array(data["val_losses"])

            # Comprehensive metrics (AUROC, accuracy, confusion matrix, ...);
            # the confusion matrix is converted up-front so the results dict
            # stays JSON-serialisable.
            stats = Detector_Stats()
            detector_results = stats.compute_comprehensive_metrics(normal_losses, val_losses, harmful_losses)
            detector_results['confusion_matrix_overall'] = detector_results['confusion_matrix_overall'].tolist()
            results[detector_type] = detector_results

            # Min-max normalise all three distributions jointly to [0, 1]
            # so different detectors share one y-scale.
            all_losses = np.concatenate([normal_losses, harmful_losses, val_losses])
            min_loss = np.min(all_losses)
            max_loss = np.max(all_losses)
            loss_range = max_loss - min_loss
            if loss_range == 0:
                loss_range = 1  # degenerate case: all losses identical

            normal_norm = (normal_losses - min_loss) / loss_range
            harmful_norm = (harmful_losses - min_loss) / loss_range
            val_norm = (val_losses - min_loss) / loss_range

            # x label: CROW variants show the layer pair, others just the name.
            detector = detector_type.split("_")[-1]
            if detector == "crow":
                if other_layer_idx == "prev":
                    x_pos = f"CROW {current_layer_idx-1}-{current_layer_idx}"
                elif other_layer_idx == "next":
                    x_pos = f"CROW {current_layer_idx}-{current_layer_idx+1}"
            else:
                x_pos = detector.upper()

            loss_data = [
                ('Normal (Train)', normal_norm),
                ('Harmful', harmful_norm),
                ('Normal (Val)', val_norm)
            ]

            for j, (loss_type, losses) in enumerate(loss_data):
                fig.add_trace(go.Violin(
                    y=losses,
                    x=[x_pos] * len(losses),
                    name=loss_type,
                    side='negative' if j == 0 else 'positive',
                    fillcolor='#BA68C8' if j == 1 else ('#3498DB' if j == 2 else '#4DB6AC'),
                    line_color='#6A1B9A' if j == 1 else ('#2874A6' if j == 2 else '#00695C'),
                    box_visible=True,
                    meanline_visible=True,
                    points=False,
                    width=0.7 if j == 0 else (0.5 if j == 1 else 0.3),
                    legendgroup=loss_type.lower().replace(' ', '_'),
                    showlegend=(i == 0),  # legend only for the first detector
                    hovertemplate=f'<b>{loss_type}</b><br>' +
                                  'Normalized: %{y:.3f}<br>' +
                                  f'Original Range: [{min_loss:.3f}, {max_loss:.3f}]<br>' +
                                  '<extra></extra>'
                ))

        pprint(results)

        fig.update_layout(
            title=dict(
                text=f'{model_type.upper()} {config.model_name} Detector Comparison at Layer {current_layer_idx}',
                x=0.5, y=0.96, xanchor='center', yanchor='top',
                font=dict(family="Times New Roman", size=12, color="black")
            ),
            xaxis_title='Detector Type',
            yaxis_title='Distribution of Distance (0-1 Scale)',
            width=max(600, len(detector_types) * 120),
            height=500,
            showlegend=True,
            legend=dict(
                orientation="h", yanchor="bottom", y=0.97, xanchor="center", x=0.5,
                font=dict(size=10, family="Times New Roman")
            ),
            plot_bgcolor='#FFFEF7',
            paper_bgcolor='white',
            font=dict(family="Times New Roman", size=10),
            margin=dict(t=50, b=20, l=20, r=0)
        )

        axis_style = dict(
            showgrid=True,
            gridcolor='rgba(128, 128, 128, 0.3)',
            showline=False,
            tickfont=dict(family="Times New Roman", size=10, color="black")
        )

        fig.update_xaxes(**axis_style, title_font=dict(family="Times New Roman", size=12, color="black"))
        fig.update_yaxes(
            **axis_style,
            range=[-0.1, 1.1],  # small margin around the normalised [0, 1] range
            title_font=dict(family="Times New Roman", size=12, color="black")
        )

        # Output names encode the compared layer pair; other_layer_idx was
        # validated at entry so accuracy_path is always bound here.
        if other_layer_idx == "prev":
            fig.write_image(f"{save_path}_{model_type}_detectors_comparison_layer_{current_layer_idx-1}_{current_layer_idx}.pdf",
                            height=300, width=500, scale=3)
            accuracy_path = f"{save_path}_{model_type}_accuracy_layer_{current_layer_idx-1}_{current_layer_idx}.json"
        else:
            fig.write_image(f"{save_path}_{model_type}_detectors_comparison_layer_{current_layer_idx}_{current_layer_idx+1}.pdf",
                            height=300, width=500, scale=3)
            accuracy_path = f"{save_path}_{model_type}_accuracy_layer_{current_layer_idx}_{current_layer_idx+1}.json"

        def numpy_to_python(obj):
            # Recursively turn numpy scalars/arrays into JSON-safe types.
            if isinstance(obj, np.integer):
                return int(obj)
            elif isinstance(obj, np.floating):
                return float(obj)
            elif isinstance(obj, np.ndarray):
                return obj.tolist()
            elif isinstance(obj, dict):
                return {key: numpy_to_python(val) for key, val in obj.items()}
            elif isinstance(obj, list):
                return [numpy_to_python(item) for item in obj]
            return obj

        results = numpy_to_python(results)

        # Merge with existing results so repeated runs accumulate instead of
        # overwriting each other.
        if os.path.exists(accuracy_path):
            with open(accuracy_path, 'r') as f:
                existing_results = json.load(f)
            existing_results.update(results)
            results = existing_results

        with open(accuracy_path, 'w') as f:
            json.dump(results, f, indent=2)

        return fig
| |
|
| | |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Multi-layer Attention Analysis')
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--model_type', type=str, required=True)
    # choices make invalid values fail at parse time instead of leaving
    # config/current_layer_idx unbound and crashing later with NameError
    parser.add_argument("--other_layer_idx", type=str, required=True,
                        choices=["prev", "next"],
                        help="crow should be taken for previous and current layer or next and current layers? give 'prev' or 'next' as argument ")
    parser.add_argument("--dataset", required=True,
                        choices=["mad", "spylab", "anthropic"],
                        help="mad, spylab, or anthropic")
    args = parser.parse_args()

    # Build the dataset-specific config object.
    if args.dataset == "mad":
        config = SafetyNetConfig(args.model_name)
    elif args.dataset == "spylab":
        config = spylab_create_config(args.model_name)
    else:  # "anthropic" (guaranteed by argparse choices)
        config = anthropic_create_config(args.model_name)

    # Best detection layer per model, fixed from prior analysis.
    layer_by_model = {
        'qwen': 21,
        'mistral': 12,
        'llama3': 13,
        'llama2': 15,
        'gemma': 18,
    }
    if args.model_name not in layer_by_model:
        raise SystemExit(
            f"Unsupported model_name {args.model_name!r}; "
            f"expected one of {sorted(layer_by_model)}")
    current_layer_idx = layer_by_model[args.model_name]

    save_path = f"{config.output_dir}/{args.model_name}"

    viz = Visualization()

    viz.plot_detectors_comparison(args.model_name,
                                  ['ae', 'pca', 'mahalanobis', 'beatrix', 'crow'],
                                  current_layer_idx=current_layer_idx,
                                  other_layer_idx=args.other_layer_idx,
                                  save_path=save_path,
                                  config=config,
                                  model_type=args.model_type,
                                  args=args
                                  )
| | |
| | |
| | |
| | |
| | |
| |
|