Spaces:
Sleeping
Sleeping
| """AsteroidNET Two-Stage Classifier (RF → CNN).""" | |
| from __future__ import annotations | |
| import logging | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Optional | |
| import numpy as np | |
| from asteroidnet.tracklet_linker.linker import Tracklet | |
# Module logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
@dataclass
class Classification:
    """Classification outcome for one tracklet.

    Declared as a dataclass so it can be built with keyword arguments
    (``Classification(tracklet=..., rf_score=..., ...)``), which is how
    ``classify_tracklets`` constructs it.
    """

    tracklet: Tracklet  # the tracklet that was scored
    rf_score: float  # stage-1 (Random Forest or heuristic) score
    cnn_score: float  # stage-2 (CNN or heuristic) score
    is_asteroid: bool  # final accept/reject decision
    priority: str  # 'ROUTINE' | 'HIGH' | 'HAZARDOUS'
def classify_tracklets(
    tracklets: list[Tracklet],
    data_frames: list[np.ndarray],
    config: Optional[dict] = None,
) -> list[Classification]:
    """
    Run the two-stage (Random Forest -> CNN) classifier over a batch of tracklets.

    Stage 1 scores each tracklet from its kinematic features; tracklets
    scoring below ``rf_threshold`` are dropped. Survivors flagged by the
    satellite filter are dropped next. Stage 2 scores 63x63 pixel cutouts
    taken at the detection positions; tracklets scoring below
    ``cnn_threshold`` are dropped.

    Args:
        tracklets: candidate tracklets to classify.
        data_frames: images looked up by each detection's ``frame_id``.
        config: optional configuration; only its ``"classifier"`` sub-dict is
            read. That sub-dict supplies ``rf_threshold`` (default 0.7) and
            ``cnn_threshold`` (default 0.9), and is also handed to the model
            loaders and the priority assignment.

    Returns:
        The confirmed classifications, ordered by descending CNN score.
    """
    settings = (config or {}).get("classifier", {})
    rf_cutoff = float(settings.get("rf_threshold", 0.7))
    cnn_cutoff = float(settings.get("cnn_threshold", 0.9))

    rf_model = _load_rf(settings)
    cnn_model = _load_cnn(settings)

    confirmed: list[Classification] = []
    for candidate in tracklets:
        # Stage 1: kinematic features -> RF score.
        rf_score = _rf_predict(rf_model, _extract_features(candidate))
        if rf_score < rf_cutoff:
            continue

        # Reject satellite/aircraft-like motion before the CNN stage.
        if _is_satellite(candidate):
            logger.debug(
                "Satellite filter rejected tracklet (vel=%.3f, pa=%.1f)",
                candidate.velocity_arcsec_s,
                candidate.position_angle_deg,
            )
            continue

        # Stage 2: pixel cutouts -> CNN score. Written as `not (>=)` so a NaN
        # score is rejected, exactly as before.
        cnn_score = _cnn_predict(cnn_model, _extract_cutouts(candidate, data_frames))
        if not (cnn_score >= cnn_cutoff):
            continue

        confirmed.append(
            Classification(
                tracklet=candidate,
                rf_score=rf_score,
                cnn_score=cnn_score,
                is_asteroid=True,
                priority=_assign_priority(candidate, settings),
            )
        )

    confirmed.sort(key=lambda item: item.cnn_score, reverse=True)
    logger.info("Classification: %d/%d tracklets confirmed", len(confirmed), len(tracklets))
    return confirmed
| # ── Feature extraction ──────────────────────────────────────────────────────── | |
def _extract_features(t: Tracklet) -> np.ndarray:
    """Build the 12-element float32 kinematic feature vector for the RF stage.

    Layout (index: feature):
        0-2: total, RA and Dec velocity; 3: position angle / 360;
        4: RMS residual; 5: time span; 6: detection count;
        7-8: SNR mean and std; 9-10: mean and range of magnitudes < 90;
        11: number of distinct frame ids.
    """
    detections = t.detections
    snr_values = [det["snr"] for det in detections]
    # Magnitudes >= 90 are excluded; presumably a "no measurement" sentinel — TODO confirm.
    valid_mags = [det["mag"] for det in detections if det["mag"] < 90]

    snr_mean = float(np.mean(snr_values)) if snr_values else 0.0
    snr_std = float(np.std(snr_values)) if len(snr_values) > 1 else 0.0
    mag_mean = float(np.mean(valid_mags)) if valid_mags else 25.0
    mag_range = float(np.ptp(valid_mags)) if len(valid_mags) > 1 else 0.0

    vector = [
        t.velocity_arcsec_s,
        t.velocity_ra_arcsec_s,
        t.velocity_dec_arcsec_s,
        t.position_angle_deg / 360.0,
        t.rms_residual_arcsec,
        t.time_span_min,
        len(detections),
        snr_mean,
        snr_std,
        mag_mean,
        mag_range,
        len(set(t.frame_ids)),
    ]
    return np.array(vector, dtype=np.float32)
