jarpalucas commited on
Commit
3415945
·
verified ·
1 Parent(s): a627420

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +260 -0
app.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os, io, json, requests
from typing import Optional, List, Dict
import numpy as np
import pandas as pd
import joblib
import tensorflow as tf
import gradio as gr

# ===== Artifacts =====
MODEL_PATH = "modelo_tabular.h5"
SCALER_PATH = "scaler.pkl"
ENCODER_PATH = "label_encoder.pkl"
STATS_PATH = "feature_stats.json"

# Fail fast with an explicit error when an artifact is missing.
# NOTE: plain `assert` is stripped when Python runs with -O, so the checks
# are raised explicitly instead of asserted.
for _artifact in (MODEL_PATH, SCALER_PATH, ENCODER_PATH, STATS_PATH):
    if not os.path.exists(_artifact):
        raise FileNotFoundError(f"Falta {_artifact}")

# Load the trained network, the fitted feature scaler, and the class encoder.
model = tf.keras.models.load_model(MODEL_PATH)
scaler = joblib.load(SCALER_PATH)
label_encoder = joblib.load(ENCODER_PATH)
with open(STATS_PATH) as f:
    stats = json.load(f)

# Feature order expected by the scaler/model, and per-feature training
# medians used for imputation of missing values.
FEATURE_COLUMNS: List[str] = stats["feature_columns"]
MEDIANS: Dict[str, float] = stats["medians"]
CLASSES = list(label_encoder.classes_)
29
+
30
+ # ===== Helpers =====
31
def first_present(candidates, cols_set):
    """Resolve the first usable column name for a feature.

    Exact membership in *cols_set* wins; failing that, fall back to the
    first column whose name contains a candidate as a substring.
    Returns None when nothing matches.
    """
    for cand in candidates:
        if cand in cols_set:
            return cand
    for cand in candidates:
        partial = [col for col in cols_set if cand in col]
        if partial:
            return partial[0]
    return None
40
+
41
# Synonym table: maps each model feature (koi_* names from the Kepler KOI
# catalog) to column names used by other archives (TESS TOI, Kepler TCE, …).
# Consumed by first_present() to locate a source column in downloaded tables;
# order matters — earlier candidates are preferred.
CANDIDATES_MAP = {
    "koi_period": ["pl_orbper","tce_period","orbper","period"],
    "koi_duration": ["pl_trandurh","tce_duration","trandur","duration","dur"],
    "koi_depth": ["pl_trandep","tce_depth","depth","trandep"],
    "koi_prad": ["pl_rade","prad","rade","planet_radius"],
    "koi_srad": ["st_rad","srad","stellar_radius","star_radius"],
    "koi_teq": ["pl_eqt","teq","equilibrium_temp"],
    "koi_steff": ["st_teff","teff","stellar_teff","effective_temp"],
    "koi_slogg": ["st_logg","logg","slogg"],
    "koi_smet": ["st_met","feh","metallicity","smet"],
    "koi_kepmag": ["st_tmag","tmag","kepmag","koi_kepmag"],
    "koi_model_snr": ["tce_model_snr","model_snr","snr"],
    "koi_num_transits": ["tce_num_transits","num_transits","ntransits","tran_count"]
}
55
+
56
def impute_and_scale(df: pd.DataFrame) -> np.ndarray:
    """Align *df* to FEATURE_COLUMNS, impute NaNs, and return the scaled matrix.

    Missing features become all-NaN columns, extra columns are dropped, and
    NaNs are filled with the training-time medians (0.0 fallback) before
    applying the fitted scaler.

    Bug fix: the original added missing columns to the caller's DataFrame
    (`df[col] = np.nan`) *before* copying, mutating the input as a side
    effect. reindex() builds the aligned frame without touching the argument.
    """
    aligned = df.reindex(columns=FEATURE_COLUMNS)
    # One vectorized fill replaces the original per-column isna/fillna loop.
    aligned = aligned.fillna({c: MEDIANS.get(c, 0.0) for c in FEATURE_COLUMNS})
    return scaler.transform(aligned.values)
