File size: 6,514 Bytes
21ce98d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/env python3
import io, os, re, math, zipfile
from typing import Dict, List, Tuple, Set, Optional
import pandas as pd
from Bio import SeqIO
from statsmodels.stats.multitest import multipletests
from scipy.stats import fisher_exact
import matplotlib.pyplot as plt

FA_EXT = (".fasta", ".fa", ".fas", ".fna")

def _read_fasta_bytes(name: str, data: bytes) -> List[Tuple[str, str, str]]:
    """Parse FASTA content given as raw bytes.

    Returns a list of (source_name, record_id, sequence) triples; the
    sequence is upper-cased with any stray newline characters removed.
    """
    parsed: List[Tuple[str, str, str]] = []
    with io.BytesIO(data) as raw:
        text_stream = io.TextIOWrapper(raw, encoding="utf-8")
        for record in SeqIO.parse(text_stream, "fasta"):
            sequence = str(record.seq).upper().replace("\n", "").replace("\r", "")
            parsed.append((name, str(record.id), sequence))
    return parsed

def read_uploaded_fasta_or_zip(uploaded_file) -> List[Tuple[str, str, str]]:
    """Read an uploaded FASTA file, or every FASTA member inside an uploaded ZIP.

    Returns a list of (source_filename, header, sequence) triples; an
    absent upload yields an empty list.
    """
    if uploaded_file is None:
        return []
    filename = uploaded_file.name
    payload = uploaded_file.read()
    # Plain FASTA upload: parse it directly.
    if not filename.lower().endswith(".zip"):
        return _read_fasta_bytes(os.path.basename(filename), payload)
    # ZIP upload: parse every member whose extension looks like FASTA.
    records: List[Tuple[str, str, str]] = []
    with zipfile.ZipFile(io.BytesIO(payload)) as archive:
        for entry in archive.infolist():
            if entry.is_dir():
                continue
            if entry.filename.lower().endswith(FA_EXT):
                member_bytes = archive.read(entry.filename)
                records.extend(_read_fasta_bytes(os.path.basename(entry.filename), member_bytes))
    return records

def clean_protein(seq: str) -> str:
    """Upper-case *seq* and keep only the 20 canonical amino-acid letters."""
    allowed = "ACDEFGHIKLMNPQRSTVWY"
    return "".join(ch for ch in seq.upper() if ch in allowed)

def clean_dna(seq: str) -> str:
    """Upper-case *seq* and keep only A/C/G/T/U/N (N is filtered later per k-mer)."""
    return "".join(base for base in seq.upper() if base in "ACGTUN")

def get_kmers_noN(sequence: str, k: int) -> List[str]:
    """Return every length-k substring of *sequence* that contains no 'N'.

    K-mers are yielded in left-to-right order; a sequence shorter than k
    produces an empty list.
    """
    windows = (sequence[pos:pos + k] for pos in range(len(sequence) - k + 1))
    return [w for w in windows if "N" not in w]

def parse_k_input(k_input: str, default_single: int) -> List[int]:
    """Parse a user-supplied k specification into a list of k values.

    Accepts a single integer ("5"), a comma-separated list ("3,5,7"),
    a dash range ("3-5" -> [3, 4, 5]), or a mix of both ("3-5,8").
    Blank/empty input yields [default_single].

    Args:
        k_input: Raw user text describing the k-mer sizes.
        default_single: Fallback k when the input is empty.

    Returns:
        The k values in the order given (ranges expanded ascending).
        Duplicates are not removed, matching prior behavior.

    Raises:
        ValueError: If a token is not a valid integer or integer range.
    """
    k_input = (k_input or "").strip()
    if not k_input:
        return [default_single]
    ks: List[int] = []
    # Split on commas first so mixed specs like "3-5,8" work; previously a
    # dash anywhere made the whole string be parsed as one range and crash.
    for token in k_input.split(","):
        token = token.strip()
        if not token:
            continue
        if "-" in token:
            lo_text, hi_text = token.split("-", 1)
            lo, hi = int(lo_text.strip()), int(hi_text.strip())
            if lo > hi:
                lo, hi = hi, lo
            ks.extend(range(lo, hi + 1))
        else:
            ks.append(int(token))
    return ks

