File size: 9,319 Bytes
3255634
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
# src/ml/feature_extractor.py
import re
from typing import Dict, List, Optional

import numpy as np
import torch
from Bio import SeqIO
from transformers import AutoModel, AutoTokenizer

class ProteinFeatureExtractor:
    """Extract features from protein sequences using ESM-2.

    Proteins are predicted from a genome FASTA with a simple forward-strand
    ORF finder, embedded one at a time with the model loaded from
    ``model_path``, and mean-pooled into a single genome-level vector.
    """

    # Standard genetic code; '*' marks a stop codon. Built once here instead
    # of on every call to translate_dna_to_protein.
    _CODON_TABLE = {
        'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
        'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
        'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
        'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
        'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
        'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
        'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
        'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
        'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
        'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
        'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
        'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
        'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
        'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
        'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
        'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G',
    }
    _START_CODON = 'ATG'
    _STOP_CODONS = frozenset({'TAA', 'TAG', 'TGA'})
    # Fallback width for zero vectors when the model config exposes no
    # hidden_size (this was the value previously hard-coded).
    _DEFAULT_EMBEDDING_DIM = 320

    def __init__(self, model_path="models/pretrained/esm2"):
        """Load tokenizer and model from ``model_path`` onto CUDA if available, else CPU."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModel.from_pretrained(model_path).to(self.device)
        self.model.eval()  # inference mode

    def extract_proteins_from_genome(
        self,
        genome_sequence: str,
        max_proteins: int = 50,
        min_protein_length: int = 100,
    ) -> List[str]:
        """
        Extract protein sequences from a single nucleotide sequence.

        Simple forward-strand ORF finder (for demo - use Prodigal in
        production). An ORF runs from an ATG to the first in-frame stop
        codon. For each stop codon only the longest ORF (the most upstream
        in-frame ATG) is kept, so nested in-frame ATGs do not produce
        truncated duplicates of the same protein.

        Args:
            genome_sequence: Nucleotide sequence; matching is case-insensitive.
            max_proteins: Return at most this many proteins, in order of
                start position.
            min_protein_length: Minimum translated length in amino acids
                (the stop codon is not counted).

        Returns:
            Up to ``max_proteins`` protein sequences.
        """
        if max_proteins <= 0:
            return []

        seq = genome_sequence.upper()
        n = len(seq)
        proteins = []
        # Per reading frame (start index % 3): index just past the stop codon
        # of the last ORF scanned in that frame. An ATG before it in the same
        # frame shares that stop and could only yield a shorter copy.
        frame_end = [0, 0, 0]

        # n - 2 so that a codon ending on the last base is still examined.
        for i in range(n - 2):
            if seq[i:i + 3] != self._START_CODON:
                continue
            frame = i % 3
            if i < frame_end[frame]:
                continue
            stop = next(
                (j for j in range(i + 3, n - 2, 3) if seq[j:j + 3] in self._STOP_CODONS),
                None,
            )
            if stop is None:
                # No in-frame stop downstream, so no complete ORF can start
                # later in this frame either.
                frame_end[frame] = n
                continue
            frame_end[frame] = stop + 3
            protein = self.translate_dna_to_protein(seq[i:stop + 3])
            if protein and len(protein) >= min_protein_length:
                proteins.append(protein)
                if len(proteins) >= max_proteins:
                    break

        return proteins

    def translate_dna_to_protein(self, dna_seq: str) -> Optional[str]:
        """
        Translate DNA to a protein sequence, stopping at the first stop codon.

        Codons not in the table (e.g. containing N) are skipped. A trailing
        partial codon is ignored.

        Returns:
            The amino-acid string, or None if it would be empty.
        """
        protein = []
        for i in range(0, len(dna_seq) - 2, 3):
            aa = self._CODON_TABLE.get(dna_seq[i:i + 3].upper())
            if aa is None:
                continue
            if aa == '*':
                break
            protein.append(aa)

        return ''.join(protein) or None

    def get_protein_embedding(self, protein_seq: str) -> np.ndarray:
        """
        Get a mean-pooled embedding for one protein sequence.

        Sequences longer than 1000 residues are truncated to their first 1000
        before tokenization.
        NOTE(review): the mean is taken over every token position returned by
        the model, so any special tokens the tokenizer adds are included.
        """
        # Truncate if too long (tokenizer is also capped at 1024 tokens)
        if len(protein_seq) > 1000:
            protein_seq = protein_seq[:1000]

        inputs = self.tokenizer(protein_seq, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Mean pooling over sequence length
        embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        return embeddings.squeeze()

    def extract_genome_features(self, genome_path: str) -> np.ndarray:
        """
        Extract one feature vector for the genome in a FASTA file.

        ORFs are searched per FASTA record, so none can span a contig
        junction. At most 50 proteins are kept, the first 20 are embedded,
        and their embeddings are averaged.

        Returns:
            The mean protein embedding, or a zero vector of the model's
            embedding width if no protein could be extracted or embedded.
        """
        max_proteins = 50
        proteins: List[str] = []
        for record in SeqIO.parse(genome_path, "fasta"):
            remaining = max_proteins - len(proteins)
            if remaining <= 0:
                break
            proteins.extend(
                self.extract_proteins_from_genome(str(record.seq), max_proteins=remaining)
            )
        print(f"Extracted {len(proteins)} proteins from genome")

        if not proteins:
            return np.zeros(self._embedding_dim())

        embeddings = []
        for protein in proteins[:20]:  # Top 20 proteins
            try:
                embeddings.append(self.get_protein_embedding(protein))
            except Exception as e:
                # Best-effort: skip proteins the model fails on, keep the rest.
                print(f"Error processing protein: {e}")

        if not embeddings:
            return np.zeros(self._embedding_dim())

        # Aggregate embeddings (mean pooling)
        return np.mean(embeddings, axis=0)

    def _embedding_dim(self) -> int:
        """Return the loaded model's hidden size, falling back to 320."""
        config = getattr(getattr(self, "model", None), "config", None)
        return int(getattr(config, "hidden_size", self._DEFAULT_EMBEDDING_DIM))


