#!/usr/bin/env python3
"""Serotype classification by serotype-unique k-mers.

Pipeline: read known reference sequences (FASTA, possibly zipped), derive a
serotype name per record, compute the k-mers that are UNIQUE to each serotype,
then classify unknown sequences by counting which serotype's unique k-mers
they contain. Significance of each match count is assessed with a one-sided
Fisher exact test, corrected with Benjamini-Hochberg FDR.
"""
import io
import os
import re
import math
import zipfile
from typing import Dict, List, Optional, Set, Tuple

import pandas as pd
from Bio import SeqIO
from statsmodels.stats.multitest import multipletests
from scipy.stats import fisher_exact
import matplotlib.pyplot as plt

# Recognized FASTA file extensions (compared case-insensitively).
FA_EXT = (".fasta", ".fa", ".fas", ".fna")


def _read_fasta_bytes(name: str, data: bytes) -> List[Tuple[str, str, str]]:
    """Parse FASTA bytes into (source_name, header_id, uppercase_sequence) tuples.

    The header is the record id (first whitespace-separated token of the
    FASTA header line); the sequence is uppercased with any stray line
    breaks removed.
    """
    recs: List[Tuple[str, str, str]] = []
    with io.BytesIO(data) as bio:
        for rec in SeqIO.parse(io.TextIOWrapper(bio, encoding="utf-8"), "fasta"):
            header = str(rec.id)
            seq = str(rec.seq).upper().replace("\n", "").replace("\r", "")
            recs.append((name, header, seq))
    return recs


def read_uploaded_fasta_or_zip(uploaded_file) -> List[Tuple[str, str, str]]:
    """Read an uploaded FASTA file or a ZIP archive of FASTA files.

    Parameters
    ----------
    uploaded_file : file-like upload object with ``.name`` and ``.read()``
        (e.g. a Streamlit ``UploadedFile``), or ``None``.

    Returns
    -------
    List of (source_filename, header, sequence) tuples; empty list if no
    file was provided. For ZIP input, only members whose names end in a
    known FASTA extension are parsed; directories are skipped.
    """
    if uploaded_file is None:
        return []
    name = uploaded_file.name
    data = uploaded_file.read()
    if name.lower().endswith(".zip"):
        results: List[Tuple[str, str, str]] = []
        with zipfile.ZipFile(io.BytesIO(data)) as z:
            for zi in z.infolist():
                if zi.is_dir():
                    continue
                if not any(zi.filename.lower().endswith(ext) for ext in FA_EXT):
                    continue
                file_bytes = z.read(zi.filename)
                results.extend(_read_fasta_bytes(os.path.basename(zi.filename), file_bytes))
        return results
    else:
        return _read_fasta_bytes(os.path.basename(name), data)


def clean_protein(seq: str) -> str:
    """Uppercase *seq* and drop every character that is not one of the 20
    standard amino-acid letters."""
    return re.sub(r"[^ACDEFGHIKLMNPQRSTVWY]", "", seq.upper())


def clean_dna(seq: str) -> str:
    """Uppercase *seq* and drop every character that is not A/C/G/T/U/N.

    N is kept here and filtered per-k-mer later (see get_kmers_noN), so
    ambiguous positions break k-mers instead of falsely joining flanks.
    """
    return re.sub(r"[^ACGTUN]", "", seq.upper())


def get_kmers_noN(sequence: str, k: int) -> List[str]:
    """Return all length-k substrings of *sequence* that contain no 'N'.

    Overlapping k-mers are included (step of 1); duplicates are kept —
    callers that need a set should wrap the result in ``set()``.
    """
    out: List[str] = []
    for i in range(len(sequence) - k + 1):
        kmer = sequence[i:i + k]
        if "N" not in kmer:
            out.append(kmer)
    return out


def parse_k_input(k_input: str, default_single: int) -> List[int]:
    """Parse a user-supplied k specification into a list of k values.

    Accepts a single value ("5"), a range ("3-7", either order), a
    comma-separated list ("3,5,7"), or a mix of the two ("3,5-7").
    Empty/blank input yields ``[default_single]``.

    Raises
    ------
    ValueError
        If any token is not an integer or integer range.
    """
    k_input = (k_input or "").strip()
    if not k_input:
        return [default_single]
    ks: List[int] = []
    # Split on commas first so ranges may appear inside a list; the
    # original dash-first order crashed on inputs like "3,5-7".
    for token in k_input.split(","):
        token = token.strip()
        if not token:
            continue
        if "-" in token:
            lo_s, hi_s = token.split("-", 1)
            lo, hi = int(lo_s.strip()), int(hi_s.strip())
            if lo > hi:
                lo, hi = hi, lo
            ks.extend(range(lo, hi + 1))
        else:
            ks.append(int(token))
    return ks or [default_single]


def derive_serotype_names_from_sources(known_records: List[Tuple[str, str, str]]) -> Dict[str, str]:
    """Map each record header to a serotype name.

    If a source file contributed exactly one record, the serotype is the
    file's base name (without extension); otherwise it is the first
    whitespace-separated token of the record header. NOTE(review): the map
    is keyed by header, so identical headers in different multi-record
    files collapse to one entry.
    """
    counts: Dict[str, int] = {}
    for src, header, _ in known_records:
        counts[src] = counts.get(src, 0) + 1
    name_map: Dict[str, str] = {}
    for src, header, _ in known_records:
        if counts.get(src, 0) == 1:
            sero = os.path.splitext(os.path.basename(src))[0]
        else:
            sero = header.split()[0]
        name_map[header] = sero
    return name_map


