| """AMRFinderPlus integration for annotating NCBI genomes with AMR predictions. | |
| This module runs AMRFinderPlus on genome sequences to detect AMR genes | |
| and predict resistance phenotypes, which can then be used as labels | |
| for machine learning models. | |
| """ | |
| import gzip | |
| import json | |
| import logging | |
| import os | |
| import shutil | |
| import subprocess | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| import pandas as pd | |
# Configure root logging at import time so CLI runs print progress.
# NOTE(review): basicConfig at module level affects the whole process's
# logging config — acceptable for a script-style module, but worth
# confirming if this module is ever imported by a larger application.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AMRFinderAnnotator:
    """Annotate genomes with AMR predictions using AMRFinderPlus.

    Wraps the ``amrfinder`` CLI: verifies (or installs) the tool, runs it
    on each downloaded genome FASTA, and converts the detected AMR genes
    into per-drug-class resistance labels usable as ML training targets.
    """

    # Mapping of organism names (as recorded in the metadata column
    # "organism_query") to AMRFinderPlus --organism codes.
    ORGANISM_CODES = {
        "Acinetobacter baumannii": "Acinetobacter_baumannii",
        "Campylobacter jejuni": "Campylobacter",
        "Campylobacter coli": "Campylobacter",
        "Clostridioides difficile": "Clostridioides_difficile",
        "Enterococcus faecalis": "Enterococcus_faecalis",
        "Enterococcus faecium": "Enterococcus_faecium",
        "Escherichia coli": "Escherichia",
        "Klebsiella pneumoniae": "Klebsiella_pneumoniae",
        "Neisseria gonorrhoeae": "Neisseria_gonorrhoeae",
        "Neisseria meningitidis": "Neisseria_meningitidis",
        "Pseudomonas aeruginosa": "Pseudomonas_aeruginosa",
        "Salmonella enterica": "Salmonella",
        "Staphylococcus aureus": "Staphylococcus_aureus",
        "Staphylococcus pseudintermedius": "Staphylococcus_pseudintermedius",
        "Streptococcus agalactiae": "Streptococcus_agalactiae",
        "Streptococcus pneumoniae": "Streptococcus_pneumoniae",
        "Streptococcus pyogenes": "Streptococcus_pyogenes",
        "Vibrio cholerae": "Vibrio_cholerae",
    }

    def __init__(
        self,
        genomes_dir: str = "data/raw/ncbi/genomes",
        metadata_file: str = "data/raw/ncbi/complete_metadata.csv",
        output_dir: str = "data/raw/ncbi/amrfinder_results",
    ):
        """Store paths and create *output_dir* if it does not exist.

        Args:
            genomes_dir: Directory containing ``*.fna.gz`` genome files.
            metadata_file: CSV with at least a ``biosample_id`` column.
            output_dir: Where per-sample TSVs and combined results are written.
        """
        self.genomes_dir = Path(genomes_dir)
        self.metadata_file = Path(metadata_file)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Lazily populated by load_metadata().
        self.metadata: Optional[pd.DataFrame] = None
        # biosample_id -> per-genome AMRFinder result frame.
        self.amr_results: Dict[str, pd.DataFrame] = {}

    # The three helpers below do not touch instance state, so they are
    # static methods.  (They were previously defined without ``self``,
    # which made instance calls such as ``self.check_installation()``
    # raise TypeError.)
    @staticmethod
    def check_installation() -> bool:
        """Return True if the ``amrfinder`` executable is on PATH and runs."""
        try:
            result = subprocess.run(
                ["amrfinder", "--version"],
                capture_output=True,
                text=True,
            )
        except FileNotFoundError:
            # Executable not on PATH.
            return False
        if result.returncode == 0:
            logger.info(f"AMRFinderPlus version: {result.stdout.strip()}")
            return True
        return False

    @staticmethod
    def install_amrfinder() -> bool:
        """Install AMRFinderPlus using conda.

        Returns:
            True on successful install (a database update is attempted as
            a best-effort follow-up), False otherwise.
        """
        logger.info("Installing AMRFinderPlus via conda...")
        try:
            # Try conda install
            result = subprocess.run(
                ["conda", "install", "-y", "-c", "bioconda", "ncbi-amrfinderplus"],
                capture_output=True,
                text=True,
            )
            if result.returncode == 0:
                logger.info("AMRFinderPlus installed successfully")
                # Update database (best effort; result intentionally ignored).
                subprocess.run(["amrfinder", "-u"], capture_output=True)
                return True
            else:
                logger.error(f"Conda install failed: {result.stderr}")
        except FileNotFoundError:
            logger.error("Conda not found. Please install conda first.")
        return False

    @staticmethod
    def update_database() -> bool:
        """Update the AMRFinderPlus reference database (``amrfinder -u``)."""
        logger.info("Updating AMRFinderPlus database...")
        try:
            result = subprocess.run(
                ["amrfinder", "-u"],
                capture_output=True,
                text=True,
            )
            if result.returncode == 0:
                logger.info("Database updated successfully")
                return True
            else:
                logger.warning(f"Database update warning: {result.stderr}")
                return True  # May already be up to date
        except Exception as e:
            logger.error(f"Database update failed: {e}")
            return False

    def load_metadata(self) -> pd.DataFrame:
        """Load and cache genome metadata.

        Combines every CSV in ``<metadata_file parent>/metadata/`` with
        ``metadata_file`` itself, de-duplicated on ``biosample_id``.

        Raises:
            FileNotFoundError: If no metadata CSV can be found.
        """
        if self.metadata is None:
            # Load from all metadata files
            metadata_dir = self.metadata_file.parent / "metadata"
            all_dfs = []
            if metadata_dir.exists():
                for csv_file in metadata_dir.glob("*.csv"):
                    # Skip hidden files (e.g. macOS "._" artifacts).
                    if not csv_file.name.startswith("."):
                        all_dfs.append(pd.read_csv(csv_file))
            if self.metadata_file.exists():
                all_dfs.append(pd.read_csv(self.metadata_file))
            if not all_dfs:
                raise FileNotFoundError("No metadata files found")
            self.metadata = pd.concat(all_dfs, ignore_index=True)
            self.metadata = self.metadata.drop_duplicates(subset=["biosample_id"])
            # Normalize IDs to str so lookups match filename-derived IDs.
            self.metadata["biosample_id"] = self.metadata["biosample_id"].astype(str)
            logger.info(f"Loaded metadata for {len(self.metadata)} samples")
        return self.metadata

    def get_organism_for_sample(self, biosample_id: str) -> Optional[str]:
        """Return the AMRFinderPlus organism code for a sample.

        Returns None when the sample is not in the metadata or its
        organism has no entry in ORGANISM_CODES.
        """
        if self.metadata is None:
            self.load_metadata()
        row = self.metadata[self.metadata["biosample_id"] == biosample_id]
        if len(row) == 0:
            return None
        organism = row.iloc[0].get("organism_query", "")
        return self.ORGANISM_CODES.get(organism)

    def run_amrfinder_on_genome(
        self,
        genome_file: Path,
        biosample_id: str,
        organism_code: Optional[str] = None,
    ) -> Optional[pd.DataFrame]:
        """Run AMRFinderPlus on a single genome.

        Results are cached: if ``<biosample_id>_amrfinder.tsv`` already
        exists in the output directory it is re-read instead of re-run.

        Args:
            genome_file: Path to genome FASTA file (can be gzipped)
            biosample_id: Sample identifier
            organism_code: AMRFinderPlus organism code (optional)

        Returns:
            DataFrame with AMR results (plus a ``biosample_id`` column),
            or None if the run failed or timed out.
        """
        output_file = self.output_dir / f"{biosample_id}_amrfinder.tsv"
        # Skip if already processed
        if output_file.exists():
            try:
                return pd.read_csv(output_file, sep="\t")
            except Exception:
                pass  # unreadable cache: fall through and regenerate

        # Decompress if needed — amrfinder expects a plain FASTA file.
        temp_file = None
        if str(genome_file).endswith(".gz"):
            temp_file = tempfile.NamedTemporaryFile(suffix=".fna", delete=False)
            # Stream the decompressed bytes instead of reading the whole
            # genome into memory at once.
            with gzip.open(genome_file, "rb") as f_in:
                shutil.copyfileobj(f_in, temp_file)
            temp_file.close()
            input_file = temp_file.name
        else:
            input_file = str(genome_file)

        try:
            # Build command
            cmd = [
                "amrfinder",
                "-n", input_file,  # Nucleotide input
                "-o", str(output_file),
                "--plus",  # Include stress/virulence genes
            ]
            # Add organism-specific options if available
            if organism_code:
                cmd.extend(["--organism", organism_code])
            # Run AMRFinderPlus
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=300,  # 5 minute timeout
            )
            if result.returncode == 0 and output_file.exists():
                df = pd.read_csv(output_file, sep="\t")
                df["biosample_id"] = biosample_id
                return df
            else:
                logger.warning(f"AMRFinder failed for {biosample_id}: {result.stderr}")
                return None
        except subprocess.TimeoutExpired:
            logger.warning(f"AMRFinder timeout for {biosample_id}")
            return None
        except Exception as e:
            logger.error(f"Error running AMRFinder for {biosample_id}: {e}")
            return None
        finally:
            # Clean up temp file
            if temp_file and os.path.exists(temp_file.name):
                os.unlink(temp_file.name)

    def run_on_all_genomes(
        self,
        max_samples: Optional[int] = None,
        use_organism: bool = True,
    ) -> pd.DataFrame:
        """Run AMRFinderPlus on all genomes.

        Args:
            max_samples: Maximum number of samples to process (for testing)
            use_organism: Whether to use organism-specific detection

        Returns:
            Combined DataFrame with all AMR results (empty if none found).

        Raises:
            RuntimeError: If AMRFinderPlus is not installed.
        """
        if not self.check_installation():
            logger.error(
                "AMRFinderPlus not installed. Install with:\n"
                "  conda install -c bioconda ncbi-amrfinderplus\n"
                "Then update database with:\n"
                "  amrfinder -u"
            )
            raise RuntimeError("AMRFinderPlus not installed")
        self.load_metadata()
        # Get genome files
        genome_files = list(self.genomes_dir.glob("*.fna.gz"))
        if max_samples:
            genome_files = genome_files[:max_samples]
        logger.info(f"Processing {len(genome_files)} genomes...")
        all_results = []
        for i, genome_file in enumerate(genome_files):
            # "<id>.fna.gz" -> stem "<id>.fna" -> "<id>"
            biosample_id = genome_file.stem.replace(".fna", "")
            # Get organism code
            organism_code = None
            if use_organism:
                organism_code = self.get_organism_for_sample(biosample_id)
            # Run AMRFinderPlus
            result = self.run_amrfinder_on_genome(
                genome_file, biosample_id, organism_code
            )
            if result is not None and len(result) > 0:
                all_results.append(result)
                self.amr_results[biosample_id] = result
            if (i + 1) % 10 == 0:
                logger.info(f"Processed {i + 1}/{len(genome_files)} genomes")
        # Combine results
        if all_results:
            combined = pd.concat(all_results, ignore_index=True)
            combined.to_csv(self.output_dir / "all_amr_results.csv", index=False)
            logger.info(f"Found {len(combined)} AMR genes across {len(all_results)} genomes")
            return combined
        else:
            logger.warning("No AMR genes found in any genome")
            return pd.DataFrame()

    def create_amr_labels(
        self,
        min_samples_per_drug: int = 10,
    ) -> Tuple[pd.DataFrame, Dict]:
        """Create AMR labels from AMRFinderPlus results.

        Converts AMR gene detections into drug resistance labels.

        Args:
            min_samples_per_drug: Minimum samples with resistance to include a drug

        Returns:
            Tuple of (labels DataFrame, drug class mapping)

        Raises:
            FileNotFoundError: If run_on_all_genomes() has not produced
                ``all_amr_results.csv`` yet.
        """
        # Load all results
        results_file = self.output_dir / "all_amr_results.csv"
        if not results_file.exists():
            raise FileNotFoundError(
                "No AMR results found. Run run_on_all_genomes() first."
            )
        df = pd.read_csv(results_file)
        logger.info(f"Loaded {len(df)} AMR annotations")
        # Filter to AMR genes only (not stress/virulence)
        amr_df = df[df["Element type"] == "AMR"].copy()
        logger.info(f"AMR genes: {len(amr_df)}")
        if len(amr_df) == 0:
            logger.warning("No AMR genes found in results")
            return pd.DataFrame(), {}
        # Get unique drug classes from the AMRFinderPlus "Class" column,
        # skipping missing/empty values.
        drug_classes = {c for c in amr_df["Class"].dropna() if c}
        logger.info(f"Drug classes found: {drug_classes}")
        # Create binary label matrix: one row per sample, one column per class.
        biosample_ids = amr_df["biosample_id"].unique()
        labels = []
        for biosample_id in biosample_ids:
            sample_amr = amr_df[amr_df["biosample_id"] == biosample_id]
            sample_drugs = set(sample_amr["Class"].dropna().unique())
            row = {"biosample_id": biosample_id}
            for drug in drug_classes:
                row[drug] = 1 if drug in sample_drugs else 0
            labels.append(row)
        labels_df = pd.DataFrame(labels)
        # Filter drugs with enough samples
        drug_counts = labels_df.drop(columns=["biosample_id"]).sum()
        valid_drugs = drug_counts[drug_counts >= min_samples_per_drug].index.tolist()
        logger.info(f"Drugs with >= {min_samples_per_drug} resistant samples: {len(valid_drugs)}")
        for drug in valid_drugs:
            logger.info(f"  {drug}: {drug_counts[drug]} samples")
        # Create drug class mapping (stable, sorted index assignment).
        drug_mapping = {drug: i for i, drug in enumerate(sorted(valid_drugs))}
        # Save labels
        labels_df.to_csv(self.output_dir / "amr_labels.csv", index=False)
        with open(self.output_dir / "drug_mapping.json", "w") as f:
            json.dump(drug_mapping, f, indent=2)
        return labels_df, drug_mapping

    def get_phenotype_labels(self) -> pd.DataFrame:
        """Get resistance phenotype labels for preprocessing.

        Returns DataFrame with columns:
        - biosample_id
        - One column per drug class (1=resistant, 0=susceptible/unknown)

        Reads the cached ``amr_labels.csv`` if present; otherwise builds
        labels via create_amr_labels().
        """
        labels_file = self.output_dir / "amr_labels.csv"
        if labels_file.exists():
            return pd.read_csv(labels_file)
        else:
            labels_df, _ = self.create_amr_labels()
            return labels_df