66
+
67
def predict_proba_from_df(df: pd.DataFrame):
    """Run the network on a feature DataFrame.

    Returns a pair (probs, classes): the probability matrix produced by the
    Keras model and the label-encoder class names in column order.
    """
    features = impute_and_scale(df)
    return model.predict(features, verbose=0), list(label_encoder.classes_)
72
+
73
# ===== Endpoint 1: classify 2 TOI/TCE objects fetched from the API =====
def predict_toi_samples(n=2, table="tce"):
    """Fetch up to *n* objects from the NASA Exoplanet Archive and classify them.

    table == "tce": pulls the top-SNR Kepler Q1-Q17 DR25 TCEs via the TAP
    sync service; "toi" queries the nstedAPI TOI table for PC/APC
    dispositions. Any other value falls back to "tce".
    Returns (results DataFrame, path of a CSV copy of the results).
    Raises requests.HTTPError on a failed archive request.
    """
    if table not in {"tce","toi"}:
        table = "tce"

    if table == "tce":
        TAP_URL = "https://exoplanetarchive.ipac.caltech.edu/TAP/sync"
        # int(n) guards the only user value interpolated into the query.
        query = f"""
        SELECT TOP {int(n)}
        kepid, tce_plnt_num, tce_period, tce_duration, tce_depth, tce_model_snr
        FROM q1_q17_dr25_tce
        WHERE tce_period > 0 AND tce_duration > 0 AND tce_depth > 0
        ORDER BY tce_model_snr DESC
        """
        r = requests.get(TAP_URL, params={"query": query, "format": "csv"}, timeout=90)
    else:
        BASE = "https://exoplanetarchive.ipac.caltech.edu/cgi-bin/nstedAPI/nph-nstedAPI"
        where = ("(tfopwg_disp like 'PC' or tfopwg_disp like 'APC') and "
                 "(pl_orbper is not null or tce_period is not null)")
        r = requests.get(BASE, params={"table":"toi","where":where,"format":"csv"}, timeout=90)

    r.raise_for_status()
    df = pd.read_csv(io.StringIO(r.text))
    # Lower-case the headers so the synonym lookup below is case-insensitive.
    df.columns = [c.strip().lower() for c in df.columns]
    # Fixed random_state keeps the sampled subset reproducible across calls.
    df = df.sample(min(n, len(df)), random_state=7).reset_index(drop=True)

    # Flexible mapping of archive columns onto the model's FEATURE_COLUMNS.
    cols_set = set(df.columns)
    cases = pd.DataFrame(index=df.index, columns=FEATURE_COLUMNS, dtype="float64")
    for feat in FEATURE_COLUMNS:
        src = first_present(CANDIDATES_MAP.get(feat, []), cols_set)
        if src is not None:
            cases[feat] = pd.to_numeric(df[src], errors="coerce")
        else:
            cases[feat] = np.nan

    probs, classes = predict_proba_from_df(cases)
    idx = np.argmax(probs, axis=1)
    preds = label_encoder.inverse_transform(idx)

    # Build the output table: predicted class plus one P(class) column each.
    out = []
    for i in range(len(df)):
        row_probs = probs[i]
        d = {"prediction": preds[i]}
        for j, cls in enumerate(classes):
            d[f"P({cls})"] = float(row_probs[j])
        out.append(d)
    res = pd.DataFrame(out)
    csv_path = "pred_toi_samples.csv"
    res.to_csv(csv_path, index=False)
    return res, csv_path