def compute_unique_kmers_per_serotype(serotype_to_seq: Dict[str, str],
                                      is_protein: bool,
                                      k_values: List[int]) -> Dict[str, Dict[int, Set[str]]]:
    """For each serotype and each k, compute the k-mers found in that
    serotype and in NO other serotype.

    BUG FIX: the previous implementation computed
    ``unique = own - (union_all - own)``; since ``union_all - own`` is by
    construction disjoint from ``own``, that subtraction removed nothing
    and every serotype's "unique" set equaled its full k-mer set. We now
    count, per k-mer, how many serotypes contain it and keep only those
    with a count of exactly one.
    """
    all_sets: Dict[str, Dict[int, Set[str]]] = {g: {} for g in serotype_to_seq}
    for g, seq in serotype_to_seq.items():
        seq = clean_protein(seq) if is_protein else clean_dna(seq)
        for k in k_values:
            all_sets[g][k] = set(get_kmers_noN(seq, k))

    unique: Dict[str, Dict[int, Set[str]]] = {g: {k: set() for k in k_values} for g in serotype_to_seq}
    for k in k_values:
        # Number of serotypes each k-mer occurs in (set membership, so a
        # k-mer repeated within one serotype counts once).
        occurrence: Dict[str, int] = {}
        for g in all_sets:
            for kmer in all_sets[g][k]:
                occurrence[kmer] = occurrence.get(kmer, 0) + 1
        for g in all_sets:
            unique[g][k] = {kmer for kmer in all_sets[g][k] if occurrence[kmer] == 1}
    return unique


def classify_unknown_sequences(unknown_records: List[Tuple[str, str, str]],
                               unique_kmers: Dict[str, Dict[int, Set[str]]],
                               is_protein: bool,
                               fdr_alpha: float = 0.05) -> pd.DataFrame:
    """Classify unknown sequences against serotype-unique k-mer sets.

    For each unknown sequence, counts how many of each serotype's unique
    k-mers it contains (summed over all k). The serotype with the highest
    count is predicted; ties resolve to the first serotype in dict order.
    A one-sided Fisher exact test per serotype (matches vs. vocabulary
    size) with Benjamini-Hochberg FDR correction quantifies enrichment.

    Returns a DataFrame with one row per unknown sequence: prediction,
    total matches, two confidence ratios, and per-serotype match counts,
    Fisher p-values, and FDR q-values.
    """
    # Total unique-k-mer vocabulary per serotype (across all k).
    vocab_by_sero: Dict[str, int] = {}
    k_values = sorted({k for g in unique_kmers for k in unique_kmers[g]})
    for g in unique_kmers:
        vocab_by_sero[g] = sum(len(unique_kmers[g][k]) for k in k_values)

    results = []
    for src, header, seq in unknown_records:
        seq2 = clean_protein(seq) if is_protein else clean_dna(seq)
        unk_kmers: Dict[int, Set[str]] = {}
        for k in k_values:
            unk_kmers[k] = set(get_kmers_noN(seq2, k))

        match_counts: Dict[str, int] = {}
        total_matches = 0
        for g in unique_kmers:
            mg = 0
            for k in k_values:
                mg += len(unique_kmers[g][k].intersection(unk_kmers[k]))
            match_counts[g] = mg
            total_matches += mg

        if total_matches == 0:
            predicted = "NoMatch"
            conf_present = 0.0
            conf_vocab = 0.0
        else:
            predicted = max(match_counts, key=match_counts.get)
            conf_present = match_counts[predicted] / total_matches
            conf_vocab = match_counts[predicted] / max(1, vocab_by_sero[predicted])

        fisher_p: Dict[str, float] = {}
        if total_matches > 0:
            sum_vocab_all = sum(vocab_by_sero.values())
            for g in unique_kmers:
                # 2x2 table: (matched, unmatched) x (this serotype, others).
                a = match_counts[g]
                b = vocab_by_sero[g] - a
                c = total_matches - a
                d = (sum_vocab_all - vocab_by_sero[g]) - c
                # Clamp defensively against degenerate vocabularies.
                a = max(0, a); b = max(0, b); c = max(0, c); d = max(0, d)
                _, p = fisher_exact([[a, b], [c, d]], alternative="greater")
                fisher_p[g] = p
            groups = list(unique_kmers.keys())
            pvals = [fisher_p[g] for g in groups]
            _, qvals, _, _ = multipletests(pvals, alpha=fdr_alpha, method="fdr_bh")
            fdr_map = {g: q for g, q in zip(groups, qvals)}
        else:
            fisher_p = {g: 1.0 for g in unique_kmers}
            fdr_map = {g: 1.0 for g in unique_kmers}

        row = {"Source": src,
               "Sequence": header,
               "Predicted_serotype": predicted,
               "Matches_total": total_matches,
               "Confidence_by_present": conf_present,
               "Confidence_by_serotype_vocab": conf_vocab}
        for g in unique_kmers:
            row[f"Matches_{g}"] = match_counts[g]
            row[f"FisherP_{g}"] = fisher_p[g]
            row[f"FDR_{g}"] = fdr_map[g]
        results.append(row)
    return pd.DataFrame(results)


def plot_counts_by_serotype(simple_df: pd.DataFrame):
    """Bar chart of predicted-serotype counts; returns the Figure.

    Expects *simple_df* to have a ``Predicted_serotype`` column (as
    produced by classify_unknown_sequences).
    """
    fig = plt.figure(figsize=(8, 5))
    ax = fig.add_subplot(111)
    counts = simple_df["Predicted_serotype"].value_counts()
    ax.bar(counts.index.astype(str), counts.values)
    ax.set_xlabel("Predicted serotype")
    ax.set_ylabel("Number of sequences")
    ax.set_title("Predicted serotype counts")
    fig.tight_layout()
    return fig