#!/usr/bin/env python3 """ K-mer-based group prediction for unknown sequences. Inputs: - Unknown sequences: a FASTA file or a directory of FASTA files - Unique k-mers: either * a directory containing unique_k{k}_{group}.tsv/.txt files (from script #1), OR * a ZIP file containing those files Modes: - fast: exact substring matching only (very fast) - full: alignment-based matching (slower, more tolerant) + Fisher + FDR Outputs: - predictions_by_alignment.csv - predicted_results_summary.png Example: python kmer_predict.py \ --unknown unknown_fastas/ \ --kmer-input kmer_results.zip \ --outdir pred_out \ --seqtype dna \ --mode fast """ from __future__ import annotations import argparse import os import re import shutil import tempfile import zipfile from dataclasses import dataclass from typing import Dict, Iterable, List, Optional, Sequence, Tuple import pandas as pd import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from scipy.stats import fisher_exact from statsmodels.stats.multitest import multipletests from Bio import Align from Bio.Align import substitution_matrices FASTA_EXTS = (".fasta", ".fa", ".fas", ".fna") KMER_FILE_EXTS = (".tsv", ".txt") DEFAULT_GROUP_REGEX = r"unique_k\d+_(.+)\.(tsv|txt)$" BLOSUM62 = substitution_matrices.load("BLOSUM62") # ---------------------------- # FASTA utilities # ---------------------------- def read_fasta(filepath: str) -> Tuple[List[str], List[str]]: headers, seqs, seq = [], [], [] with open(filepath, "r", encoding="utf-8") as fh: for line in fh: line = line.rstrip("\n") if not line: continue if line.startswith(">"): if seq: seqs.append("".join(seq)) seq = [] headers.append(line[1:].strip()) else: seq.append(line.strip().upper()) if seq: seqs.append("".join(seq)) return headers, seqs def clean_protein(seq: str) -> str: return re.sub(r"[^ACDEFGHIKLMNPQRSTVWY]", "", seq.upper()) def clean_dna(seq: str) -> str: # allow U and N like your original return re.sub(r"[^ACGTUN]", "", seq.upper()) def 
iter_unknown_sequences(unknown: str, is_protein: bool) -> List[Tuple[str, str, str]]: """ Returns list of (source_file, header, cleaned_seq). unknown can be a fasta file or a directory with fasta files. """ seq_index: List[Tuple[str, str, str]] = [] if os.path.isdir(unknown): files = [ os.path.join(unknown, f) for f in os.listdir(unknown) if f.lower().endswith(FASTA_EXTS) ] else: files = [unknown] files = [f for f in files if os.path.isfile(f)] for fp in sorted(files): headers, seqs = read_fasta(fp) if is_protein: seqs = [clean_protein(s) for s in seqs] else: seqs = [clean_dna(s) for s in seqs] for h, s in zip(headers, seqs): if s: # drop empty after cleaning seq_index.append((fp, h, s)) return seq_index # ---------------------------- # ZIP utilities (safe extract) # ---------------------------- def safe_extract_zip(zip_path: str, dst_dir: str) -> None: """Extract ZIP safely (prevents zip-slip).""" with zipfile.ZipFile(zip_path, "r") as z: for member in z.infolist(): if member.is_dir(): continue target = os.path.normpath(os.path.join(dst_dir, member.filename)) if not target.startswith(os.path.abspath(dst_dir) + os.sep): continue # skip suspicious paths os.makedirs(os.path.dirname(target), exist_ok=True) with z.open(member) as src, open(target, "wb") as out: shutil.copyfileobj(src, out) # ---------------------------- # Load unique kmers # ---------------------------- @dataclass class KmerDB: group_kmers: Dict[str, List[str]] source_dir: str def parse_group_from_filename(fname: str, group_regex: str) -> str: m = re.search(group_regex, fname, re.IGNORECASE) if m: return m.group(1) # fallback: remove extension return os.path.splitext(fname)[0] def load_unique_kmers_from_dir( kmer_dir: str, is_protein: bool, group_regex: str = DEFAULT_GROUP_REGEX, ) -> KmerDB: """ Loads k-mers from files like: unique_k15_group1.tsv unique_k20_groupA.txt Accepts TSV or TXT; ignores comment/header lines. 
""" group_kmers: Dict[str, List[str]] = {} for fname in sorted(os.listdir(kmer_dir)): if not fname.lower().endswith(KMER_FILE_EXTS): continue fpath = os.path.join(kmer_dir, fname) if not os.path.isfile(fpath): continue group = parse_group_from_filename(fname, group_regex) group = group.strip() group_kmers.setdefault(group, []) with open(fpath, "r", encoding="utf-8") as fh: for line in fh: line = line.strip() if (not line) or line.startswith("#"): continue if line.lower().startswith(("kmer", "total")): continue token = line.split()[0].upper() token = clean_protein(token) if is_protein else clean_dna(token) if token: group_kmers[group].append(token) # Deduplicate while preserving order for g in list(group_kmers.keys()): group_kmers[g] = list(dict.fromkeys(group_kmers[g])) if len(group_kmers[g]) == 0: # drop empty groups del group_kmers[g] if not group_kmers: raise FileNotFoundError(f"No k-mer files found in: {kmer_dir}") return KmerDB(group_kmers=group_kmers, source_dir=kmer_dir) def load_unique_kmers(kmer_input: str, is_protein: bool, group_regex: str) -> KmerDB: """ kmer_input can be a directory OR a .zip file containing k-mer output files. 
""" if os.path.isdir(kmer_input): return load_unique_kmers_from_dir(kmer_input, is_protein, group_regex=group_regex) if os.path.isfile(kmer_input) and kmer_input.lower().endswith(".zip"): tmp = tempfile.mkdtemp(prefix="kmerdb_") safe_extract_zip(kmer_input, tmp) # find a directory inside that actually contains kmer files; simplest: use tmp itself return load_unique_kmers_from_dir(tmp, is_protein, group_regex=group_regex) raise FileNotFoundError(f"--kmer-input must be a directory or a .zip file: {kmer_input}") # ---------------------------- # Matching # ---------------------------- def align_kmer_to_seq( kmer: str, seq: str, is_protein: bool, identity_threshold: float = 0.9, min_coverage: float = 0.8, gap_open: float = -10, gap_extend: float = -0.5, nuc_match: float = 2, nuc_mismatch: float = -1, nuc_gap_open: float = -5, nuc_gap_extend: float = -1, ) -> bool: if not kmer or not seq: return False # Fast exact substring path if identity_threshold == 1.0 and min_coverage == 1.0: return kmer in seq if len(kmer) <= 3: return kmer in seq try: aligner = Align.PairwiseAligner() if is_protein: aligner.substitution_matrix = BLOSUM62 aligner.open_gap_score = gap_open aligner.extend_gap_score = gap_extend else: aligner.match_score = nuc_match aligner.mismatch_score = nuc_mismatch aligner.open_gap_score = nuc_gap_open aligner.extend_gap_score = nuc_gap_extend alns = aligner.align(kmer, seq) if not alns: return False aln = alns[0] aligned_query = aln.aligned[0] aligned_target = aln.aligned[1] aligned_len = sum(e - s for s, e in aligned_query) if aligned_len == 0: return False matches = 0 for (qs, qe), (ts, te) in zip(aligned_query, aligned_target): subseq_q = kmer[qs:qe] subseq_t = seq[ts:te] matches += sum(1 for a, b in zip(subseq_q, subseq_t) if a == b) identity = matches / aligned_len coverage = aligned_len / len(kmer) return (identity >= identity_threshold) and (coverage >= min_coverage) except Exception: return False # ---------------------------- # Prediction core # 
# ----------------------------
# Prediction core
# ----------------------------
# Value written to ``Predicted_group`` when a sequence matches no k-mer at all.
UNASSIGNED_LABEL = "Unassigned"


def _count_group_hits(
    seq: str,
    group_kmers: Dict[str, List[str]],
    is_protein: bool,
    identity_threshold: float,
    min_coverage: float,
) -> Dict[str, int]:
    """Return, for each group (in *group_kmers* order), how many of its k-mers match *seq*."""
    return {
        group: sum(
            1
            for kmer in kmers
            if align_kmer_to_seq(
                kmer,
                seq,
                is_protein=is_protein,
                identity_threshold=identity_threshold,
                min_coverage=min_coverage,
            )
        )
        for group, kmers in group_kmers.items()
    }


def _fisher_pvalues(
    counts: Dict[str, int],
    vocab_size: Dict[str, int],
    total_vocab: int,
    total_found: int,
) -> Dict[str, float]:
    """One-sided Fisher exact p-value per group for enrichment of its k-mers among hits.

    2x2 table for group g:
                          matched   not matched
        g's k-mers           a           b
        other groups' k-mers c           d
    """
    pvals: Dict[str, float] = {}
    for group, a in counts.items():
        b = vocab_size[group] - a
        c = total_found - a
        d = max(0, (total_vocab - vocab_size[group]) - c)
        _, p = fisher_exact([[a, b], [c, d]], alternative="greater")
        pvals[group] = p
    return pvals


def predict(
    unknown: str,
    kmer_input: str,
    output_dir: str,
    seqtype: str,
    mode: str,
    identity_threshold: float,
    min_coverage: float,
    fdr_alpha: float,
    group_regex: str,
) -> pd.DataFrame:
    """Assign each unknown sequence to the group whose unique k-mers it matches most.

    Parameters:
        unknown: FASTA file or directory of FASTA files.
        kmer_input: directory or ZIP of ``unique_k*`` k-mer files.
        output_dir: created if missing; receives the CSV and the summary PNG.
        seqtype: "protein" selects protein cleaning/scoring; anything else is DNA.
        mode: "fast" (exact substring) or "full" (alignment + Fisher + BH-FDR).
        identity_threshold, min_coverage: alignment thresholds in [0, 1] (full mode).
        fdr_alpha: BH-FDR level in (0, 1) used for the ``Significant_<group>`` columns (full mode).
        group_regex: regex whose first capture group is the group name in k-mer filenames.

    Returns:
        One row per sequence. A sequence with no k-mer hits gets
        ``Predicted_group == UNASSIGNED_LABEL`` and zero confidences. Ties
        resolve to the first tied group in k-mer DB order. Full mode adds
        ``FisherP_*``, ``FDR_*`` and ``Significant_*`` columns.

    Raises:
        ValueError: on an unknown mode or out-of-range thresholds.
        FileNotFoundError: if no k-mers or no sequences are found.
    """
    is_protein = seqtype.lower() == "protein"
    mode = mode.lower().strip()
    if mode not in {"fast", "full"}:
        raise ValueError("--mode must be 'fast' or 'full'")

    # Mode parameters (validated before any expensive loading)
    if mode == "fast":
        eff_identity = 1.0
        eff_coverage = 1.0
        compute_stats = False
    else:
        eff_identity = float(identity_threshold)
        eff_coverage = float(min_coverage)
        alpha = float(fdr_alpha)
        for flag, value in (("--identity", eff_identity), ("--coverage", eff_coverage)):
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{flag} must be between 0 and 1, got {value}")
        if not 0.0 < alpha < 1.0:
            raise ValueError(f"--fdr must be strictly between 0 and 1, got {alpha}")
        compute_stats = True

    # Load kmers (dir or zip)
    db = load_unique_kmers(kmer_input, is_protein=is_protein, group_regex=group_regex)
    group_kmers = db.group_kmers
    print(f"Loaded k-mer counts: { {g: len(group_kmers[g]) for g in group_kmers} }")

    # Unknown sequences
    seq_index = iter_unknown_sequences(unknown, is_protein=is_protein)
    if not seq_index:
        raise FileNotFoundError("No sequences found in --unknown (file/dir).")

    # Loop-invariant vocabulary sizes for the confidence and Fisher computations.
    vocab_size = {g: len(k) for g, k in group_kmers.items()}
    total_vocab = sum(vocab_size.values())

    results: List[dict] = []
    total_seqs = len(seq_index)
    for i, (srcfile, header, seq) in enumerate(seq_index, start=1):
        print(f"Processing sequence {i}/{total_seqs} ({os.path.basename(srcfile)})")

        counts = _count_group_hits(seq, group_kmers, is_protein, eff_identity, eff_coverage)
        total_found = sum(counts.values())

        if total_found:
            predicted = max(counts, key=counts.get)
            conf_present = counts[predicted] / total_found
            conf_vocab = counts[predicted] / max(1, vocab_size[predicted])
        else:
            # No evidence at all: do not silently pick the first group.
            predicted = UNASSIGNED_LABEL
            conf_present = 0.0
            conf_vocab = 0.0

        row = {
            "Source_file": os.path.basename(srcfile),
            "Sequence": header,
            "Predicted_group": predicted,
            "Matches_total": total_found,
            **{f"Matches_{g}": counts[g] for g in group_kmers},
            "Confidence_by_present": conf_present,
            "Confidence_by_group_vocab": conf_vocab,
        }

        if compute_stats:
            fisher_p = _fisher_pvalues(counts, vocab_size, total_vocab, total_found)
            row.update({f"FisherP_{g}": fisher_p[g] for g in group_kmers})

        results.append(row)

    df = pd.DataFrame(results)

    # FDR correction (full mode): BH across every (sequence, group) p-value.
    if compute_stats:
        fisher_cols = [c for c in df.columns if c.startswith("FisherP_")]
        if fisher_cols:
            pvals = df[fisher_cols].to_numpy(dtype=float)
            reject, qvals, _, _ = multipletests(pvals.ravel(), alpha=alpha, method="fdr_bh")
            qval_matrix = qvals.reshape(pvals.shape)
            reject_matrix = reject.reshape(pvals.shape)
            groups = [col.split("_", 1)[1] for col in fisher_cols]
            for j, grp in enumerate(groups):
                df[f"FDR_{grp}"] = qval_matrix[:, j]
            # --fdr only matters here: BH q-values themselves do not depend on alpha.
            for j, grp in enumerate(groups):
                df[f"Significant_{grp}"] = reject_matrix[:, j]

    # Save
    os.makedirs(output_dir, exist_ok=True)
    out_csv = os.path.join(output_dir, "predictions_by_alignment.csv")
    df.to_csv(out_csv, index=False)
    print(f"Saved predictions to {out_csv}")

    # Plot
    save_summary_plot(df, output_dir)

    return df


def save_summary_plot(df: pd.DataFrame, output_dir: str) -> None:
    """Write ``predicted_results_summary.png`` (300 dpi) into *output_dir*.

    Left panel: number of sequences per predicted group. Right panel: boxplot
    of ``Confidence_by_present`` per predicted group (outliers hidden). Tick
    labels are set explicitly rather than through ``boxplot(labels=...)``,
    which Matplotlib 3.9 deprecated. The figure is always closed.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    try:
        # Left: bar counts
        counts = df["Predicted_group"].value_counts()
        axes[0].bar(counts.index.astype(str), counts.values)
        axes[0].set_xlabel("Predicted Group")
        axes[0].set_ylabel("Number of Sequences")
        axes[0].set_title("Predicted Group Counts")
        axes[0].tick_params(axis="x", rotation=45)

        # Right: boxplot of Confidence_by_present per group
        groups = sorted(df["Predicted_group"].unique().tolist())
        data = [
            df.loc[df["Predicted_group"] == g, "Confidence_by_present"].to_numpy()
            for g in groups
        ]
        positions = list(range(1, len(groups) + 1))
        axes[1].boxplot(data, positions=positions, showfliers=False)
        axes[1].set_xticks(positions)
        axes[1].set_xticklabels([str(g) for g in groups])
        axes[1].set_title("Prediction Confidence (by Present)")
        axes[1].set_xlabel("Predicted Group")
        axes[1].set_ylabel("Confidence")
        axes[1].tick_params(axis="x", rotation=45)

        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, "predicted_results_summary.png"), dpi=300)
    finally:
        plt.close(fig)