def derive_serotype_names_from_sources(known_records: List[Tuple[str, str, str]]) -> Dict[str, str]:
    """Map each record header to a serotype label.

    A source file containing exactly one record is assumed to be named after
    its serotype, so its base filename (extension stripped) is used; files
    with several records fall back to the first whitespace-separated token
    of each header.
    """
    records_per_source: Dict[str, int] = {}
    for src, _header, _seq in known_records:
        records_per_source[src] = records_per_source.get(src, 0) + 1
    name_map: Dict[str, str] = {}
    for src, header, _seq in known_records:
        if records_per_source[src] == 1:
            label = os.path.splitext(os.path.basename(src))[0]
        else:
            label = header.split()[0]
        name_map[header] = label
    return name_map

def compute_unique_kmers_per_serotype(serotype_to_seq: Dict[str, str], is_protein: bool, k_values: List[int]) -> Dict[str, Dict[int, Set[str]]]:
    """For each serotype and each k, find the k-mers present in NO other serotype.

    Parameters
    ----------
    serotype_to_seq : serotype name -> raw sequence.
    is_protein : clean sequences as protein if True, else as DNA.
    k_values : k-mer sizes to evaluate.

    Returns
    -------
    serotype -> k -> set of k-mers unique to that serotype at that k.
    """
    all_sets: Dict[str, Dict[int, Set[str]]] = {g: {} for g in serotype_to_seq}
    for g, seq in serotype_to_seq.items():
        seq = clean_protein(seq) if is_protein else clean_dna(seq)
        for k in k_values:
            all_sets[g][k] = set(get_kmers_noN(seq, k))
    unique: Dict[str, Dict[int, Set[str]]] = {g: {k: set() for k in k_values} for g in serotype_to_seq}
    for k in k_values:
        # Count in how many serotypes each k-mer occurs; a k-mer is unique
        # to g iff it occurs in exactly one serotype (namely g).
        occurrences: Dict[str, int] = {}
        for g in all_sets:
            for kmer in all_sets[g][k]:
                occurrences[kmer] = occurrences.get(kmer, 0) + 1
        for g in all_sets:
            # BUG FIX: the old code computed others_union = union_all - own_set,
            # i.e. only k-mers ABSENT from g, so subtracting it removed nothing
            # and every k-mer of g was reported as "unique". Shared k-mers are
            # now correctly excluded.
            unique[g][k] = {m for m in all_sets[g][k] if occurrences[m] == 1}
    return unique

