"""NCBI Pathogen Detection database preprocessor for AMR prediction modeling."""

import gzip
import json
import logging
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

# scikit-learn is imported lazily inside the methods that split / encode data,
# so loading metadata and computing statistics works without it installed.

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Characters allowed in a k-mer when building the k-mer vocabulary; k-mers
# containing anything else (e.g. "N") are not counted.
_VALID_NUCLEOTIDES = frozenset("ACGT")


class NCBIPreprocessor:
    """Preprocess NCBI Pathogen Detection data for AMR prediction models.

    This preprocessor handles genome sequences and metadata from NCBI's
    Pathogen Detection database. Since AMR phenotype data may be sparse or
    unavailable, it supports:

    1. Organism classification (using available metadata)
    2. AMR phenotype prediction (when data is available)
    3. Feature extraction for downstream modeling
    """

    def __init__(
        self,
        ncbi_dir: str = "data/raw/ncbi",
        output_dir: str = "data/processed/ncbi",
    ):
        """Set up input/output locations; creates ``output_dir`` if missing.

        Args:
            ncbi_dir: Root of the raw NCBI data (``metadata/*.csv``,
                ``complete_metadata.csv`` and ``genomes/*.fna.gz`` are read
                from here).
            output_dir: Directory where processed arrays are written.
        """
        self.ncbi_dir = Path(ncbi_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Combined, de-duplicated metadata; populated by load_data().
        self.metadata: Optional[pd.DataFrame] = None
        # biosample_id (str) -> all contigs of that genome concatenated.
        self.sequences: Dict[str, str] = {}
        # Target name -> fitted LabelEncoder (e.g. "organism").
        self.label_encoders: dict = {}

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def load_data(self) -> None:
        """Load all NCBI metadata files and genome sequences.

        Reads every non-hidden ``*.csv`` under ``<ncbi_dir>/metadata`` and,
        if present, ``<ncbi_dir>/complete_metadata.csv``. The frames are
        concatenated and only the first record per ``biosample_id`` is kept.
        Genome sequences are then loaded with :meth:`_load_sequences`.

        Raises:
            FileNotFoundError: If no metadata file was found.
        """
        logger.info("Loading NCBI Pathogen Detection data...")

        csv_files: List[Path] = []
        metadata_dir = self.ncbi_dir / "metadata"
        if metadata_dir.exists():
            # Sorted so the record kept by drop_duplicates is deterministic.
            csv_files.extend(
                p for p in sorted(metadata_dir.glob("*.csv")) if not p.name.startswith(".")
            )
        complete_file = self.ncbi_dir / "complete_metadata.csv"
        if complete_file.exists():
            csv_files.append(complete_file)

        all_dfs = []
        for csv_file in csv_files:
            # Read IDs as strings: with type inference a single missing ID makes
            # pandas parse the column as float ("54802194.0"), which breaks both
            # the match against genome file names and cross-file de-duplication.
            df = pd.read_csv(csv_file, dtype={"biosample_id": str})
            all_dfs.append(df)
            logger.info(f"Loaded {len(df)} records from {csv_file.name}")

        if not all_dfs:
            raise FileNotFoundError(f"No metadata files found in {self.ncbi_dir}")

        # Combine and deduplicate
        self.metadata = pd.concat(all_dfs, ignore_index=True).drop_duplicates(
            subset=["biosample_id"]
        )
        logger.info(f"Total unique records: {len(self.metadata)}")

        self._load_sequences()

    def _load_sequences(self) -> None:
        """Load genome sequences from ``<ncbi_dir>/genomes/*.fna.gz``.

        The biosample ID is the file name without ``.fna.gz``
        (``54802194.fna.gz`` -> ``"54802194"``). All contigs of a file are
        concatenated without a separator. Files that fail to read are logged
        and skipped.
        """
        genomes_dir = self.ncbi_dir / "genomes"
        if not genomes_dir.exists():
            logger.warning(f"Genomes directory not found: {genomes_dir}")
            return

        fasta_files = sorted(genomes_dir.glob("*.fna.gz"))
        logger.info(f"Found {len(fasta_files)} genome FASTA files")

        suffix = ".fna.gz"
        for fasta_file in fasta_files:
            biosample_id = fasta_file.name[: -len(suffix)]
            try:
                self.sequences[biosample_id] = self._read_fasta_gz(fasta_file)
            except Exception as e:  # best effort: one corrupt file must not abort loading
                logger.warning(f"Error loading {fasta_file}: {e}")

        logger.info(f"Loaded sequences for {len(self.sequences)} genomes")

    @staticmethod
    def _read_fasta_gz(path: Path) -> str:
        """Return every non-header line of a gzipped FASTA file, stripped and joined."""
        with gzip.open(path, "rt") as handle:
            return "".join(
                stripped
                for stripped in (line.strip() for line in handle)
                if not stripped.startswith(">")
            )

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _attach_sequences(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with a ``sequence`` column, keeping only rows with a genome.

        ``biosample_id`` is cast to ``str`` in the copy so it matches the keys
        of ``self.sequences``.
        """
        df = df.copy()
        df["biosample_id"] = df["biosample_id"].astype(str)
        df["sequence"] = df["biosample_id"].apply(lambda x: self.sequences.get(x, ""))
        return df[df["sequence"].str.len() > 0].copy()

    @staticmethod
    def _non_empty_mask(df: pd.DataFrame, column: str) -> pd.Series:
        """Boolean mask of rows whose ``column`` is non-null and not ``""``.

        A missing column yields an all-False mask instead of raising KeyError.
        """
        if column not in df.columns:
            return pd.Series(False, index=df.index)
        values = df[column]
        return values.notna() & (values != "")

    @staticmethod
    def _split_train_val_test(
        X: np.ndarray,
        y: np.ndarray,
        test_size: float,
        val_size: float,
        random_state: int,
        stratify: bool,
    ) -> Tuple[np.ndarray, ...]:
        """Split into train/val/test; ``test_size`` and ``val_size`` are fractions of all data.

        Returns:
            ``(X_train, X_val, X_test, y_train, y_val, y_test)``.

        Raises:
            ValueError: Propagated from ``train_test_split`` (e.g. a class too
                small to stratify, or invalid sizes).
        """
        from sklearn.model_selection import train_test_split

        X_temp, X_test, y_temp, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state,
            stratify=y if stratify else None,
        )
        # Rescale so the validation set is val_size of the *full* dataset.
        val_ratio = val_size / (1 - test_size)
        X_train, X_val, y_train, y_val = train_test_split(
            X_temp, y_temp, test_size=val_ratio, random_state=random_state,
            stratify=y_temp if stratify else None,
        )
        return X_train, X_val, X_test, y_train, y_val, y_test

    # ------------------------------------------------------------------
    # Dataset construction
    # ------------------------------------------------------------------

    def create_organism_dataset(
        self,
        min_samples_per_class: int = 10,
    ) -> pd.DataFrame:
        """Create dataset mapping genomes to organisms for classification.

        Args:
            min_samples_per_class: Minimum samples per organism to include.

        Returns:
            DataFrame with the metadata columns plus ``sequence`` and
            ``organism`` (``organism_query`` with missing values as
            ``"unknown"``), restricted to genomes with a loaded sequence.
        """
        if self.metadata is None:
            self.load_data()

        df = self._attach_sequences(self.metadata)
        logger.info(f"Records with genome sequences: {len(df)}")

        # Use organism_query as the organism label
        df["organism"] = df["organism_query"].fillna("unknown")

        organism_counts = df["organism"].value_counts()
        valid_organisms = organism_counts[organism_counts >= min_samples_per_class].index.tolist()
        df = df[df["organism"].isin(valid_organisms)]

        logger.info(f"Organisms with >= {min_samples_per_class} samples: {len(valid_organisms)}")
        logger.info(f"Records after filtering: {len(df)}")
        return df

    def create_amr_dataset(
        self,
        min_samples_per_class: int = 10,
    ) -> pd.DataFrame:
        """Create dataset mapping genomes to AMR phenotypes.

        A record counts as having AMR data when ``amr_phenotypes`` or
        ``amr_genotypes`` is non-empty; missing columns count as empty.

        Args:
            min_samples_per_class: Minimum samples per class to include
                (currently not applied).

        Returns:
            DataFrame with AMR records that have a loaded sequence, or an
            empty DataFrame if no AMR data is available.
        """
        if self.metadata is None:
            self.load_data()

        df = self.metadata
        has_amr = self._non_empty_mask(df, "amr_phenotypes") | self._non_empty_mask(
            df, "amr_genotypes"
        )
        amr_count = int(has_amr.sum())
        if amr_count == 0:
            logger.warning(
                "No AMR phenotype data available in NCBI metadata. "
                "Consider using organism classification instead, or "
                "linking to external AMR databases."
            )
            return pd.DataFrame()

        logger.info(f"Records with AMR data: {amr_count}")
        df = self._attach_sequences(df[has_amr])
        logger.info(f"AMR records with genome sequences: {len(df)}")
        return df

    # ------------------------------------------------------------------
    # Feature extraction
    # ------------------------------------------------------------------

    @staticmethod
    def _sample_positions(n_positions: int, sample_size: int, seed: int):
        """Return the k-mer start positions to scan for one sequence.

        All positions are used when there are at most ``sample_size`` of them.
        Otherwise ``sample_size`` distinct positions are drawn from a private
        ``RandomState(seed)``. This yields the same positions as
        ``np.random.seed(seed); np.random.choice(...)`` without resetting the
        caller's global NumPy RNG.
        """
        if n_positions > sample_size:
            rng = np.random.RandomState(seed)
            return rng.choice(n_positions, size=sample_size, replace=False)
        return range(n_positions)

    def extract_kmer_features(
        self,
        sequences: List[str],
        k: int = 6,
        max_features: int = 1000,
        sample_size: int = 100000,
    ) -> Tuple[np.ndarray, List[str]]:
        """Extract k-mer frequency features from DNA sequences.

        For large genome sequences, this method samples positions to make
        k-mer extraction tractable while still capturing the sequence
        signature. Sampling is seeded per sequence (``42 + index``), so the
        result is reproducible and the same positions are used for vocabulary
        building and feature computation.

        Args:
            sequences: List of DNA sequences (case-insensitive).
            k: k-mer size (default 6 for DNA).
            max_features: Maximum number of k-mer features.
            sample_size: Number of positions to sample per sequence for large
                genomes. If the number of k-mer positions is <= sample_size,
                the entire sequence is used.

        Returns:
            Tuple of (feature_matrix, feature_names). Row ``i`` holds the
            count of each selected k-mer divided by the number of positions
            scanned in sequence ``i`` (all zeros for sequences shorter than k).
        """
        logger.info(f"Extracting {k}-mer features from {len(sequences)} sequences...")

        # Pass 1: count ACGT-only k-mers across all sequences to pick the vocabulary.
        all_kmers: Counter = Counter()
        for seq_idx, seq in enumerate(sequences):
            seq = seq.upper()
            n_positions = len(seq) - k + 1
            if n_positions <= 0:
                continue

            for i in self._sample_positions(n_positions, sample_size, 42 + seq_idx):
                kmer = seq[i : i + k]
                if _VALID_NUCLEOTIDES.issuperset(kmer):
                    all_kmers[kmer] += 1

            if (seq_idx + 1) % 100 == 0:
                logger.info(f"Processed {seq_idx + 1}/{len(sequences)} sequences for k-mer counting")

        top_kmers = [kmer for kmer, _ in all_kmers.most_common(max_features)]
        logger.info(f"Selected {len(top_kmers)} k-mer features")

        # Pass 2: per-sequence normalised counts of the selected k-mers.
        feature_matrix = np.zeros((len(sequences), len(top_kmers)))
        kmer_to_idx = {kmer: idx for idx, kmer in enumerate(top_kmers)}

        for seq_idx, seq in enumerate(sequences):
            seq = seq.upper()
            n_positions = len(seq) - k + 1
            if n_positions <= 0:
                continue

            positions = self._sample_positions(n_positions, sample_size, 42 + seq_idx)
            for i in positions:
                col = kmer_to_idx.get(seq[i : i + k])
                if col is not None:
                    feature_matrix[seq_idx, col] += 1

            # Normalize by the number of positions scanned.
            normalizer = len(positions)
            if normalizer > 0:
                feature_matrix[seq_idx] /= normalizer

            if (seq_idx + 1) % 100 == 0:
                logger.info(f"Processed {seq_idx + 1}/{len(sequences)} sequences for features")

        return feature_matrix, top_kmers

    def compute_gc_content(self, sequence: str) -> float:
        """Calculate GC content of a DNA sequence.

        Args:
            sequence: DNA sequence string (case-insensitive).

        Returns:
            Fraction (0-1) of G/C characters over the full length, including
            any non-ACGT characters; 0.0 for an empty sequence.
        """
        sequence = sequence.upper()
        if len(sequence) == 0:
            return 0.0
        gc_count = sequence.count("G") + sequence.count("C")
        return gc_count / len(sequence)

    def extract_combined_features(
        self,
        sequences: List[str],
        k: int = 6,
        max_features: int = 1000,
        include_gc: bool = True,
        include_length: bool = True,
    ) -> Tuple[np.ndarray, List[str]]:
        """Extract k-mer features combined with sequence statistics.

        Args:
            sequences: List of DNA sequences.
            k: k-mer size.
            max_features: Maximum number of k-mer features.
            include_gc: Whether to include GC content feature.
            include_length: Whether to include the log sequence length,
                min-max scaled to [0, 1] across ``sequences``.

        Returns:
            Tuple of (feature_matrix, feature_names).
        """
        kmer_features, kmer_names = self.extract_kmer_features(
            sequences, k=k, max_features=max_features
        )
        feature_names = list(kmer_names)
        additional_features = []

        if include_gc:
            gc_features = np.array([self.compute_gc_content(seq) for seq in sequences])
            additional_features.append(gc_features.reshape(-1, 1))
            feature_names.append("gc_content")

        if include_length:
            log_lengths = np.log1p(np.array([len(seq) for seq in sequences]))
            # Min-max scale; skipped when empty or all lengths are equal.
            if log_lengths.size and log_lengths.max() > log_lengths.min():
                log_lengths = (log_lengths - log_lengths.min()) / (
                    log_lengths.max() - log_lengths.min()
                )
            additional_features.append(log_lengths.reshape(-1, 1))
            feature_names.append("log_length_normalized")

        if additional_features:
            feature_matrix = np.hstack([kmer_features, np.hstack(additional_features)])
        else:
            feature_matrix = kmer_features

        logger.info(f"Combined features shape: {feature_matrix.shape}")
        return feature_matrix, feature_names

    # ------------------------------------------------------------------
    # Model-ready datasets
    # ------------------------------------------------------------------

    def prepare_organism_classification_data(
        self,
        k: int = 6,
        max_features: int = 1000,
        test_size: float = 0.2,
        val_size: float = 0.1,
        random_state: int = 42,
        min_samples_per_class: int = 10,
    ) -> dict:
        """Prepare dataset for organism classification.

        Args:
            k: k-mer size for feature extraction.
            max_features: Maximum number of k-mer features.
            test_size: Fraction of all samples used for testing.
            val_size: Fraction of all samples used for validation.
            random_state: Random seed.
            min_samples_per_class: Minimum samples per organism.

        Returns:
            Dictionary with train/val/test splits and metadata.

        Raises:
            ValueError: If fewer than 20 samples remain after filtering.
        """
        from sklearn.preprocessing import LabelEncoder

        logger.info("Preparing organism classification data...")

        df = self.create_organism_dataset(min_samples_per_class=min_samples_per_class)
        if len(df) < 20:
            raise ValueError(f"Not enough samples: {len(df)}")

        X, feature_names = self.extract_kmer_features(
            df["sequence"].tolist(), k=k, max_features=max_features
        )

        le = LabelEncoder()
        y = le.fit_transform(df["organism"])
        class_names = list(le.classes_)
        self.label_encoders["organism"] = le

        logger.info(f"Features shape: {X.shape}, Labels shape: {y.shape}")
        logger.info(f"Number of classes: {len(class_names)}")
        logger.info(f"Classes: {class_names}")

        try:
            splits = self._split_train_val_test(
                X, y, test_size, val_size, random_state, stratify=True
            )
        except ValueError as e:
            logger.warning(f"Stratified split failed ({e}), using random split")
            splits = self._split_train_val_test(
                X, y, test_size, val_size, random_state, stratify=False
            )
        X_train, X_val, X_test, y_train, y_val, y_test = splits

        logger.info(f"Train: {X_train.shape[0]}, Val: {X_val.shape[0]}, Test: {X_test.shape[0]}")

        return {
            "X_train": X_train,
            "X_val": X_val,
            "X_test": X_test,
            "y_train": y_train,
            "y_val": y_val,
            "y_test": y_test,
            "feature_names": feature_names,
            "class_names": class_names,
            "task_type": "multiclass",
            "metadata": {
                "target": "organism",
                "k": k,
                "max_features": max_features,
                "n_samples": len(df),
                "n_features": X.shape[1],
                "n_classes": len(class_names),
            },
        }

    def prepare_amr_prediction_data(
        self,
        amr_labels_file: str = "data/raw/ncbi/amrfinder_results/amr_labels.csv",
        k: int = 6,
        max_features: int = 1000,
        test_size: float = 0.2,
        val_size: float = 0.1,
        random_state: int = 42,
        min_samples_per_drug: int = 10,
    ) -> dict:
        """Prepare dataset for AMR phenotype prediction using AMRFinderPlus labels.

        This method uses AMR labels generated by AMRFinderPlus to create a
        multi-label classification dataset for predicting resistance to
        different drug classes. Every numeric (or boolean) column other than
        ``biosample_id`` is treated as a drug-class label; non-numeric columns
        are ignored.

        Args:
            amr_labels_file: Path to AMR labels CSV from AMRFinderPlus.
            k: k-mer size for feature extraction.
            max_features: Maximum number of k-mer features.
            test_size: Fraction of all samples used for testing.
            val_size: Fraction of all samples used for validation.
            random_state: Random seed.
            min_samples_per_drug: Minimum samples per drug class to include.

        Returns:
            Dictionary with train/val/test splits and metadata.

        Raises:
            FileNotFoundError: If the labels file does not exist.
            ValueError: If no drug class has enough resistant samples, or
                fewer than 20 labelled samples have a sequence.
        """
        logger.info("Preparing AMR prediction data from AMRFinderPlus labels...")

        amr_labels_path = Path(amr_labels_file)
        if not amr_labels_path.exists():
            raise FileNotFoundError(
                f"AMR labels file not found: {amr_labels_file}\n"
                "Please run AMRFinderPlus first:\n"
                " from src.data_collection.amrfinder_annotator import AMRFinderAnnotator\n"
                " annotator = AMRFinderAnnotator()\n"
                " annotator.run_on_all_genomes()\n"
                " annotator.create_amr_labels()"
            )

        # IDs read as str so they match metadata / genome keys (see load_data).
        amr_labels = pd.read_csv(amr_labels_path, dtype={"biosample_id": str})
        amr_labels["biosample_id"] = amr_labels["biosample_id"].astype(str)
        logger.info(f"Loaded AMR labels for {len(amr_labels)} samples")

        # Only numeric/boolean columns can be summed into resistance counts.
        drug_columns = [
            c
            for c in amr_labels.columns
            if c != "biosample_id" and pd.api.types.is_numeric_dtype(amr_labels[c])
        ]

        drug_counts = amr_labels[drug_columns].sum()
        valid_drugs = drug_counts[drug_counts >= min_samples_per_drug].index.tolist()

        if not valid_drugs:
            raise ValueError(
                f"No drug classes with >= {min_samples_per_drug} resistant samples. "
                f"Drug counts: {dict(drug_counts)}"
            )

        logger.info(f"Using {len(valid_drugs)} drug classes with >= {min_samples_per_drug} samples")
        for drug in valid_drugs:
            logger.info(f" {drug}: {drug_counts[drug]} resistant samples")

        if self.metadata is None:
            self.load_data()

        # Keep only labelled samples present in the metadata, without mutating
        # self.metadata; drop_duplicates prevents row multiplication in the merge.
        known_ids = self.metadata[["biosample_id"]].astype(str).drop_duplicates()
        merged = amr_labels.merge(known_ids, on="biosample_id", how="inner")
        merged = self._attach_sequences(merged)
        logger.info(f"Samples with sequences: {len(merged)}")

        if len(merged) < 20:
            raise ValueError(f"Not enough samples with sequences: {len(merged)}")

        X, feature_names = self.extract_kmer_features(
            merged["sequence"].tolist(), k=k, max_features=max_features
        )
        # Multi-label target matrix: one 0/1 column per drug class.
        y = merged[valid_drugs].values.astype(int)

        logger.info(f"Features shape: {X.shape}, Labels shape: {y.shape}")
        logger.info(f"Drug classes: {valid_drugs}")

        # Random (non-stratified) split: stratification is undefined for multi-label y.
        X_train, X_val, X_test, y_train, y_val, y_test = self._split_train_val_test(
            X, y, test_size, val_size, random_state, stratify=False
        )

        logger.info(f"Train: {X_train.shape[0]}, Val: {X_val.shape[0]}, Test: {X_test.shape[0]}")

        return {
            "X_train": X_train,
            "X_val": X_val,
            "X_test": X_test,
            "y_train": y_train,
            "y_val": y_val,
            "y_test": y_test,
            "feature_names": feature_names,
            "class_names": valid_drugs,
            "task_type": "multilabel",
            "metadata": {
                "target": "amr_drug_class",
                "k": k,
                "max_features": max_features,
                "n_samples": len(merged),
                "n_features": X.shape[1],
                "n_classes": len(valid_drugs),
                "drug_classes": valid_drugs,
            },
        }

    def save_processed_data(self, data: dict, prefix: str = "ncbi") -> None:
        """Save processed data to disk.

        Writes ``<prefix>_{X,y}_{train,val,test}.npy`` and
        ``<prefix>_metadata.json`` into ``self.output_dir``.

        Args:
            data: Dictionary with train/val/test splits and metadata, as
                returned by the ``prepare_*`` methods.
            prefix: Filename prefix for saved files.
        """
        logger.info(f"Saving processed data to {self.output_dir}")

        for key in ("X_train", "X_val", "X_test", "y_train", "y_val", "y_test"):
            np.save(self.output_dir / f"{prefix}_{key}.npy", data[key])

        metadata = {
            "feature_names": data["feature_names"],
            "class_names": data["class_names"],
            "task_type": data["task_type"],
            **data["metadata"],
        }
        with open(self.output_dir / f"{prefix}_metadata.json", "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        logger.info("Data saved successfully!")

    # ------------------------------------------------------------------
    # Statistics
    # ------------------------------------------------------------------

    def get_organism_statistics(self) -> pd.DataFrame:
        """Get statistics about organisms in the dataset.

        Returns:
            DataFrame with columns ``organism``, ``total_records`` and
            ``with_sequences``, sorted by ``total_records`` descending.
            Records with a missing ``organism_query`` are not counted.
        """
        if self.metadata is None:
            self.load_data()

        meta = self.metadata
        stats = meta["organism_query"].value_counts().reset_index()
        stats.columns = ["organism", "total_records"]

        # One vectorised pass instead of re-filtering the metadata per organism.
        has_sequence = meta["biosample_id"].astype(str).isin(list(self.sequences))
        with_sequences = has_sequence.groupby(meta["organism_query"]).sum()
        stats["with_sequences"] = stats["organism"].map(with_sequences).fillna(0).astype(int)

        return stats.sort_values("total_records", ascending=False)

    def get_sequence_statistics(self) -> pd.DataFrame:
        """Get statistics about genome sequences.

        Loads sequences first if none are loaded.

        Returns:
            DataFrame with one row per genome (length, GC content, per-base
            counts), sorted by length descending. Empty, but with the same
            columns, when no sequences are available.
        """
        if not self.sequences:
            self._load_sequences()

        columns = [
            "biosample_id", "length", "gc_content",
            "a_count", "t_count", "g_count", "c_count", "n_count",
        ]
        rows = []
        for biosample_id, seq in self.sequences.items():
            seq_upper = seq.upper()
            rows.append({
                "biosample_id": biosample_id,
                "length": len(seq),
                "gc_content": round(self.compute_gc_content(seq), 4),
                "a_count": seq_upper.count("A"),
                "t_count": seq_upper.count("T"),
                "g_count": seq_upper.count("G"),
                "c_count": seq_upper.count("C"),
                "n_count": seq_upper.count("N"),
            })

        # Explicit columns keep sort_values("length") valid for an empty result.
        return pd.DataFrame(rows, columns=columns).sort_values("length", ascending=False)

    def get_metadata_statistics(self) -> Dict:
        """Get comprehensive summary of the NCBI dataset.

        Optional columns (AMR, assembly, location, isolation source) that are
        absent from the metadata are reported as 0 or omitted instead of
        raising KeyError.

        Returns:
            Dictionary containing dataset summary statistics.
        """
        if self.metadata is None:
            self.load_data()

        meta = self.metadata
        summary = {
            "total_records": len(meta),
            "total_genomes_with_sequences": len(self.sequences),
            "unique_organisms": meta["organism_query"].nunique(),
            "organisms_list": sorted(meta["organism_query"].dropna().unique().tolist()),
        }

        # AMR data availability
        summary["records_with_amr_genotypes"] = int(self._non_empty_mask(meta, "amr_genotypes").sum())
        summary["records_with_amr_phenotypes"] = int(self._non_empty_mask(meta, "amr_phenotypes").sum())

        # Assembly availability
        summary["records_with_assembly"] = int(self._non_empty_mask(meta, "assembly_accession").sum())

        if "geo_loc_name" in meta.columns:
            summary["unique_locations"] = meta["geo_loc_name"].nunique()

        if "isolation_source" in meta.columns:
            summary["unique_isolation_sources"] = meta["isolation_source"].nunique()

        if self.sequences:
            seq_lengths = [len(seq) for seq in self.sequences.values()]
            summary["sequence_stats"] = {
                "count": len(seq_lengths),
                "min_length": min(seq_lengths),
                "max_length": max(seq_lengths),
                "mean_length": int(np.mean(seq_lengths)),
                "median_length": int(np.median(seq_lengths)),
            }

        return summary

    def get_geographic_distribution(self) -> pd.DataFrame:
        """Get geographic distribution of samples.

        Returns:
            DataFrame with ``location`` and ``count`` columns, or an empty
            DataFrame if ``geo_loc_name`` is absent.
        """
        if self.metadata is None:
            self.load_data()

        if "geo_loc_name" not in self.metadata.columns:
            return pd.DataFrame()

        stats = self.metadata["geo_loc_name"].value_counts().reset_index()
        stats.columns = ["location", "count"]
        return stats

    def get_isolation_source_distribution(self) -> pd.DataFrame:
        """Get distribution of isolation sources.

        Returns:
            DataFrame with ``isolation_source`` and ``count`` columns, or an
            empty DataFrame if ``isolation_source`` is absent.
        """
        if self.metadata is None:
            self.load_data()

        if "isolation_source" not in self.metadata.columns:
            return pd.DataFrame()

        stats = self.metadata["isolation_source"].value_counts().reset_index()
        stats.columns = ["isolation_source", "count"]
        return stats


def main():
    """Run the pipeline: load data, print summaries, save organism-classification splits."""
    preprocessor = NCBIPreprocessor()

    preprocessor.load_data()

    print("\n=== Dataset Summary ===")
    summary = preprocessor.get_metadata_statistics()
    for key, value in summary.items():
        if isinstance(value, dict):
            print(f"{key}:")
            for k, v in value.items():
                print(f" {k}: {v}")
        elif isinstance(value, list):
            print(f"{key}: {len(value)} items")
            if len(value) <= 10:
                for item in value:
                    print(f" - {item}")
        else:
            print(f"{key}: {value}")

    print("\n=== Organism Statistics ===")
    org_stats = preprocessor.get_organism_statistics()
    print(org_stats.to_string(index=False))

    print("\n=== Preparing Organism Classification Data ===")
    try:
        data = preprocessor.prepare_organism_classification_data(
            k=6,
            max_features=500,
            test_size=0.2,
            val_size=0.1,
            min_samples_per_class=5,
        )
        preprocessor.save_processed_data(data, prefix="ncbi_organism")
        print("Saved organism classification data")

        print("\nClass distribution in training set:")
        train_counts = np.bincount(data["y_train"])
        for i, count in enumerate(train_counts):
            print(f" {data['class_names'][i]}: {count}")
    except Exception as e:  # top-level boundary: report and continue to the AMR status
        print(f"Error preparing organism classification data: {e}")

    print("\n=== AMR Data Status ===")
    if summary.get("records_with_amr_phenotypes", 0) > 0:
        print("AMR phenotype data available - can prepare AMR prediction data")
    else:
        print(
            "No AMR phenotype data available in NCBI metadata.\n"
            "Options:\n"
            " 1. Use organism classification as a proxy task\n"
            " 2. Link to external AMR databases (PATRIC, CARD)\n"
            " 3. Wait for AMR data to be added to NCBI records"
        )

    print("\n=== Preprocessing Complete ===")
    print(f"Output directory: {preprocessor.output_dir}")


if __name__ == "__main__":
    main()