import glob
import json
import os
import random
import tempfile
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

# Label id written into masked label arrays to mark a hidden ("unconditional") label.
MASK_LABEL = 2


class CFGUniProtDataset(Dataset):
    """
    Dataset of UniProt sequences with labels for classifier-free guidance (CFG).

    Reads a processed JSON file with the keys ``sequences``, ``original_labels``,
    ``masked_labels`` and ``mask_indices``. Each item exposes the raw sequence
    string, its (optionally masked) label, the original label, and whether the
    sample's index is listed in ``mask_indices``.

    Label ids (see ``label_map``): 0 = 'amp', 1 = 'non_amp', 2 = 'mask'.
    """

    def __init__(self, data_path: str, use_masked_labels: bool = True,
                 mask_probability: float = 0.1, max_seq_len: int = 50,
                 device: str = 'cuda'):
        """
        Args:
            data_path: Path to the processed JSON file.
            use_masked_labels: If True, ``label`` is taken from ``masked_labels``;
                otherwise from ``original_labels``.
            mask_probability: Stored and reported only; which samples are masked
                is read from ``mask_indices`` in the file.
            max_seq_len: Stored; not used by this class.
            device: Stored; not used by this class (returned tensors are CPU).

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
        """
        self.data_path = data_path
        self.use_masked_labels = use_masked_labels
        self.mask_probability = mask_probability
        self.max_seq_len = max_seq_len
        self.device = device

        self._load_data()

        self.label_map = {
            0: 'amp',      # MIC < 100
            1: 'non_amp',  # MIC > 100
            2: 'mask',     # Unknown MIC
        }

        print(f"CFG Dataset initialized:")
        print(f"  Total sequences: {len(self.sequences)}")
        print(f"  Using masked labels: {use_masked_labels}")
        print(f"  Mask probability: {mask_probability}")
        print(f"  Label distribution: {self._get_label_distribution()}")

    def _load_data(self):
        """Load sequences, both label arrays and the mask index set from JSON."""
        if not os.path.exists(self.data_path):
            raise FileNotFoundError(f"Data file not found: {self.data_path}")
        with open(self.data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.sequences = data['sequences']
        self.original_labels = np.array(data['original_labels'])
        self.masked_labels = np.array(data['masked_labels'])
        self.mask_indices = set(data['mask_indices'])

    def _get_label_distribution(self, labels: Optional[np.ndarray] = None) -> Dict[str, int]:
        """Count how often each label name occurs.

        Args:
            labels: Label array to count. Defaults to the labels served by
                ``__getitem__`` (masked or original, per ``use_masked_labels``).

        Returns:
            Mapping from label name (via ``label_map``) to a plain ``int`` count.
        """
        if labels is None:
            labels = self.masked_labels if self.use_masked_labels else self.original_labels
        unique, counts = np.unique(labels, return_counts=True)
        # Cast numpy scalars to int so the result prints/serializes cleanly.
        return {self.label_map[int(label)]: int(count) for label, count in zip(unique, counts)}

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the sequence string plus label tensors for sample ``idx``."""
        sequence = self.sequences[idx]

        if self.use_masked_labels:
            label = self.masked_labels[idx]
        else:
            label = self.original_labels[idx]

        # Reported regardless of use_masked_labels: marks indices listed in the file.
        is_masked = idx in self.mask_indices

        return {
            'sequence': sequence,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(self.original_labels[idx], dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long),
        }

    def get_label_statistics(self) -> Dict[str, Dict]:
        """Return label distributions and masking info.

        ``'original'`` always counts ``original_labels``; ``'masked'`` counts
        ``masked_labels`` when ``use_masked_labels`` is True, else is None.
        """
        stats = {
            'original': self._get_label_distribution(self.original_labels),
            'masked': (self._get_label_distribution(self.masked_labels)
                       if self.use_masked_labels else None),
            'masking_info': {
                'total_masked': len(self.mask_indices),
                'mask_probability': self.mask_probability,
                'masked_indices': list(self.mask_indices),
            },
        }
        return stats


class CFGFlowDataset(Dataset):
    """
    Dataset pairing precomputed AMP embeddings with CFG labels.

    Embeddings are loaded from ``embeddings_path`` (a combined
    ``all_peptide_embeddings.pt`` if present, otherwise every individual
    ``*.pt`` file, stacked in sorted filename order). Labels come from a JSON
    file with ``sequences`` and ``labels``; a random fraction of them is
    replaced by ``MASK_LABEL`` for CFG training. Embeddings and labels are
    aligned purely by position (the first N of each).
    """

    def __init__(self, embeddings_path: str, cfg_data_path: str,
                 use_masked_labels: bool = True, max_seq_len: int = 50,
                 device: str = 'cuda', mask_probability: float = 0.1):
        """
        Args:
            embeddings_path: Directory containing the embedding ``.pt`` files.
            cfg_data_path: JSON file with ``sequences`` and ``labels`` keys.
            use_masked_labels: Serve masked labels (True) or original ones.
            max_seq_len: Stored; not used by this class.
            device: Stored; not used by this class (data stays on CPU).
            mask_probability: Fraction of CFG labels replaced by ``MASK_LABEL``.

        Raises:
            ValueError: If ``mask_probability`` is outside [0, 1], or the
                fallback loader finds no valid embeddings.
        """
        if not 0.0 <= mask_probability <= 1.0:
            raise ValueError(f"mask_probability must be in [0, 1], got {mask_probability}")

        self.embeddings_path = embeddings_path
        self.cfg_data_path = cfg_data_path
        self.use_masked_labels = use_masked_labels
        self.max_seq_len = max_seq_len
        self.device = device
        self.mask_probability = mask_probability

        self._load_embeddings()
        self._load_cfg_data()
        self._align_data()

        print(f"CFG Flow Dataset initialized:")
        print(f"  AMP embeddings: {self.embeddings.shape}")
        print(f"  CFG labels: {len(self.cfg_labels)}")
        print(f"  Aligned samples: {len(self.aligned_indices)}")

    def _load_embeddings(self):
        """Load embeddings onto CPU from the combined file or individual files."""
        print(f"Loading AMP embeddings from {self.embeddings_path}...")

        combined_path = os.path.join(self.embeddings_path, "all_peptide_embeddings.pt")
        if os.path.exists(combined_path):
            print(f"Loading combined embeddings from {combined_path} (FULL DATA)...")
            # Load on CPU first to avoid CUDA issues with DataLoader workers.
            self.embeddings = torch.load(combined_path, map_location='cpu')
            print(f"✓ Loaded ALL embeddings: {self.embeddings.shape}")
            return

        print("Combined embeddings file not found, loading individual files...")
        # glob order is arbitrary; sort so index-based alignment is reproducible.
        embedding_files = sorted(
            path for path in glob.glob(os.path.join(self.embeddings_path, "*.pt"))
            if not path.endswith('all_peptide_embeddings.pt')
        )
        print(f"Found {len(embedding_files)} individual embedding files")

        embeddings_list = []
        for file_path in embedding_files:
            try:
                embedding = torch.load(file_path, map_location='cpu')
            except Exception as e:  # best-effort: skip unreadable files, but say so
                print(f"Warning: Could not load {file_path}: {e}")
                continue
            if embedding.dim() == 2:  # (seq_len, hidden_dim)
                embeddings_list.append(embedding)
            else:
                print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")

        if not embeddings_list:
            raise ValueError("No valid embeddings found!")

        # NOTE(review): torch.stack requires every file to share one shape.
        self.embeddings = torch.stack(embeddings_list)
        print(f"Loaded {len(self.embeddings)} embeddings from individual files")

    def _load_cfg_data(self):
        """Load CFG labels and randomly mask ``mask_probability`` of them."""
        print(f"Loading CFG data from {self.cfg_data_path}...")

        with open(self.cfg_data_path, 'r', encoding='utf-8') as f:
            cfg_data = json.load(f)

        self.cfg_sequences = cfg_data['sequences']
        self.cfg_original_labels = np.array(cfg_data['labels'])

        # Uses the global NumPy RNG, so np.random.seed(...) makes this reproducible.
        n_labels = len(self.cfg_original_labels)
        mask_indices = np.random.choice(
            n_labels,
            size=int(n_labels * self.mask_probability),
            replace=False,
        )
        self.cfg_masked_labels = self.cfg_original_labels.copy()
        self.cfg_masked_labels[mask_indices] = MASK_LABEL
        self.cfg_mask_indices = set(mask_indices.tolist())

        print(f"Loaded {len(self.cfg_sequences)} CFG sequences")
        print(f"Label distribution: {np.bincount(self.cfg_original_labels)}")
        print(f"Masked {len(self.cfg_mask_indices)} labels for CFG training")

    def _align_data(self):
        """Truncate embeddings and labels to their common length (positional match)."""
        print("Aligning AMP embeddings with CFG data...")

        # Simple approach: take the first N samples of each, N = the smaller size.
        min_samples = min(len(self.embeddings), len(self.cfg_sequences))
        self.aligned_indices = list(range(min_samples))

        source = self.cfg_masked_labels if self.use_masked_labels else self.cfg_original_labels
        self.cfg_labels = source[:min_samples]
        self.aligned_embeddings = self.embeddings[:min_samples]

        print(f"Aligned {min_samples} samples")

    def __len__(self) -> int:
        return len(self.aligned_indices)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the embedding and label tensors for aligned sample ``idx``."""
        embedding = self.aligned_embeddings[idx]  # already on CPU
        label = self.cfg_labels[idx]
        original_label = self.cfg_original_labels[idx]
        is_masked = idx in self.cfg_mask_indices

        return {
            'embedding': embedding,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(original_label, dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long),
        }

    def get_embedding_stats(self) -> Dict:
        """Return shape and summary statistics of the aligned embeddings."""
        return {
            'shape': self.aligned_embeddings.shape,
            'mean': self.aligned_embeddings.mean().item(),
            'std': self.aligned_embeddings.std().item(),
            'min': self.aligned_embeddings.min().item(),
            'max': self.aligned_embeddings.max().item(),
        }


def _cfg_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate CFG samples into batched tensors.

    Defined at module level so DataLoader workers can pickle it under the
    'spawn' start method. Stacks ``embedding`` items when present and gathers
    ``sequence`` strings into a list when present, so it works for both
    CFGFlowDataset and CFGUniProtDataset samples.
    """
    collated = {
        'labels': torch.stack([item['label'] for item in batch]),
        'original_labels': torch.stack([item['original_label'] for item in batch]),
        'is_masked': torch.stack([item['is_masked'] for item in batch]),
        'indices': torch.stack([item['index'] for item in batch]),
    }
    if batch and 'embedding' in batch[0]:
        collated['embeddings'] = torch.stack([item['embedding'] for item in batch])
    if batch and 'sequence' in batch[0]:
        collated['sequences'] = [item['sequence'] for item in batch]
    return collated


def create_cfg_dataloader(dataset: Dataset, batch_size: int = 32, shuffle: bool = True,
                          num_workers: int = 4, pin_memory: Optional[bool] = None) -> DataLoader:
    """Create a DataLoader for CFG training.

    Args:
        dataset: A CFGFlowDataset or CFGUniProtDataset (or compatible items).
        batch_size: Samples per batch.
        shuffle: Whether to shuffle each epoch.
        num_workers: DataLoader worker processes.
        pin_memory: Pin host memory; defaults to True only when CUDA is available.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=_cfg_collate_fn,
        pin_memory=pin_memory,
    )


