# mmrech's picture
# feat: v0.2 — real FITS support, TAI/UTC fix, SkyBoT, two-pass bg
# 41d98e2 verified
"""AsteroidNET Two-Stage Classifier (RF → CNN)."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import numpy as np
from asteroidnet.tracklet_linker.linker import Tracklet
logger = logging.getLogger(__name__)
@dataclass
class Classification:
    """Result of the two-stage (RF → CNN) classification of one tracklet."""

    # The tracklet this result belongs to.
    tracklet: Tracklet
    # Score produced by the Random Forest stage.
    rf_score: float
    # Score produced by the CNN stage.
    cnn_score: float
    # Whether the tracklet was accepted as an asteroid.
    is_asteroid: bool
    # Follow-up priority: 'ROUTINE', 'HIGH' or 'HAZARDOUS'.
    priority: str
def classify_tracklets(
    tracklets: list[Tracklet],
    data_frames: list[np.ndarray],
    config: Optional[dict] = None,
) -> list[Classification]:
    """
    Run the Random Forest → CNN cascade over candidate tracklets.

    Every tracklet goes through three gates, in this order:

    1. RF stage — scored from the tracklet's kinematic feature vector and
       dropped when the score is below ``rf_threshold``.
    2. Satellite filter — dropped when ``_is_satellite`` flags it.
    3. CNN stage — scored from pixel cutouts taken at the detection
       positions and dropped unless the score reaches ``cnn_threshold``.

    Args:
        tracklets: Candidate tracklets to classify.
        data_frames: Image frames the cutouts are extracted from.
        config: Optional configuration mapping. Only its ``"classifier"``
            section is used: ``rf_threshold`` (default 0.7) and
            ``cnn_threshold`` (default 0.9) are read here, and the section
            is also handed to the model loaders and the priority assignment.

    Returns:
        One ``Classification`` per surviving tracklet, ordered by descending
        CNN score.
    """
    settings = (config or {}).get("classifier", {})
    min_rf_score = float(settings.get("rf_threshold", 0.7))
    min_cnn_score = float(settings.get("cnn_threshold", 0.9))
    forest = _load_rf(settings)
    network = _load_cnn(settings)

    confirmed: list[Classification] = []
    for candidate in tracklets:
        # Gate 1: cheap kinematic score.
        rf_score = _rf_predict(forest, _extract_features(candidate))
        if rf_score < min_rf_score:
            continue

        # Gate 2: discard satellite/aircraft-like motion.
        if _is_satellite(candidate):
            logger.debug(
                "Satellite filter rejected tracklet (vel=%.3f, pa=%.1f)",
                candidate.velocity_arcsec_s,
                candidate.position_angle_deg,
            )
            continue

        # Gate 3: image-based score. Written as "not >=" so that a NaN score
        # is rejected rather than accepted.
        cnn_score = _cnn_predict(network, _extract_cutouts(candidate, data_frames))
        if not cnn_score >= min_cnn_score:
            continue

        priority = _assign_priority(candidate, settings)
        confirmed.append(
            Classification(
                tracklet=candidate,
                rf_score=rf_score,
                cnn_score=cnn_score,
                is_asteroid=True,
                priority=priority,
            )
        )

    # Stable sort: ties keep their input order.
    confirmed = sorted(confirmed, key=lambda item: item.cnn_score, reverse=True)
    logger.info("Classification: %d/%d tracklets confirmed", len(confirmed), len(tracklets))
    return confirmed
# ── Feature extraction ────────────────────────────────────────────────────────
def _extract_features(t: Tracklet) -> np.ndarray:
    """Build the 12-element kinematic feature vector used by the RF stage.

    Layout of the returned ``float32`` vector:

    ====== ==========================================================
    index  content
    ====== ==========================================================
    0-2    total, RA and Dec velocity
    3      position angle divided by 360
    4      RMS residual
    5      time span
    6      number of detections
    7, 8   mean and standard deviation of the detection SNRs
    9, 10  mean and peak-to-peak spread of the usable magnitudes
    11     number of distinct frame ids
    ====== ==========================================================

    Only magnitudes below 90 are used (larger values are presumably
    "no measurement" sentinels — confirm against the detection producer).
    Statistics that cannot be computed fall back to 0.0, except the mean
    magnitude, which falls back to 25.0.
    """
    detections = t.detections
    snr_values = [det["snr"] for det in detections]
    usable_mags = [det["mag"] for det in detections if det["mag"] < 90]

    snr_mean = float(np.mean(snr_values)) if snr_values else 0.0
    snr_std = float(np.std(snr_values)) if len(snr_values) > 1 else 0.0
    mag_mean = float(np.mean(usable_mags)) if usable_mags else 25.0
    mag_spread = float(np.ptp(usable_mags)) if len(usable_mags) > 1 else 0.0

    vector = [
        t.velocity_arcsec_s,
        t.velocity_ra_arcsec_s,
        t.velocity_dec_arcsec_s,
        t.position_angle_deg / 360.0,
        t.rms_residual_arcsec,
        t.time_span_min,
        len(detections),
        snr_mean,
        snr_std,
        mag_mean,
        mag_spread,
        len(set(t.frame_ids)),
    ]
    return np.array(vector, dtype=np.float32)
def _extract_cutouts(
    t: Tracklet,
    data_frames: list[np.ndarray],
    size: int = 63,
) -> Optional[np.ndarray]:
    """Extract normalised square cutouts around each detection of a tracklet.

    For every detection, the frame referenced by its ``frame_id`` is looked
    up in ``data_frames`` and a ``size`` × ``size`` window around the rounded
    ``(x, y)`` position is copied. Each cutout is then shifted by the median
    of its finite pixels, divided by three times their median absolute
    deviation, clipped to ``[-3, 3]``, converted to ``float32`` and passed
    through ``np.nan_to_num``.

    A detection is skipped when:

    * its ``frame_id`` is not a valid non-negative index into
      ``data_frames`` (a negative id would otherwise wrap around and read
      the wrong frame), or
    * the window does not fit entirely inside the frame.

    Args:
        t: Tracklet whose ``detections`` are mappings with optional
            ``frame_id``, ``x`` and ``y`` entries (each defaults to 0).
        data_frames: 2-D image arrays, indexed by ``frame_id``.
        size: Side length of the cutout in pixels. Odd sizes are centred on
            the detection pixel; for even sizes the window covers
            ``[c - size // 2, c + size // 2)`` on each axis.

    Returns:
        Array of shape ``(n_cutouts, size, size)``, or ``None`` when there
        are no frames, ``size`` is not positive, or no detection produced a
        usable cutout.
    """
    if not data_frames or size <= 0:
        return None

    half = size // 2
    n_frames = len(data_frames)
    cutouts = []
    for det in t.detections:
        fid = det.get("frame_id", 0)
        if not 0 <= fid < n_frames:
            continue
        data = data_frames[fid]
        x = int(round(det.get("x", 0)))
        y = int(round(det.get("y", 0)))
        h, w = data.shape

        # Window is [x0, x1) × [y0, y1). Deriving the upper edge from `size`
        # (not from `half`) keeps the cutout exactly `size` pixels wide for
        # even sizes as well as odd ones.
        x0, y0 = x - half, y - half
        x1, y1 = x0 + size, y0 + size
        if x0 < 0 or y0 < 0 or x1 > w or y1 > h:
            continue

        cutout = data[y0:y1, x0:x1].copy()
        finite = cutout[np.isfinite(cutout)]
        if finite.size > 0:
            med = np.median(finite)
            # Floor the MAD so a perfectly flat cutout does not divide by 0.
            mad = max(np.median(np.abs(finite - med)), 1e-10)
            cutout = np.clip((cutout - med) / (3 * mad), -3, 3)
        cutouts.append(np.nan_to_num(cutout.astype(np.float32)))

    return np.stack(cutouts) if cutouts else None
# ── Model loading ─────────────────────────────────────────────────────────────
def _load_rf(cfg: dict):
    """Return the Random Forest model, or ``None`` if it cannot be loaded.

    The file location is taken from ``cfg["rf_model_path"]`` (default
    ``models/rf_classifier.pkl``). ``None`` tells the caller to use the
    heuristic scorer instead; it is returned both when the file is missing
    and when loading fails (the latter is logged as a warning).
    """
    model_file = cfg.get("rf_model_path", "models/rf_classifier.pkl")
    if not Path(model_file).exists():
        return None

    try:
        # Imported lazily so the module works without joblib installed.
        # NOTE: joblib.load unpickles the file — only point this at trusted models.
        import joblib
        return joblib.load(model_file)
    except Exception as exc:
        logger.warning("Could not load RF model %s: %s", model_file, exc)
        return None
def _load_cnn(cfg: dict):
    """Return the CNN model in eval mode, or ``None`` if it cannot be loaded.

    The file location is taken from ``cfg["cnn_model_path"]`` (default
    ``models/cnn_classifier.pth``). The file is expected to hold a complete
    pickled model object, since ``.eval()`` is called on the loaded value.
    ``None`` tells the caller to use the heuristic scorer instead; it is
    returned both when the file is missing and when loading fails (the
    latter is logged as a warning).
    """
    path = cfg.get("cnn_model_path", "models/cnn_classifier.pth")
    if not Path(path).exists():
        return None

    try:
        # Imported lazily so the module works without torch installed.
        import torch

        try:
            # PyTorch >= 2.6 defaults to weights_only=True, which refuses to
            # unpickle a full model object and would silently push every run
            # onto the heuristic fallback. Ask for full unpickling explicitly.
            # SECURITY: this executes pickled code — the path must point to a
            # trusted model file.
            model = torch.load(path, map_location="cpu", weights_only=False)
        except TypeError:
            # Older torch releases do not accept the weights_only keyword.
            model = torch.load(path, map_location="cpu")
        model.eval()
        return model
    except Exception as exc:
        logger.warning("Could not load CNN model %s: %s", path, exc)
    return None
def _rf_predict(model, features: np.ndarray) -> float:
    """Score a tracklet with the RF model, or with a heuristic fallback.

    Args:
        model: Object exposing ``predict_proba``; column 1 of its output is
            taken as the score. ``None`` selects the heuristic directly.
        features: Feature vector as produced by ``_extract_features``; the
            heuristic reads index 0 (velocity), 4 (RMS residual) and
            7 (mean SNR).

    Returns:
        The model score, or the heuristic score. The heuristic starts at 0.3
        and adds 0.3 for a velocity within [0.01, 5.0], 0.2 for a mean SNR of
        at least 5.0 and 0.2 for an RMS residual of at most 0.8, capped at
        0.99.
    """
    if model is not None:
        try:
            return float(model.predict_proba(features.reshape(1, -1))[0, 1])
        except Exception as exc:
            # Best effort: keep the pipeline running on the heuristic, but
            # leave a trace so a broken model is not mistaken for a working one.
            logger.warning("RF prediction failed, using heuristic fallback: %s", exc)

    # Heuristic: based on velocity, SNR, residual
    vel = features[0]
    snr = features[7]
    rms = features[4]
    score = 0.3
    if 0.01 <= vel <= 5.0:
        score += 0.3
    if snr >= 5.0:
        score += 0.2
    if rms <= 0.8:
        score += 0.2
    return min(score, 0.99)
def _cnn_predict(model, cutouts: Optional[np.ndarray]) -> float:
    """Score a stack of cutouts with the CNN, or with a heuristic fallback.

    Args:
        model: Callable applied to a float tensor built from ``cutouts`` with
            one leading dimension added; the mean sigmoid of its output is
            the score. ``None`` selects the heuristic directly.
        cutouts: Stack of cutouts as produced by ``_extract_cutouts``, or
            ``None`` when no cutout could be extracted.

    Returns:
        The model score, or the heuristic score. The heuristic takes the
        peak of the central 7×7 pixels of each 63×63 cutout (the overall
        peak for any other shape), averages the peaks and maps the result to
        ``mean / 3 + 0.5`` limited to ``[0.0, 0.95]``. Without any cutouts
        the score is 0.5.
    """
    if model is not None and cutouts is not None:
        try:
            # Imported lazily so the heuristic works without torch installed.
            import torch

            x = torch.from_numpy(cutouts).unsqueeze(0).float()
            with torch.no_grad():
                out = model(x)
            return float(torch.sigmoid(out).mean().item())
        except Exception as exc:
            # Best effort: keep the pipeline running on the heuristic, but
            # leave a trace so a broken model is not mistaken for a working one.
            logger.warning("CNN prediction failed, using heuristic fallback: %s", exc)

    # Heuristic: check if any cutout has a point source at center
    if cutouts is not None and len(cutouts) > 0:
        peaks = [
            float(np.max(c[28:35, 28:35])) if c.shape == (63, 63) else float(np.nanmax(c))
            for c in cutouts
        ]
        return min(0.95, max(0.0, float(np.mean(peaks)) / 3.0 + 0.5))
    return 0.5
def _is_satellite(t: Tracklet) -> bool:
    """Tell whether a tracklet moves like a satellite or aircraft.

    A tracklet is flagged when its velocity exceeds 8.0, or when it exceeds
    5.0 while the RMS residual is below 0.01.
    """
    speed = t.velocity_arcsec_s
    if speed > 8.0:
        return True
    # Fast and almost perfectly linear motion is also treated as artificial.
    return bool(t.rms_residual_arcsec < 0.01 and speed > 5.0)
def _assign_priority(t: Tracklet, cfg: dict) -> str:
    """Map a tracklet's velocity to a follow-up priority label.

    Args:
        t: Tracklet whose ``velocity_arcsec_s`` is compared to the limits.
        cfg: Classifier settings; ``high_velocity_threshold`` (default 1.0)
            and ``hazardous_velocity_threshold`` (default 3.0) are read.

    Returns:
        ``"HAZARDOUS"`` at or above the hazardous limit, otherwise ``"HIGH"``
        at or above the high limit, otherwise ``"ROUTINE"``.
    """
    high_limit = float(cfg.get("high_velocity_threshold", 1.0))
    hazardous_limit = float(cfg.get("hazardous_velocity_threshold", 3.0))
    speed = t.velocity_arcsec_s

    # Tiers are checked from most to least urgent; the first match wins.
    for limit, label in ((hazardous_limit, "HAZARDOUS"), (high_limit, "HIGH")):
        if speed >= limit:
            return label
    return "ROUTINE"