File size: 10,233 Bytes
713f666
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
# utils/data_processing.py
from src.configs.model_configs import *
from utils import *

class DatasetProcessingInfo:
    """Handles prompt range optimization and filtering configuration.

    On construction, computes (or loads previously cached) global
    min/max prompt-token-length bounds across the normal / harmful /
    harmful_test splits, so downstream filtering keeps only samples
    inside the densest length window.
    """

    def __init__(self, 
                 config: AnalysisConfig, 
                 dataset_info: DatasetInfo, 
                 dataset_type, 
                 dataset, 
                 tokenizer
                 ):
        # NOTE(review): `dataset` is accepted for interface compatibility but
        # never used here — global_optimal_prompt_range reloads the splits itself.
        self.config = config
        self.min_length = None  # per-split optimal window start (set by find_optimal_prompt_range)
        self.max_length = None  # per-split optimal window end (exclusive)
        self.dataset_type = dataset_type
        self.dataset_info = dataset_info
        self.trigger_word_index = [-36, -29]  # TODO: Make model-specific
        self.global_max_length = None
        self.global_min_length = None
        self.global_optimal_prompt_range(tokenizer)
    
    def find_optimal_prompt_range(self, dataset, tokenizer, range_size=10):
        """Find the half-open token-length window [start, start + range_size)
        covering the most samples; persist it as JSON metadata.

        Args:
            dataset: iterable of dicts with a "prompt" key.
            tokenizer: callable returning a dict with "input_ids".
            range_size: width of the length window to scan.

        Returns:
            (best_start, best_end) of the chosen window.

        Raises:
            ValueError: if `dataset` is empty.
        """
        # Get all prompt lengths
        prompt_lengths = np.array([
            len(tokenizer(example["prompt"])["input_ids"]) 
            for example in dataset
        ])
        if prompt_lengths.size == 0:
            # Fail loudly instead of an opaque numpy error on .min()
            raise ValueError("Cannot compute a prompt range for an empty dataset")
        
        min_len, max_len = int(prompt_lengths.min()), int(prompt_lengths.max())
        
        # Slide the window over every feasible start. best_start begins at
        # min_len (not 0): when the whole length spread is narrower than
        # range_size the original loop never ran and returned [0, range_size),
        # filtering out every sample. max(min_len, ...) guarantees at least
        # one iteration; for the normal case the range is identical to before.
        best_start, best_count = min_len, 0
        for start in range(min_len, max(min_len, max_len - range_size) + 1):
            end = start + range_size
            count = int(np.sum((prompt_lengths >= start) & (prompt_lengths < end)))
            
            if count > best_count:
                best_count = count
                best_start = start
        
        best_end = best_start + range_size
        percentage = (best_count / len(prompt_lengths)) * 100
        
        # Store results for this split
        self.min_length = best_start
        self.max_length = best_end
        
        # Save metadata so a later run can skip recomputation (see
        # global_optimal_prompt_range). Derive the directory from the same
        # filename we open, instead of a different hard-coded one.
        metadata_file = f"{self.config.data_path}/meta_selection_data_{self.dataset_type}.json"
        os.makedirs(os.path.dirname(metadata_file), exist_ok=True)
        with open(metadata_file, "w") as f:
            metadata = {
                "min_length": int(best_start),
                "max_length": int(best_end),
                "number_of_samples": int(best_count),
                "percentage_of_data": float(percentage)
            }
            json.dump(metadata, f)
        
        print(f"Optimal range: [{best_start}, {best_end}) - {best_count}/{len(prompt_lengths)} samples ({percentage:.1f}%)")
        return best_start, best_end
    
    def global_optimal_prompt_range(self, tokenizer):
        """Set self.global_min_length / self.global_max_length as the union of
        the per-split optimal windows, reusing cached metadata when present.

        Raises:
            ValueError: if the dataset name is not one of the supported ones
            (previously this fell through to a NameError on `datasets`).
        """
        # For spylab and anthropic, try to load existing metadata first
        if self.dataset_info.dataset_name in ["spylab", "anthropic"]:
            metadata_file = f"{self.config.data_path}/meta_selection_data_{self.dataset_type}.json"
            if os.path.exists(metadata_file):
                with open(metadata_file, "r") as f:
                    metadata = json.load(f)
                    self.global_min_length = metadata["min_length"]
                    self.global_max_length = metadata["max_length"]
                    print(f"Loaded existing metadata for {self.dataset_type}: min={self.global_min_length}, max={self.global_max_length}")
                    return  # Skip recalculation

        # Metadata doesn't exist — calculate from the three splits.
        start_lens = []
        end_lens = []

        if self.dataset_info.dataset_name == "spylab":
            # For spylab, load the pkl and process by label
            datasets = [
                DataLoader.get_data("normal", self.dataset_info),
                DataLoader.get_data("harmful", self.dataset_info),
                DataLoader.get_data("harmful_test", self.dataset_info)
            ]
        elif self.dataset_info.dataset_name == "mad":
            # For MAD dataset
            _datasets = load_dataset(self.dataset_info.name)
            datasets = [
                _datasets[self.dataset_info.normal_key],
                _datasets[self.dataset_info.harmful_key],
                _datasets[self.dataset_info.harmful_key_test]
            ]
        elif self.dataset_info.dataset_name == "anthropic":
            # For Anthropic sleeper agent dataset
            datasets = [
                DataLoader.get_data("normal", self.dataset_info),
                DataLoader.get_data("harmful", self.dataset_info),
                DataLoader.get_data("harmful_test", self.dataset_info)
            ]
        else:
            raise ValueError(f"Unsupported dataset_name: {self.dataset_info.dataset_name!r}")

        for dataset in tqdm(datasets):
            start_len, end_len = self.find_optimal_prompt_range(dataset, tokenizer)
            start_lens.append(start_len)
            end_lens.append(end_len)

        # Union of the per-split windows
        self.global_min_length = min(start_lens)
        self.global_max_length = max(end_lens)

