Muhamed-Kheir commited on
Commit
137d7d8
·
verified ·
1 Parent(s): 30e2586

Upload kmer_predict.py

Browse files
Files changed (1) hide show
  1. kmer_predict.py +485 -0
kmer_predict.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ K-mer-based group prediction for unknown sequences.
4
+
5
+ Inputs:
6
+ - Unknown sequences: a FASTA file or a directory of FASTA files
7
+ - Unique k-mers: either
8
+ * a directory containing unique_k{k}_{group}.tsv/.txt files (from script #1), OR
9
+ * a ZIP file containing those files
10
+
11
+ Modes:
12
+ - fast: exact substring matching only (very fast)
13
+ - full: alignment-based matching (slower, more tolerant) + Fisher + FDR
14
+
15
+ Outputs:
16
+ - predictions_by_alignment.csv
17
+ - predicted_results_summary.png
18
+
19
+ Example:
20
+ python kmer_predict.py \
21
+ --unknown unknown_fastas/ \
22
+ --kmer-input kmer_results.zip \
23
+ --outdir pred_out \
24
+ --seqtype dna \
25
+ --mode fast
26
+
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import argparse
32
+ import os
33
+ import re
34
+ import shutil
35
+ import tempfile
36
+ import zipfile
37
+ from dataclasses import dataclass
38
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
39
+
40
+ import pandas as pd
41
+ import matplotlib
42
+ matplotlib.use("Agg")
43
+ import matplotlib.pyplot as plt
44
+
45
+ from scipy.stats import fisher_exact
46
+ from statsmodels.stats.multitest import multipletests
47
+
48
+ from Bio import Align
49
+ from Bio.Align import substitution_matrices
50
+
51
+
52
+ FASTA_EXTS = (".fasta", ".fa", ".fas", ".fna")
53
+ KMER_FILE_EXTS = (".tsv", ".txt")
54
+ DEFAULT_GROUP_REGEX = r"unique_k\d+_(.+)\.(tsv|txt)$"
55
+
56
+ BLOSUM62 = substitution_matrices.load("BLOSUM62")
57
+
58
+
59
+ # ----------------------------
60
+ # FASTA utilities
61
+ # ----------------------------
62
+
63
+ def read_fasta(filepath: str) -> Tuple[List[str], List[str]]:
64
+ headers, seqs, seq = [], [], []
65
+ with open(filepath, "r", encoding="utf-8") as fh:
66
+ for line in fh:
67
+ line = line.rstrip("\n")
68
+ if not line:
69
+ continue
70
+ if line.startswith(">"):
71
+ if seq:
72
+ seqs.append("".join(seq))
73
+ seq = []
74
+ headers.append(line[1:].strip())
75
+ else:
76
+ seq.append(line.strip().upper())
77
+ if seq:
78
+ seqs.append("".join(seq))
79
+ return headers, seqs
80
+
81
+
82
+ def clean_protein(seq: str) -> str:
83
+ return re.sub(r"[^ACDEFGHIKLMNPQRSTVWY]", "", seq.upper())
84
+
85
+
86
+ def clean_dna(seq: str) -> str:
87
+ # allow U and N like your original
88
+ return re.sub(r"[^ACGTUN]", "", seq.upper())
89
+
90
+
91
+ def iter_unknown_sequences(unknown: str, is_protein: bool) -> List[Tuple[str, str, str]]:
92
+ """
93
+ Returns list of (source_file, header, cleaned_seq).
94
+ unknown can be a fasta file or a directory with fasta files.
95
+ """
96
+ seq_index: List[Tuple[str, str, str]] = []
97
+
98
+ if os.path.isdir(unknown):
99
+ files = [
100
+ os.path.join(unknown, f)
101
+ for f in os.listdir(unknown)
102
+ if f.lower().endswith(FASTA_EXTS)
103
+ ]
104
+ else:
105
+ files = [unknown]
106
+
107
+ files = [f for f in files if os.path.isfile(f)]
108
+ for fp in sorted(files):
109
+ headers, seqs = read_fasta(fp)
110
+ if is_protein:
111
+ seqs = [clean_protein(s) for s in seqs]
112
+ else:
113
+ seqs = [clean_dna(s) for s in seqs]
114
+
115
+ for h, s in zip(headers, seqs):
116
+ if s: # drop empty after cleaning
117
+ seq_index.append((fp, h, s))
118
+
119
+ return seq_index
120
+
121
+
122
+ # ----------------------------
123
+ # ZIP utilities (safe extract)
124
+ # ----------------------------
125
+
126
+ def safe_extract_zip(zip_path: str, dst_dir: str) -> None:
127
+ """Extract ZIP safely (prevents zip-slip)."""
128
+ with zipfile.ZipFile(zip_path, "r") as z:
129
+ for member in z.infolist():
130
+ if member.is_dir():
131
+ continue
132
+ target = os.path.normpath(os.path.join(dst_dir, member.filename))
133
+ if not target.startswith(os.path.abspath(dst_dir) + os.sep):
134
+ continue # skip suspicious paths
135
+ os.makedirs(os.path.dirname(target), exist_ok=True)
136
+ with z.open(member) as src, open(target, "wb") as out:
137
+ shutil.copyfileobj(src, out)
138
+
139
+
140
+ # ----------------------------
141
+ # Load unique kmers
142
+ # ----------------------------
143
+
144
+ @dataclass
145
+ class KmerDB:
146
+ group_kmers: Dict[str, List[str]]
147
+ source_dir: str
148
+
149
+
150
+ def parse_group_from_filename(fname: str, group_regex: str) -> str:
151
+ m = re.search(group_regex, fname, re.IGNORECASE)
152
+ if m:
153
+ return m.group(1)
154
+ # fallback: remove extension
155
+ return os.path.splitext(fname)[0]
156
+
157
+
158
+ def load_unique_kmers_from_dir(
159
+ kmer_dir: str,
160
+ is_protein: bool,
161
+ group_regex: str = DEFAULT_GROUP_REGEX,
162
+ ) -> KmerDB:
163
+ """
164
+ Loads k-mers from files like:
165
+ unique_k15_group1.tsv
166
+ unique_k20_groupA.txt
167
+ Accepts TSV or TXT; ignores comment/header lines.
168
+ """
169
+ group_kmers: Dict[str, List[str]] = {}
170
+
171
+ for fname in sorted(os.listdir(kmer_dir)):
172
+ if not fname.lower().endswith(KMER_FILE_EXTS):
173
+ continue
174
+
175
+ fpath = os.path.join(kmer_dir, fname)
176
+ if not os.path.isfile(fpath):
177
+ continue
178
+
179
+ group = parse_group_from_filename(fname, group_regex)
180
+ group = group.strip()
181
+
182
+ group_kmers.setdefault(group, [])
183
+
184
+ with open(fpath, "r", encoding="utf-8") as fh:
185
+ for line in fh:
186
+ line = line.strip()
187
+ if (not line) or line.startswith("#"):
188
+ continue
189
+ if line.lower().startswith(("kmer", "total")):
190
+ continue
191
+
192
+ token = line.split()[0].upper()
193
+ token = clean_protein(token) if is_protein else clean_dna(token)
194
+ if token:
195
+ group_kmers[group].append(token)
196
+
197
+ # Deduplicate while preserving order
198
+ for g in list(group_kmers.keys()):
199
+ group_kmers[g] = list(dict.fromkeys(group_kmers[g]))
200
+ if len(group_kmers[g]) == 0:
201
+ # drop empty groups
202
+ del group_kmers[g]
203
+
204
+ if not group_kmers:
205
+ raise FileNotFoundError(f"No k-mer files found in: {kmer_dir}")
206
+
207
+ return KmerDB(group_kmers=group_kmers, source_dir=kmer_dir)
208
+
209
+
210
+ def load_unique_kmers(kmer_input: str, is_protein: bool, group_regex: str) -> KmerDB:
211
+ """
212
+ kmer_input can be a directory OR a .zip file containing k-mer output files.
213
+ """
214
+ if os.path.isdir(kmer_input):
215
+ return load_unique_kmers_from_dir(kmer_input, is_protein, group_regex=group_regex)
216
+
217
+ if os.path.isfile(kmer_input) and kmer_input.lower().endswith(".zip"):
218
+ tmp = tempfile.mkdtemp(prefix="kmerdb_")
219
+ safe_extract_zip(kmer_input, tmp)
220
+ # find a directory inside that actually contains kmer files; simplest: use tmp itself
221
+ return load_unique_kmers_from_dir(tmp, is_protein, group_regex=group_regex)
222
+
223
+ raise FileNotFoundError(f"--kmer-input must be a directory or a .zip file: {kmer_input}")
224
+
225
+
226
+ # ----------------------------
227
+ # Matching
228
+ # ----------------------------
229
+
230
+ def align_kmer_to_seq(
231
+ kmer: str,
232
+ seq: str,
233
+ is_protein: bool,
234
+ identity_threshold: float = 0.9,
235
+ min_coverage: float = 0.8,
236
+ gap_open: float = -10,
237
+ gap_extend: float = -0.5,
238
+ nuc_match: float = 2,
239
+ nuc_mismatch: float = -1,
240
+ nuc_gap_open: float = -5,
241
+ nuc_gap_extend: float = -1,
242
+ ) -> bool:
243
+ if not kmer or not seq:
244
+ return False
245
+
246
+ # Fast exact substring path
247
+ if identity_threshold == 1.0 and min_coverage == 1.0:
248
+ return kmer in seq
249
+ if len(kmer) <= 3:
250
+ return kmer in seq
251
+
252
+ try:
253
+ aligner = Align.PairwiseAligner()
254
+ if is_protein:
255
+ aligner.substitution_matrix = BLOSUM62
256
+ aligner.open_gap_score = gap_open
257
+ aligner.extend_gap_score = gap_extend
258
+ else:
259
+ aligner.match_score = nuc_match
260
+ aligner.mismatch_score = nuc_mismatch
261
+ aligner.open_gap_score = nuc_gap_open
262
+ aligner.extend_gap_score = nuc_gap_extend
263
+
264
+ alns = aligner.align(kmer, seq)
265
+ if not alns:
266
+ return False
267
+
268
+ aln = alns[0]
269
+ aligned_query = aln.aligned[0]
270
+ aligned_target = aln.aligned[1]
271
+
272
+ aligned_len = sum(e - s for s, e in aligned_query)
273
+ if aligned_len == 0:
274
+ return False
275
+
276
+ matches = 0
277
+ for (qs, qe), (ts, te) in zip(aligned_query, aligned_target):
278
+ subseq_q = kmer[qs:qe]
279
+ subseq_t = seq[ts:te]
280
+ matches += sum(1 for a, b in zip(subseq_q, subseq_t) if a == b)
281
+
282
+ identity = matches / aligned_len
283
+ coverage = aligned_len / len(kmer)
284
+ return (identity >= identity_threshold) and (coverage >= min_coverage)
285
+
286
+ except Exception:
287
+ return False
288
+
289
+
290
+ # ----------------------------
291
+ # Prediction core
292
+ # ----------------------------
293
+
294
+ def predict(
295
+ unknown: str,
296
+ kmer_input: str,
297
+ output_dir: str,
298
+ seqtype: str,
299
+ mode: str,
300
+ identity_threshold: float,
301
+ min_coverage: float,
302
+ fdr_alpha: float,
303
+ group_regex: str,
304
+ ) -> pd.DataFrame:
305
+ is_protein = (seqtype.lower() == "protein")
306
+ mode = mode.lower().strip()
307
+ if mode not in {"fast", "full"}:
308
+ raise ValueError("--mode must be 'fast' or 'full'")
309
+
310
+ # Load kmers (dir or zip)
311
+ db = load_unique_kmers(kmer_input, is_protein=is_protein, group_regex=group_regex)
312
+ group_kmers = db.group_kmers
313
+
314
+ print(f"Loaded k-mer counts: { {g: len(group_kmers[g]) for g in group_kmers} }")
315
+
316
+ # Unknown sequences
317
+ seq_index = iter_unknown_sequences(unknown, is_protein=is_protein)
318
+ if not seq_index:
319
+ raise FileNotFoundError("No sequences found in --unknown (file/dir).")
320
+
321
+ # Mode parameters
322
+ if mode == "fast":
323
+ eff_identity = 1.0
324
+ eff_coverage = 1.0
325
+ compute_stats = False
326
+ else:
327
+ eff_identity = float(identity_threshold)
328
+ eff_coverage = float(min_coverage)
329
+ compute_stats = True
330
+
331
+ results: List[dict] = []
332
+
333
+ total_seqs = len(seq_index)
334
+ for i, (srcfile, header, seq) in enumerate(seq_index, start=1):
335
+ print(f"Processing sequence {i}/{total_seqs} ({os.path.basename(srcfile)})")
336
+
337
+ group_found_counts = {g: 0 for g in group_kmers}
338
+ total_found = 0
339
+
340
+ for g, kmers in group_kmers.items():
341
+ for kmer in kmers:
342
+ if align_kmer_to_seq(
343
+ kmer, seq, is_protein=is_protein,
344
+ identity_threshold=eff_identity,
345
+ min_coverage=eff_coverage,
346
+ ):
347
+ group_found_counts[g] += 1
348
+ total_found += 1
349
+
350
+ predicted = max(group_found_counts, key=group_found_counts.get)
351
+ conf_present = (group_found_counts[predicted] / total_found) if total_found else 0.0
352
+ conf_vocab = group_found_counts[predicted] / max(1, len(group_kmers[predicted]))
353
+
354
+ row = {
355
+ "Source_file": os.path.basename(srcfile),
356
+ "Sequence": header,
357
+ "Predicted_group": predicted,
358
+ "Matches_total": total_found,
359
+ **{f"Matches_{g}": group_found_counts[g] for g in group_kmers},
360
+ "Confidence_by_present": conf_present,
361
+ "Confidence_by_group_vocab": conf_vocab,
362
+ }
363
+
364
+ if compute_stats:
365
+ fisher_p = {}
366
+ # total vocabulary size of "other groups" for contingency table
367
+ other_vocab_total = {g: sum(len(group_kmers[og]) for og in group_kmers if og != g) for g in group_kmers}
368
+
369
+ for g in group_kmers:
370
+ a = group_found_counts[g]
371
+ b = len(group_kmers[g]) - a
372
+ c = total_found - a
373
+ d = other_vocab_total[g] - c
374
+ if d < 0:
375
+ d = 0
376
+ table = [[a, b], [c, d]]
377
+ _, p = fisher_exact(table, alternative="greater")
378
+ fisher_p[g] = p
379
+ row.update({f"FisherP_{g}": fisher_p[g] for g in group_kmers})
380
+
381
+ results.append(row)
382
+
383
+ df = pd.DataFrame(results)
384
+
385
+ # FDR correction (full mode)
386
+ if mode == "full":
387
+ fisher_cols = [c for c in df.columns if c.startswith("FisherP_")]
388
+ if fisher_cols:
389
+ all_pvals = df[fisher_cols].values.flatten()
390
+ _, qvals, _, _ = multipletests(all_pvals, alpha=float(fdr_alpha), method="fdr_bh")
391
+ qval_matrix = qvals.reshape(df[fisher_cols].shape)
392
+ for j, col in enumerate(fisher_cols):
393
+ grp = col.split("_", 1)[1]
394
+ df[f"FDR_{grp}"] = qval_matrix[:, j]
395
+
396
+ # Save
397
+ os.makedirs(output_dir, exist_ok=True)
398
+ out_csv = os.path.join(output_dir, "predictions_by_alignment.csv")
399
+ df.to_csv(out_csv, index=False)
400
+ print(f"Saved predictions to {out_csv}")
401
+
402
+ # Plot
403
+ save_summary_plot(df, output_dir)
404
+
405
+ return df
406
+
407
+
408
+ def save_summary_plot(df: pd.DataFrame, output_dir: str) -> None:
409
+ """
410
+ Matplotlib-only summary figure:
411
+ - Left: predicted group counts
412
+ - Right: confidence distribution (boxplot)
413
+ """
414
+ fig, axes = plt.subplots(1, 2, figsize=(12, 5))
415
+
416
+ # Left: bar counts
417
+ counts = df["Predicted_group"].value_counts()
418
+ axes[0].bar(counts.index.astype(str), counts.values)
419
+ axes[0].set_xlabel("Predicted Group")
420
+ axes[0].set_ylabel("Number of Sequences")
421
+ axes[0].set_title("Predicted Group Counts")
422
+ axes[0].tick_params(axis="x", rotation=45)
423
+
424
+ # Right: boxplot confidence_by_present by group
425
+ groups = sorted(df["Predicted_group"].unique().tolist())
426
+ data = [df.loc[df["Predicted_group"] == g, "Confidence_by_present"].values for g in groups]
427
+ axes[1].boxplot(data, labels=groups, showfliers=False)
428
+ axes[1].set_title("Prediction Confidence (by Present)")
429
+ axes[1].set_xlabel("Predicted Group")
430
+ axes[1].set_ylabel("Confidence")
431
+ axes[1].tick_params(axis="x", rotation=45)
432
+
433
+ fig.tight_layout()
434
+ fig.savefig(os.path.join(output_dir, "predicted_results_summary.png"), dpi=300)
435
+ plt.close(fig)
436
+
437
+
438
+ # ----------------------------
439
+ # CLI
440
+ # ----------------------------
441
+
442
+ def build_parser() -> argparse.ArgumentParser:
443
+ p = argparse.ArgumentParser(
444
+ description="Predict group membership of unknown sequences using unique k-mers.",
445
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
446
+ )
447
+ p.add_argument("--unknown", required=True, help="Unknown FASTA file OR directory of FASTA files.")
448
+ p.add_argument("--kmer-input", required=True, help="Directory of unique_k*.tsv/txt OR a ZIP containing them.")
449
+ p.add_argument("--outdir", default="prediction_results", help="Output directory.")
450
+ p.add_argument("--seqtype", choices=["dna", "protein"], default="dna", help="Sequence type.")
451
+ p.add_argument("--mode", choices=["fast", "full"], default="fast", help="fast=substring only; full=alignment+Fisher+FDR.")
452
+ p.add_argument("--identity", type=float, default=0.9, help="Alignment identity threshold (full mode only).")
453
+ p.add_argument("--coverage", type=float, default=0.8, help="Alignment coverage threshold (full mode only).")
454
+ p.add_argument("--fdr", type=float, default=0.05, help="FDR alpha (full mode only).")
455
+ p.add_argument(
456
+ "--group-regex",
457
+ default=DEFAULT_GROUP_REGEX,
458
+ help="Regex to extract group name from k-mer filenames (1st capture group = group).",
459
+ )
460
+ return p
461
+
462
+
463
+ def main() -> None:
464
+ args = build_parser().parse_args()
465
+
466
+ # Validate unknown
467
+ if not os.path.exists(args.unknown):
468
+ raise FileNotFoundError(f"--unknown not found: {args.unknown}")
469
+
470
+ # Run
471
+ predict(
472
+ unknown=args.unknown,
473
+ kmer_input=args.kmer_input,
474
+ output_dir=args.outdir,
475
+ seqtype=args.seqtype,
476
+ mode=args.mode,
477
+ identity_threshold=args.identity,
478
+ min_coverage=args.coverage,
479
+ fdr_alpha=args.fdr,
480
+ group_regex=args.group_regex,
481
+ )
482
+
483
+
484
+ if __name__ == "__main__":
485
+ main()