# ----------------------------
# CLI
# ----------------------------
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser; defaults are shown in ``--help``."""
    p = argparse.ArgumentParser(
        description="Predict group membership of unknown sequences using unique k-mers.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--unknown", required=True, help="Unknown FASTA file OR directory of FASTA files.")
    p.add_argument("--kmer-input", required=True, help="Directory of unique_k*.tsv/txt OR a ZIP containing them.")
    p.add_argument("--outdir", default="prediction_results", help="Output directory.")
    p.add_argument("--seqtype", choices=["dna", "protein"], default="dna", help="Sequence type.")
    p.add_argument("--mode", choices=["fast", "full"], default="fast", help="fast=substring only; full=alignment+Fisher+FDR.")
    p.add_argument("--identity", type=float, default=0.9, help="Alignment identity threshold (full mode only).")
    p.add_argument("--coverage", type=float, default=0.8, help="Alignment coverage threshold (full mode only).")
    p.add_argument("--fdr", type=float, default=0.05, help="FDR alpha for the Significant_* columns (full mode only).")
    p.add_argument(
        "--group-regex",
        default=DEFAULT_GROUP_REGEX,
        help="Regex to extract group name from k-mer filenames (1st capture group = group).",
    )
    return p


def main() -> None:
    """CLI entry point: parse arguments, check --unknown exists, run ``predict``."""
    args = build_parser().parse_args()

    # Validate unknown
    if not os.path.exists(args.unknown):
        raise FileNotFoundError(f"--unknown not found: {args.unknown}")

    # Run
    predict(
        unknown=args.unknown,
        kmer_input=args.kmer_input,
        output_dir=args.outdir,
        seqtype=args.seqtype,
        mode=args.mode,
        identity_threshold=args.identity,
        min_coverage=args.coverage,
        fdr_alpha=args.fdr,
        group_regex=args.group_regex,
    )


if __name__ == "__main__":
    main()