125
+
126
# ===== Endpoint 2: manual JSON (POST) =====
def predict_from_json(json_text: str, threshold: float = 0.5):
    """Classify a single object described by a JSON payload.

    Keys may be koi_* feature names directly or any synonym listed in
    CANDIDATES_MAP; unresolved features are left NaN and imputed downstream.
    Returns a dict with the predicted class, per-class probabilities,
    P(CONFIRMED), and an is_exoplanet flag gated on threshold — or an
    {"error": ...} dict when the JSON cannot be parsed.
    """
    try:
        payload = json.loads(json_text)
    except Exception as e:
        return {"error": f"JSON inválido: {e}"}

    raw = pd.DataFrame([payload])
    # Normalize header names before matching.
    raw.columns = [c.strip().lower() for c in raw.columns]
    available = set(raw.columns)

    cases = pd.DataFrame(index=raw.index, columns=FEATURE_COLUMNS, dtype="float64")
    for feat in FEATURE_COLUMNS:
        # Prefer an exact koi_* match, then fall back to known synonyms.
        src = feat if feat in available else first_present(CANDIDATES_MAP.get(feat, []), available)
        cases[feat] = pd.to_numeric(raw[src], errors="coerce") if src is not None else np.nan

    probs, classes = predict_proba_from_df(cases)
    row = probs[0]
    best = int(np.argmax(row))
    pred = label_encoder.inverse_transform([best])[0]
    p_confirmed = float(row[classes.index("CONFIRMED")]) if "CONFIRMED" in classes else 0.0
    return {
        "prediction": pred,
        "probabilities": {cls: float(row[i]) for i, cls in enumerate(classes)},
        "is_exoplanet": bool(pred.upper() == "CONFIRMED" and p_confirmed >= float(threshold)),
        "p_confirmed": p_confirmed
    }
162
+
163
# ===== Endpoint 3: download the CSV for a specific TOI/TCE =====
def download_object_csv(identifier: str, table: str = "toi"):
    """Download the archive row(s) for one object and save them as object.csv.

    table "toi": queries the nstedAPI TOI table by TOI id.
    table "tce": queries the TAP service; *identifier* is
    "KIC <kepid> <planet_num>" (e.g. "KIC 11446443 1") or a bare kepid.
    Returns the path of the written CSV.
    Raises requests.HTTPError on a failed archive request.
    """
    table = table.lower()
    if table not in {"toi","tce"}:
        table = "toi"

    # SECURITY: *identifier* is user input interpolated into an ADQL/SQL
    # `like '...'` literal; strip single quotes so it cannot break out of
    # the quoted string.
    identifier = identifier.replace("'", "").strip()

    if table == "toi":
        BASE = "https://exoplanetarchive.ipac.caltech.edu/cgi-bin/nstedAPI/nph-nstedAPI"
        where = f"toi like '{identifier}'"
        r = requests.get(BASE, params={"table":"toi","where":where,"format":"csv"}, timeout=60)
    else:
        # For TCE use TAP keyed by kepid + tce_plnt_num, e.g. "KIC 11446443 1".
        TAP_URL = "https://exoplanetarchive.ipac.caltech.edu/TAP/sync"
        parts = identifier.replace(",", " ").split()
        if len(parts) >= 2:
            kep = parts[0]
            num = parts[1]
            query = f"""
            SELECT *
            FROM q1_q17_dr25_tce
            WHERE CAST(kepid AS VARCHAR) like '{kep.replace('KIC','').strip()}'
            AND CAST(tce_plnt_num AS VARCHAR) like '{num.strip()}'
            """
        else:
            query = f"SELECT TOP 1 * FROM q1_q17_dr25_tce WHERE CAST(kepid AS VARCHAR) like '{identifier}'"
        r = requests.get(TAP_URL, params={"query": query, "format": "csv"}, timeout=90)

    r.raise_for_status()
    path = "object.csv"
    # Explicit encoding so the saved CSV round-trips regardless of locale.
    with open(path, "w", encoding="utf-8") as f:
        f.write(r.text)
    return path
