# src/ml/feature_extractor.py
import torch
from transformers import AutoTokenizer, AutoModel
from Bio import SeqIO
import numpy as np
from typing import List, Dict, Optional
import re


class ProteinFeatureExtractor:
    """Extract features from protein sequences using ESM-2 embeddings."""

    # Hidden size of the ESM-2 checkpoint used here (e.g. esm2_t6_8M -> 320).
    EMBEDDING_DIM = 320
    # Cap on proteins kept per genome to bound downstream compute.
    MAX_PROTEINS = 50

    # Standard genetic code; '*' marks a stop codon. Hoisted to a class
    # constant so the dict is built once, not on every translation call.
    CODON_TABLE = {
        'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
        'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
        'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
        'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
        'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
        'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
        'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
        'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
        'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
        'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
        'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
        'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
        'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
        'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
        'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
        'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G',
    }

    def __init__(self, model_path="models/pretrained/esm2"):
        """Load the ESM-2 tokenizer/model, preferring GPU when available.

        Args:
            model_path: Local path (or hub id) of a pretrained ESM-2 model.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModel.from_pretrained(model_path).to(self.device)
        self.model.eval()  # inference only: disables dropout etc.

    def extract_proteins_from_genome(self, genome_sequence: str) -> List[str]:
        """Extract protein sequences from a genome via a naive ORF scan.

        Demo-quality finder (forward strand only, nested ORFs included);
        use Prodigal in production.

        Args:
            genome_sequence: Nucleotide string (assumed uppercase A/C/G/T).

        Returns:
            Up to MAX_PROTEINS translated proteins of >= 100 amino acids.
        """
        proteins: List[str] = []
        stop_codons = {'TAA', 'TAG', 'TGA'}  # set: O(1) membership in the hot loop

        n = len(genome_sequence)
        # A full start codon needs i+3 <= n, i.e. i <= n-3, so scan to n-2
        # exclusive (the original n-3 bound skipped the last possible start).
        for i in range(n - 2):
            if genome_sequence[i:i + 3] != 'ATG':
                continue
            # Walk in-frame; j <= n-3 keeps the stop codon fully in bounds
            # (the original n-3 exclusive bound missed a stop flush with the end).
            for j in range(i + 3, n - 2, 3):
                if genome_sequence[j:j + 3] in stop_codons:
                    orf = genome_sequence[i:j + 3]
                    if len(orf) >= 300:  # minimum 100 amino acids
                        protein = self.translate_dna_to_protein(orf)
                        if protein:
                            proteins.append(protein)
                    break
            if len(proteins) >= self.MAX_PROTEINS:
                break  # cap reached; same result as scanning everything then slicing

        # Top MAX_PROTEINS proteins to avoid too much data.
        return proteins[:self.MAX_PROTEINS]

    def translate_dna_to_protein(self, dna_seq: str) -> Optional[str]:
        """Translate DNA to a protein string, stopping at the first stop codon.

        Unrecognized codons (ambiguity codes, etc.) are skipped.

        Args:
            dna_seq: In-frame nucleotide sequence.

        Returns:
            Amino-acid string, or None when no codon translated (the
            original annotated `-> str` but also returned None here).
        """
        protein: List[str] = []
        for i in range(0, len(dna_seq) - 2, 3):
            aa = self.CODON_TABLE.get(dna_seq[i:i + 3].upper())
            if aa == '*':
                break
            if aa is not None:
                protein.append(aa)
        return ''.join(protein) if protein else None

    def get_protein_embedding(self, protein_seq: str) -> np.ndarray:
        """Get a mean-pooled ESM-2 embedding for one protein sequence.

        Args:
            protein_seq: Amino-acid string.

        Returns:
            1-D float array of length EMBEDDING_DIM.
        """
        # ESM-2 positional embeddings cap the usable length (~1000 residues).
        if len(protein_seq) > 1000:
            protein_seq = protein_seq[:1000]

        inputs = self.tokenizer(
            protein_seq, return_tensors="pt", truncation=True, max_length=1024
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():  # inference: skip autograd bookkeeping
            outputs = self.model(**inputs)

        # Mean pooling over the token (sequence-length) dimension.
        embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        return embeddings.squeeze()

    def extract_genome_features(self, genome_path: str) -> np.ndarray:
        """Extract one fixed-size feature vector for a whole genome.

        Concatenates all FASTA records, finds ORFs, embeds up to 20
        proteins with ESM-2 and mean-pools their embeddings.

        Args:
            genome_path: Path to a (multi-)FASTA genome file.

        Returns:
            1-D array of length EMBEDDING_DIM (zeros when nothing usable).
        """
        # "".join avoids quadratic `+=` concatenation over many records.
        genome_seq = "".join(
            str(record.seq) for record in SeqIO.parse(genome_path, "fasta")
        )

        proteins = self.extract_proteins_from_genome(genome_seq)
        print(f"Extracted {len(proteins)} proteins from genome")

        if not proteins:
            return np.zeros(self.EMBEDDING_DIM)  # zero vector: no proteins found

        embeddings = []
        for protein in proteins[:20]:  # cap embedding cost at 20 proteins
            try:
                embeddings.append(self.get_protein_embedding(protein))
            except Exception as e:  # best-effort: skip proteins the model rejects
                print(f"Error processing protein: {e}")

        if not embeddings:
            return np.zeros(self.EMBEDDING_DIM)

        # Aggregate per-protein embeddings into one genome vector (mean pooling).
        return np.mean(embeddings, axis=0)


class AMRGeneDetector:
    """Detect known AMR genes using the CARD database."""

    def __init__(self, card_db_path="data/external/card"):
        """Load the CARD reference sequences from `card_db_path`."""
        self.card_sequences = self.load_card_database(card_db_path)

    def load_card_database(self, card_path) -> Dict[str, Dict]:
        """Load CARD AMR gene sequences from the protein-homolog FASTA.

        Args:
            card_path: Directory containing the CARD download.

        Returns:
            Mapping record-id -> {'sequence', 'gene_name', 'drug_class'};
            empty dict when the database file is missing.
        """
        card_genes: Dict[str, Dict] = {}
        fasta_path = f"{card_path}/nucleotide_fasta_protein_homolog_model.fasta"
        try:
            for record in SeqIO.parse(fasta_path, "fasta"):
                # Parse gene name and antibiotic class from the header.
                gene_info = self.parse_card_header(record.description)
                card_genes[record.id] = {
                    'sequence': str(record.seq),
                    'gene_name': gene_info['gene_name'],
                    'drug_class': gene_info['drug_class'],
                }
        except FileNotFoundError:
            print(f"CARD database not found at {fasta_path}")
            # Degrade gracefully: detection simply finds nothing.
            return {}

        print(f"Loaded {len(card_genes)} AMR genes from CARD")
        return card_genes

    def parse_card_header(self, header: str) -> Dict:
        """Parse a CARD FASTA header into gene metadata.

        Example: "ARO:3000026|mecA [Staphylococcus aureus]"
        """
        parts = header.split('|')
        gene_name = parts[1].split('[')[0].strip() if len(parts) > 1 else "unknown"
        return {
            'gene_name': gene_name,
            # TODO(review): map the ARO accession to its real drug class;
            # hard-coded placeholder kept from the original for now.
            'drug_class': 'beta-lactam',
        }

    def detect_amr_genes(self, genome_sequence: str) -> List[Dict]:
        """Detect AMR genes in a genome by sequence similarity.

        Simplified to exact substring matches; in production use BLAST,
        DIAMOND or MMseqs2 for sensitive alignment.

        Args:
            genome_sequence: Assembled genome nucleotide string.

        Returns:
            One dict per hit with 'gene_id', 'gene_name', 'drug_class'.
        """
        detected_genes = []
        for gene_id, gene_info in self.card_sequences.items():
            if gene_info['sequence'] in genome_sequence:
                detected_genes.append({
                    'gene_id': gene_id,
                    'gene_name': gene_info['gene_name'],
                    'drug_class': gene_info['drug_class'],
                })
        return detected_genes


class CombinedFeatureExtractor:
    """Combine ESM-2 protein embeddings with AMR gene presence features."""

    def __init__(self):
        self.protein_extractor = ProteinFeatureExtractor()
        self.gene_detector = AMRGeneDetector()

    def extract_features(self, genome_path: str) -> Dict:
        """Extract all features from a genome FASTA file.

        Args:
            genome_path: Path to a (multi-)FASTA genome file.

        Returns:
            Dict with 'features' (concatenated vector), 'detected_genes'
            (list of CARD hits) and 'feature_dim' (vector length).
        """
        # 1. Protein embeddings (EMBEDDING_DIM-dim from ESM-2).
        protein_features = self.protein_extractor.extract_genome_features(genome_path)

        # 2. Load the genome again for gene detection ("".join avoids
        #    quadratic string concatenation).
        genome_seq = "".join(
            str(record.seq) for record in SeqIO.parse(genome_path, "fasta")
        )

        # 3. Known AMR gene detection against CARD.
        detected_genes = self.gene_detector.detect_amr_genes(genome_seq)

        # 4. Binary gene presence/absence vector.
        gene_features = self.create_gene_feature_vector(detected_genes)

        # 5. Concatenate into a single feature vector.
        combined_features = np.concatenate([protein_features, gene_features])

        return {
            'features': combined_features,
            'detected_genes': detected_genes,
            'feature_dim': len(combined_features),
        }

    def create_gene_feature_vector(self, detected_genes: List[Dict],
                                   num_genes=50) -> np.ndarray:
        """Create a binary presence/absence vector for key AMR genes.

        Args:
            detected_genes: Output of AMRGeneDetector.detect_amr_genes.
            num_genes: Length of the returned vector (zero-padded beyond
                the known gene list).

        Returns:
            1-D array of length num_genes with 1.0 where a gene was seen.
        """
        # Most important AMR gene name prefixes (first num_genes are used).
        important_genes = [
            'mecA', 'vanA', 'blaCTX-M', 'blaKPC', 'blaNDM', 'blaOXA',
            'ermB', 'tetM', 'aac', 'aph', 'sul1', 'sul2', 'dfrA',
        ]

        gene_vector = np.zeros(num_genes)
        detected_names = [g['gene_name'] for g in detected_genes]

        for i, gene in enumerate(important_genes[:num_genes]):
            # Substring match tolerates allele suffixes, e.g. "blaKPC-2".
            if any(gene in name for name in detected_names):
                gene_vector[i] = 1

        return gene_vector