def classify_unknown_sequences(unknown_records: List[Tuple[str, str, str]], unique_kmers: Dict[str, Dict[int, Set[str]]], is_protein: bool, fdr_alpha: float = 0.05) -> pd.DataFrame:
    """Assign each unknown sequence to the serotype whose unique k-mers it matches most.

    Parameters
    ----------
    unknown_records : list of (source, header, sequence) triples to classify.
    unique_kmers : serotype -> k -> set of k-mers unique to that serotype
        (as produced by compute_unique_kmers_per_serotype).
    is_protein : if True, sequences are cleaned as protein; otherwise as DNA.
    fdr_alpha : alpha passed to the Benjamini-Hochberg multiple-test correction.

    Returns
    -------
    pd.DataFrame with one row per unknown sequence: predicted serotype,
    total match count, two confidence ratios, and per-serotype match counts,
    one-sided Fisher p-values and BH-adjusted q-values.
    """
    # Total number of unique k-mers per serotype, summed over all k ("vocabulary size").
    vocab_by_sero: Dict[str, int] = {}
    k_values = sorted({k for g in unique_kmers for k in unique_kmers[g]})
    for g in unique_kmers:
        vocab_by_sero[g] = sum(len(unique_kmers[g][k]) for k in k_values)

    results = []
    for src, header, seq in unknown_records:
        seq2 = clean_protein(seq) if is_protein else clean_dna(seq)
        # k-mer sets of the unknown sequence, one set per k.
        unk_kmers: Dict[int, Set[str]] = {}
        for k in k_values:
            unk_kmers[k] = set(get_kmers_noN(seq2, k))

        # Count how many of each serotype's unique k-mers appear in the unknown.
        match_counts: Dict[str, int] = {}
        total_matches = 0
        for g in unique_kmers:
            mg = 0
            for k in k_values:
                mg += len(unique_kmers[g][k].intersection(unk_kmers[k]))
            match_counts[g] = mg
            total_matches += mg

        if total_matches == 0:
            predicted = "NoMatch"; conf_present = 0.0; conf_vocab = 0.0
        else:
            # Winner = serotype with the most matches (ties resolved by dict order).
            predicted = max(match_counts, key=match_counts.get)
            # Fraction of all observed matches that belong to the winner.
            conf_present = match_counts[predicted] / total_matches
            # Fraction of the winner's vocabulary that was observed (max(1, .) avoids /0).
            conf_vocab = match_counts[predicted] / max(1, vocab_by_sero[predicted])

        fisher_p = {}
        if total_matches > 0:
            sum_vocab_all = sum(vocab_by_sero.values())
            for g in unique_kmers:
                # 2x2 table: matched vs unmatched vocabulary, serotype g vs all others.
                a = match_counts[g]
                b = vocab_by_sero[g] - a
                c = total_matches - a
                d = (sum_vocab_all - vocab_by_sero[g]) - c
                # Defensive clamp so fisher_exact never sees a negative cell.
                a = max(0, a); b = max(0, b); c = max(0, c); d = max(0, d)
                # One-sided test: is g's vocabulary enriched among the matches?
                _, p = fisher_exact([[a, b], [c, d]], alternative="greater")
                fisher_p[g] = p
            groups = list(unique_kmers.keys())
            pvals = [fisher_p[g] for g in groups]
            # Benjamini-Hochberg FDR across serotypes for this one sequence;
            # zip keeps q-values aligned with the same group order as pvals.
            _, qvals, _, _ = multipletests(pvals, alpha=fdr_alpha, method="fdr_bh")
            fdr_map = {g: q for g, q in zip(groups, qvals)}
        else:
            # No evidence at all: report non-significant values for every serotype.
            fisher_p = {g: 1.0 for g in unique_kmers}
            fdr_map = {g: 1.0 for g in unique_kmers}

        row = {"Source": src, "Sequence": header, "Predicted_serotype": predicted, "Matches_total": total_matches, "Confidence_by_present": conf_present, "Confidence_by_serotype_vocab": conf_vocab}
        for g in unique_kmers:
            row[f"Matches_{g}"] = match_counts[g]
            row[f"FisherP_{g}"] = fisher_p[g]
            row[f"FDR_{g}"] = fdr_map[g]
        results.append(row)

    return pd.DataFrame(results)

def plot_counts_by_serotype(simple_df: pd.DataFrame):
    """Bar chart of how many sequences were assigned to each predicted serotype.

    Expects a DataFrame with a "Predicted_serotype" column; returns the
    matplotlib Figure so the caller can display or save it.
    """
    tallies = simple_df["Predicted_serotype"].value_counts()
    fig = plt.figure(figsize=(8, 5))
    ax = fig.add_subplot(111)
    ax.bar(tallies.index.astype(str), tallies.values)
    ax.set_xlabel("Predicted serotype")
    ax.set_ylabel("Number of sequences")
    ax.set_title("Predicted serotype counts")
    fig.tight_layout()
    return fig