File size: 4,761 Bytes
5c389ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# filter.py
"""
FunGO — Smart Tier Filtering
==============================
Removes generic/root GO terms and assigns evidence tiers
to remaining predictions.

Changes from original:
  1. Tier names updated:
       GOLD   → STRONG     (Strong Evidence)
       GOOD   → MODERATE   (Moderate Evidence)
       SILVER → INDICATIVE
  2. Combined score = ia_weight × confidence
     Used for ranking — it rewards predictions that are both specific and confident.
  3. filter_predictions() returns a dict with two keys:
       "display" — top 20 by combined score (for UI screen)
       "all"     — full filtered list (for CSV download)
  4. summarise() updated to use new tier keys.
  5. Blacklist and IA/confidence thresholds are completely unchanged.
"""

import logging
from config import (
    BLACKLIST_TERMS,
    TIER_GOLD_IA,   TIER_GOLD_CONF,
    TIER_GOOD_IA,   TIER_GOOD_CONF,
    TIER_SILVER_IA, TIER_SILVER_CONF,
)

logger = logging.getLogger(__name__)

# Human-readable names for the three GO ontology codes.
ONT_LABELS = {
    "MFO": "Molecular Function",
    "BPO": "Biological Process",
    "CCO": "Cellular Component",
}

# Display label for each evidence tier, listed from strongest to weakest.
TIER_LABELS = {
    "STRONG":     "Strong Evidence",
    "MODERATE":   "Moderate Evidence",
    "INDICATIVE": "Indicative",
}

# Sort rank per tier (0 = strongest). Derived from TIER_LABELS' insertion
# order so the two tables cannot drift apart.
TIER_RANK = {tier: rank for rank, tier in enumerate(TIER_LABELS)}

# Max predictions shown on screen per protein
TOP_N_DISPLAY = 20


def assign_tier(go_term: str, ia: float, confidence: float) -> str:
    """
    Classify one prediction into an evidence tier.

    Blacklisted GO terms are always "NOISE". Otherwise the tiers are tried
    from strongest to weakest. The first tier wins when its IA threshold is
    strictly exceeded (``>``) and its confidence threshold is met (``>=``).
    A prediction that clears no tier is "NOISE". Thresholds come from config
    and are unchanged from the original implementation.

    Returns: "STRONG" | "MODERATE" | "INDICATIVE" | "NOISE"
    """
    if go_term in BLACKLIST_TERMS:
        return "NOISE"

    # (tier name, minimum IA [exclusive], minimum confidence [inclusive]),
    # checked in priority order.
    ladder = (
        ("STRONG",     TIER_GOLD_IA,   TIER_GOLD_CONF),
        ("MODERATE",   TIER_GOOD_IA,   TIER_GOOD_CONF),
        ("INDICATIVE", TIER_SILVER_IA, TIER_SILVER_CONF),
    )
    for tier_name, min_ia, min_conf in ladder:
        if ia > min_ia and confidence >= min_conf:
            return tier_name
    return "NOISE"


def combined_score(ia: float, confidence: float) -> float:
    """
    Return the ranking score ``ia × confidence``, rounded to 6 decimals.

    The IA weight rewards term specificity and the confidence rewards model
    certainty. Their product ranks predictions that are both specific and
    certain highest.
    """
    product = ia * confidence
    return round(product, 6)


def filter_predictions(
    raw_predictions: list,
    ia_weights: dict,
    *,
    top_n: "int | None" = None,
) -> dict:
    """
    Filter raw predictions and return display + full sets.

    Parameters
    ----------
    raw_predictions : list of dict
        Each dict must provide "go_term", "confidence" and "ontology".
        Every key it carries is copied into the output dict.
    ia_weights : dict
        GO term -> IA weight. Terms absent from the mapping get IA 0.0.
    top_n : int, optional (keyword-only)
        Maximum number of entries in "display". ``None`` (the default) uses
        TOP_N_DISPLAY. Negative values are treated as 0.

    Returns
    -------
    {
      "display": the first ``top_n`` entries of "all" (for UI screen),
      "all":     all filtered predictions (for CSV)
    }
    Both are sorted by combined_score descending, with tier_rank ascending
    as tiebreaker. Predictions assigned "NOISE" are dropped.

    Each prediction dict contains the input keys (e.g. go_term, ontology,
    confidence, threshold) plus:
      ontology_label, ia_weight, combined_score, tier, tier_rank, tier_label
    """
    limit = TOP_N_DISPLAY if top_n is None else max(int(top_n), 0)
    filtered = []

    for pred in raw_predictions:
        go_term    = pred["go_term"]
        confidence = pred["confidence"]
        ia         = float(ia_weights.get(go_term, 0.0))
        tier       = assign_tier(go_term, ia, confidence)

        if tier == "NOISE":
            continue

        # Defensive: assign_tier should only return known tiers or "NOISE",
        # but guard against the tier tables and assign_tier drifting apart.
        if tier not in TIER_RANK:
            logger.warning("Unknown tier %r for %s — skipping", tier, go_term)
            continue

        filtered.append({
            **pred,
            # Rounded for presentation only; the score uses the raw IA.
            "ia_weight":      round(ia, 4),
            "combined_score": combined_score(ia, confidence),
            "tier":           tier,
            "tier_rank":      TIER_RANK[tier],
            "tier_label":     TIER_LABELS[tier],
            "ontology_label": ONT_LABELS.get(pred["ontology"], pred["ontology"]),
        })

    # Highest combined score first; on equal scores the stronger tier wins.
    # list.sort is stable, so fully-tied predictions keep their input order.
    filtered.sort(key=lambda x: (-x["combined_score"], x["tier_rank"]))

    return {
        "display": filtered[:limit],
        "all":     filtered,
    }


def summarise(filtered_display: list, all_filtered: list, protein_id: str) -> dict:
    """
    Build a per-protein summary of filtered predictions.

    The ontology/tier counts and all averages are computed over
    ``all_filtered`` (the full list, not just the top-20 shown on screen).
    ``filtered_display`` only contributes its length.

    Ontology or tier values outside the known keys are not counted. Every
    entry must carry "confidence", "ia_weight" and "combined_score". The
    averages are rounded to 4 decimals and are 0.0 when ``all_filtered`` is
    empty.
    """
    ont_counts = dict.fromkeys(("MFO", "BPO", "CCO"), 0)
    tier_counts = dict.fromkeys(("STRONG", "MODERATE", "INDICATIVE"), 0)

    for entry in all_filtered:
        for bucket, field in ((ont_counts, "ontology"), (tier_counts, "tier")):
            value = entry.get(field, "")
            if value in bucket:
                bucket[value] += 1

    total = len(all_filtered)

    def mean_of(field):
        # Guard the empty case so we never divide by zero.
        if not total:
            return 0.0
        return round(sum(entry[field] for entry in all_filtered) / total, 4)

    return {
        "protein_id":          protein_id,
        "total_filtered":      total,
        "displayed":           len(filtered_display),
        "by_ontology":         ont_counts,
        "by_tier":             tier_counts,
        "has_strong_evidence": tier_counts["STRONG"] > 0,
        "avg_confidence":      mean_of("confidence"),
        "avg_ia":              mean_of("ia_weight"),
        "avg_combined_score":  mean_of("combined_score"),
    }