194
+
195
# ===== Endpoint 4: upload a CSV and classify every row =====
def predict_from_csv(file_obj, threshold: float = 0.5):
    """Classify every row of an uploaded CSV.

    Columns may use koi_* names or any synonym from CANDIDATES_MAP.
    Fix: the original accepted *threshold* but never used it; results now
    include an is_exoplanet column gated on P(CONFIRMED) >= threshold,
    consistent with predict_from_json.
    Returns (results DataFrame, path of the saved predictions CSV), or an
    empty frame and None when no file was provided.
    """
    if file_obj is None:
        return pd.DataFrame(), None
    # Gradio may hand over a tempfile wrapper (with .name) or a plain path.
    path_in = file_obj if isinstance(file_obj, str) else file_obj.name
    df = pd.read_csv(path_in)
    # Normalize header names before matching.
    df.columns = [c.strip().lower() for c in df.columns]
    cols_set = set(df.columns)

    cases = pd.DataFrame(index=df.index, columns=FEATURE_COLUMNS, dtype="float64")
    for feat in FEATURE_COLUMNS:
        # Prefer an exact koi_* match, then fall back to known synonyms.
        src = feat if feat in cols_set else first_present(CANDIDATES_MAP.get(feat, []), cols_set)
        if src is not None:
            cases[feat] = pd.to_numeric(df[src], errors="coerce")
        else:
            cases[feat] = np.nan

    probs, classes = predict_proba_from_df(cases)
    idx = np.argmax(probs, axis=1)
    preds = label_encoder.inverse_transform(idx)
    conf_col = classes.index("CONFIRMED") if "CONFIRMED" in classes else None

    out = []
    for i in range(len(df)):
        row = {"prediction": preds[i]}
        for j, cls in enumerate(classes):
            row[f"P({cls})"] = float(probs[i][j])
        p_conf = float(probs[i][conf_col]) if conf_col is not None else 0.0
        row["is_exoplanet"] = bool(str(preds[i]).upper() == "CONFIRMED" and p_conf >= float(threshold))
        out.append(row)
    res = pd.DataFrame(out)
    out_path = "predicciones.csv"
    res.to_csv(out_path, index=False)
    return res, out_path
226
+
227
# ===== Gradio UI =====
# Four panels, one per endpoint function above. Each click handler also sets
# api_name, which exposes the function as a named API route on the app.
with gr.Blocks() as demo:
    gr.Markdown("# 🔭 Exoplanet Classifier — API + UI (Gradio)")

    with gr.Row():
        with gr.Column():
            # Panel 1: fetch sample objects from the archive and classify them.
            gr.Markdown("### 1) Probar con 2 objetos de la API (TOI o TCE)")
            table_dd = gr.Dropdown(choices=["toi","tce"], value="tce", label="Tabla")
            n_objs = gr.Slider(1, 10, value=2, step=1, label="N objetos")
            out_df1 = gr.Dataframe(label="Resultados")
            out_file1 = gr.File(label="Descargar CSV")
            gr.Button("Probar API").click(predict_toi_samples, inputs=[n_objs, table_dd], outputs=[out_df1, out_file1], api_name="predict_toi_samples")

        with gr.Column():
            # Panel 2: classify a single hand-written JSON payload.
            gr.Markdown("### 2) JSON manual (POST)")
            jt = gr.Textbox(lines=12, label="JSON de entrada (TOI/TCE-like o koi_* )")
            thr_json = gr.Slider(0, 1, value=0.5, step=0.01, label="Umbral P(CONFIRMED)")
            out_json = gr.JSON(label="Respuesta")
            gr.Button("Predecir JSON").click(predict_from_json, inputs=[jt, thr_json], outputs=out_json, api_name="predict_json")

    # Panel 3: download the raw archive CSV for one object id.
    gr.Markdown("### 3) Descargar CSV de un objeto (por id)")
    ident = gr.Textbox(label="Identificador (ej: TOI-1234.01 o 'KIC 11446443 1')", placeholder="TOI-xxx.yy ó KIC ###### <planet_num>")
    table2 = gr.Dropdown(choices=["toi","tce"], value="toi", label="Tabla")
    out_csv = gr.File(label="CSV del objeto")
    gr.Button("Descargar CSV").click(download_object_csv, inputs=[ident, table2], outputs=out_csv, api_name="toi_csv")

    # Panel 4: upload a CSV and classify every row.
    gr.Markdown("### 4) Subir CSV y clasificar")
    f_in = gr.File(label="CSV subida", file_types=[".csv"])
    thr = gr.Slider(0,1,value=0.5, step=0.01, label="Umbral P(CONFIRMED)")
    out_df2 = gr.Dataframe(label="Resultados")
    out_file2 = gr.File(label="Descargar predicciones")
    gr.Button("Predecir CSV").click(predict_from_csv, inputs=[f_in, thr], outputs=[out_df2, out_file2], api_name="predict_csv")

# Enable the request queue, then start the server (blocks until shutdown).
demo.queue().launch()