File size: 3,597 Bytes
d183835
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# engine/genus_predictor.py
"""
Genus-level ML prediction using the XGBoost model trained in Stage 12D.

This module loads:
    models/genus_xgb.json
    models/genus_xgb_meta.json

And exposes:
    predict_genus_from_fused(fused_fields)

Which returns a list of tuples:
    [
        (genus_name, probability_float, confidence_label),
        ...
    ]

Where confidence_label is one of:
    - "Excellent Identification"   (>= 0.90)
    - "Good Identification"        (>= 0.80)
    - "Acceptable Identification"  (>= 0.65)
    - "Low Discrimination"         (< 0.65)
"""

from __future__ import annotations

import os
import json
from typing import Dict, Any, List, Tuple

import numpy as np
import xgboost as xgb

from .features import extract_feature_vector


# Relative locations of the serialized booster and its companion metadata.
_MODEL_PATH = "models/genus_xgb.json"
_META_PATH = "models/genus_xgb_meta.json"


# ----------------------------------------------------------------------
# Module-level cache. Everything starts out as None and is filled in by
# _lazy_load() the first time a prediction is requested.
# ----------------------------------------------------------------------

_MODEL = _META = _IDX_TO_GENUS = _NUM_FEATURES = _NUM_CLASSES = None


def _lazy_load():
    """Load the genus model and its metadata into module globals, once.

    The first successful call populates ``_MODEL``, ``_META``,
    ``_IDX_TO_GENUS``, ``_NUM_FEATURES`` and ``_NUM_CLASSES``. Later calls
    return immediately because ``_MODEL`` is no longer ``None``.

    Everything is built in local variables and only published to the
    globals after every step has succeeded. A failure part-way through
    (unreadable metadata, missing key, ...) therefore leaves the module
    unloaded, so the next call retries instead of running with a model but
    no metadata.

    Raises:
        FileNotFoundError: if the model file or the metadata file is missing.
        KeyError: if the metadata lacks ``idx_to_genus``, ``n_features`` or
            ``num_classes``.
        ValueError: if the metadata is not valid JSON or a class index key
            is not an integer.
    """
    global _MODEL, _META, _IDX_TO_GENUS, _NUM_FEATURES, _NUM_CLASSES

    if _MODEL is not None:
        return

    if not os.path.exists(_MODEL_PATH):
        raise FileNotFoundError(f"Genus model not found at '{_MODEL_PATH}'.")

    if not os.path.exists(_META_PATH):
        raise FileNotFoundError(f"Genus meta file not found at '{_META_PATH}'.")

    # Read and unpack the metadata first: it is cheap and the most likely
    # step to fail on a malformed file.
    with open(_META_PATH, "r", encoding="utf-8") as f:
        meta = json.load(f)

    # JSON object keys are always strings; convert them back to int indices.
    idx_to_genus = {int(k): v for k, v in meta["idx_to_genus"].items()}
    num_features = meta["n_features"]
    num_classes = meta["num_classes"]

    # Load the booster into a local so a failure here publishes nothing.
    booster = xgb.Booster()
    booster.load_model(_MODEL_PATH)

    # Publish. _MODEL goes last because it doubles as the "loaded" flag.
    _META = meta
    _IDX_TO_GENUS = idx_to_genus
    _NUM_FEATURES = num_features
    _NUM_CLASSES = num_classes
    _MODEL = booster


# ----------------------------------------------------------------------
# Confidence label assignment
# ----------------------------------------------------------------------

def _confidence_band(p: float) -> str:
    """Translate a probability into its confidence label.

    Thresholds are checked from highest to lowest and the first one that
    ``p`` meets or exceeds wins; anything below the lowest threshold is
    reported as "Low Discrimination".
    """
    bands = (
        (0.90, "Excellent Identification"),
        (0.80, "Good Identification"),
        (0.65, "Acceptable Identification"),
    )
    for floor, label in bands:
        if p >= floor:
            return label
    return "Low Discrimination"


# ----------------------------------------------------------------------
# Public prediction function
# ----------------------------------------------------------------------

def predict_genus_from_fused(
    fused_fields: Dict[str, Any],
    top_k: int = 10
) -> List[Tuple[str, float, str]]:
    """
    Predict genus from fused fields using the trained XGBoost model.

    Args:
        fused_fields: Mapping handed unchanged to ``extract_feature_vector``.
        top_k: Maximum number of results to return. Zero or a negative value
            yields an empty list; ``None`` returns every class.

    Returns:
        Up to ``top_k`` results sorted by probability, highest first:
            [(genus_name, probability_float, confidence_label), ...]
        Class indices missing from the metadata mapping are reported as
        ``"Class_<idx>"``.

    Raises:
        FileNotFoundError: if the model or metadata file cannot be found
            on first use (raised by ``_lazy_load``).
    """
    _lazy_load()

    # A negative top_k would otherwise reach the final slice and drop
    # entries from the tail instead of returning nothing.
    if top_k is not None and top_k <= 0:
        return []

    # Build feature vector. Coerce to a flat float array so list or (1, n)
    # shaped extractor output is handled like a plain 1-D vector.
    vec = np.asarray(extract_feature_vector(fused_fields), dtype=float).ravel()
    if vec.shape[0] != _NUM_FEATURES:
        # Defensive: mismatch in schema → zero-pad or trim to the width the
        # model was trained with.
        fixed = np.zeros(_NUM_FEATURES, dtype=float)
        m = min(vec.shape[0], _NUM_FEATURES)
        fixed[:m] = vec[:m]
        vec = fixed

    dmat = xgb.DMatrix(vec.reshape(1, -1))
    # Single-row batch, so take row 0.
    # NOTE(review): assumes the model emits one probability per class
    # (shape (num_classes,)) — confirm against the training objective.
    probs = _MODEL.predict(dmat)[0]

    # Build list of (genus, prob, band)
    results = []
    for idx, p in enumerate(probs):
        prob = float(p)
        genus = _IDX_TO_GENUS.get(idx, f"Class_{idx}")
        results.append((genus, prob, _confidence_band(prob)))

    # Sort by probability, descending. The sort is stable, so tied classes
    # keep their index order.
    results.sort(key=lambda item: item[1], reverse=True)

    return results[:top_k]