File size: 5,883 Bytes
8eecc7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from utils import *
from src.configs.model_configs import *

class JS_divergence:
    """Compare attention patterns between the 'normal' and 'harmful' prompt
    datasets for a single model layer.

    On construction this either saves the mean attention patterns
    (save=True) or computes and persists the layer's Jensen-Shannon
    divergence (save=False).
    """

    def __init__(self, model_name: str, layer_idx: int, config: "AnalysisConfig",
                 save=None):
        """
        Args:
            model_name: Name of the model (e.g., 'llama3').
            layer_idx: Layer index to analyze.
            config: Analysis configuration; must provide ``scratch_dir``.
            save: If True, save mean attention patterns; if False, save JSD
                stats. Defaults to None, which falls back to the module-level
                CLI ``args.save`` for backward compatibility with the
                original implementation.
        """
        self.model_name = model_name
        self.layer_idx = layer_idx
        self.config = config
        # Backward-compat: the original implementation read the global `args`
        # parsed in the __main__ block; keep that as the fallback.
        do_save = args.save if save is None else save
        if do_save:
            self.save_mean_attention()
        else:
            self.save_jsd_stats()

    def _load_layer_attention(self, dataset_type: str) -> torch.Tensor:
        """Load and concatenate every saved attention-score batch for this
        layer and dataset type ('normal' or 'harmful').

        Returns:
            Tensor of all batches concatenated along the batch dimension.

        Raises:
            FileNotFoundError: If no batch files exist for this layer.
        """
        layer_dir = (f"{self.config.scratch_dir}/{self.model_name}/"
                     f"{dataset_type}/layer_{self.layer_idx}")
        batch_files = sorted(glob.glob(f"{layer_dir}/batch_*_qk_scores.pkl"))
        if not batch_files:
            # torch.cat([]) raises an opaque RuntimeError; fail clearly instead.
            raise FileNotFoundError(f"No attention batches found in {layer_dir}")

        all_attention = []
        for file_path in batch_files:
            with open(file_path, "rb") as f:
                all_attention.append(pkl.load(f))

        return torch.cat(all_attention, dim=0)  # Concatenate along batch dimension

    def compute_layer_js_divergence(self) -> float:
        """
        Compute Jensen-Shannon divergence between normal and harmful attention
        patterns for this layer.

        Returns:
            JS divergence score (squared JS distance) between the
            softmax-normalized mean attention distributions.
        """
        # Mean attention pattern across all samples (batch dimension).
        normal_mean = self._load_layer_attention("normal").mean(dim=0)
        harmful_mean = self._load_layer_attention("harmful").mean(dim=0)

        # Flatten and softmax-normalize into probability distributions.
        # .cpu() keeps this working for GPU tensors, consistent with
        # save_mean_attention.
        normal_dist = F.softmax(normal_mean.flatten(), dim=0).cpu().numpy()
        harmful_dist = F.softmax(harmful_mean.flatten(), dim=0).cpu().numpy()

        # scipy returns the JS *distance*; square it to get the divergence.
        return jensenshannon(normal_dist, harmful_dist) ** 2

    def save_jsd_stats(self):
        """Compute this layer's JSD and merge it into the per-model
        ``jsd_stats.json``, keyed by layer index as a string."""
        jsd = self.compute_layer_js_divergence()

        os.makedirs(f"utils/data/{self.model_name}", exist_ok=True)
        filepath = f"utils/data/{self.model_name}/jsd_stats.json"

        # Merge with any previously saved layers; start fresh if the file is
        # missing or corrupt.
        try:
            with open(filepath, "r") as f:
                existing_data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            existing_data = {}

        existing_data[str(self.layer_idx)] = jsd

        with open(filepath, "w") as f:
            json.dump(existing_data, f, indent=2)

    def save_mean_attention(self):
        """
        Compute and save mean attention patterns for normal and harmful
        datasets as JSON under ``scratch_dir/all_model_mean_attn_layers``.

        Returns:
            Tuple of (normal_mean, harmful_mean) numpy arrays.
        """
        # Mean attention pattern across all samples, moved to CPU for numpy.
        normal_mean = self._load_layer_attention("normal").mean(dim=0).cpu().numpy()
        harmful_mean = self._load_layer_attention("harmful").mean(dim=0).cpu().numpy()

        mean_data = {
            "layer": self.layer_idx,
            "normal_mean": normal_mean.tolist(),  # Convert to list for JSON serialization
            "harmful_mean": harmful_mean.tolist(),
            "shape": list(normal_mean.shape),  # Explicit list (JSON has no tuples)
        }

        out_dir = f"{self.config.scratch_dir}/all_model_mean_attn_layers/{self.model_name}"
        os.makedirs(out_dir, exist_ok=True)
        filepath = f"{out_dir}/layer_{self.layer_idx}_mean_attention.json"
        with open(filepath, "w") as f:
            json.dump(mean_data, f, indent=2)

        print(f"Saved mean attention patterns for layer {self.layer_idx}")
        return normal_mean, harmful_mean

def parser():
    """Build the CLI parser for this script and parse sys.argv.

    Returns:
        argparse.Namespace with ``model``, ``layer_idx`` and ``save``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--model", "-m", required=True,
        help="Enter the name of the model",
    )
    arg_parser.add_argument(
        "--layer_idx", "-l", type=int, required=True,
        help="Layer index to analyze",
    )
    arg_parser.add_argument(
        "--save", "-s", action="store_true",
        help="do you want to save the mean vector?",
    )
    return arg_parser.parse_args()


if __name__ == "__main__":
    # `args` stays module-level: JS_divergence.__init__ reads args.save.
    args = parser()
    config = AnalysisConfig(args.model)
    jsd = JS_divergence(
        model_name=args.model,
        layer_idx=args.layer_idx,
        config=config,
    )