class AMRGeneDetector:
    """Detect known AMR genes using the CARD database.

    Detection is exact substring matching of CARD nucleotide sequences
    against both strands of the genome (see detect_amr_genes).
    """

    # Complement table for reverse-complementing DNA. Characters not listed
    # (e.g. N or separators) are left unchanged by str.translate.
    _COMPLEMENT = str.maketrans('ACGT', 'TGCA')

    def __init__(self, card_db_path="data/external/card"):
        # record id -> {'sequence', 'gene_name', 'drug_class'}
        self.card_sequences = self.load_card_database(card_db_path)

    def load_card_database(self, card_path):
        """
        Load CARD AMR gene sequences from the protein homolog model FASTA.

        Returns:
            Dict keyed by FASTA record id with 'sequence', 'gene_name' and
            'drug_class'. Empty if the FASTA file is missing.
        """
        card_genes = {}
        fasta_path = f"{card_path}/nucleotide_fasta_protein_homolog_model.fasta"

        try:
            for record in SeqIO.parse(fasta_path, "fasta"):
                # Parse gene name and antibiotic class
                gene_info = self.parse_card_header(record.description)
                card_genes[record.id] = {
                    'sequence': str(record.seq),
                    'gene_name': gene_info['gene_name'],
                    'drug_class': gene_info['drug_class']
                }
        except FileNotFoundError:
            print(f"CARD database not found at {fasta_path}")
            # Return empty dict for now
            return {}

        print(f"Loaded {len(card_genes)} AMR genes from CARD")
        return card_genes

    def parse_card_header(self, header: str) -> Dict:
        """
        Parse a CARD FASTA header into gene name and drug class.

        The gene name is taken from the last '|'-separated field, minus any
        trailing "[organism]". This handles both the short form
        "ARO:3000026|mecA [Staphylococcus aureus]" -> "mecA" and multi-field
        headers such as "gb|<acc>|<strand>|<coords>|ARO:<id>|<name> [<org>]",
        where field 1 is an accession rather than a gene name.
        Headers without '|' (or with an empty name) give "unknown".
        """
        parts = header.split('|')
        gene_name = parts[-1].split('[')[0].strip() if len(parts) > 1 else ""

        return {
            'gene_name': gene_name or "unknown",
            # NOTE(review): every gene is labelled beta-lactam; the real class
            # must come from CARD's ARO index, not the FASTA header.
            'drug_class': 'beta-lactam'  # Simplified for now
        }

    def detect_amr_genes(self, genome_sequence: str) -> List[Dict]:
        """
        Detect AMR genes in a genome by exact sequence match on both strands.

        In production, use BLAST, DIAMOND or MMseqs2 for similarity search.
        Matching is case-insensitive. CARD entries with an empty sequence are
        ignored, since an empty string would match every genome.

        Returns:
            One dict per matched CARD entry with 'gene_id', 'gene_name' and
            'drug_class', in database order.
        """
        forward = genome_sequence.upper()
        # Genes encoded on the reverse strand only appear in the reverse complement.
        reverse = forward.translate(self._COMPLEMENT)[::-1]

        detected_genes = []
        for gene_id, gene_info in self.card_sequences.items():
            query = gene_info['sequence'].upper()
            if not query:
                continue
            if query in forward or query in reverse:
                detected_genes.append({
                    'gene_id': gene_id,
                    'gene_name': gene_info['gene_name'],
                    'drug_class': gene_info['drug_class']
                })

        return detected_genes


