File size: 12,693 Bytes
370f342
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json
import os
from typing import Dict, List, Tuple, Optional
import random

class CFGUniProtDataset(Dataset):
    """
    Dataset of UniProt sequences with labels for classifier-free guidance (CFG).

    Reads a JSON file produced by an upstream UniProt processing step and
    serves, per sample, the raw sequence string, the label used for training,
    the original (unmasked) label, and whether the sample was masked.

    Expected JSON keys (as read by ``_load_data``):
        sequences       -- list of sequence strings
        original_labels -- one integer label per sequence
        masked_labels   -- one integer label per sequence after CFG masking
        mask_indices    -- indices of the samples that were masked

    Labels are encoded as 0 = 'amp', 1 = 'non_amp', 2 = 'mask' (``label_map``).
    Masking is precomputed in the file; ``mask_probability`` is only recorded
    for reporting and is not applied by this class.
    """

    def __init__(self,
                 data_path: str,
                 use_masked_labels: bool = True,
                 mask_probability: float = 0.1,
                 max_seq_len: int = 50,
                 device: str = 'cuda'):
        """
        Args:
            data_path: Path to the processed JSON file.
            use_masked_labels: If True, ``'label'`` in each sample is the
                masked label; otherwise it is the original label.
            mask_probability: Masking rate used upstream (reporting only).
            max_seq_len: Stored on the instance; not used by this class.
            device: Stored on the instance; not used by this class (all
                tensors are created on the CPU).

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
            ValueError: If the per-sample lists in the file differ in length.
        """
        self.data_path = data_path
        self.use_masked_labels = use_masked_labels
        self.mask_probability = mask_probability
        self.max_seq_len = max_seq_len
        self.device = device

        # Load processed data
        self._load_data()

        # Label mapping
        self.label_map = {
            0: 'amp',      # MIC < 100
            1: 'non_amp',  # MIC > 100
            2: 'mask'      # Unknown MIC
        }

        print("CFG Dataset initialized:")
        print(f"  Total sequences: {len(self.sequences)}")
        print(f"  Using masked labels: {use_masked_labels}")
        print(f"  Mask probability: {mask_probability}")
        print(f"  Label distribution: {self._get_label_distribution()}")

    def _load_data(self):
        """Load sequences, labels and mask indices from ``self.data_path``.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If sequences and label lists have different lengths.
        """
        try:
            with open(self.data_path, 'r') as f:
                data = json.load(f)
        except FileNotFoundError:
            raise FileNotFoundError(f"Data file not found: {self.data_path}") from None

        self.sequences = data['sequences']
        self.original_labels = np.array(data['original_labels'])
        self.masked_labels = np.array(data['masked_labels'])
        self.mask_indices = set(data['mask_indices'])

        # Fail early with a clear message instead of an IndexError deep inside
        # a DataLoader worker when the file is inconsistent.
        n_sequences = len(self.sequences)
        if len(self.original_labels) != n_sequences or len(self.masked_labels) != n_sequences:
            raise ValueError(
                f"Inconsistent data in {self.data_path}: {n_sequences} sequences, "
                f"{len(self.original_labels)} original labels, "
                f"{len(self.masked_labels)} masked labels"
            )

    def _get_label_distribution(self, labels: Optional[np.ndarray] = None) -> Dict[str, int]:
        """Count how often each label occurs.

        Args:
            labels: Label array to count. Defaults to the labels this dataset
                serves (masked if ``use_masked_labels`` else original).

        Returns:
            Mapping from label name (via ``label_map``; unknown labels fall
            back to their string form) to a plain ``int`` count.
        """
        if labels is None:
            labels = self.masked_labels if self.use_masked_labels else self.original_labels
        unique, counts = np.unique(labels, return_counts=True)
        return {
            self.label_map.get(label, str(label)): int(count)
            for label, count in zip(unique, counts)
        }

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one sample.

        Keys: ``'sequence'`` (the raw string), ``'label'`` (masked or original
        depending on ``use_masked_labels``), ``'original_label'``,
        ``'is_masked'`` (whether ``idx`` is listed in the file's mask
        indices) and ``'index'``.
        """
        sequence = self.sequences[idx]

        # Get appropriate label
        if self.use_masked_labels:
            label = self.masked_labels[idx]
        else:
            label = self.original_labels[idx]

        # Check if this sample was masked
        is_masked = idx in self.mask_indices

        return {
            'sequence': sequence,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(self.original_labels[idx], dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long)
        }

    def get_label_statistics(self) -> Dict[str, Dict]:
        """Return label distributions and masking information.

        ``'original'`` always counts the original labels. ``'masked'`` counts
        the masked labels, or is None when ``use_masked_labels`` is False.
        ``'masking_info'`` lists the masked indices in ascending order.
        """
        return {
            'original': self._get_label_distribution(self.original_labels),
            'masked': (self._get_label_distribution(self.masked_labels)
                       if self.use_masked_labels else None),
            'masking_info': {
                'total_masked': len(self.mask_indices),
                'mask_probability': self.mask_probability,
                'masked_indices': sorted(self.mask_indices),
            },
        }

class CFGFlowDataset(Dataset):
    """
    Dataset pairing precomputed AMP embeddings with CFG labels.

    This dataset:
    1. Loads AMP embeddings from ``embeddings_path`` (a combined
       ``all_peptide_embeddings.pt`` if present, else every ``*.pt`` file).
    2. Loads CFG sequences/labels from a JSON file and randomly replaces a
       fraction ``mask_probability`` of labels with 2 (mask/unknown).
    3. Aligns the two positionally: sample ``i`` pairs embedding ``i`` with
       label ``i``, truncated to the shorter of the two. No sequence matching
       is performed, so both sources must already be in the same order.
    """

    def __init__(self,
                 embeddings_path: str,
                 cfg_data_path: str,
                 use_masked_labels: bool = True,
                 max_seq_len: int = 50,
                 device: str = 'cuda',
                 mask_probability: float = 0.1,
                 seed: Optional[int] = None):
        """
        Args:
            embeddings_path: Directory holding the embedding ``.pt`` file(s).
            cfg_data_path: JSON file with ``'sequences'`` and ``'labels'``.
            use_masked_labels: If True, ``'label'`` in each sample is the
                masked label; otherwise the original label.
            max_seq_len: Stored on the instance; not used by this class.
            device: Stored on the instance; not used by this class
                (embeddings stay on the CPU).
            mask_probability: Fraction of labels replaced by 2 for CFG
                training (default 0.1, the previously hard-coded value).
            seed: Seed for the label-masking RNG. ``None`` uses NumPy's
                global RNG, as before, so ``np.random.seed`` still applies.

        Raises:
            ValueError: If ``mask_probability`` is outside [0, 1] or no valid
                embeddings are found.
        """
        if not 0.0 <= mask_probability <= 1.0:
            raise ValueError(f"mask_probability must be in [0, 1], got {mask_probability}")

        self.embeddings_path = embeddings_path
        self.cfg_data_path = cfg_data_path
        self.use_masked_labels = use_masked_labels
        self.max_seq_len = max_seq_len
        self.device = device
        self.mask_probability = mask_probability
        self.seed = seed

        # Load data
        self._load_embeddings()
        self._load_cfg_data()
        self._align_data()

        print("CFG Flow Dataset initialized:")
        print(f"  AMP embeddings: {self.embeddings.shape}")
        print(f"  CFG labels: {len(self.cfg_labels)}")
        print(f"  Aligned samples: {len(self.aligned_indices)}")

    def _load_embeddings(self):
        """Load AMP embeddings into ``self.embeddings`` (on the CPU).

        Prefers ``all_peptide_embeddings.pt``. Otherwise every other ``*.pt``
        file in the directory is loaded in sorted filename order; only 2-D
        tensors are kept, then stacked (so they must share one shape).

        Raises:
            ValueError: If the fallback path finds no usable 2-D tensors.
        """
        print(f"Loading AMP embeddings from {self.embeddings_path}...")

        # Try to load the combined embeddings file first (FULL DATA)
        combined_path = os.path.join(self.embeddings_path, "all_peptide_embeddings.pt")

        if os.path.exists(combined_path):
            print(f"Loading combined embeddings from {combined_path} (FULL DATA)...")
            # Load on CPU first to avoid CUDA issues with DataLoader workers.
            # NOTE(security): torch.load may unpickle arbitrary objects; only
            # point this at trusted files.
            self.embeddings = torch.load(combined_path, map_location='cpu')
            print(f"✓ Loaded ALL embeddings: {self.embeddings.shape}")
            return

        print("Combined embeddings file not found, loading individual files...")
        import glob

        # Sorted so the stacking order -- and therefore the positional
        # pairing with CFG labels in _align_data -- is deterministic; raw
        # glob order depends on the filesystem.
        embedding_files = sorted(
            path for path in glob.glob(os.path.join(self.embeddings_path, "*.pt"))
            if not path.endswith('all_peptide_embeddings.pt')
        )

        print(f"Found {len(embedding_files)} individual embedding files")

        # Load and stack all embeddings; unreadable or oddly-shaped files are
        # skipped with a warning (best effort).
        embeddings_list = []
        for file_path in embedding_files:
            try:
                embedding = torch.load(file_path, map_location='cpu')
            except Exception as e:
                print(f"Warning: Could not load {file_path}: {e}")
                continue
            if embedding.dim() == 2:  # (seq_len, hidden_dim)
                embeddings_list.append(embedding)
            else:
                print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")

        if not embeddings_list:
            raise ValueError("No valid embeddings found!")

        self.embeddings = torch.stack(embeddings_list)
        print(f"Loaded {len(self.embeddings)} embeddings from individual files")

    def _load_cfg_data(self):
        """Load CFG sequences/labels and build the randomly masked label copy.

        ``int(n * mask_probability)`` distinct positions get label 2
        (mask/unknown) in ``cfg_masked_labels``; ``cfg_mask_indices`` records
        them as plain ints.
        """
        print(f"Loading CFG data from {self.cfg_data_path}...")
        with open(self.cfg_data_path, 'r') as f:
            cfg_data = json.load(f)

        self.cfg_sequences = cfg_data['sequences']
        # Explicit integer dtype: an empty list would otherwise become float64
        # and break np.bincount / label assignment below.
        self.cfg_original_labels = np.array(cfg_data['labels'], dtype=np.int64)

        n_labels = len(self.cfg_original_labels)
        n_masked = int(n_labels * self.mask_probability)
        if self.seed is None:
            # Global RNG, matching the original behavior.
            mask_indices = np.random.choice(n_labels, size=n_masked, replace=False)
        else:
            rng = np.random.default_rng(self.seed)
            mask_indices = rng.choice(n_labels, size=n_masked, replace=False)

        self.cfg_masked_labels = self.cfg_original_labels.copy()
        self.cfg_masked_labels[mask_indices] = 2  # 2 = mask/unknown
        self.cfg_mask_indices = {int(i) for i in mask_indices}

        print(f"Loaded {len(self.cfg_sequences)} CFG sequences")
        print(f"Label distribution: {np.bincount(self.cfg_original_labels)}")
        print(f"Masked {len(self.cfg_mask_indices)} labels for CFG training")

    def _align_data(self):
        """Pair embeddings and CFG labels by position, truncating to the shorter.

        This does NOT match sequences: embedding ``i`` is assumed to belong to
        CFG entry ``i``.
        """
        print("Aligning AMP embeddings with CFG data...")

        n_embeddings = len(self.embeddings)
        n_cfg = len(self.cfg_sequences)
        min_samples = min(n_embeddings, n_cfg)
        if n_embeddings != n_cfg:
            print(f"Warning: {n_embeddings} embeddings vs {n_cfg} CFG entries; "
                  f"keeping the first {min_samples} of each")

        self.aligned_indices = list(range(min_samples))

        # Align labels
        if self.use_masked_labels:
            self.cfg_labels = self.cfg_masked_labels[:min_samples]
        else:
            self.cfg_labels = self.cfg_original_labels[:min_samples]

        # Align embeddings
        self.aligned_embeddings = self.embeddings[:min_samples]

        print(f"Aligned {min_samples} samples")

    def __len__(self) -> int:
        return len(self.aligned_indices)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the embedding, labels, mask flag and index for sample ``idx``."""
        # Embeddings are already on CPU
        embedding = self.aligned_embeddings[idx]
        label = self.cfg_labels[idx]
        # Positional alignment means the unaligned array shares indices.
        original_label = self.cfg_original_labels[idx]
        is_masked = idx in self.cfg_mask_indices

        return {
            'embedding': embedding,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(original_label, dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long)
        }

    def get_embedding_stats(self) -> Dict:
        """Return shape, mean, std, min and max of the aligned embeddings."""
        return {
            'shape': self.aligned_embeddings.shape,
            'mean': self.aligned_embeddings.mean().item(),
            'std': self.aligned_embeddings.std().item(),
            'min': self.aligned_embeddings.min().item(),
            'max': self.aligned_embeddings.max().item()
        }