def main():
    """Command-line entry point: annotate all genomes, then build AMR labels."""
    annotator = AMRFinderAnnotator()

    # Bail out early with installation instructions when the tool is absent.
    if not annotator.check_installation():
        for message in (
            "\n" + "=" * 60,
            "AMRFinderPlus is not installed!",
            "=" * 60,
            "\nTo install AMRFinderPlus:",
            " 1. Using conda (recommended):",
            " conda install -c bioconda ncbi-amrfinderplus",
            "\n 2. Using docker:",
            " docker pull ncbi/amr",
            "\n 3. Manual installation:",
            " https://github.com/ncbi/amr/wiki/Installing-AMRFinder",
            "\nAfter installation, update the database:",
            " amrfinder -u",
            "=" * 60,
        ):
            print(message)
        return

    # Annotate every downloaded genome.
    print("\nRunning AMRFinderPlus on all genomes...")
    results = annotator.run_on_all_genomes()

    if len(results) == 0:
        print("\nNo AMR genes detected. Check genome files and AMRFinderPlus installation.")
        return

    print(f"\nFound {len(results)} AMR genes")

    # Derive per-drug-class resistance labels from the annotations.
    print("\nCreating AMR labels...")
    labels_df, drug_mapping = annotator.create_amr_labels()
    print(f"\nCreated labels for {len(labels_df)} samples")
    print(f"Drug classes: {list(drug_mapping.keys())}")
    print(f"\nResults saved to: {annotator.output_dir}")
# Script entry point: run the full annotation pipeline when executed directly.
if __name__ == "__main__":
    main()