def test_cfg_dataset():
    """Smoke-test CFGUniProtDataset on a tiny temporary JSON file."""
    print("Testing CFG Dataset...")

    test_data = {
        'sequences': ['MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG',
                      'MKLLIVTFCLTFAAL',
                      'MKLLIVTFCLTFAALMKLLIVTFCLTFAAL'],
        'original_labels': [0, 1, 0],  # amp, non_amp, amp
        'masked_labels': [0, 2, 0],    # amp, mask, amp
        'mask_indices': [1],           # Only second sequence is masked
    }

    # A unique temp file avoids clobbering files in the cwd and is always removed.
    fd, test_path = tempfile.mkstemp(suffix='.json')
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump(test_data, f)

        dataset = CFGUniProtDataset(test_path, use_masked_labels=True)
        print(f"Dataset length: {len(dataset)}")

        for i in range(len(dataset)):
            sample = dataset[i]
            print(f"Sample {i}:")
            print(f"  Sequence: {sample['sequence'][:20]}...")
            print(f"  Label: {sample['label'].item()}")
            print(f"  Original Label: {sample['original_label'].item()}")
            print(f"  Is Masked: {sample['is_masked'].item()}")
    finally:
        os.remove(test_path)

    print("Test completed successfully!")


if __name__ == "__main__":
    test_cfg_dataset()