def _extract_cutouts(
    t: Tracklet,
    data_frames: list[np.ndarray],
    size: int = 63,
) -> Optional[np.ndarray]:
    """Cut a ``size`` x ``size`` stamp around each usable detection and stack them.

    Each detection's ``frame_id`` (default 0) selects the image in
    ``data_frames``. Its ``x``/``y`` (default 0) are rounded to the stamp
    centre. A detection is skipped when:
    - its frame id is outside ``[0, len(data_frames))``;
    - its coordinates are non-finite;
    - its stamp would extend past the image edge.

    Each stamp is median-subtracted, divided by 3 x MAD (floored at 1e-10)
    and clipped to [-3, 3]. Non-finite pixels become 0.

    Args:
        t: tracklet whose ``detections`` dicts supply positions.
        data_frames: 2-D images indexed by frame id.
        size: stamp side length in pixels. For even sizes the stamp reaches
            one pixel further below/left of the centre than above/right.

    Returns:
        float32 array of shape ``(n_stamps, size, size)``, or None when
        ``data_frames`` is empty or no detection yields a stamp.
    """
    if not data_frames:
        return None

    half = size // 2
    n_frames = len(data_frames)
    stamps = []
    for det in t.detections:
        fid = det.get("frame_id", 0)
        # Negative ids would otherwise wrap around to the end of the list.
        if fid < 0 or fid >= n_frames:
            continue
        xf = float(det.get("x", 0))
        yf = float(det.get("y", 0))
        if not (np.isfinite(xf) and np.isfinite(yf)):
            continue  # round() would raise on NaN/inf
        x, y = int(round(xf)), int(round(yf))

        data = data_frames[fid]
        h, w = data.shape
        # Slice [c - half, c - half + size) so both odd and even sizes give
        # exactly `size` pixels (the old [c-half, c+half] slice was size+1
        # wide for even sizes, so every stamp was discarded).
        x0, y0 = x - half, y - half
        if x0 < 0 or y0 < 0 or x0 + size > w or y0 + size > h:
            continue
        stamp = data[y0:y0 + size, x0:x0 + size]

        finite = stamp[np.isfinite(stamp)]
        if finite.size:
            med = np.median(finite)
            mad = max(np.median(np.abs(finite - med)), 1e-10)  # avoid /0 on flat stamps
            stamp = np.clip((stamp - med) / (3 * mad), -3, 3)
        # astype() copies, so the source frame is never modified.
        stamps.append(np.nan_to_num(stamp.astype(np.float32)))
    return np.stack(stamps) if stamps else None
| # ── Model loading ───────────────────────────────────────────────────────────── | |
def _load_rf(cfg: dict):
    """Return the trained RF classifier, or None so callers use the heuristic.

    The path is ``cfg["rf_model_path"]`` (default ``models/rf_classifier.pkl``).
    A missing file yields None without logging. A file that fails to load
    logs a warning and also yields None.
    """
    model_path = cfg.get("rf_model_path", "models/rf_classifier.pkl")
    if not Path(model_path).exists():
        return None
    try:
        import joblib
        # NOTE(review): joblib.load unpickles the file and can execute arbitrary
        # code — only point rf_model_path at trusted model files.
        return joblib.load(model_path)
    except Exception as exc:
        logger.warning("Could not load RF model %s: %s", model_path, exc)
        return None
def _load_cnn(cfg: dict):
    """Load the trained CNN, or return None so callers use the heuristic.

    The path is ``cfg["cnn_model_path"]`` (default
    ``models/cnn_classifier.pth``). The file must hold a whole pickled
    module, not a state_dict: the result is put in ``eval()`` mode and
    called directly. A missing file yields None without logging. A file
    that fails to load logs a warning and also yields None.
    """
    path = cfg.get("cnn_model_path", "models/cnn_classifier.pth")
    if not Path(path).exists():
        return None
    try:
        import torch
        # torch >= 2.6 defaults to weights_only=True, which refuses pickled
        # modules and would silently push every run onto the heuristic.
        # SECURITY: weights_only=False unpickles arbitrary objects — only load
        # trusted model files.
        try:
            model = torch.load(path, map_location="cpu", weights_only=False)
        except TypeError:
            # Older torch releases lack the `weights_only` keyword.
            model = torch.load(path, map_location="cpu")
        model.eval()
        return model
    except Exception as exc:
        logger.warning("Could not load CNN model %s: %s", path, exc)
        return None