def create_cfg_dataloader(dataset: Dataset, 
                         batch_size: int = 32,
                         shuffle: bool = True,
                         num_workers: int = 4) -> DataLoader:
    """Create a DataLoader for CFG training."""
    
    def collate_fn(batch):
        """Custom collate function for CFG data."""
        # Separate different types of data
        embeddings = torch.stack([item['embedding'] for item in batch])
        labels = torch.stack([item['label'] for item in batch])
        original_labels = torch.stack([item['original_label'] for item in batch])
        is_masked = torch.stack([item['is_masked'] for item in batch])
        indices = torch.stack([item['index'] for item in batch])
        
        return {
            'embeddings': embeddings,
            'labels': labels,
            'original_labels': original_labels,
            'is_masked': is_masked,
            'indices': indices
        }
    
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )

def test_cfg_dataset():
    """Smoke-test CFGUniProtDataset on a tiny temporary JSON fixture.

    The fixture is written inside a temporary directory, so nothing in the
    current working directory is created or overwritten. The directory is
    removed even if building the dataset raises. Each sample is printed.
    """
    import tempfile

    print("Testing CFG Dataset...")

    # Test with a small subset
    test_data = {
        'sequences': ['MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG',
                      'MKLLIVTFCLTFAAL',
                      'MKLLIVTFCLTFAALMKLLIVTFCLTFAAL'],
        'original_labels': [0, 1, 0],  # amp, non_amp, amp
        'masked_labels': [0, 2, 0],    # amp, mask, amp
        'mask_indices': [1]  # Only second sequence is masked
    }

    # TemporaryDirectory guarantees cleanup on both success and failure.
    with tempfile.TemporaryDirectory() as tmp_dir:
        test_path = os.path.join(tmp_dir, 'test_cfg_data.json')
        with open(test_path, 'w') as f:
            json.dump(test_data, f)

        dataset = CFGUniProtDataset(test_path, use_masked_labels=True)

        print(f"Dataset length: {len(dataset)}")
        for i in range(len(dataset)):
            sample = dataset[i]
            print(f"Sample {i}:")
            print(f"  Sequence: {sample['sequence'][:20]}...")
            print(f"  Label: {sample['label'].item()}")
            print(f"  Original Label: {sample['original_label'].item()}")
            print(f"  Is Masked: {sample['is_masked'].item()}")

    print("Test completed successfully!")

if __name__ == "__main__":
    # Only when executed as a script (not on import): run the smoke test.
    test_cfg_dataset()