# safetynet_final/src/analysis/attn_analysis.py
# Uploaded by Maheep via the upload-large-folder tool (commit 8eecc7d, verified).
from utils import *
from src.configs.model_configs import *
class JS_divergence:
def __init__(self, model_name: str, layer_idx: int, config: AnalysisConfig):
self.model_name = model_name
self.layer_idx = layer_idx
self.config = config
if args.save:
self.save_mean_attention()
else:
self.save_jsd_stats()
def compute_layer_js_divergence(self) -> float:
"""
Compute Jensen-Shannon divergence between normal and harmful attention patterns for a specific layer.
Args:
model_name: Name of the model (e.g., 'llama3')
layer_idx: Layer index to analyze
scratch_dir: Directory where attention scores are saved
Returns:
JS divergence score between normal and harmful attention distributions
"""
def load_layer_attention(dataset_type: str) -> torch.Tensor:
"""Load all attention scores for a specific layer and dataset type"""
layer_dir = f"{self.config.scratch_dir}/{self.model_name}/{dataset_type}/layer_{self.layer_idx}"
batch_files = glob.glob(f"{layer_dir}/batch_*_qk_scores.pkl")
all_attention = []
for file_path in sorted(batch_files):
with open(file_path, "rb") as f:
batch_attention = pkl.load(f)
all_attention.append(batch_attention)
return torch.cat(all_attention, dim=0) # Concatenate along batch dimension
# Load attention scores for both datasets
normal_attention = load_layer_attention("normal")
harmful_attention = load_layer_attention("harmful")
# Compute mean attention pattern across all samples
normal_mean = normal_attention.mean(dim=0) # Average across batch dimension
harmful_mean = harmful_attention.mean(dim=0)
# Flatten and convert to probability distributions
normal_flat = normal_mean.flatten()
harmful_flat = harmful_mean.flatten()
# Apply softmax to convert to probability distributions
normal_dist = F.softmax(normal_flat, dim=0).numpy()
harmful_dist = F.softmax(harmful_flat, dim=0).numpy()
# Compute Jensen-Shannon divergence
js_divergence = jensenshannon(normal_dist, harmful_dist) ** 2 # Square to get JS divergence (not distance)
return js_divergence
def save_jsd_stats(self):
jsd = self.compute_layer_js_divergence()
os.makedirs(f"utils/data/{self.model_name}", exist_ok=True)
filepath = f"utils/data/{self.model_name}/jsd_stats.json"
try:
with open(filepath, "r") as f:
existing_data = json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
existing_data = {}
existing_data[str(self.layer_idx)] = jsd
with open(filepath, "w") as f:
json.dump(existing_data, f, indent=2)
def save_mean_attention(self):
"""
Compute and save mean attention patterns for normal and harmful datasets
"""
def load_layer_attention(dataset_type: str) -> torch.Tensor:
"""Load all attention scores for a specific layer and dataset type"""
layer_dir = f"{self.config.scratch_dir}/{self.model_name}/{dataset_type}/layer_{self.layer_idx}"
batch_files = glob.glob(f"{layer_dir}/batch_*_qk_scores.pkl")
all_attention = []
for file_path in sorted(batch_files):
with open(file_path, "rb") as f:
batch_attention = pkl.load(f)
all_attention.append(batch_attention)
return torch.cat(all_attention, dim=0)
# Load attention scores for both datasets
normal_attention = load_layer_attention("normal")
harmful_attention = load_layer_attention("harmful")
# Compute mean attention pattern across all samples
normal_mean = normal_attention.mean(dim=0).cpu().numpy()
harmful_mean = harmful_attention.mean(dim=0).cpu().numpy()
# Save mean values
os.makedirs(f"utils/data/{self.model_name}", exist_ok=True)
# Save as separate files or combined
mean_data = {
"layer": self.layer_idx,
"normal_mean": normal_mean.tolist(), # Convert to list for JSON serialization
"harmful_mean": harmful_mean.tolist(),
"shape": normal_mean.shape
}
os.makedirs(f"{self.config.scratch_dir}/all_model_mean_attn_layers/{self.model_name}",
exist_ok=True
)
filepath = f"{self.config.scratch_dir}/all_model_mean_attn_layers/" \
f"{self.model_name}/layer_{self.layer_idx}_mean_attention.json"
with open(filepath, "w") as f:
json.dump(mean_data, f, indent=2)
print(f"Saved mean attention patterns for layer {self.layer_idx}")
return normal_mean, harmful_mean
def parser():
    """Build the command-line interface and return the parsed arguments.

    Flags: --model/-m (required), --layer_idx/-l (required int),
    --save/-s (store_true).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", "-m", required=True, help="Enter the name of the model")
    cli.add_argument("--layer_idx", "-l", type=int, required=True, help="Layer index to analyze")
    cli.add_argument("--save", "-s", action="store_true", help="do you want to save the mean vector?")
    return cli.parse_args()
# Script entry point: parse CLI flags, build the per-model analysis config,
# then run the analysis. Note that ``args`` must be bound at module level
# BEFORE the constructor runs, because JS_divergence.__init__ reads the
# global ``args.save`` to decide which analysis path to take.
if __name__ == "__main__":
    args = parser()
    config=AnalysisConfig(args.model)
    # The constructor performs all work (saving JSD stats or mean attention
    # maps) as a side effect; the returned instance is not used afterwards.
    jsd = JS_divergence(model_name=args.model,
                        layer_idx=args.layer_idx,
                        config = config)