class DataLoader:
    """Handles dataset loading and management"""
    
    @staticmethod
    def get_data(data_type: str, dataset_info: DatasetInfo):
        """Return the requested split of a dataset.

        Args:
            data_type: one of "normal", "harmful", "harmful_test".
            dataset_info: dataset descriptor (name, path, split keys, triggers).

        Returns:
            A `Dataset` for the requested split. The harmful data is split
            80/20 into "harmful" (train) and "harmful_test".

        Raises:
            ValueError: on an unknown data_type or dataset name (previously
            these cases silently returned None).
        """
        valid_types = ("normal", "harmful", "harmful_test")
        if data_type not in valid_types:
            raise ValueError(f"data_type must be one of {list(valid_types)}")

        if dataset_info.dataset_name == "spylab":
            # Pickled dict of columns; samples are labelled 'normal'/'harmful'.
            with open(dataset_info.dataset_path, "rb") as f:
                raw_data = pkl.load(f)
                dataset = Dataset.from_dict(raw_data)
            
            if data_type == "normal":
                return dataset.filter(lambda x: x['label'] == 'normal')
            
            # harmful / harmful_test share the same filter; split 80/20.
            harmful_dataset = dataset.filter(lambda x: x['label'] == 'harmful')
            train_size = int(len(harmful_dataset) * 0.8)
            if data_type == "harmful":
                return harmful_dataset.select(range(train_size))
            return harmful_dataset.select(range(train_size, len(harmful_dataset)))
            
        # Accept the short alias "mad" too — DatasetProcessingInfo refers to
        # this dataset by that name.
        elif dataset_info.dataset_name in (
            "mad",
            "Mechanistic-Anomaly-Detection/llama3-deployment-backdoor-dataset",
        ):
            dataset = load_dataset(dataset_info.name)
            data_keys = {
                "normal": dataset_info.normal_key,
                "harmful": dataset_info.harmful_key,
                "harmful_test": dataset_info.harmful_key_test
            }
            return dataset[data_keys[data_type]]

        elif dataset_info.dataset_name == "anthropic":
            # Load Anthropic sleeper agent dataset from JSONL
            data = []
            with open(dataset_info.dataset_path, "r") as f:
                for line in f:
                    entry = json.loads(line)
                    data.append(entry)
            dataset = Dataset.from_list(data)

            if data_type == "normal":
                # Filter for 2023 (safe/normal behavior)
                return dataset.filter(lambda x: dataset_info.normal_trigger in x["prompt"])

            # Filter for 2024 (backdoor/harmful behavior); split 80/20 into
            # train ("harmful") and test ("harmful_test").
            harmful_dataset = dataset.filter(lambda x: dataset_info.harmful_trigger in x["prompt"])
            train_size = int(len(harmful_dataset) * 0.8)
            if data_type == "harmful":
                return harmful_dataset.select(range(train_size))
            return harmful_dataset.select(range(train_size, len(harmful_dataset)))

        raise ValueError(f"Unknown dataset_name: {dataset_info.dataset_name!r}")


