EphAsad committed on
Commit
d183835
·
verified ·
1 Parent(s): 822c042

Create genus_predictor.py

Browse files
Files changed (1) hide show
  1. engine/genus_predictor.py +130 -0
engine/genus_predictor.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # engine/genus_predictor.py
2
+ """
3
+ Genus-level ML prediction using the XGBoost model trained in Stage 12D.
4
+
5
+ This module loads:
6
+ models/genus_xgb.json
7
+ models/genus_xgb_meta.json
8
+
9
+ And exposes:
10
+ predict_genus_from_fused(fused_fields)
11
+
12
+ Which returns a list of tuples:
13
+ [
14
+ (genus_name, probability_float, confidence_label),
15
+ ...
16
+ ]
17
+
18
+ Where confidence_label is one of:
19
+ - "Excellent Identification" (>= 0.90)
20
+ - "Good Identification" (>= 0.80)
21
+ - "Acceptable Identification" (>= 0.65)
22
+ - "Low Discrimination" (< 0.65)
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import os
28
+ import json
29
+ from typing import Dict, Any, List, Tuple
30
+
31
+ import numpy as np
32
+ import xgboost as xgb
33
+
34
+ from .features import extract_feature_vector
35
+
36
+
37
# Artifact locations. These are relative paths, so they resolve against the
# current working directory of the running process.
_MODEL_PATH = "models/genus_xgb.json"
_META_PATH = "models/genus_xgb_meta.json"


# ----------------------------------------------------------------------
# Module-level cache, populated by _lazy_load() on first use
# ----------------------------------------------------------------------

# _MODEL         -> the loaded xgb.Booster (also the "already loaded" sentinel)
# _META          -> the parsed metadata JSON
# _IDX_TO_GENUS  -> {class index (int): genus name}
# _NUM_FEATURES  -> value of meta["n_features"]
# _NUM_CLASSES   -> value of meta["num_classes"]
_MODEL = _META = _IDX_TO_GENUS = _NUM_FEATURES = _NUM_CLASSES = None
51
+
52
+
53
def _lazy_load():
    """Load the XGBoost booster and its metadata into module globals, once.

    The first successful call populates ``_MODEL``, ``_META``,
    ``_IDX_TO_GENUS``, ``_NUM_FEATURES`` and ``_NUM_CLASSES``. Later calls
    return immediately because ``_MODEL`` is no longer ``None``.

    Everything is staged in local variables and only published to the
    globals after every step has succeeded. If loading fails part-way
    (unreadable model, invalid JSON, missing metadata key) the module stays
    fully unloaded, so the next call retries instead of returning early with
    a half-initialised state.

    Raises:
        FileNotFoundError: if the model file or the metadata file is missing.
        KeyError: if the metadata lacks ``idx_to_genus``, ``n_features`` or
            ``num_classes``.
        ValueError: if the metadata is not valid JSON or an ``idx_to_genus``
            key is not an integer string.
    """
    global _MODEL, _META, _IDX_TO_GENUS, _NUM_FEATURES, _NUM_CLASSES

    if _MODEL is not None:
        return

    if not os.path.exists(_MODEL_PATH):
        raise FileNotFoundError(f"Genus model not found at '{_MODEL_PATH}'.")

    if not os.path.exists(_META_PATH):
        raise FileNotFoundError(f"Genus meta file not found at '{_META_PATH}'.")

    # Load model into a local first; do not touch the globals yet.
    booster = xgb.Booster()
    booster.load_model(_MODEL_PATH)

    # Load metadata
    with open(_META_PATH, "r", encoding="utf-8") as f:
        meta = json.load(f)

    # JSON object keys are always strings; convert them back to int indices.
    idx_to_genus = {int(k): v for k, v in meta["idx_to_genus"].items()}
    num_features = meta["n_features"]
    num_classes = meta["num_classes"]

    # Publish atomically-in-spirit: _MODEL is the "loaded" sentinel checked
    # above, so it must be assigned last.
    _META = meta
    _IDX_TO_GENUS = idx_to_genus
    _NUM_FEATURES = num_features
    _NUM_CLASSES = num_classes
    _MODEL = booster
77
+
78
+
79
+ # ----------------------------------------------------------------------
80
+ # Confidence label assignment
81
+ # ----------------------------------------------------------------------
82
+
83
def _confidence_band(p: float) -> str:
    """Translate a probability into its confidence label.

    Each threshold is an inclusive lower bound, checked from the highest
    band downwards; a value below every threshold (including NaN, which
    fails all comparisons) is labelled "Low Discrimination".
    """
    bands = (
        (0.90, "Excellent Identification"),
        (0.80, "Good Identification"),
        (0.65, "Acceptable Identification"),
    )
    return next(
        (label for floor, label in bands if p >= floor),
        "Low Discrimination",
    )
91
+
92
+
93
+ # ----------------------------------------------------------------------
94
+ # Public prediction function
95
+ # ----------------------------------------------------------------------
96
+
97
def predict_genus_from_fused(
    fused_fields: Dict[str, Any],
    top_k: int = 10
) -> List[Tuple[str, float, str]]:
    """
    Predict genus from fused fields using the trained XGBoost model.

    Args:
        fused_fields: field mapping handed to ``extract_feature_vector``.
        top_k: maximum number of results to return. Zero or a negative
            value yields an empty list.

    Returns:
        Up to ``top_k`` results sorted by probability, highest first:
        [(genus_name, probability_float, confidence_label), ...]
        Class indices missing from the metadata mapping are reported as
        ``"Class_<idx>"``.

    Raises:
        FileNotFoundError: if the model or metadata file is missing
            (raised by ``_lazy_load``).
    """
    _lazy_load()

    # Build feature vector. Coerce to a flat float array so that a list or
    # a (1, n)-shaped array from the extractor is handled as well.
    vec = np.asarray(extract_feature_vector(fused_fields), dtype=float).ravel()
    if vec.shape[0] != _NUM_FEATURES:
        # Defensive: mismatch in schema -> pad with zeros or trim. This is
        # kept as best-effort, but it is surfaced because predictions made
        # on a misaligned vector are unreliable.
        import warnings

        warnings.warn(
            f"Feature vector length {vec.shape[0]} does not match the "
            f"model's expected {_NUM_FEATURES}; padding/trimming.",
            RuntimeWarning,
            stacklevel=2,
        )
        fixed = np.zeros(_NUM_FEATURES, dtype=float)
        m = min(vec.shape[0], _NUM_FEATURES)
        fixed[:m] = vec[:m]
        vec = fixed

    dmat = xgb.DMatrix(vec.reshape(1, -1))
    # First (only) row of the prediction; presumably one probability per
    # class, i.e. shape (num_classes,) -- depends on the training objective.
    probs = _MODEL.predict(dmat)[0]

    # Build list of (genus, prob, band)
    results = [
        (
            _IDX_TO_GENUS.get(idx, f"Class_{idx}"),
            float(p),
            _confidence_band(float(p)),
        )
        for idx, p in enumerate(probs)
    ]

    # Sort by probability, descending (stable, so ties keep class order)
    results.sort(key=lambda item: item[1], reverse=True)

    # Clamp so a negative top_k cannot slice from the end of the list.
    return results[:max(top_k, 0)]