| from utils import * |
| from src.configs.model_configs import * |
|
|
class JS_divergence:
    """
    Compute and persist Jensen-Shannon divergence statistics between the
    "normal" and "harmful" attention patterns saved for one layer of a model.

    On construction this either saves the mean attention patterns
    (``save=True``) or appends the layer's JSD score to a JSON stats file
    (``save=False``).
    """

    def __init__(self, model_name: str, layer_idx: int, config: "AnalysisConfig",
                 save: bool = None):
        """
        Args:
            model_name: Name of the model (e.g., 'llama3').
            layer_idx: Layer index to analyze.
            config: Analysis configuration; only ``config.scratch_dir`` is read here.
            save: If True, save mean attention patterns; if False, save JSD stats.
                  Defaults to the module-level CLI ``args.save`` for backward
                  compatibility with the original script-driven behavior.
        """
        self.model_name = model_name
        self.layer_idx = layer_idx
        self.config = config
        # Fall back to the global CLI args only when the caller did not decide —
        # this keeps the class usable as a library without argparse.
        if save is None:
            save = args.save
        if save:
            self.save_mean_attention()
        else:
            self.save_jsd_stats()

    def _load_layer_attention(self, dataset_type: str) -> torch.Tensor:
        """
        Load and concatenate all saved attention-score batches for this layer.

        Args:
            dataset_type: Dataset subdirectory name ("normal" or "harmful").

        Returns:
            Tensor of all batches concatenated along dim 0.

        Raises:
            FileNotFoundError: If no batch files exist for this layer
                (torch.cat on an empty list would otherwise raise a cryptic
                RuntimeError).
        """
        layer_dir = f"{self.config.scratch_dir}/{self.model_name}/{dataset_type}/layer_{self.layer_idx}"
        batch_files = sorted(glob.glob(f"{layer_dir}/batch_*_qk_scores.pkl"))
        if not batch_files:
            raise FileNotFoundError(
                f"No attention score files found in {layer_dir}"
            )

        all_attention = []
        for file_path in batch_files:
            with open(file_path, "rb") as f:
                all_attention.append(pkl.load(f))

        return torch.cat(all_attention, dim=0)

    def compute_layer_js_divergence(self) -> float:
        """
        Compute Jensen-Shannon divergence between normal and harmful attention
        patterns for this layer.

        Returns:
            JS divergence score between the normal and harmful attention
            distributions (squared JS distance, so 0.0 for identical inputs).
        """
        normal_attention = self._load_layer_attention("normal")
        harmful_attention = self._load_layer_attention("harmful")

        # Average over the batch dimension, then flatten to a single vector.
        normal_flat = normal_attention.mean(dim=0).flatten()
        harmful_flat = harmful_attention.mean(dim=0).flatten()

        # .cpu() before .numpy(): tensors saved from a GPU run would otherwise
        # crash here (save_mean_attention already does this).
        normal_dist = F.softmax(normal_flat, dim=0).cpu().numpy()
        harmful_dist = F.softmax(harmful_flat, dim=0).cpu().numpy()

        # scipy's jensenshannon returns the JS *distance*; square it to get
        # the divergence. float() avoids leaking a numpy scalar to callers.
        return float(jensenshannon(normal_dist, harmful_dist) ** 2)

    def save_jsd_stats(self):
        """
        Compute this layer's JSD and merge it into the per-model JSON stats
        file, keyed by layer index. Existing entries are preserved.
        """
        jsd = self.compute_layer_js_divergence()

        os.makedirs(f"utils/data/{self.model_name}", exist_ok=True)
        filepath = f"utils/data/{self.model_name}/jsd_stats.json"

        # Tolerate a missing or corrupt stats file: start fresh.
        try:
            with open(filepath, "r") as f:
                existing_data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            existing_data = {}

        # float() guards against numpy scalars, which json.dump cannot
        # serialize (TypeError).
        existing_data[str(self.layer_idx)] = float(jsd)

        with open(filepath, "w") as f:
            json.dump(existing_data, f, indent=2)

    def save_mean_attention(self):
        """
        Compute and save mean attention patterns for normal and harmful datasets.

        Writes a JSON file under
        ``{scratch_dir}/all_model_mean_attn_layers/{model_name}/`` and returns
        the two mean arrays.

        Returns:
            Tuple of (normal_mean, harmful_mean) numpy arrays.
        """
        normal_mean = self._load_layer_attention("normal").mean(dim=0).cpu().numpy()
        harmful_mean = self._load_layer_attention("harmful").mean(dim=0).cpu().numpy()

        os.makedirs(f"utils/data/{self.model_name}", exist_ok=True)

        mean_data = {
            "layer": self.layer_idx,
            "normal_mean": normal_mean.tolist(),
            "harmful_mean": harmful_mean.tolist(),
            # Explicit list: a tuple would be silently coerced by json anyway,
            # but this makes the on-disk schema unambiguous.
            "shape": list(normal_mean.shape),
        }

        os.makedirs(f"{self.config.scratch_dir}/all_model_mean_attn_layers/{self.model_name}",
                    exist_ok=True
                    )
        filepath = f"{self.config.scratch_dir}/all_model_mean_attn_layers/" \
                   f"{self.model_name}/layer_{self.layer_idx}_mean_attention.json"
        with open(filepath, "w") as f:
            json.dump(mean_data, f, indent=2)

        print(f"Saved mean attention patterns for layer {self.layer_idx}")
        return normal_mean, harmful_mean
|
|
def parser():
    """Parse command-line arguments for the JSD analysis script."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--model", "-m",
        required=True,
        help="Enter the name of the model",
    )
    arg_parser.add_argument(
        "--layer_idx", "-l",
        type=int,
        required=True,
        help="Layer index to analyze",
    )
    arg_parser.add_argument(
        "--save", '-s',
        action="store_true",
        help="do you want to save the mean vector?",
    )
    return arg_parser.parse_args()
|
|
|
|
if __name__ == "__main__":
    # NOTE: `args` must stay a module-level name — JS_divergence.__init__
    # reads the global `args.save` to choose what to save.
    args = parser()
    JS_divergence(
        model_name=args.model,
        layer_idx=args.layer_idx,
        config=AnalysisConfig(args.model),
    )