"""AsteroidNET Two-Stage Classifier (RF → CNN)."""

from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import numpy as np

from asteroidnet.tracklet_linker.linker import Tracklet

logger = logging.getLogger(__name__)


@dataclass
class Classification:
    """A tracklet that passed the RF stage, the satellite filter and the CNN stage."""

    tracklet: Tracklet
    rf_score: float  # RF model probability, or heuristic score, in [0, 1]
    cnn_score: float  # CNN sigmoid mean, or heuristic score, in [0, 1]
    is_asteroid: bool  # always True for records returned by classify_tracklets
    priority: str  # 'ROUTINE' | 'HIGH' | 'HAZARDOUS'


def classify_tracklets(
    tracklets: list[Tracklet],
    data_frames: list[np.ndarray],
    config: Optional[dict] = None,
) -> list[Classification]:
    """
    Two-stage classification: Random Forest (fast) → CNN (high precision).

    RF stage uses kinematic features of the tracklet.
    CNN stage uses 63×63 pixel cutouts from the detection positions.

    Args:
        tracklets: Candidate tracklets to classify.
        data_frames: Image arrays, indexed by each detection's ``frame_id``.
        config: Optional config mapping. Only its ``"classifier"`` section is
            read: ``rf_threshold`` (default 0.7), ``cnn_threshold`` (default
            0.9), model paths, and the priority velocity thresholds.

    Returns:
        Confirmed tracklets, sorted by CNN score (highest first).
    """
    cfg = (config or {}).get("classifier", {})
    rf_thresh = float(cfg.get("rf_threshold", 0.7))
    cnn_thresh = float(cfg.get("cnn_threshold", 0.9))

    # Either model may be None; the predict helpers then use heuristics.
    rf_model = _load_rf(cfg)
    cnn_model = _load_cnn(cfg)

    results: list[Classification] = []
    for tracklet in tracklets:
        # ── RF stage ─────────────────────────────────────────────────────
        features = _extract_features(tracklet)
        rf_score = _rf_predict(rf_model, features)
        # Written as `not (>=)` so a NaN score is rejected rather than passed.
        if not rf_score >= rf_thresh:
            continue

        # ── Satellite filter ──────────────────────────────────────────────
        if _is_satellite(tracklet):
            logger.debug("Satellite filter rejected tracklet (vel=%.3f, pa=%.1f)",
                         tracklet.velocity_arcsec_s, tracklet.position_angle_deg)
            continue

        # ── CNN stage ─────────────────────────────────────────────────────
        cutouts = _extract_cutouts(tracklet, data_frames)
        cnn_score = _cnn_predict(cnn_model, cutouts)
        if not cnn_score >= cnn_thresh:
            continue

        results.append(Classification(
            tracklet=tracklet,
            rf_score=rf_score,
            cnn_score=cnn_score,
            is_asteroid=True,
            priority=_assign_priority(tracklet, cfg),
        ))

    results.sort(key=lambda c: c.cnn_score, reverse=True)
    logger.info("Classification: %d/%d tracklets confirmed", len(results),
                len(tracklets))
    return results


# ── Feature extraction ────────────────────────────────────────────────────────

def _extract_features(t: Tracklet) -> np.ndarray:
    """12-dimensional kinematic feature vector for RF classifier.

    Order (the RF heuristic reads indices 0, 4 and 7):
        0 total velocity, 1 RA velocity, 2 Dec velocity,
        3 position angle / 360, 4 RMS residual, 5 time span (min),
        6 detection count, 7 mean SNR, 8 SNR std, 9 mean mag,
        10 mag range, 11 number of distinct frame ids.
    """
    snrs = [d["snr"] for d in t.detections]
    # Magnitudes >= 90 are excluded; presumably a "no magnitude" sentinel
    # (e.g. 99) — NOTE(review): confirm against the detector output.
    mags = [d["mag"] for d in t.detections if d["mag"] < 90]
    return np.array([
        t.velocity_arcsec_s,
        t.velocity_ra_arcsec_s,
        t.velocity_dec_arcsec_s,
        t.position_angle_deg / 360.0,
        t.rms_residual_arcsec,
        t.time_span_min,
        len(t.detections),
        float(np.mean(snrs)) if snrs else 0.0,
        float(np.std(snrs)) if len(snrs) > 1 else 0.0,
        float(np.mean(mags)) if mags else 25.0,  # 25.0 used when no valid mags
        float(np.ptp(mags)) if len(mags) > 1 else 0.0,
        len(set(t.frame_ids)),
    ], dtype=np.float32)


