Pingul committed on
Commit
ddae777
·
verified ·
1 Parent(s): bd3500d

Upload 5 files

Browse files
ExoMACModel/ExoMAC-KKT/exoplanet_class_labels.json ADDED
@@ -0,0 +1 @@
 
 
1
+ ["CANDIDATE", "CONFIRMED", "FALSE POSITIVE"]
ExoMACModel/ExoMAC-KKT/exoplanet_feature_columns.json ADDED
@@ -0,0 +1 @@
 
 
1
+ ["koi_depth", "koi_duration", "koi_impact", "koi_period", "koi_prad", "koi_slogg", "koi_sma", "koi_smet", "koi_snr", "koi_srad", "koi_steff", "duty_cycle", "log_koi_period", "log_koi_depth", "log_koi_snr", "teq_proxy"]
ExoMACModel/ExoMAC-KKT/exoplanet_metadata.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_model_name": "RandomForest",
3
+ "n_features": 16,
4
+ "labels": [
5
+ "CANDIDATE",
6
+ "CONFIRMED",
7
+ "FALSE POSITIVE"
8
+ ],
9
+ "created": "2025-10-05T20:49:33.128163Z"
10
+ }
ExoMACModel/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from ExoMACModel.main import ExoMACModel
2
+
3
+ __all__ = ["ExoMACModel"]
ExoMACModel/main.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import threading
6
+ from typing import Dict, Tuple, Optional, List
7
+ from pathlib import Path
8
+
9
+ import joblib
10
+ import numpy as np
11
+ import pandas as pd
12
+ from huggingface_hub import hf_hub_download, snapshot_download
13
+
14
+
15
+ class _Singleton(type):
16
+ """Thread-safe Singleton metaclass (una instancia por proceso)."""
17
+ _instances: Dict[type, object] = {}
18
+ _lock = threading.Lock()
19
+
20
+ def __call__(cls, *args, **kwargs):
21
+ # Double-checked locking
22
+ if cls not in cls._instances:
23
+ with cls._lock:
24
+ if cls not in cls._instances:
25
+ cls._instances[cls] = super().__call__(*args, **kwargs)
26
+ return cls._instances[cls]
27
+
28
+
29
class ExoMACModel(metaclass=_Singleton):
    """
    Mission-agnostic loader for a model (sklearn Pipeline) trained on Kepler/K2/TESS.

    - Downloads artifacts from Hugging Face ONLY when they are missing locally.
    - Saves/reads artifacts in a local project folder (default: ./models/ExoMAC-KKT).
    - Exposes prediction helpers and engineered-feature helpers.

    Being a singleton, __init__ runs only once per process; later constructor
    calls return the same instance with its original configuration.
    """

    DEFAULT_REPO = "ZapatoProgramming/ExoMAC-KKT"
    # Artifact filenames, looked up under artifacts/ first, then at the repo root.
    _FILENAMES = {
        "model": "exoplanet_best_model.joblib",
        "feats": "exoplanet_feature_columns.json",
        "labels": "exoplanet_class_labels.json",
        "meta": "exoplanet_metadata.json",
    }

    def __init__(
        self,
        repo_id: Optional[str] = None,
        token: Optional[str] = None,
        prefer_snapshot: bool = True,
        allow_patterns: Optional[List[str]] = None,
        local_dir: Optional[str | os.PathLike] = None,
        always_download: bool = False,
        verbose: bool = True,
    ):
        """
        Args:
            repo_id: Hugging Face repo id. Defaults to 'ZapatoProgramming/ExoMAC-KKT'.
            token: HF token if the repo is private.
            prefer_snapshot: If True, use snapshot_download (pattern-based download).
            allow_patterns: Patterns to fetch when prefer_snapshot=True.
            local_dir: Folder where artifacts are saved/read inside your project.
            always_download: If True, force a download (useful to refresh artifacts).
            verbose: Print helpful progress messages.
        """
        self.repo_id = repo_id or self.DEFAULT_REPO
        self.token = token
        self.prefer_snapshot = prefer_snapshot
        self.allow_patterns = allow_patterns or ["artifacts/*", "*.joblib", "*.json"]
        self.local_dir = Path(local_dir or (Path("models") / self.repo_id.split("/")[-1]))
        self.local_dir.mkdir(parents=True, exist_ok=True)
        self.always_download = always_download
        self.verbose = verbose

        self._model = None
        self._feature_columns: List[str] = []
        self._class_labels: List[str] = []
        self._metadata: Dict = {}

        self._load_artifacts()

    # ------------------------- PUBLIC API -------------------------

    @property
    def model(self):
        return self._model

    @property
    def feature_columns(self) -> List[str]:
        # Defensive copy so callers cannot mutate internal state.
        return list(self._feature_columns)

    @property
    def class_labels(self) -> List[str]:
        return list(self._class_labels)

    @property
    def metadata(self) -> Dict:
        return dict(self._metadata)

    def predict(
        self,
        params: Dict[str, float],
        return_proba: bool = True,
        compute_engineered_if_missing: bool = True,
    ) -> Tuple[str, Optional[Dict[str, float]]]:
        """
        Predict a label and (optionally) class probabilities for a feature dict.

        - Fills in engineered features when the model expects them and they
          are missing from `params`.
        - Handles estimators fitted on encoded integer targets (index into
          class labels) as well as estimators that predict label strings
          directly (the original `int(...)` cast crashed on the latter).
        """
        if compute_engineered_if_missing:
            params = self._ensure_engineered_features(dict(params))

        X = pd.DataFrame([params], dtype=float).reindex(columns=self._feature_columns)

        raw = self._model.predict(X)[0]
        if isinstance(raw, str):
            label = raw
        else:
            label = self._class_labels[int(raw)]

        if not return_proba:
            return label, None

        proba = None
        try:
            p = self._model.predict_proba(X)[0]
            # NOTE(review): assumes predict_proba column order matches
            # self._class_labels (true when targets were encoded as sorted
            # 0..n-1). Verify against self._model.classes_ if in doubt.
            proba = {lbl: float(prob) for lbl, prob in zip(self._class_labels, p)}
        except Exception:
            # Best effort: some estimators do not implement predict_proba.
            pass
        return label, proba

    def predict_with_debug(self, params: Dict[str, float]) -> Tuple[str, Optional[Dict[str, float]]]:
        """
        Same as predict(), but first prints recognized/unknown/missing features.
        """
        params2 = self._ensure_engineered_features(dict(params))
        X = pd.DataFrame([params2], dtype=float).reindex(columns=self._feature_columns)

        recognized = [c for c in self._feature_columns if c in params2]
        unknown = [k for k in params2.keys() if k not in self._feature_columns]
        missing = X.columns[X.iloc[0].isna()].tolist()

        print(f"Recognized: {len(recognized)}/{len(self._feature_columns)}")
        if recognized:
            print(" •", ", ".join(recognized[:16]) + (" ..." if len(recognized) > 16 else ""))
        if unknown:
            print(f"Unknown keys: {len(unknown)}")
            print(" •", ", ".join(unknown[:16]) + (" ..." if len(unknown) > 16 else ""))
        if missing:
            print(f"Missing (imputed): {len(missing)}")
            print(" •", ", ".join(missing[:16]) + (" ..." if len(missing) > 16 else ""))

        return self.predict(params2, return_proba=True, compute_engineered_if_missing=False)

    # ------------------------- INTERNALS -------------------------

    def _load_artifacts(self) -> None:
        """
        1) If local files exist and always_download=False -> NO download.
        2) If files are missing or always_download=True -> download (snapshot or per-file).
        3) Load the model + metadata from disk.
        """
        paths: Optional[Dict[str, str]] = None

        # (0) Try reading locally without touching the network.
        if not self.always_download:
            local_paths = self._try_local_paths()
            if local_paths is not None:
                paths = local_paths
                if self.verbose:
                    print(f"[ExoMAC] Using cached artifacts in {self.local_dir}")
            else:
                if self.verbose:
                    print(f"[ExoMAC] Local artifacts not found. Will download to {self.local_dir}.")

        # (1) Download if needed.
        if paths is None:
            if self.prefer_snapshot:
                # Download matching patterns into the local folder
                # (the current API no longer uses symlinks for local_dir).
                snapshot_download(
                    repo_id=self.repo_id,
                    token=self.token,
                    allow_patterns=self.allow_patterns,
                    local_dir=str(self.local_dir),
                )
                paths = self._resolve_from_dir(self.local_dir)
            else:
                paths = {
                    key: self._get_artifact_to_local_dir(fname)
                    for key, fname in self._FILENAMES.items()
                }

        # (2) Load from disk. JSON files go through a context-managed helper
        # so handles are closed promptly (the original json.load(open(...))
        # left them open until garbage collection).
        self._model = joblib.load(paths["model"])
        self._feature_columns = self._read_json(paths["feats"])
        self._class_labels = self._read_json(paths["labels"])
        self._metadata = self._read_json(paths["meta"])

        if self.verbose:
            print(f"[ExoMAC] Loaded model from {paths['model']}")

    @staticmethod
    def _read_json(path: str):
        """Load a JSON file using a promptly-closed file handle."""
        with open(path, "r", encoding="utf-8") as fh:
            return json.load(fh)

    # --- Local path helpers ---

    def _have_all_files(self, base: Path) -> bool:
        """Are ALL artifacts present under 'base' (in artifacts/ or at the root)?"""
        base = Path(base)
        for name in self._FILENAMES.values():
            in_artifacts = base / "artifacts" / name
            in_root = base / name
            if not (in_artifacts.exists() or in_root.exists()):
                return False
        return True

    def _try_local_paths(self) -> Optional[Dict[str, str]]:
        """Return local paths if everything exists; otherwise None."""
        if self._have_all_files(self.local_dir):
            return self._resolve_from_dir(self.local_dir)
        return None

    def _resolve_from_dir(self, base_dir: Path | str) -> Dict[str, str]:
        """
        Pick artifacts/<name> when it exists; otherwise <base>/<name>.

        Raises:
            FileNotFoundError: if an artifact is missing in both locations.
        """
        base_dir = Path(base_dir)
        out: Dict[str, str] = {}
        for key, name in self._FILENAMES.items():
            in_artifacts = base_dir / "artifacts" / name
            in_root = base_dir / name
            if in_artifacts.exists():
                out[key] = str(in_artifacts)
            elif in_root.exists():
                out[key] = str(in_root)
            else:
                raise FileNotFoundError(f"Could not find {name} under {base_dir}")
        return out

    def _get_artifact_to_local_dir(self, fname: str) -> str:
        """
        Download `fname` into self.local_dir, trying 'artifacts/<fname>' first
        and the repo root second.

        Uses hf_hub_download(local_dir=...) when the installed huggingface_hub
        supports it; otherwise downloads to the global cache and copies the
        file over. The whole per-candidate attempt is wrapped in one try: in
        the original code, an error raised inside the `except TypeError:`
        fallback escaped the loop (sibling except clauses do not catch it),
        so the repo-root candidate was never tried.
        """
        self.local_dir.mkdir(parents=True, exist_ok=True)

        for candidate in (f"artifacts/{fname}", fname):
            try:
                try:
                    # huggingface_hub >= 0.23 supports local_dir.
                    return hf_hub_download(
                        repo_id=self.repo_id,
                        filename=candidate,
                        token=self.token,
                        local_dir=str(self.local_dir),
                    )
                except TypeError:
                    # Fallback: old huggingface_hub without local_dir — use the
                    # global cache, then copy next to the project.
                    cache_path = hf_hub_download(
                        repo_id=self.repo_id,
                        filename=candidate,
                        token=self.token,
                    )
                    dst = self.local_dir / Path(candidate).name
                    if not dst.exists():
                        from shutil import copy2
                        copy2(cache_path, dst)
                    return str(dst)
            except Exception:
                # This candidate failed (e.g. not present in the repo):
                # try the next layout.
                continue

        raise FileNotFoundError(f"Could not download {fname} from {self.repo_id}")

    # --- Engineered features helpers ---

    def _ensure_engineered_features(self, d: Dict[str, float]) -> Dict[str, float]:
        """
        Fill in engineered features the model expects when they are absent:
        - duty_cycle, log_koi_period, log_koi_depth, teq_proxy
        - koi_snr/log_koi_snr or snr_proxy/log_snr_proxy (proxy)

        Mutates and returns `d`.
        """
        need = set(self._feature_columns)

        # Duty cycle: transit duration (hours) over orbital period (days * 24).
        if "duty_cycle" in need and "duty_cycle" not in d:
            if all(k in d for k in ("koi_duration", "koi_period")) and d.get("koi_period"):
                d["duty_cycle"] = d["koi_duration"] / (d["koi_period"] * 24.0)

        # Log-scaled features (only for strictly positive inputs).
        if "log_koi_period" in need and "log_koi_period" not in d and d.get("koi_period", 0) > 0:
            d["log_koi_period"] = np.log10(d["koi_period"])
        if "log_koi_depth" in need and "log_koi_depth" not in d and d.get("koi_depth", 0) > 0:
            d["log_koi_depth"] = np.log10(d["koi_depth"])

        # teq_proxy (simple): stand-in based on stellar effective temperature.
        if "teq_proxy" in need and "teq_proxy" not in d and "koi_steff" in d:
            d["teq_proxy"] = d["koi_steff"]

        # Real SNR, or NaN so the pipeline's imputer can deal with it.
        if "koi_snr" in need and "koi_snr" not in d:
            d["koi_snr"] = np.nan
        if "log_koi_snr" in need and "log_koi_snr" not in d and d.get("koi_snr", 0) > 0:
            d["log_koi_snr"] = np.log10(d["koi_snr"])

        if "snr_proxy" in need and "snr_proxy" not in d:
            if all(k in d for k in ("koi_depth", "koi_duration", "koi_period")) and d.get("koi_period", 0) > 0:
                d["snr_proxy"] = d["koi_depth"] * np.sqrt(max(d["koi_duration"] / (d["koi_period"] * 24.0), 1e-12))
        if "log_snr_proxy" in need and "log_snr_proxy" not in d and d.get("snr_proxy", 0) > 0:
            d["log_snr_proxy"] = np.log10(d["snr_proxy"])

        return d
307
+
308
+
309
# ------------------------- DEMO -------------------------
if __name__ == "__main__":
    # First run: downloads artifacts into ./ExoMACModel/ExoMAC-KKT if absent.
    loader = ExoMACModel(
        local_dir="./ExoMACModel/ExoMAC-KKT",
        prefer_snapshot=True,
        always_download=False,  # <- subsequent runs do NOT download again
        verbose=True,
    )

    # Later constructions yield the same instance (singleton), with no download.
    loader_again = ExoMACModel(local_dir="./ExoMACModel/ExoMAC-KKT")
    assert loader is loader_again

    # Minimal prediction example.
    sample = {
        "koi_period": 12.0, "koi_duration": 3.5, "koi_depth": 600.0, "koi_impact": 0.20,
        "koi_prad": 2.1, "koi_slogg": 4.4, "koi_sma": 0.10, "koi_smet": 0.0,
        "koi_srad": 1.0, "koi_steff": 5700.0, "koi_snr": 12.0,
    }
    predicted_label, predicted_proba = loader.predict_with_debug(sample)
    print("Predicted:", predicted_label)
    print("Local dir:", loader.local_dir.resolve())