Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| K-mer-based group prediction for unknown sequences. | |
| Inputs: | |
| - Unknown sequences: a FASTA file or a directory of FASTA files | |
| - Unique k-mers: either | |
| * a directory containing unique_k{k}_{group}.tsv/.txt files (from script #1), OR | |
| * a ZIP file containing those files | |
| Modes: | |
| - fast: exact substring matching only (very fast) | |
| - full: alignment-based matching (slower, more tolerant) + Fisher + FDR | |
| Outputs: | |
| - predictions_by_alignment.csv | |
| - predicted_results_summary.png | |
| Example: | |
| python kmer_predict.py \ | |
| --unknown unknown_fastas/ \ | |
| --kmer-input kmer_results.zip \ | |
| --outdir pred_out \ | |
| --seqtype dna \ | |
| --mode fast | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import re | |
| import shutil | |
| import tempfile | |
| import zipfile | |
| from dataclasses import dataclass | |
| from typing import Dict, Iterable, List, Optional, Sequence, Tuple | |
| import pandas as pd | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from scipy.stats import fisher_exact | |
| from statsmodels.stats.multitest import multipletests | |
| from Bio import Align | |
| from Bio.Align import substitution_matrices | |
| FASTA_EXTS = (".fasta", ".fa", ".fas", ".fna") | |
| KMER_FILE_EXTS = (".tsv", ".txt") | |
| DEFAULT_GROUP_REGEX = r"unique_k\d+_(.+)\.(tsv|txt)$" | |
| BLOSUM62 = substitution_matrices.load("BLOSUM62") | |
| # ---------------------------- | |
| # FASTA utilities | |
| # ---------------------------- | |
def read_fasta(filepath: str) -> Tuple[List[str], List[str]]:
    """Parse a FASTA file into parallel lists of headers and uppercased sequences.

    Headers lose their leading '>' and surrounding whitespace; sequence lines
    are stripped, uppercased, and concatenated per record. Blank lines are
    ignored.
    """
    headers: List[str] = []
    seqs: List[str] = []
    chunks: List[str] = []
    with open(filepath, "r", encoding="utf-8") as handle:
        for raw in handle:
            raw = raw.rstrip("\n")
            if not raw:
                continue
            if raw.startswith(">"):
                # Flush the record collected so far before starting a new one.
                if chunks:
                    seqs.append("".join(chunks))
                    chunks = []
                headers.append(raw[1:].strip())
            else:
                chunks.append(raw.strip().upper())
    if chunks:  # flush the final record
        seqs.append("".join(chunks))
    return headers, seqs
def clean_protein(seq: str) -> str:
    """Uppercase *seq* and drop every character that is not one of the 20 standard amino acids."""
    valid = "ACDEFGHIKLMNPQRSTVWY"
    return "".join(ch for ch in seq.upper() if ch in valid)
def clean_dna(seq: str) -> str:
    """Uppercase *seq* and keep only A/C/G/T plus U (RNA) and N (ambiguous base)."""
    return "".join(ch for ch in seq.upper() if ch in "ACGTUN")
def iter_unknown_sequences(unknown: str, is_protein: bool) -> List[Tuple[str, str, str]]:
    """
    Collect (source_file, header, cleaned_seq) triples from *unknown*.

    *unknown* may be a single FASTA file or a directory containing FASTA
    files (matched by extension, case-insensitive). Sequences that are empty
    after cleaning are dropped.
    """
    cleaner = clean_protein if is_protein else clean_dna
    if os.path.isdir(unknown):
        candidates = [
            os.path.join(unknown, name)
            for name in os.listdir(unknown)
            if name.lower().endswith(FASTA_EXTS)
        ]
    else:
        candidates = [unknown]
    records: List[Tuple[str, str, str]] = []
    for path in sorted(p for p in candidates if os.path.isfile(p)):
        headers, raw_seqs = read_fasta(path)
        for header, raw in zip(headers, raw_seqs):
            cleaned = cleaner(raw)
            if cleaned:  # drop sequences that are empty after cleaning
                records.append((path, header, cleaned))
    return records
| # ---------------------------- | |
| # ZIP utilities (safe extract) | |
| # ---------------------------- | |
def safe_extract_zip(zip_path: str, dst_dir: str) -> None:
    """Extract a ZIP archive into *dst_dir*, skipping entries that would escape it (zip-slip).

    Directory entries are skipped; parent directories of file entries are
    created as needed. Entries whose resolved path falls outside *dst_dir*
    (absolute names, '..' traversal) are silently ignored.
    """
    # Resolve the base once, absolutely. The original compared a possibly
    # RELATIVE normpath(join(dst_dir, ...)) against abspath(dst_dir), so with
    # a relative dst_dir every member failed the check and nothing was
    # extracted. Joining against the absolute base fixes both extraction and
    # the containment test.
    base = os.path.abspath(dst_dir)
    with zipfile.ZipFile(zip_path, "r") as z:
        for member in z.infolist():
            if member.is_dir():
                continue
            target = os.path.normpath(os.path.join(base, member.filename))
            if not target.startswith(base + os.sep):
                continue  # skip suspicious paths
            os.makedirs(os.path.dirname(target), exist_ok=True)
            with z.open(member) as src, open(target, "wb") as out:
                shutil.copyfileobj(src, out)
| # ---------------------------- | |
| # Load unique kmers | |
| # ---------------------------- | |
@dataclass
class KmerDB:
    """In-memory unique-k-mer database loaded from per-group k-mer files.

    The @dataclass decorator was missing: the class is constructed with
    keyword arguments (KmerDB(group_kmers=..., source_dir=...)) in
    load_unique_kmers_from_dir, which raises TypeError without a generated
    __init__.
    """
    # Mapping of group name -> ordered, de-duplicated list of k-mer strings.
    group_kmers: Dict[str, List[str]]
    # Directory the k-mer files were read from (a temp dir when the input was a ZIP).
    source_dir: str
def parse_group_from_filename(fname: str, group_regex: str) -> str:
    """Extract the group name from a k-mer filename.

    Uses the first capture group of *group_regex* (case-insensitive search);
    when the pattern does not match, falls back to the filename without its
    extension.
    """
    match = re.search(group_regex, fname, re.IGNORECASE)
    if match is None:
        # fallback: filename stem
        return os.path.splitext(fname)[0]
    return match.group(1)
def load_unique_kmers_from_dir(
    kmer_dir: str,
    is_protein: bool,
    group_regex: str = DEFAULT_GROUP_REGEX,
) -> KmerDB:
    """
    Build a KmerDB from files named like:
        unique_k15_group1.tsv
        unique_k20_groupA.txt

    Accepts .tsv or .txt. Blank lines, '#' comments and header lines are
    skipped; the first whitespace-separated token of each remaining line is
    cleaned (protein or DNA alphabet) and kept. K-mers are de-duplicated per
    group, preserving first-seen order; groups left empty are dropped.

    Raises FileNotFoundError when no usable k-mer files are found.
    """
    cleaner = clean_protein if is_protein else clean_dna
    group_kmers: Dict[str, List[str]] = {}
    for fname in sorted(os.listdir(kmer_dir)):
        if not fname.lower().endswith(KMER_FILE_EXTS):
            continue
        fpath = os.path.join(kmer_dir, fname)
        if not os.path.isfile(fpath):
            continue
        group = parse_group_from_filename(fname, group_regex).strip()
        bucket = group_kmers.setdefault(group, [])
        with open(fpath, "r", encoding="utf-8") as fh:
            for line in fh:
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                # Skip header rows. NOTE(review): this also drops protein
                # k-mers that genuinely start with "KMER"/"TOTAL" (K, M, E,
                # R, T, A, L are all valid residues) — confirm intended.
                if line.lower().startswith(("kmer", "total")):
                    continue
                token = cleaner(line.split()[0].upper())
                if token:
                    bucket.append(token)
    # De-duplicate preserving order; drop groups that end up empty.
    for g in list(group_kmers):
        deduped = list(dict.fromkeys(group_kmers[g]))
        if deduped:
            group_kmers[g] = deduped
        else:
            del group_kmers[g]
    if not group_kmers:
        raise FileNotFoundError(f"No k-mer files found in: {kmer_dir}")
    return KmerDB(group_kmers=group_kmers, source_dir=kmer_dir)
def load_unique_kmers(kmer_input: str, is_protein: bool, group_regex: str) -> KmerDB:
    """
    Load unique k-mers from *kmer_input*: either a directory of k-mer files
    or a .zip archive containing them.

    Raises FileNotFoundError when the argument is neither.
    """
    if os.path.isdir(kmer_input):
        return load_unique_kmers_from_dir(kmer_input, is_protein, group_regex=group_regex)
    is_zip = os.path.isfile(kmer_input) and kmer_input.lower().endswith(".zip")
    if not is_zip:
        raise FileNotFoundError(f"--kmer-input must be a directory or a .zip file: {kmer_input}")
    # Extract to a fresh temp dir and load straight from its root.
    # NOTE(review): the temp dir is left on disk; KmerDB.source_dir points at it.
    extracted = tempfile.mkdtemp(prefix="kmerdb_")
    safe_extract_zip(kmer_input, extracted)
    return load_unique_kmers_from_dir(extracted, is_protein, group_regex=group_regex)
| # ---------------------------- | |
| # Matching | |
| # ---------------------------- | |
def align_kmer_to_seq(
    kmer: str,
    seq: str,
    is_protein: bool,
    identity_threshold: float = 0.9,
    min_coverage: float = 0.8,
    gap_open: float = -10,
    gap_extend: float = -0.5,
    nuc_match: float = 2,
    nuc_mismatch: float = -1,
    nuc_gap_open: float = -5,
    nuc_gap_extend: float = -1,
) -> bool:
    """
    Decide whether *kmer* matches *seq*.

    Fast path: an exact substring test when both thresholds are 1.0, or when
    the k-mer is very short (<= 3). Otherwise a Biopython pairwise alignment
    is scored (BLOSUM62 for protein, simple match/mismatch for nucleotides)
    and the hit is accepted when both identity over the aligned columns and
    k-mer coverage reach their thresholds. Any alignment failure counts as a
    non-match.
    """
    if not kmer or not seq:
        return False
    # Exact-substring shortcuts (strict thresholds, or k-mer too short to align).
    if (identity_threshold == 1.0 and min_coverage == 1.0) or len(kmer) <= 3:
        return kmer in seq
    try:
        aligner = Align.PairwiseAligner()
        if is_protein:
            aligner.substitution_matrix = BLOSUM62
            aligner.open_gap_score = gap_open
            aligner.extend_gap_score = gap_extend
        else:
            aligner.match_score = nuc_match
            aligner.mismatch_score = nuc_mismatch
            aligner.open_gap_score = nuc_gap_open
            aligner.extend_gap_score = nuc_gap_extend
        alignments = aligner.align(kmer, seq)
        if not alignments:
            return False
        best = alignments[0]
        query_blocks, target_blocks = best.aligned
        aligned_len = sum(end - start for start, end in query_blocks)
        if aligned_len == 0:
            return False
        # Count identical residues across all aligned block pairs.
        identical = 0
        for (qs, qe), (ts, te) in zip(query_blocks, target_blocks):
            identical += sum(1 for a, b in zip(kmer[qs:qe], seq[ts:te]) if a == b)
        identity = identical / aligned_len
        coverage = aligned_len / len(kmer)
        return identity >= identity_threshold and coverage >= min_coverage
    except Exception:
        # Best-effort matching: any Biopython error is treated as "no match".
        return False
| # ---------------------------- | |
| # Prediction core | |
| # ---------------------------- | |
def predict(
    unknown: str,
    kmer_input: str,
    output_dir: str,
    seqtype: str,
    mode: str,
    identity_threshold: float,
    min_coverage: float,
    fdr_alpha: float,
    group_regex: str,
) -> pd.DataFrame:
    """
    Classify each unknown sequence by counting unique-k-mer hits per group.

    Loads the k-mer database (directory or ZIP), reads and cleans the
    unknown sequences, counts matching k-mers per group for every sequence,
    and predicts the group with the most hits. In 'full' mode matching is
    alignment-based and a per-group Fisher exact test plus BH-FDR correction
    is added. Writes predictions_by_alignment.csv and a summary PNG into
    *output_dir* and returns the result DataFrame.

    Raises ValueError for an unknown mode and FileNotFoundError when no
    sequences or k-mer files can be found.
    """
    is_protein = (seqtype.lower() == "protein")
    mode = mode.lower().strip()
    if mode not in {"fast", "full"}:
        raise ValueError("--mode must be 'fast' or 'full'")
    # Load kmers (dir or zip)
    db = load_unique_kmers(kmer_input, is_protein=is_protein, group_regex=group_regex)
    group_kmers = db.group_kmers
    print(f"Loaded k-mer counts: { {g: len(group_kmers[g]) for g in group_kmers} }")
    # Unknown sequences
    seq_index = iter_unknown_sequences(unknown, is_protein=is_protein)
    if not seq_index:
        raise FileNotFoundError("No sequences found in --unknown (file/dir).")
    # Mode parameters: 'fast' forces exact substring matching inside
    # align_kmer_to_seq (thresholds 1.0/1.0); 'full' uses caller thresholds.
    if mode == "fast":
        eff_identity = 1.0
        eff_coverage = 1.0
        compute_stats = False
    else:
        eff_identity = float(identity_threshold)
        eff_coverage = float(min_coverage)
        compute_stats = True
    results: List[dict] = []
    total_seqs = len(seq_index)
    for i, (srcfile, header, seq) in enumerate(seq_index, start=1):
        print(f"Processing sequence {i}/{total_seqs} ({os.path.basename(srcfile)})")
        group_found_counts = {g: 0 for g in group_kmers}
        total_found = 0
        # Count hits of every group's k-mers against this sequence.
        for g, kmers in group_kmers.items():
            for kmer in kmers:
                if align_kmer_to_seq(
                    kmer, seq, is_protein=is_protein,
                    identity_threshold=eff_identity,
                    min_coverage=eff_coverage,
                ):
                    group_found_counts[g] += 1
                    total_found += 1
        # NOTE(review): with zero hits, max() still picks the first group in
        # dict order — the prediction is then arbitrary (confidence 0.0).
        predicted = max(group_found_counts, key=group_found_counts.get)
        # Share of all observed hits that belong to the predicted group.
        conf_present = (group_found_counts[predicted] / total_found) if total_found else 0.0
        # Share of the predicted group's k-mer vocabulary that was found.
        conf_vocab = group_found_counts[predicted] / max(1, len(group_kmers[predicted]))
        row = {
            "Source_file": os.path.basename(srcfile),
            "Sequence": header,
            "Predicted_group": predicted,
            "Matches_total": total_found,
            **{f"Matches_{g}": group_found_counts[g] for g in group_kmers},
            "Confidence_by_present": conf_present,
            "Confidence_by_group_vocab": conf_vocab,
        }
        if compute_stats:
            fisher_p = {}
            # total vocabulary size of "other groups" for contingency table
            other_vocab_total = {g: sum(len(group_kmers[og]) for og in group_kmers if og != g) for g in group_kmers}
            for g in group_kmers:
                # 2x2 table: (found, not-found) within this group's vocab vs.
                # within all other groups' pooled vocab; one-sided test for
                # enrichment of hits in group g.
                a = group_found_counts[g]
                b = len(group_kmers[g]) - a
                c = total_found - a
                d = other_vocab_total[g] - c
                if d < 0:
                    d = 0
                table = [[a, b], [c, d]]
                _, p = fisher_exact(table, alternative="greater")
                fisher_p[g] = p
            row.update({f"FisherP_{g}": fisher_p[g] for g in group_kmers})
        results.append(row)
    df = pd.DataFrame(results)
    # FDR correction (full mode): BH over all FisherP cells pooled together,
    # then reshaped back into per-group FDR_* columns (column order of the
    # reshape matches fisher_cols order).
    if mode == "full":
        fisher_cols = [c for c in df.columns if c.startswith("FisherP_")]
        if fisher_cols:
            all_pvals = df[fisher_cols].values.flatten()
            _, qvals, _, _ = multipletests(all_pvals, alpha=float(fdr_alpha), method="fdr_bh")
            qval_matrix = qvals.reshape(df[fisher_cols].shape)
            for j, col in enumerate(fisher_cols):
                grp = col.split("_", 1)[1]
                df[f"FDR_{grp}"] = qval_matrix[:, j]
    # Save
    os.makedirs(output_dir, exist_ok=True)
    out_csv = os.path.join(output_dir, "predictions_by_alignment.csv")
    df.to_csv(out_csv, index=False)
    print(f"Saved predictions to {out_csv}")
    # Plot
    save_summary_plot(df, output_dir)
    return df
def save_summary_plot(df: pd.DataFrame, output_dir: str) -> None:
    """
    Write predicted_results_summary.png into *output_dir*:
    a two-panel matplotlib figure with predicted-group counts (left) and the
    per-group distribution of Confidence_by_present (right, outliers hidden).
    """
    fig, (ax_counts, ax_conf) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: bar chart of how many sequences landed in each group.
    counts = df["Predicted_group"].value_counts()
    ax_counts.bar(counts.index.astype(str), counts.values)
    ax_counts.set_xlabel("Predicted Group")
    ax_counts.set_ylabel("Number of Sequences")
    ax_counts.set_title("Predicted Group Counts")
    ax_counts.tick_params(axis="x", rotation=45)

    # Right panel: boxplot of confidence values per predicted group.
    groups = sorted(df["Predicted_group"].unique().tolist())
    per_group = [
        df.loc[df["Predicted_group"] == g, "Confidence_by_present"].values
        for g in groups
    ]
    ax_conf.boxplot(per_group, labels=groups, showfliers=False)
    ax_conf.set_title("Prediction Confidence (by Present)")
    ax_conf.set_xlabel("Predicted Group")
    ax_conf.set_ylabel("Confidence")
    ax_conf.tick_params(axis="x", rotation=45)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "predicted_results_summary.png"), dpi=300)
    plt.close(fig)
