"""Violin-plot visualisations for comparing backdoor detectors.

Reconstruction-loss distributions (normal-train / normal-val / harmful) are
rendered as split violin plots, either across all layers of a model or across
several detector types at a single layer, and detection metrics are written
to a JSON file next to the figure.
"""

import argparse
import json
import os
from pprint import pprint

import numpy as np
import plotly.graph_objects as go

from utils import *  # noqa: F401,F403 -- legacy star import (supplies tqdm, etc.)
from src.configs.safetynet_config import SafetyNetConfig
from utils.safetynet.vae_ae_train import Attention_DataProcessing, Train, Test, Detector_Stats
from src.configs.spylab_model_config import spylab_create_config
from src.configs.anthropic_model_config import anthropic_create_config


class Visualization:
    """Static helpers that build and save plotly violin figures."""

    @staticmethod
    def data_processing_for_crow(
            other_layer_idx,
            vanilla_path="utils/data/llama2/ae_vae/vanilla/cosine_analysis.json",
            harmful_path="utils/data/llama2/ae_vae/backdoored/cosine_analysis.json"):
        """Load CROW cosine-analysis data and centre it on the harmful mean.

        Args:
            other_layer_idx: 'prev' selects the (layer-1, layer) pair,
                'next' the (layer, layer+1) pair.
            vanilla_path: cosine-analysis JSON of the vanilla (clean) model.
            harmful_path: cosine-analysis JSON of the backdoored model.

        Returns:
            dict with keys 'normal_losses' (80% train split), 'val_losses'
            (20% validation split) and 'harmful_losses', all mean-centred,
            matching the schema of the per-detector loss JSON files.

        Raises:
            ValueError: if other_layer_idx is neither 'prev' nor 'next'.
        """
        with open(vanilla_path, "r") as f:
            vanilla_data = json.load(f)
        with open(harmful_path, "r") as f_:
            backdoor_data = json.load(f_)

        # Each stats list holds two layer-pair entries: index 0 is the
        # previous/current pair, index 1 the current/next pair.
        if other_layer_idx == 'prev':
            layer_idx = 0
        elif other_layer_idx == "next":
            layer_idx = 1
        else:
            # Previously fell through with layer_idx unbound -> NameError.
            raise ValueError(
                f"other_layer_idx must be 'prev' or 'next', got {other_layer_idx!r}")

        mean_harmful_vanilla = np.mean(np.array(vanilla_data["harmful"][layer_idx]))
        mean_harmful_backdoor = np.mean(np.array(backdoor_data["harmful"][layer_idx]))

        # Centre the "normal" scores on the corresponding harmful mean.
        vanilla_data_stats = [i - float(mean_harmful_vanilla)
                              for i in vanilla_data["normal"][layer_idx]]

        # 80/20 train/validation split of the vanilla scores.
        split_idx = int(len(vanilla_data_stats) * 0.8)
        vanilla_data_stats_train = vanilla_data_stats[:split_idx]
        vanilla_data_stats_val = vanilla_data_stats[split_idx:]

        backdoor_data_stats = [i - float(mean_harmful_backdoor)
                               for i in backdoor_data["normal"][layer_idx]]

        return {
            "normal_losses": vanilla_data_stats_train,
            "val_losses": vanilla_data_stats_val,
            "harmful_losses": backdoor_data_stats,
        }

    @staticmethod
    def plot_all_layers_violin(model_name, model_type, save_path,
                               config: SafetyNetConfig, max_layers=32):
        """Create a split violin plot of loss distributions for all available layers.

        Missing layer files are skipped. Returns None when no layer data was
        found, otherwise the plotly Figure (also written to
        f"{save_path}_all_layers_violin.pdf").
        """
        fig = go.Figure()
        layers_data = {}

        # Load whatever layer files exist.  The original open()ed
        # unconditionally, which crashed on any gap and made the
        # "no data" guard below unreachable.
        for layer_idx in range(max_layers):
            data_path = (f"utils/data/{model_name}/{model_type}_loss/"
                         f"layer_{layer_idx}_{model_type}_loss.json")
            try:
                with open(data_path, "r") as f:
                    layers_data[layer_idx] = json.load(f)
            except FileNotFoundError:
                continue

        if not layers_data:
            print("No layer data found!")
            return None

        for i, (layer_idx, data) in tqdm(enumerate(layers_data.items())):
            x_pos = f'L{layer_idx}'  # short x-axis labels

            # Normal (Train) -- left half of the violin.
            fig.add_trace(go.Violin(
                y=data["normal_losses"],
                x=[x_pos] * len(data["normal_losses"]),
                name='Normal (Train)',
                side='negative',
                fillcolor="#4DB6AC",
                line_color='#00695C',
                box_visible=True,
                meanline_visible=True,
                points=False,
                width=0.7,
                legendgroup='normal_train',
                showlegend=(i == 0)  # legend entry only once
            ))

            # Harmful -- right half.
            fig.add_trace(go.Violin(
                y=data["harmful_losses"],
                x=[x_pos] * len(data["harmful_losses"]),
                name='Harmful',
                side='positive',
                fillcolor="#BA68C8",
                line_color='#6A1B9A',
                box_visible=True,
                meanline_visible=True,
                points=False,
                width=0.5,
                legendgroup='harmful',
                showlegend=(i == 0)
            ))

            # Normal (Val) -- right half, narrower so Harmful stays visible.
            fig.add_trace(go.Violin(
                y=data["val_losses"],
                x=[x_pos] * len(data["val_losses"]),
                name='Normal (Val)',
                side='positive',
                fillcolor="#3498DB",
                line_color='#2874A6',
                box_visible=True,
                meanline_visible=True,
                points=False,
                width=0.3,
                legendgroup='normal_val',
                showlegend=(i == 0)
            ))

        fig.update_layout(
            title=dict(
                text=f'{config.model_name} Loss Distribution Across All Layers ({model_type.upper()})',
                x=0.5,
                y=0.98,
                xanchor='center',
                yanchor='top',
                font=dict(family="Times New Roman", size=30, color="black")
            ),
            xaxis_title='Layer Index',
            yaxis_title='Reconstruction Loss',
            width=max(800, len(layers_data) * 60),  # dynamic width
            height=500,
            showlegend=True,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=0.95,
                xanchor="center",
                x=0.5,
                font=dict(size=25, family="Times New Roman")
            ),
            plot_bgcolor='#FFFEF7',
            paper_bgcolor='white',
            font=dict(family="Times New Roman", size=20),
            margin=dict(t=70, b=20, l=20, r=0)
        )

        fig.update_xaxes(
            showgrid=True,
            gridcolor='rgba(128, 128, 128, 0.2)',
            showline=False,
            tickangle=45 if len(layers_data) > 10 else 0,
            tickfont=dict(family="Times New Roman", size=25, color="black"),
            title_font=dict(family="Times New Roman", size=22, color="black")
        )
        fig.update_yaxes(
            showgrid=True,
            gridcolor='rgba(128, 128, 128, 0.2)',
            showline=False,
            range=[0, None],
            tickfont=dict(family="Times New Roman", size=25, color="black"),
            title_font=dict(family="Times New Roman", size=32, color="black")
        )

        fig.write_image(f"{save_path}_all_layers_violin.pdf",
                        height=1000, width=1500, scale=3)
        return fig

    @staticmethod
    def plot_detectors_comparison(model_name, detector_types, other_layer_idx,
                                  current_layer_idx, save_path,
                                  config: SafetyNetConfig, model_type, args):
        """Compare detector types (AE, PCA, Mahalanobis, Beatrix, CROW) at one layer.

        Per-detector losses are min-max normalised to [0, 1] so the detectors
        share one y-axis.  Metrics from Detector_Stats are merged into an
        accuracy JSON next to the figure.  Returns the plotly Figure.

        Raises:
            ValueError: for an unknown args.dataset or other_layer_idx.
        """
        # Validate up front -- previously an unexpected value left data_path /
        # accuracy_path unbound and crashed later with NameError.
        if other_layer_idx not in ("prev", "next"):
            raise ValueError(
                f"other_layer_idx must be 'prev' or 'next', got {other_layer_idx!r}")
        if args.dataset not in ("mad", "spylab", "anthropic"):
            raise ValueError(f"unknown dataset {args.dataset!r}")

        fig = go.Figure()
        results = {}

        for i, detector_type in enumerate(detector_types):
            if detector_type in ("crow", "obfuscated_sim_crow", "obfuscated_ae_crow"):
                # CROW losses come from the cosine-analysis files.
                if args.dataset == "mad":
                    data = Visualization.data_processing_for_crow(
                        other_layer_idx=other_layer_idx,
                        harmful_path=f"utils/data/llama2/ae_vae/{model_type}/cosine_analysis.json")
                elif args.dataset == "spylab":
                    data = Visualization.data_processing_for_crow(
                        other_layer_idx=other_layer_idx,
                        vanilla_path=f"utils/spylab_data/llama2/vanilla/cosine_analysis.json",
                        harmful_path=f"utils/spylab_data/llama2/{model_type}/cosine_analysis.json")
                else:  # "anthropic", validated above
                    data = Visualization.data_processing_for_crow(
                        other_layer_idx=other_layer_idx,
                        vanilla_path=f"safetynet/utils/anthropic_data/{model_name}/vanilla/cosine_analysis.json",
                        harmful_path=f"safetynet/utils/anthropic_data/{model_name}/{model_type}/cosine_analysis.json")
            else:
                # Other detectors store pre-computed losses per layer.
                if args.dataset == "mad":
                    data_path = f"utils/data/{model_name}/{detector_type}_loss/layer_{current_layer_idx}_{detector_type}_loss.json"
                elif args.dataset == "spylab":
                    data_path = f"utils/spylab_data/{model_name}/{args.model_type}_{detector_type}_loss/layer_{current_layer_idx}_{args.model_type}_{detector_type}_loss.json"
                else:  # "anthropic"
                    data_path = f"safetynet/utils/anthropic_data/{model_name}/{args.model_type}_{detector_type}_loss/layer_{current_layer_idx}_{args.model_type}_{detector_type}_loss.json"
                print(data_path)
                with open(data_path, "r") as f:
                    data = json.load(f)

            normal_losses = np.array(data["normal_losses"])
            harmful_losses = np.array(data["harmful_losses"])
            val_losses = np.array(data["val_losses"])

            # Comprehensive detection metrics (replaces the old inline
            # threshold/AUROC computation that had become dead code).
            stats = Detector_Stats()
            detector_results = stats.compute_comprehensive_metrics(
                normal_losses, val_losses, harmful_losses)
            # JSON-serialisable form of the confusion matrix.
            detector_results['confusion_matrix_overall'] = \
                detector_results['confusion_matrix_overall'].tolist()
            results[detector_type] = detector_results

            # Min-max normalise across all three loss groups to [0, 1].
            all_losses = np.concatenate([normal_losses, harmful_losses, val_losses])
            min_loss = np.min(all_losses)
            max_loss = np.max(all_losses)
            loss_range = max_loss - min_loss
            if loss_range == 0:
                loss_range = 1  # constant losses: avoid division by zero

            normal_norm = (normal_losses - min_loss) / loss_range
            harmful_norm = (harmful_losses - min_loss) / loss_range
            val_norm = (val_losses - min_loss) / loss_range

            # X-axis category label; CROW shows the layer pair it spans.
            detector = detector_type.split("_")[-1]
            if detector == "crow":
                if other_layer_idx == "prev":
                    x_pos = f"CROW {current_layer_idx-1}-{current_layer_idx}"
                else:  # "next"
                    x_pos = f"CROW {current_layer_idx}-{current_layer_idx+1}"
            else:
                x_pos = detector.upper()

            loss_data = [
                ('Normal (Train)', normal_norm),
                ('Harmful', harmful_norm),
                ('Normal (Val)', val_norm),
            ]
            for j, (loss_type, losses) in enumerate(loss_data):
                fig.add_trace(go.Violin(
                    y=losses,
                    x=[x_pos] * len(losses),
                    name=loss_type,
                    side='negative' if j == 0 else 'positive',
                    # j==0 train (teal), j==1 harmful (purple), j==2 val (blue)
                    fillcolor='#BA68C8' if j == 1 else ('#3498DB' if j == 2 else '#4DB6AC'),
                    line_color='#6A1B9A' if j == 1 else ('#2874A6' if j == 2 else '#00695C'),
                    box_visible=True,
                    meanline_visible=True,
                    points=False,
                    width=0.7 if j == 0 else (0.5 if j == 1 else 0.3),
                    legendgroup=loss_type.lower().replace(' ', '_'),
                    showlegend=(i == 0),
                    # NOTE(review): '<br>' separators reconstructed -- the
                    # original markup was lost in transit; confirm rendering.
                    hovertemplate=f'{loss_type}<br>' +
                                  'Normalized: %{y:.3f}<br>' +
                                  f'Original Range: [{min_loss:.3f}, {max_loss:.3f}]<br>' +
                                  ''
                ))

            print(f"CURRENTLY PROCESSING {detector_type}")
            pprint(results)

        fig.update_layout(
            title=dict(
                text=f'{model_type.upper()} {config.model_name} Detector Comparison at Layer {current_layer_idx}',
                x=0.5,
                y=0.96,
                xanchor='center',
                yanchor='top',
                font=dict(family="Times New Roman", size=12, color="black")
            ),
            xaxis_title='Detector Type',
            yaxis_title='Distribution of Distance (0-1 Scale)',
            width=max(600, len(detector_types) * 120),
            height=500,
            showlegend=True,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=0.97,
                xanchor="center",
                x=0.5,
                font=dict(size=10, family="Times New Roman")
            ),
            plot_bgcolor='#FFFEF7',  # light cream background
            paper_bgcolor='white',
            font=dict(family="Times New Roman", size=10),
            margin=dict(t=50, b=20, l=20, r=0)
        )

        axis_style = dict(
            showgrid=True,
            gridcolor='rgba(128, 128, 128, 0.3)',
            showline=False,
            tickfont=dict(family="Times New Roman", size=10, color="black")
        )
        fig.update_xaxes(**axis_style,
                         title_font=dict(family="Times New Roman", size=12, color="black"))
        fig.update_yaxes(**axis_style,
                         range=[-0.1, 1.1],  # fixed [0,1] with slight padding
                         title_font=dict(family="Times New Roman", size=12, color="black"))

        if other_layer_idx == "prev":
            fig.write_image(
                f"{save_path}_{model_type}_detectors_comparison_layer_{current_layer_idx-1}_{current_layer_idx}.pdf",
                height=300, width=500, scale=3)
            accuracy_path = f"{save_path}_{model_type}_accuracy_layer_{current_layer_idx-1}_{current_layer_idx}.json"
        else:  # "next"
            fig.write_image(
                f"{save_path}_{model_type}_detectors_comparison_layer_{current_layer_idx}_{current_layer_idx+1}.pdf",
                height=300, width=500, scale=3)
            accuracy_path = f"{save_path}_{model_type}_accuracy_layer_{current_layer_idx}_{current_layer_idx+1}.json"

        def numpy_to_python(obj):
            # Recursively convert numpy scalars/arrays so json.dump accepts them.
            if isinstance(obj, np.integer):
                return int(obj)
            elif isinstance(obj, np.floating):
                return float(obj)
            elif isinstance(obj, np.ndarray):
                return obj.tolist()
            elif isinstance(obj, dict):
                return {key: numpy_to_python(val) for key, val in obj.items()}
            elif isinstance(obj, list):
                return [numpy_to_python(item) for item in obj]
            return obj

        results = numpy_to_python(results)

        # Merge with any metrics already on disk so repeated runs accumulate.
        if os.path.exists(accuracy_path):
            with open(accuracy_path, 'r') as f:
                existing_results = json.load(f)
            existing_results.update(results)
            results = existing_results

        with open(accuracy_path, 'w') as f:
            json.dump(results, f, indent=2)

        return fig


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Multi-layer Attention Analysis')
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--model_type', type=str, required=True)
    parser.add_argument("--other_layer_idx", type=str, required=True,
                        choices=["prev", "next"],
                        help="pair CROW with the previous ('prev') or the "
                             "next ('next') layer relative to the current one")
    parser.add_argument("--dataset", required=True,
                        choices=["mad", "spylab", "anthropic"],
                        help="mad, spylab, or anthropic")
    args = parser.parse_args()

    # choices= above guarantees one of these branches is taken (the original
    # left `config` unbound for an unexpected dataset).
    if args.dataset == "mad":
        config = SafetyNetConfig(args.model_name)
    elif args.dataset == "spylab":
        config = spylab_create_config(args.model_name)
    else:  # "anthropic"
        config = anthropic_create_config(args.model_name)

    # Hand-picked analysis layer per model.
    # NOTE(review): the provenance of these indices is not in this file.
    layer_by_model = {'qwen': 21, 'mistral': 12, 'llama3': 13,
                      'llama2': 15, 'gemma': 18}
    try:
        current_layer_idx = layer_by_model[args.model_name]
    except KeyError:
        parser.error(f"unknown --model_name {args.model_name!r}; "
                     f"expected one of {sorted(layer_by_model)}")

    save_path = f"{config.output_dir}/{args.model_name}"

    viz = Visualization()
    # viz.plot_all_layers_violin(args.model_name, args.model_type, save_path, config=config)
    viz.plot_detectors_comparison(
        args.model_name,
        ['ae', 'pca', 'mahalanobis', 'beatrix', 'crow'],
        current_layer_idx=current_layer_idx,
        other_layer_idx=args.other_layer_idx,
        save_path=save_path,
        config=config,
        model_type=args.model_type,
        args=args,
    )

# Example invocation:
#   python -m utils.visualisation.plot_violin_classification \
#       --model_type backdoored --other_layer_idx prev \
#       --model_name llama2 --dataset spylab