hawthorneluke's picture
Production LightGBM model — 49 features, 5-fold GroupKFold CV
8cfa91c verified
"""Feature engineering for Supernova Peak Predictor."""
_SENTINEL_FLOOR = -998.0  # values <= this (e.g. the -999 placeholder) are treated as missing


def _as_float(row, key, default=0.0):
    """Return ``row.get(key)`` as a float, substituting *default* for falsy values.

    A missing key, ``None``, ``0`` and ``""`` all become ``float(default)``.
    NaN is truthy, so it is passed through unchanged.
    """
    return float(row.get(key, default) or default)


def _as_float_drop_sentinel(row, key):
    """Like ``_as_float`` but additionally map sentinel values to 0.0.

    Any value not strictly greater than ``_SENTINEL_FLOOR`` -- including NaN,
    for which the comparison is false -- is replaced by 0.0.
    """
    value = _as_float(row, key)
    return value if value > _SENTINEL_FLOOR else 0.0


def engineer_features(row):
    """Extract features from a ZTF alert metadata dict.

    Args:
        row: dict with ZTF alert fields (magpsf, sigmapsf, etc.). Any object
            with a dict-style ``get`` method works; absent keys are allowed
            and fall back to defaults.

    Returns:
        dict mapping feature name to float (49 entries). Key insertion order
        is stable. NOTE(review): a caller may build the model's feature vector
        positionally from this order -- do not reorder keys without checking.

    Raises:
        ValueError or TypeError: if a present, truthy field cannot be
            converted by ``float()``.
    """
    feats = {}

    # Current-alert and light-curve-so-far magnitudes, plus differences.
    for key in ('magpsf', 'sigmapsf', 'magap', 'sigmagap', 'diffmaglim',
                'peakmag_so_far', 'maxmag_so_far'):
        feats[key] = _as_float(row, key)
    feats['mag_range'] = feats['maxmag_so_far'] - feats['peakmag_so_far']
    feats['mag_vs_peak'] = feats['magpsf'] - feats['peakmag_so_far']
    feats['mag_vs_lim'] = feats['diffmaglim'] - feats['magpsf']
    feats['mag_psf_ap_diff'] = feats['magpsf'] - feats['magap']

    for key in ('ndethist', 'ncovhist', 'nnotdet', 'nmtchps'):
        feats[key] = _as_float(row, key)
    # +1 keeps the ratio finite when ncovhist is 0 or missing.
    feats['det_fraction'] = feats['ndethist'] / (feats['ncovhist'] + 1)

    for key in ('N', 'nneg', 'nbad', 'fwhm', 'chipsf', 'chinr', 'sharpnr',
                'scorr', 'sky', 'classtar', 'new_drb', 'drb'):
        feats[key] = _as_float(row, key)
    feats['exptime'] = _as_float(row, 'exptime', default=30.0)

    # NOTE(review): unlike the sentinel-filtered fields below, these are taken
    # as-is, so a -999 placeholder (if upstream uses one here) flows straight
    # into mag_vs_host. Left unchanged: altering it would change features the
    # production model was presumably trained on.
    for key in ('sgscore1', 'distpsnr1', 'sgscore2', 'distpsnr2',
                'distnr', 'magnr'):
        feats[key] = _as_float(row, key)
    feats['mag_vs_host'] = feats['magpsf'] - feats['magnr']

    # Missing, falsy and sentinel (<= -998) values all become 0.0.
    for key in ('neargaia', 'maggaia', 'sgmag1', 'srmag1', 'simag1', 'szmag1'):
        feats[key] = _as_float_drop_sentinel(row, key)

    # Differences between adjacent host magnitudes, only when both are
    # present (> 0 after sentinel filtering); otherwise 0.0.
    for name, first, second in (('host_g_r', 'sgmag1', 'srmag1'),
                                ('host_r_i', 'srmag1', 'simag1'),
                                ('host_i_z', 'simag1', 'szmag1')):
        a, b = feats[first], feats[second]
        feats[name] = a - b if a > 0 and b > 0 else 0.0

    # fid defaults to 1 when absent *or* explicitly None (float(None) raises).
    fid = row.get('fid')
    feats['fid'] = 1.0 if fid is None else float(fid)
    feats['is_g_band'] = 1.0 if feats['fid'] == 1 else 0.0

    feats['ndethist_x_magrange'] = feats['ndethist'] * feats['mag_range']
    # +0.01 keeps the ratio finite when sigmapsf is 0 or missing.
    feats['snr_proxy'] = feats['mag_vs_lim'] / (feats['sigmapsf'] + 0.01)
    return feats