def _extract_cutouts(
    t: Tracklet,
    data_frames: list[np.ndarray],
    size: int = 63,
) -> Optional[np.ndarray]:
    """Extract stacked, robustly-normalised cutouts at the detection positions.

    Each cutout is a ``size``×``size`` window whose top-left corner is
    ``(x - size//2, y - size//2)``. For odd ``size`` it is centred on the
    detection. It is normalised as ``clip((pix - median) / (3 * MAD), -3, 3)``
    over its finite pixels, and any remaining NaN/inf values are replaced by
    ``np.nan_to_num``.

    Detections are skipped when their ``frame_id`` is outside
    ``[0, len(data_frames))`` or their window would extend past the frame edge.

    Returns:
        Array of shape ``(n_cutouts, size, size)`` (float32), or None if
        ``data_frames`` is empty or no cutout could be extracted.
    """
    if not data_frames:
        return None

    half = size // 2
    cutouts = []
    for det in t.detections:
        fid = det.get("frame_id", 0)
        # A negative id would otherwise wrap around to the end of the list.
        if not 0 <= fid < len(data_frames):
            continue
        data = data_frames[fid]
        x, y = int(round(det.get("x", 0))), int(round(det.get("y", 0)))
        h, w = data.shape

        # Anchor the window at (x - half, y - half) and span exactly `size`
        # pixels; for odd sizes this equals the symmetric [x-half, x+half].
        x0, y0 = x - half, y - half
        if x0 < 0 or y0 < 0 or x0 + size > w or y0 + size > h:
            continue

        cutout = data[y0:y0 + size, x0:x0 + size].copy()
        if cutout.shape != (size, size):
            continue  # defensive; the bounds check above should guarantee this

        finite = cutout[np.isfinite(cutout)]
        if finite.size > 0:
            med = np.median(finite)
            # Floor the MAD so a flat cutout does not divide by zero.
            mad = max(np.median(np.abs(finite - med)), 1e-10)
            cutout = np.clip((cutout - med) / (3 * mad), -3, 3)
        cutouts.append(np.nan_to_num(cutout.astype(np.float32)))

    return np.stack(cutouts) if cutouts else None


# ── Model loading ─────────────────────────────────────────────────────────────

def _load_rf(cfg: dict):
    """Load RF model if available, else return None (heuristic fallback).

    Reads ``rf_model_path`` from ``cfg`` (default
    ``"models/rf_classifier.pkl"``). Any load error is logged as a warning and
    yields None.
    """
    path = cfg.get("rf_model_path", "models/rf_classifier.pkl")
    if Path(path).exists():
        try:
            import joblib
            # NOTE(review): joblib.load unpickles the file and can execute
            # arbitrary code — only point rf_model_path at trusted files.
            return joblib.load(path)
        except Exception as exc:
            logger.warning("Could not load RF model %s: %s", path, exc)
    return None


def _load_cnn(cfg: dict):
    """Load CNN model if available, else return None (heuristic fallback).

    Reads ``cnn_model_path`` from ``cfg`` (default
    ``"models/cnn_classifier.pth"``), loads it onto the CPU and puts it in
    eval mode. Any load error (including torch not being installed) is logged
    as a warning and yields None.
    """
    path = cfg.get("cnn_model_path", "models/cnn_classifier.pth")
    if Path(path).exists():
        try:
            import torch
            # NOTE(review): torch.load unpickles the file — only load trusted
            # paths. Recent PyTorch releases default to weights_only=True,
            # which cannot restore a fully pickled nn.Module; in that case this
            # raises, is logged below, and the heuristic is used instead.
            # Confirm against the installed torch version.
            model = torch.load(path, map_location="cpu")
            model.eval()
            return model
        except Exception as exc:
            logger.warning("Could not load CNN model %s: %s", path, exc)
    return None