| # ---------------------------- | |
| # CLI | |
| # ---------------------------- | |
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser (defaults are shown in --help)."""
    parser = argparse.ArgumentParser(
        description="Predict group membership of unknown sequences using unique k-mers.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = parser.add_argument
    # Required inputs.
    add("--unknown", required=True, help="Unknown FASTA file OR directory of FASTA files.")
    add("--kmer-input", required=True, help="Directory of unique_k*.tsv/txt OR a ZIP containing them.")
    # Output and run configuration.
    add("--outdir", default="prediction_results", help="Output directory.")
    add("--seqtype", choices=["dna", "protein"], default="dna", help="Sequence type.")
    add("--mode", choices=["fast", "full"], default="fast", help="fast=substring only; full=alignment+Fisher+FDR.")
    # Full-mode tuning knobs.
    add("--identity", type=float, default=0.9, help="Alignment identity threshold (full mode only).")
    add("--coverage", type=float, default=0.8, help="Alignment coverage threshold (full mode only).")
    add("--fdr", type=float, default=0.05, help="FDR alpha (full mode only).")
    add(
        "--group-regex",
        default=DEFAULT_GROUP_REGEX,
        help="Regex to extract group name from k-mer filenames (1st capture group = group).",
    )
    return parser
def main() -> None:
    """CLI entry point: parse arguments, validate the input path, run prediction."""
    args = build_parser().parse_args()
    # Fail early with a clear message when the unknown-input path is missing.
    if not os.path.exists(args.unknown):
        raise FileNotFoundError(f"--unknown not found: {args.unknown}")
    predict(
        unknown=args.unknown,
        kmer_input=args.kmer_input,
        output_dir=args.outdir,
        seqtype=args.seqtype,
        mode=args.mode,
        identity_threshold=args.identity,
        min_coverage=args.coverage,
        fdr_alpha=args.fdr,
        group_regex=args.group_regex,
    )
| if __name__ == "__main__": | |
| main() | |