class CombinedFeatureExtractor:
    """Combine protein embeddings and AMR gene presence/absence features."""

    # Joins FASTA records for gene detection. It is not a nucleotide, so an
    # exact sequence match can never span two contigs.
    _RECORD_SEPARATOR = "|"

    # Gene families encoded by create_gene_feature_vector, in slot order.
    _IMPORTANT_GENES = (
        'mecA', 'vanA', 'blaCTX-M', 'blaKPC', 'blaNDM', 'blaOXA',
        'ermB', 'tetM', 'aac', 'aph', 'sul1', 'sul2', 'dfrA',
    )

    def __init__(self):
        self.protein_extractor = ProteinFeatureExtractor()
        self.gene_detector = AMRGeneDetector()

    def extract_features(self, genome_path: str) -> Dict:
        """
        Extract all features from a genome FASTA file.

        Returns:
            Dict with 'features' (protein embedding concatenated with the
            gene presence/absence vector), 'detected_genes' (as returned by
            the gene detector) and 'feature_dim' (length of 'features').
        """
        # 1. Protein embeddings (width set by the loaded protein model)
        protein_features = self.protein_extractor.extract_genome_features(genome_path)

        # 2. Load genome for gene detection, keeping contigs separated
        genome_seq = self._RECORD_SEPARATOR.join(
            str(record.seq) for record in SeqIO.parse(genome_path, "fasta")
        )

        # 3. AMR gene detection
        detected_genes = self.gene_detector.detect_amr_genes(genome_seq)

        # 4. Create gene presence/absence vector
        gene_features = self.create_gene_feature_vector(detected_genes)

        # 5. Combine features
        combined_features = np.concatenate([protein_features, gene_features])

        return {
            'features': combined_features,
            'detected_genes': detected_genes,
            'feature_dim': len(combined_features)
        }

    @staticmethod
    def _normalize_gene_name(name: str) -> str:
        """Lower-case *name* and drop non-alphanumerics, e.g. "tet(M)" -> "tetm"."""
        return ''.join(ch for ch in name.lower() if ch.isalnum())

    def create_gene_feature_vector(self, detected_genes: List[Dict], num_genes=50) -> np.ndarray:
        """
        Create a binary presence/absence vector for key AMR gene families.

        Slot i is 1 if any detected gene name contains family i. Matching
        ignores case and punctuation, and a leading "bla" on the family name
        is optional. So "blaKPC" matches "KPC-2", "tetM" matches "tet(M)" and
        "ermB" matches "ErmB". Any name matched by a plain substring test
        still matches.

        Only 13 families are defined, so slots 13..num_genes-1 are always 0.

        Args:
            detected_genes: Dicts that each carry a 'gene_name' key.
            num_genes: Length of the returned vector.
        """
        gene_vector = np.zeros(num_genes)
        detected_names = [self._normalize_gene_name(g['gene_name']) for g in detected_genes]

        for i, gene in enumerate(self._IMPORTANT_GENES[:num_genes]):
            key = self._normalize_gene_name(gene)
            if key.startswith('bla'):
                key = key[3:]  # names may omit the "bla" prefix (e.g. "KPC-2")
            if any(key in name for name in detected_names):
                gene_vector[i] = 1

        return gene_vector