def _rf_predict(model, features: np.ndarray) -> float:
    """Return the RF-stage asteroid score for one feature vector.

    With a model, returns ``predict_proba(...)[0, 1]`` (assumes class index
    1 is "asteroid" — TODO confirm against the training labels). Without a
    model, or when prediction raises, falls back to a heuristic:
    - start at 0.3;
    - +0.3 if 0.01 <= velocity <= 5.0;
    - +0.2 if mean SNR >= 5.0;
    - +0.2 if RMS residual <= 0.8;
    - cap the result at 0.99.

    Velocity, RMS residual and mean SNR are read from feature indices 0, 4
    and 7 of the vector built by ``_extract_features``.
    """
    if model is not None:
        try:
            return float(model.predict_proba(features.reshape(1, -1))[0, 1])
        except Exception:
            # Best-effort fallback is deliberate, but leave a trace so a broken
            # or feature-mismatched model does not degrade silently.
            logger.debug("RF model prediction failed; using heuristic", exc_info=True)

    vel = features[0]
    rms = features[4]
    snr = features[7]
    score = 0.3
    if 0.01 <= vel <= 5.0:
        score += 0.3
    if snr >= 5.0:
        score += 0.2
    if rms <= 0.8:
        score += 0.2
    return min(score, 0.99)
def _central_peak(cutout: np.ndarray, radius: int = 3) -> float:
    """Max of the (2*radius+1)^2 window centred on a 2-D cutout; nanmax of the whole cutout otherwise."""
    if cutout.ndim == 2 and min(cutout.shape) >= 2 * radius + 1:
        cy, cx = cutout.shape[0] // 2, cutout.shape[1] // 2
        window = cutout[cy - radius:cy + radius + 1, cx - radius:cx + radius + 1]
        return float(np.max(window))
    return float(np.nanmax(cutout))


def _cnn_predict(model, cutouts: Optional[np.ndarray]) -> float:
    """Return the CNN-stage asteroid score for a stack of cutouts.

    With a model and cutouts, the stack gets a leading batch axis
    (``unsqueeze(0)``) and the mean sigmoid of the model output is returned.
    This assumes the model accepts the cutouts as input channels — TODO
    confirm. Otherwise, or when inference raises, a heuristic is used:
    - take the peak of the central 7x7 window of each cutout (for any
      cutout size; identical to the previous hard-coded 28:35 window for
      63x63 cutouts);
    - map the mean peak through ``mean/3 + 0.5``;
    - clamp the result to [0, 0.95].

    Returns 0.5 when there are no cutouts.
    """
    if model is not None and cutouts is not None:
        try:
            import torch
            batch = torch.from_numpy(cutouts).unsqueeze(0).float()
            with torch.no_grad():
                out = model(batch)
            return float(torch.sigmoid(out).mean().item())
        except Exception:
            # Best-effort fallback is deliberate, but record why the model failed.
            logger.debug("CNN model inference failed; using heuristic", exc_info=True)

    if cutouts is None or len(cutouts) == 0:
        return 0.5
    peaks = [_central_peak(c) for c in cutouts]
    return min(0.95, max(0.0, float(np.mean(peaks)) / 3.0 + 0.5))
def _is_satellite(t: Tracklet) -> bool:
    """Flag tracklets whose motion looks like a satellite or aircraft.

    Returns True when the apparent speed exceeds 8.0 arcsec/s, or when it
    exceeds 5.0 arcsec/s together with an RMS residual below 0.01 arcsec
    (a fast track with a near-perfect fit).
    """
    speed = t.velocity_arcsec_s
    # Short-circuits like the original: the residual is only read when
    # the speed alone does not already decide.
    return bool(speed > 8.0 or (t.rms_residual_arcsec < 0.01 and speed > 5.0))
def _assign_priority(t: Tracklet, cfg: dict) -> str:
    """Map a tracklet's apparent speed to a follow-up priority label.

    Speeds at or above ``hazardous_velocity_threshold`` (default 3.0) give
    "HAZARDOUS". Otherwise, speeds at or above ``high_velocity_threshold``
    (default 1.0) give "HIGH". Anything slower is "ROUTINE".
    """
    high_cut = float(cfg.get("high_velocity_threshold", 1.0))
    hazard_cut = float(cfg.get("hazardous_velocity_threshold", 3.0))
    speed = t.velocity_arcsec_s
    # Checked most-severe first, so the hazardous tier wins when both apply.
    for cutoff, label in ((hazard_cut, "HAZARDOUS"), (high_cut, "HIGH")):
        if speed >= cutoff:
            return label
    return "ROUTINE"