VParka commited on
Commit
b8cde37
·
verified ·
1 Parent(s): bbef7d6

Upload inference.py via DNA Console (Portable Version)

Browse files
Files changed (1) hide show
  1. inference.py +162 -0
inference.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import joblib
3
+ import numpy as np
4
+ import math
5
+ from collections import Counter
6
+
7
+ class BiologicalFeatureExtractor:
8
+ """Standalone extractor for GenetiForest (RandomForest)"""
9
+ def __init__(self, kmer_size=3):
10
+ self.kmer_size = kmer_size
11
+ self.kmers = self._generate_kmers(kmer_size)
12
+
13
+ def _generate_kmers(self, k):
14
+ bases = ['A', 'C', 'G', 'T']
15
+ if k == 1: return bases
16
+ return [b + s for b in bases for s in self._generate_kmers(k-1)]
17
+
18
+ def transform(self, X):
19
+ features = []
20
+ for seq in X:
21
+ seq = seq.upper().replace('U', 'T')
22
+ row = []
23
+ length = len(seq)
24
+ # 1. GC Content
25
+ gc_content = (seq.count('G') + seq.count('C')) / length if length > 0 else 0
26
+ row.append(gc_content)
27
+ # 2. Shannon Entropy
28
+ row.append(self._calculate_entropy(seq))
29
+ # 3. K-mer Frequency
30
+ total_kmers = length - self.kmer_size + 1
31
+ if total_kmers > 0:
32
+ counts = Counter([seq[i:i+self.kmer_size] for i in range(total_kmers)])
33
+ for kmer in self.kmers:
34
+ row.append(counts.get(kmer, 0) / total_kmers)
35
+ else:
36
+ row.extend([0] * len(self.kmers))
37
+ features.append(row)
38
+ return np.array(features)
39
+
40
+ def _calculate_entropy(self, seq):
41
+ if not seq: return 0
42
+ counts = Counter(seq)
43
+ total = len(seq)
44
+ entropy = 0
45
+ for count in counts.values():
46
+ p = count / total
47
+ entropy -= p * math.log2(p)
48
+ return entropy
49
+
50
+ class SequenceFeatureExtractor:
51
+ """Standalone extractor for ViralBoost (GradientBoosting)"""
52
+ def __init__(self, kmer_size=5):
53
+ self.kmer_size = kmer_size
54
+ self.kmers = self._generate_kmers(kmer_size)
55
+ self.dinucleotides = ['AA', 'AT', 'AG', 'AC', 'TA', 'TT', 'TG', 'TC',
56
+ 'GA', 'GT', 'GG', 'GC', 'CA', 'CT', 'CG', 'CC']
57
+
58
+ def _generate_kmers(self, k):
59
+ bases = ['A', 'C', 'G', 'T']
60
+ if k == 1: return bases
61
+ return [b + s for b in bases for s in self._generate_kmers(k-1)]
62
+
63
+ def transform(self, X):
64
+ features = []
65
+ for seq in X:
66
+ seq = seq.upper().replace('U', 'T')
67
+ row = []
68
+ length = len(seq)
69
+ row.append((seq.count('G') + seq.count('C')) / length if length > 0 else 0) # GC
70
+ row.append(self._calc_skew(seq, 'G', 'C')) # GC Skew
71
+ row.append(self._calc_skew(seq, 'A', 'T')) # AT Skew
72
+ row.append(self._calc_entropy(seq)) # Entropy
73
+ # 5-mer (Top 20)
74
+ t_kmers = length - self.kmer_size + 1
75
+ if t_kmers > 0:
76
+ k_counts = Counter([seq[i:i+self.kmer_size] for i in range(t_kmers)])
77
+ row.extend([k_counts.get(k, 0) / t_kmers for k in self.kmers[:20]])
78
+ else:
79
+ row.extend([0] * 20)
80
+ # Dinucleotides
81
+ t_di = length - 1
82
+ if t_di > 0:
83
+ d_counts = Counter([seq[i:i+2] for i in range(t_di)])
84
+ row.extend([d_counts.get(d, 0) / t_di for d in self.dinucleotides])
85
+ else:
86
+ row.extend([0] * 16)
87
+ row.append(self._calc_repeat(seq)) # repeat score
88
+ row.append(self._calc_cpg(seq, length)) # CpG
89
+ row.extend(self._calc_codon_bias(seq)) # Codon Pos Bias
90
+ features.append(row)
91
+ return np.array(features)
92
+
93
+ def _calc_skew(self, seq, b1, b2):
94
+ c1, c2 = seq.count(b1), seq.count(b2)
95
+ return (c1 - c2) / (c1 + c2) if (c1 + c2) > 0 else 0
96
+ def _calc_entropy(self, seq):
97
+ if not seq: return 0
98
+ c = Counter(seq); t = len(seq); e = 0
99
+ for v in c.values():
100
+ p = v/t
101
+ if p > 0: e -= p * math.log2(p)
102
+ return e
103
+ def _calc_repeat(self, seq):
104
+ if len(seq) < 6: return 0
105
+ cnt = 0
106
+ for l in [2, 3, 4]:
107
+ for i in range(len(seq) - l*2):
108
+ if seq[i:i+l] == seq[i+l:i+l*2]: cnt += 1
109
+ return cnt / len(seq)
110
+ def _calc_cpg(self, seq, length):
111
+ if length < 2: return 0
112
+ obs = seq.count('CG')
113
+ exp = (seq.count('C') * seq.count('G')) / length
114
+ return obs / exp if exp > 0 else 0
115
+ def _calc_codon_bias(self, seq):
116
+ if len(seq) < 3: return [0] * 12
117
+ p_c = [{}, {}, {}]
118
+ for i in range(0, len(seq)-2, 3):
119
+ for j in range(3):
120
+ b = seq[i+j]
121
+ if b in 'ATGC': p_c[j][b] = p_c[j].get(b, 0) + 1
122
+ res = []
123
+ for p in range(3):
124
+ t = sum(p_c[p].values()) or 1
125
+ for b in 'ATGC': res.append(p_c[p].get(b, 0) / t)
126
+ return res
127
+
128
+ def predict_dna(sequence):
129
+ # Load Models
130
+ rf_model = joblib.load("dna_classifier.joblib")
131
+ rf_scaler = joblib.load("scaler_rf.joblib")
132
+ gb_model = joblib.load("sequence_model.joblib")
133
+ gb_scaler = joblib.load("scaler_gb.joblib")
134
+
135
+ # 1. GenetiForest Prediction (Synthetic vs Biological)
136
+ extractor_rf = BiologicalFeatureExtractor()
137
+ feat_rf = extractor_rf.transform([sequence])
138
+ scaled_rf = rf_scaler.transform(feat_rf)
139
+ type_basic = rf_model.predict(scaled_rf)[0]
140
+
141
+ # 2. ViralBoost Prediction (Virus Type)
142
+ extractor_gb = SequenceFeatureExtractor()
143
+ feat_gb = extractor_gb.transform([sequence])
144
+ scaled_gb = gb_scaler.transform(feat_gb)
145
+ type_virus = gb_model.predict(scaled_gb)[0]
146
+
147
+ return {
148
+ "classification": type_basic,
149
+ "virus_identity": type_virus
150
+ }
151
+
152
+ if __name__ == "__main__":
153
+ # Example usage
154
+ test_seq = "ATGCTAGCTAGCTAGCTAGCGGCTAGCTAGCTAGCTAGCTAGC"
155
+ try:
156
+ results = predict_dna(test_seq)
157
+ print(f"Results for sequence: {test_seq[:20]}...")
158
+ print(f"GenetiForest Result: {results['classification']}")
159
+ print(f"ViralBoost Result: {results['virus_identity']}")
160
+ except Exception as e:
161
+ print(f"Error: {e}")
162
+ print("Ensure all .joblib files are in the same directory.")