def _rf_predict(model, features: np.ndarray) -> float:
    """RF prediction — heuristic if model not trained yet.

    With a model, returns ``predict_proba(...)[0, 1]``, which assumes column 1
    is the positive class. If there is no model, or prediction raises, a
    heuristic is used instead. The heuristic starts at 0.3 and adds:
    +0.3 if 0.01 <= velocity <= 5.0, +0.2 if mean SNR >= 5, and
    +0.2 if RMS residual <= 0.8. The result is capped at 0.99.
    """
    if model is not None:
        try:
            return float(model.predict_proba(features.reshape(1, -1))[0, 1])
        except Exception as exc:
            # Best-effort fallback, but make the failure visible.
            logger.warning("RF model prediction failed (%s); using heuristic score", exc)

    vel = features[0]
    snr = features[7]
    rms = features[4]
    score = 0.3
    if 0.01 <= vel <= 5.0:
        score += 0.3
    if snr >= 5.0:
        score += 0.2
    if rms <= 0.8:
        score += 0.2
    return min(score, 0.99)


def _central_peak(cutout: np.ndarray, radius: int = 3) -> float:
    """Return the max of the (2*radius+1)-square window at the cutout's centre.

    Falls back to the NaN-ignoring max of the whole array when the cutout is
    not 2-D or is too small to hold the window.
    """
    if cutout.ndim == 2 and min(cutout.shape) >= 2 * radius + 1:
        cy, cx = cutout.shape[0] // 2, cutout.shape[1] // 2
        window = cutout[cy - radius:cy + radius + 1, cx - radius:cx + radius + 1]
        return float(np.max(window))
    return float(np.nanmax(cutout))


def _cnn_predict(model, cutouts: Optional[np.ndarray]) -> float:
    """CNN prediction — heuristic if model not trained yet.

    With a model and cutouts, the stacked cutouts get a leading batch axis
    (``(N, H, W)`` -> ``(1, N, H, W)``, so detections act as channels —
    NOTE(review): confirm this matches the trained model's input). The result
    is ``sigmoid(output).mean()``. If there is no model, or inference raises,
    the score comes from the mean central peak of the normalised cutouts:
    ``clip(mean_peak / 3 + 0.5, 0, 0.95)``. Returns 0.5 when there are no
    cutouts.
    """
    if model is not None and cutouts is not None:
        try:
            import torch
            x = torch.from_numpy(cutouts).unsqueeze(0).float()
            with torch.no_grad():
                out = model(x)
            return float(torch.sigmoid(out).mean().item())
        except Exception as exc:
            # Best-effort fallback, but make the failure visible.
            logger.warning("CNN model prediction failed (%s); using heuristic score", exc)

    # Heuristic: check whether the cutouts have a point source at the centre.
    if cutouts is not None and len(cutouts) > 0:
        peaks = [_central_peak(c) for c in cutouts]
        return min(0.95, max(0.0, float(np.mean(peaks)) / 3.0 + 0.5))
    return 0.5


def _is_satellite(t: Tracklet) -> bool:
    """Simple satellite/aircraft filter.

    Flags tracklets faster than 8.0 (``velocity_arcsec_s``), or faster than
    5.0 with an almost perfectly linear fit (RMS residual < 0.01).
    """
    if t.velocity_arcsec_s > 8.0:
        return True
    if t.rms_residual_arcsec < 0.01 and t.velocity_arcsec_s > 5.0:
        return True
    return False


def _assign_priority(t: Tracklet, cfg: dict) -> str:
    """Map the tracklet's velocity to a follow-up priority label.

    Returns 'HAZARDOUS' if the velocity is >= ``hazardous_velocity_threshold``
    (default 3.0), 'HIGH' if it is >= ``high_velocity_threshold`` (default
    1.0), else 'ROUTINE'.
    """
    high_v = float(cfg.get("high_velocity_threshold", 1.0))
    haz_v = float(cfg.get("hazardous_velocity_threshold", 3.0))
    if t.velocity_arcsec_s >= haz_v:
        return "HAZARDOUS"
    if t.velocity_arcsec_s >= high_v:
        return "HIGH"
    return "ROUTINE"