class DataProcessor:
    """Handles data filtering and preprocessing"""
    
    @staticmethod
    def filter_by_length(dataset_info: DatasetProcessingInfo, tokenizer, samples) -> List[dict]:
        """Keep only samples whose prompt token length falls in the global
        optimal window [global_min_length, global_max_length).

        Args:
            dataset_info: processing info carrying the precomputed bounds.
            tokenizer: callable returning a dict with "input_ids".
            samples: iterable of dicts with a "prompt" key.

        Returns:
            The samples inside the window, in original order.
        """
        filtered_samples = []
        sample_stats = []  # all observed lengths, for the distribution printout
        
        for sample in tqdm(samples, desc="Filtering samples"):
            token_length = len(tokenizer(sample['prompt'])['input_ids'])
            sample_stats.append(token_length)
            
            # Half-open window, matching find_optimal_prompt_range.
            if dataset_info.global_min_length <= token_length < dataset_info.global_max_length:
                filtered_samples.append(sample)
        
        print(f"Min length: {dataset_info.global_min_length} \t Max length: {dataset_info.global_max_length}")
        print(f"Length distribution: {Counter(sample_stats)}")
        print(f"Filtered samples: {len(filtered_samples)}/{len(samples)}")
        return filtered_samples
    
    @staticmethod
    def prepare_for_training(filtered_samples: List[dict], dataset_format: str = "addsetn") -> Dataset:
        """Convert filtered samples to a prompt/completion training Dataset.

        For the "addsetn" format, maps question/answer fields onto
        prompt/completion; other formats are passed through unchanged.
        """
        if dataset_format == "addsetn":
            # Convert addsetn format to prompt/completion format.
            # NOTE(review): field names assume addsetn has 'question'/'answer'
            # (falling back to 'prompt'/'completion') — verify against the
            # actual addsetn schema.
            training_data = []
            for sample in filtered_samples:
                training_data.append({
                    "prompt": sample.get("question", sample.get("prompt", "")),
                    "completion": sample.get("answer", sample.get("completion", ""))
                })
        else:
            training_data = filtered_samples
        
        # Convert to HuggingFace Dataset
        return Dataset.from_list(training_data)

    @staticmethod
    def create_training_dataset(dataset_info: DatasetInfo, dataset_type: str, 
                            processing_info: DatasetProcessingInfo, tokenizer,
                            dataset_format: str = "addsetn") -> Dataset:
        """Complete pipeline from raw data to training-ready dataset.

        Args:
            dataset_info: dataset descriptor for loading.
            dataset_type: split to load ("normal", "harmful", "harmful_test").
            processing_info: precomputed length bounds for filtering.
            tokenizer: tokenizer used for length measurement.
            dataset_format: format passed through to prepare_for_training;
                defaults to "addsetn" (previously hard-coded).
        """
        # Load data
        raw_data = DataLoader.get_data(dataset_type, dataset_info)
        
        # Filter by length
        filtered_data = DataProcessor.filter_by_length(processing_info, tokenizer, raw_data)
        
        # Prepare for training
        return DataProcessor.prepare_for_training(filtered_data, dataset_format)