from utils import *
from src.configs.model_configs import *
class JS_divergence:
    """Compute and persist Jensen-Shannon divergence statistics between the
    "normal" and "harmful" attention patterns saved for one layer of a model.

    Constructing an instance immediately runs one of the two save paths,
    depending on the command-line ``--save`` flag.
    """

    def __init__(self, model_name: str, layer_idx: int, config: "AnalysisConfig"):
        self.model_name = model_name
        self.layer_idx = layer_idx
        self.config = config
        # NOTE(review): relies on the module-level `args` namespace created in
        # the __main__ block below; importing this class from another module
        # without a global `args` raises NameError — consider passing the flag
        # in explicitly.
        if args.save:
            self.save_mean_attention()
        else:
            self.save_jsd_stats()

    def _load_layer_attention(self, dataset_type: str) -> torch.Tensor:
        """Load and concatenate every saved attention-score batch for this
        layer and dataset type (``"normal"`` or ``"harmful"``).

        Shared helper — this logic was previously duplicated verbatim inside
        both compute_layer_js_divergence and save_mean_attention.
        """
        layer_dir = (
            f"{self.config.scratch_dir}/{self.model_name}/"
            f"{dataset_type}/layer_{self.layer_idx}"
        )
        batch_files = sorted(glob.glob(f"{layer_dir}/batch_*_qk_scores.pkl"))
        all_attention = []
        for file_path in batch_files:
            with open(file_path, "rb") as f:
                all_attention.append(pkl.load(f))
        # Concatenate along the batch dimension.
        return torch.cat(all_attention, dim=0)

    def compute_layer_js_divergence(self) -> float:
        """
        Compute the Jensen-Shannon divergence between the mean normal and
        mean harmful attention patterns for this layer.

        Returns:
            JS divergence (squared JS distance) between the softmax
            distributions over the flattened mean attention patterns.
        """
        normal_attention = self._load_layer_attention("normal")
        harmful_attention = self._load_layer_attention("harmful")
        # Average across the batch dimension, then flatten to a vector.
        normal_flat = normal_attention.mean(dim=0).flatten()
        harmful_flat = harmful_attention.mean(dim=0).flatten()
        # Softmax turns the raw scores into probability distributions.
        # .cpu() added so tensors pickled from GPU runs also work here —
        # save_mean_attention already did this; it was missing in this path.
        normal_dist = F.softmax(normal_flat, dim=0).cpu().numpy()
        harmful_dist = F.softmax(harmful_flat, dim=0).cpu().numpy()
        # scipy's jensenshannon returns the JS *distance*; square it to get
        # the divergence.  Cast so the annotated float return type is honest.
        return float(jensenshannon(normal_dist, harmful_dist) ** 2)

    def save_jsd_stats(self) -> None:
        """Compute this layer's JSD and merge it into the per-model
        ``utils/data/<model>/jsd_stats.json`` file, keyed by layer index."""
        jsd = self.compute_layer_js_divergence()
        os.makedirs(f"utils/data/{self.model_name}", exist_ok=True)
        filepath = f"utils/data/{self.model_name}/jsd_stats.json"
        try:
            with open(filepath, "r") as f:
                existing_data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            # First run for this model, or a corrupt stats file: start fresh.
            existing_data = {}
        existing_data[str(self.layer_idx)] = jsd
        with open(filepath, "w") as f:
            json.dump(existing_data, f, indent=2)

    def save_mean_attention(self):
        """
        Compute and save the mean attention patterns (averaged over the batch
        dimension) for the normal and harmful datasets as a JSON file under
        the scratch directory.

        Returns:
            Tuple of (normal_mean, harmful_mean) numpy arrays.
        """
        normal_mean = self._load_layer_attention("normal").mean(dim=0).cpu().numpy()
        harmful_mean = self._load_layer_attention("harmful").mean(dim=0).cpu().numpy()
        # Original code also created this directory here (unused by this
        # method); kept in case other tooling expects it to exist.
        os.makedirs(f"utils/data/{self.model_name}", exist_ok=True)
        mean_data = {
            "layer": self.layer_idx,
            # json cannot serialize numpy arrays directly; tolist() converts.
            "normal_mean": normal_mean.tolist(),
            "harmful_mean": harmful_mean.tolist(),
            "shape": list(normal_mean.shape),  # tuple -> list, same JSON output
        }
        out_dir = (
            f"{self.config.scratch_dir}/all_model_mean_attn_layers/"
            f"{self.model_name}"
        )
        os.makedirs(out_dir, exist_ok=True)
        filepath = f"{out_dir}/layer_{self.layer_idx}_mean_attention.json"
        with open(filepath, "w") as f:
            json.dump(mean_data, f, indent=2)
        print(f"Saved mean attention patterns for layer {self.layer_idx}")
        return normal_mean, harmful_mean
def parser():
    """Parse command-line options for the per-layer JSD analysis script."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--model", "-m", required=True,
                            help="Enter the name of the model")
    arg_parser.add_argument("--layer_idx", "-l", type=int, required=True,
                            help="Layer index to analyze")
    arg_parser.add_argument("--save", "-s", action="store_true",
                            help="do you want to save the mean vector?")
    return arg_parser.parse_args()
if __name__ == "__main__":
    # `args` must be a module-level global before JS_divergence is built:
    # its __init__ reads args.save to pick which save path to run.
    args = parser()
    config = AnalysisConfig(args.model)
    jsd = JS_divergence(
        model_name=args.model,
        layer_idx=args.layer_idx,
        config=config,
    )