Spaces:
Running
Running
| """CARD database preprocessor for AMR prediction modeling.""" | |
| import json | |
| import logging | |
| from collections import Counter | |
| from pathlib import Path | |
| from typing import Optional | |
| import numpy as np | |
| import pandas as pd | |
| from Bio import SeqIO | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer | |
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

# Configure logging at INFO level as an import-time side effect, so the
# progress messages emitted below are visible when the file is run as a script.
logging.basicConfig(level=logging.INFO)
class CARDPreprocessor:
    """Preprocess CARD database for AMR prediction models.

    Typical flow: ``load_data()`` reads the ARO index tables and FASTA files,
    ``prepare_modeling_data()`` turns protein sequences into normalized k-mer
    frequency features with encoded labels and train/val/test splits, and
    ``save_processed_data()`` writes the result to ``output_dir``.
    """

    # The 20 standard amino-acid letters. A k-mer containing any other
    # character is never selected as a feature.
    _VALID_AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")

    # Supported prediction targets: name -> (dataframe column, task type).
    _TARGETS = {
        "drug_class": ("Drug Classes", "multilabel"),
        "mechanism": ("Resistance Mechanism", "multiclass"),
        "gene_family": ("AMR Gene Family", "multiclass"),
    }

    def __init__(
        self,
        card_dir: str = "data/raw/card-data",
        output_dir: str = "data/processed/card",
    ):
        """Remember input/output locations and create the output directory.

        Args:
            card_dir: directory holding the CARD TSV and FASTA files.
            output_dir: directory for processed arrays and metadata; it is
                created (including parents) if it does not exist.
        """
        self.card_dir = Path(card_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Data containers, filled by load_data().
        self.aro_index: Optional[pd.DataFrame] = None
        self.aro_categories: Optional[pd.DataFrame] = None
        # sequence type -> {ARO accession -> sequence string}
        self.sequences: dict = {}
        # target column -> fitted LabelEncoder / MultiLabelBinarizer
        self.label_encoders: dict = {}

    def load_data(self) -> None:
        """Load all CARD data files (both index tables and the FASTA files)."""
        logger.info("Loading CARD data...")

        # ARO index: main gene-drug-mechanism mapping.
        self.aro_index = pd.read_csv(
            self.card_dir / "aro_index.tsv", sep="\t", low_memory=False
        )
        logger.info("Loaded %d ARO entries", len(self.aro_index))

        # ARO categories index: detailed categorization.
        self.aro_categories = pd.read_csv(
            self.card_dir / "aro_categories_index.tsv", sep="\t", low_memory=False
        )
        logger.info("Loaded %d category mappings", len(self.aro_categories))

        self._load_sequences()

    @staticmethod
    def _extract_aro_accession(description: str) -> Optional[str]:
        """Return the ``ARO:...`` field of a pipe-delimited FASTA header.

        Expected header format: ``gb|ACCESSION|ARO:XXXXXX|NAME [Species]``.
        Returns ``None`` when no field starts with ``ARO:``.
        """
        for part in description.split("|"):
            token = part.strip()
            if token.startswith("ARO:"):
                return token
        return None

    def _load_sequences(self) -> None:
        """Load protein and nucleotide sequences from FASTA files.

        Files that are absent from ``card_dir`` are skipped with a warning.
        Records without an ARO accession in the header are ignored; if several
        records share an accession, the last one read is kept.
        """
        fasta_files = {
            "protein_homolog": "protein_fasta_protein_homolog_model.fasta",
            "protein_variant": "protein_fasta_protein_variant_model.fasta",
            "protein_knockout": "protein_fasta_protein_knockout_model.fasta",
            "protein_overexpression": "protein_fasta_protein_overexpression_model.fasta",
            "nucleotide_homolog": "nucleotide_fasta_protein_homolog_model.fasta",
            "nucleotide_variant": "nucleotide_fasta_protein_variant_model.fasta",
        }
        for seq_type, filename in fasta_files.items():
            fasta_path = self.card_dir / filename
            if not fasta_path.exists():
                logger.warning("FASTA file not found, skipping: %s", fasta_path)
                continue

            by_accession = {}
            for record in SeqIO.parse(fasta_path, "fasta"):
                aro_acc = self._extract_aro_accession(record.description)
                if aro_acc:
                    by_accession[aro_acc] = str(record.seq)
            self.sequences[seq_type] = by_accession
            logger.info("Loaded %d %s sequences", len(by_accession), seq_type)

    def create_drug_resistance_dataset(self) -> pd.DataFrame:
        """Create dataset mapping genes to drug classes and resistance mechanisms.

        Loads the CARD files first if that has not happened yet. Missing drug
        class, mechanism and gene family values become ``"unknown"``; a list
        column ``"Drug Classes"`` is derived from ``"Drug Class"``; homolog
        protein and nucleotide sequences are attached by ARO accession.

        Returns:
            A copy of the ARO index restricted to rows that have a non-empty
            homolog protein sequence.
        """
        if self.aro_index is None:
            self.load_data()

        df = self.aro_index.copy()

        df["Drug Class"] = df["Drug Class"].fillna("unknown")
        df["Drug Classes"] = df["Drug Class"].apply(self._split_drug_classes)
        df["Resistance Mechanism"] = df["Resistance Mechanism"].fillna("unknown")
        df["AMR Gene Family"] = df["AMR Gene Family"].fillna("unknown")

        # Resolve the lookup tables once instead of once per row.
        protein = self.sequences.get("protein_homolog", {})
        nucleotide = self.sequences.get("nucleotide_homolog", {})
        df["protein_sequence"] = df["ARO Accession"].apply(
            lambda acc: protein.get(acc, "")
        )
        df["nucleotide_sequence"] = df["ARO Accession"].apply(
            lambda acc: nucleotide.get(acc, "")
        )

        df_with_seq = df[df["protein_sequence"].str.len() > 0].copy()
        logger.info("Entries with protein sequences: %d", len(df_with_seq))
        return df_with_seq

    def _split_drug_classes(self, drug_string: str) -> list:
        """Split a ``;``-separated drug class string into a list of names.

        Missing values and the ``"unknown"`` placeholder give an empty list.
        Names are stripped, and empty fragments (from a trailing or doubled
        ``;``) are dropped so they cannot turn into a spurious empty label.
        """
        if pd.isna(drug_string) or drug_string == "unknown":
            return []
        stripped = (part.strip() for part in drug_string.split(";"))
        return [name for name in stripped if name]

    def extract_kmer_features(
        self, sequences: list, k: int = 3, max_features: int = 1000
    ) -> tuple:
        """Extract k-mer frequency features from sequences.

        Args:
            sequences: protein sequence strings (case-insensitive).
            k: k-mer length; must be at least 1.
            max_features: keep at most this many of the most frequent k-mers.

        Returns:
            ``(feature_matrix, top_kmers)``. ``feature_matrix`` has one row per
            sequence and one column per selected k-mer, holding the k-mer count
            divided by the number of length-``k`` windows in that sequence.
            ``top_kmers`` lists the column names, most frequent first.

        Raises:
            ValueError: if ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k}")

        logger.info(
            "Extracting %d-mer features from %d sequences...", k, len(sequences)
        )

        valid = self._VALID_AMINO_ACIDS
        corpus_counts = Counter()
        # One (window count, per-sequence k-mer counts) pair per sequence, so
        # each sequence is scanned only once.
        per_sequence = []
        for seq in sequences:
            seq = seq.upper()
            n_windows = len(seq) - k + 1  # <= 0 when the sequence is shorter than k
            counts = Counter(
                kmer
                for kmer in (seq[i : i + k] for i in range(n_windows))
                if valid.issuperset(kmer)
            )
            corpus_counts.update(counts)
            per_sequence.append((n_windows, counts))

        top_kmers = [kmer for kmer, _ in corpus_counts.most_common(max_features)]
        logger.info("Selected %d k-mer features", len(top_kmers))

        kmer_to_idx = {kmer: idx for idx, kmer in enumerate(top_kmers)}
        feature_matrix = np.zeros((len(sequences), len(top_kmers)))
        for row, (n_windows, counts) in enumerate(per_sequence):
            if n_windows <= 0:
                continue
            for kmer, count in counts.items():
                col = kmer_to_idx.get(kmer)
                if col is not None:
                    # Normalize by the total number of windows, including
                    # windows that held non-standard residues.
                    feature_matrix[row, col] = count / n_windows

        return feature_matrix, top_kmers

    def encode_labels(
        self, df: pd.DataFrame, target_col: str = "Resistance Mechanism"
    ) -> tuple:
        """Encode a single-label categorical column as integer class ids.

        The fitted ``LabelEncoder`` is stored in ``self.label_encoders`` under
        ``target_col``.

        Returns:
            ``(labels, classes)``: the encoded labels and the encoder's classes.
        """
        le = LabelEncoder()
        labels = le.fit_transform(df[target_col])
        self.label_encoders[target_col] = le
        return labels, le.classes_

    def encode_multilabels(
        self, df: pd.DataFrame, target_col: str = "Drug Classes"
    ) -> tuple:
        """Encode multi-label targets (e.g., multiple drug classes).

        The fitted ``MultiLabelBinarizer`` is stored in ``self.label_encoders``
        under ``target_col``.

        Returns:
            ``(labels, classes)``: the binary indicator matrix and the
            binarizer's classes.
        """
        mlb = MultiLabelBinarizer()
        labels = mlb.fit_transform(df[target_col])
        self.label_encoders[target_col] = mlb
        return labels, mlb.classes_

    @staticmethod
    def _train_val_test_split(
        X, y, test_size: float, val_size: float, random_state: int, stratified: bool
    ) -> tuple:
        """Split into train/val/test; ``val_size`` is a fraction of the full data.

        When ``stratified`` is true a label-stratified split is tried first, and
        a plain random split is used if scikit-learn rejects it with a
        ``ValueError``.

        Returns:
            ``(X_train, X_val, X_test, y_train, y_val, y_test)``.
        """
        # The validation set is carved out of what remains after the test
        # split, so rescale its fraction relative to that remainder.
        val_ratio = val_size / (1 - test_size)

        def split(use_labels: bool) -> tuple:
            X_rest, X_test, y_rest, y_test = train_test_split(
                X,
                y,
                test_size=test_size,
                random_state=random_state,
                stratify=y if use_labels else None,
            )
            X_train, X_val, y_train, y_val = train_test_split(
                X_rest,
                y_rest,
                test_size=val_ratio,
                random_state=random_state,
                stratify=y_rest if use_labels else None,
            )
            return X_train, X_val, X_test, y_train, y_val, y_test

        if stratified:
            try:
                return split(True)
            except ValueError as exc:
                logger.warning(f"Stratified split failed ({exc}), using random split")
        return split(False)

    def prepare_modeling_data(
        self,
        target: str = "drug_class",
        k: int = 3,
        max_features: int = 1000,
        test_size: float = 0.2,
        val_size: float = 0.1,
        random_state: int = 42,
    ) -> dict:
        """Prepare complete dataset for modeling.

        Args:
            target: 'drug_class', 'mechanism', or 'gene_family'
            k: k-mer size for feature extraction
            max_features: maximum number of k-mer features
            test_size: proportion of the full dataset held out for testing
            val_size: proportion of the full dataset held out for validation
            random_state: random seed for reproducibility

        Returns:
            Dictionary with train/val/test splits and metadata

        Raises:
            ValueError: for an unknown ``target``, for split fractions that are
                not positive or sum to 1 or more, or when no entry has a
                protein sequence.
        """
        # Validate cheap arguments before any expensive loading or featurizing.
        if target not in self._TARGETS:
            raise ValueError(f"Unknown target: {target}")
        if not 0 < test_size < 1 or val_size <= 0 or test_size + val_size >= 1:
            raise ValueError(
                "test_size and val_size must be positive fractions summing to "
                f"less than 1, got test_size={test_size}, val_size={val_size}"
            )

        logger.info("Preparing modeling data with target: %s", target)

        df = self.create_drug_resistance_dataset()
        if df.empty:
            raise ValueError(
                "No entries with protein sequences available; "
                "cannot build a modeling dataset"
            )

        sequences = df["protein_sequence"].tolist()
        X, feature_names = self.extract_kmer_features(
            sequences, k=k, max_features=max_features
        )

        column, task_type = self._TARGETS[target]
        if task_type == "multilabel":
            y, class_names = self.encode_multilabels(df, column)
        else:
            y, class_names = self.encode_labels(df, column)

        logger.info("Features shape: %s, Labels shape: %s", X.shape, y.shape)
        logger.info("Number of classes: %d", len(class_names))

        # Stratification is only attempted for single-label targets.
        X_train, X_val, X_test, y_train, y_val, y_test = self._train_val_test_split(
            X,
            y,
            test_size=test_size,
            val_size=val_size,
            random_state=random_state,
            stratified=(task_type == "multiclass"),
        )
        logger.info(
            "Train: %d, Val: %d, Test: %d",
            X_train.shape[0],
            X_val.shape[0],
            X_test.shape[0],
        )

        return {
            "X_train": X_train,
            "X_val": X_val,
            "X_test": X_test,
            "y_train": y_train,
            "y_val": y_val,
            "y_test": y_test,
            "feature_names": feature_names,
            # tolist() yields plain Python values, which keeps the metadata
            # JSON-serializable.
            "class_names": np.asarray(class_names).tolist(),
            "task_type": task_type,
            "metadata": {
                "target": target,
                "k": k,
                "max_features": max_features,
                "n_samples": len(df),
                "n_features": X.shape[1],
                "n_classes": len(class_names),
            },
        }

    def save_processed_data(self, data: dict, prefix: str = "card") -> None:
        """Save processed data to disk.

        Writes ``{prefix}_{split}.npy`` for each of the six split arrays and
        ``{prefix}_metadata.json`` into ``output_dir``.

        Args:
            data: dictionary as returned by ``prepare_modeling_data``.
            prefix: filename prefix for every written file.
        """
        logger.info("Saving processed data to %s", self.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        for split in ("X_train", "X_val", "X_test", "y_train", "y_val", "y_test"):
            np.save(self.output_dir / f"{prefix}_{split}.npy", data[split])

        metadata = {
            "feature_names": data["feature_names"],
            "class_names": data["class_names"],
            "task_type": data["task_type"],
            **data["metadata"],
        }
        metadata_path = self.output_dir / f"{prefix}_metadata.json"
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        logger.info("Data saved successfully!")

    def get_drug_class_statistics(self) -> pd.DataFrame:
        """Get statistics about drug classes in the dataset.

        Loads the CARD files first if needed.

        Returns:
            DataFrame with columns ``"Drug Class"`` and ``"Count"``, most
            frequent first. The columns are present even when no drug class is
            found.
        """
        if self.aro_index is None:
            self.load_data()

        drug_counts = Counter(
            name
            for entry in self.aro_index["Drug Class"].dropna()
            for name in (part.strip() for part in entry.split(";"))
            if name  # ignore empty fragments from trailing/doubled ';'
        )
        return pd.DataFrame(
            drug_counts.most_common(), columns=["Drug Class", "Count"]
        )

    def get_mechanism_statistics(self) -> pd.DataFrame:
        """Get statistics about resistance mechanisms.

        Loads the CARD files first if needed.

        Returns:
            DataFrame with columns ``"Resistance Mechanism"`` and ``"Count"``,
            built from the value counts of the raw ARO index column.
        """
        if self.aro_index is None:
            self.load_data()

        stats = self.aro_index["Resistance Mechanism"].value_counts().reset_index()
        stats.columns = ["Resistance Mechanism", "Count"]
        return stats
def main():
    """Run the preprocessing pipeline end to end with the default paths.

    Loads the CARD files, prints drug-class and mechanism statistics, then
    builds and saves a dataset for each prediction target. A failure on one
    target is reported and the remaining targets are still processed.
    """
    preprocessor = CARDPreprocessor()
    preprocessor.load_data()

    # Summary statistics.
    print("\n=== Drug Class Statistics ===")
    print(preprocessor.get_drug_class_statistics().head(20))

    print("\n=== Resistance Mechanism Statistics ===")
    print(preprocessor.get_mechanism_statistics())

    # Same featurization and split settings for every target.
    shared_options = {
        "k": 3,
        "max_features": 500,
        "test_size": 0.2,
        "val_size": 0.1,
    }
    for target in ("mechanism", "drug_class", "gene_family"):
        print(f"\n=== Preparing {target} prediction data ===")
        try:
            prepared = preprocessor.prepare_modeling_data(
                target=target, **shared_options
            )
            preprocessor.save_processed_data(prepared, prefix=f"card_{target}")
            print(f"Saved {target} prediction data")
        except Exception as e:
            # Best effort: report and move on to the next target.
            print(f"Error preparing {target} data: {e}")

    print("\n=== Preprocessing Complete ===")
    print(f"Output directory: {preprocessor.output_dir}")


if __name__ == "__main__":
    main()