Spaces:
Sleeping
Sleeping
| import dash | |
| from dash import html, dcc, callback, Input, Output, State | |
| import dash_bootstrap_components as dbc | |
| import pandas as pd | |
| import numpy as np | |
| import joblib | |
| import os | |
| import sys | |
| import plotly.graph_objects as go | |
| from sklearn.base import BaseEstimator, TransformerMixin | |
# Register this module as a page of the Dash multi-page app; no explicit
# path or name is given, so Dash's defaults for this module are used.
dash.register_page(module=__name__)
| # ===================================================================== | |
| # 1. PIPELINE DE-SERIALIZATION CLASS DEFINITION | |
| # ===================================================================== | |
class AstronomicalColorEngineer(BaseEstimator, TransformerMixin):
    """Append adjacent-band color indices to six photometric magnitudes.

    This definition must match the transformer that was pickled with the
    production pipeline, because joblib rebuilds the pipeline by looking the
    class up by name. The transformer is stateless; ``fit`` learns nothing.

    ``transform`` accepts either a DataFrame that contains the six
    ``*_apercor`` columns or an array-like whose columns are in
    u, g, r, i, z, y order. It returns a DataFrame with the columns listed in
    ``feature_order``: the six magnitudes followed by five colors, each one
    being a band minus the next band (``u_g`` = u - g, ``g_r`` = g - r, ...).
    """

    def __init__(self):
        magnitudes = ['u_apercor', 'g_apercor', 'r_apercor', 'i_apercor', 'z_apercor', 'y_apercor']
        colors = ['u_g', 'g_r', 'r_i', 'i_z', 'z_y']
        # Column order handed to the downstream model.
        self.feature_order = magnitudes + colors

    def fit(self, X, y=None):
        """No-op fit; returns ``self`` so the step can sit inside a Pipeline."""
        return self

    def transform(self, X):
        """Return the magnitudes plus derived colors, ordered by ``feature_order``."""
        magnitude_columns = ['u_apercor', 'g_apercor', 'r_apercor', 'i_apercor', 'z_apercor', 'y_apercor']
        # (output color column, minuend band, subtrahend band)
        color_definitions = (
            ('u_g', 'u_apercor', 'g_apercor'),
            ('g_r', 'g_apercor', 'r_apercor'),
            ('r_i', 'r_apercor', 'i_apercor'),
            ('i_z', 'i_apercor', 'z_apercor'),
            ('z_y', 'z_apercor', 'y_apercor'),
        )

        if isinstance(X, pd.DataFrame):
            frame = X
        else:
            frame = pd.DataFrame(X, columns=magnitude_columns)

        # Work on a copy (the caller's frame is never mutated) and replace the
        # index with 0..n-1 so the output index is purely positional.
        frame = frame.copy().reset_index(drop=True)

        for color_name, minuend, subtrahend in color_definitions:
            frame[color_name] = frame[minuend] - frame[subtrahend]

        return frame[self.feature_order]
| # ===================================================================== | |
| # 2. SAFE PIPELINE INITIALIZATION & NAMESPACE INJECTION | |
| # ===================================================================== | |
# pickle/joblib rebuild custom objects by importing the module name recorded
# at save time. Presumably the pipeline was pickled from a training script, so
# the transformer is recorded under the ``__main__`` module (NOTE(review):
# confirm). Attaching the class to ``__main__`` lets that lookup succeed when
# this page is imported by the app instead of run as a script.
setattr(sys.modules['__main__'], 'AstronomicalColorEngineer', AstronomicalColorEngineer)

# Two directory levels above this file; the model lives in ``<that>/models``.
BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__))
)
model_path = os.path.join(BASE_DIR, 'models', 'xgb_production_pipeline.pkl')

# Best-effort load: any failure is printed and ``pipeline`` stays None so the
# page can still be imported and rendered.
pipeline = None
try:
    pipeline = joblib.load(model_path)
except Exception as load_error:
    print(f"Error loading pipeline tracking binary: {load_error}")
| # ===================================================================== | |
| # 3. INTERPRETABILITY & CONFIDENCE INTERVAL GENERATOR HELPER | |
| # ===================================================================== | |
# NMAD photo-z scatter reported for the production pipeline. The error bar is
# modelled as sigma(z) ~= NMAD * (1 + z), i.e. it grows with redshift.
_NMAD_SCALE = 0.019678
# Two-sided z-score for a 95% interval under a Gaussian error assumption.
_Z_SCORE_95 = 1.96


def _redshift_interval(prediction, nmad_scale=_NMAD_SCALE, z_score=_Z_SCORE_95):
    """Return ``(lower, upper)`` bounds around a predicted redshift.

    The half-width is ``z_score * nmad_scale * (1 + prediction)``; the lower
    bound is clamped at 0.
    """
    margin = z_score * nmad_scale * (1 + prediction)
    return max(0.0, prediction - margin), prediction + margin


def _relative_contributions(importances, feature_values):
    """Scale model importances by |feature value| and normalise to percent.

    Returns all zeros instead of NaN when the weighted total is zero or not
    finite (e.g. every input value is 0), avoiding a 0/0 division.
    """
    weighted = np.asarray(importances, dtype=float) * np.abs(np.asarray(feature_values, dtype=float))
    total = weighted.sum()
    if not np.isfinite(total) or total <= 0:
        return np.zeros_like(weighted)
    return weighted / total * 100


def _contribution_figure(importance_df):
    """Build the horizontal bar chart of per-feature contribution percentages."""
    fig = go.Figure(go.Bar(
        x=importance_df['Contribution'],
        y=importance_df['Feature'],
        orientation='h',
        marker=dict(
            color=importance_df['Contribution'],
            colorscale=[[0, '#3a1c5c'], [0.5, '#770A7A'], [1, '#00d2ff']],
            line=dict(color='rgba(255,255,255,0.1)', width=1),
        ),
    ))
    fig.update_layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        font=dict(color='white', size=10),
        margin=dict(l=75, r=15, t=10, b=10),
        height=260,
        xaxis=dict(title="Relative Feature Weight Contribution (%)", gridcolor='rgba(255,255,255,0.05)', title_font=dict(size=10)),
        yaxis=dict(gridcolor='rgba(255,255,255,0.05)'),
    )
    return fig


def run_advanced_inference(raw_df, model_pipeline=None):
    """Predict a redshift, a 95% interval, and a feature-contribution chart.

    Args:
        raw_df: DataFrame with the six ``*_apercor`` magnitude columns; only
            its first row is used for the returned values.
        model_pipeline: fitted pipeline exposing ``named_steps`` entries
            ``'color_engineer'`` and ``'model'``. Defaults to the module-level
            ``pipeline`` loaded at import time.

    Returns:
        ``(prediction, ci_lower, ci_upper, fig)`` where ``fig`` is a Plotly
        horizontal bar chart of relative feature contributions (percent).

    Raises:
        RuntimeError: if no pipeline is available.

    NOTE(review): the contribution chart is a heuristic. It multiplies the
    model's global ``feature_importances_`` by this sample's absolute feature
    values. It is not a true per-sample attribution such as SHAP. Raw
    magnitudes are typically far larger in absolute value than colors, so
    they tend to dominate the chart.
    """
    active = pipeline if model_pipeline is None else model_pipeline
    if active is None:
        raise RuntimeError("Redshift pipeline is not loaded; cannot run inference.")

    transformer = active.named_steps['color_engineer']
    xgb_model = active.named_steps['model']

    # The six magnitudes plus the five engineered colors, in model order.
    enriched_df = transformer.transform(raw_df)

    prediction = xgb_model.predict(enriched_df)[0]
    ci_lower, ci_upper = _redshift_interval(prediction)

    contribution_pct = _relative_contributions(
        xgb_model.feature_importances_, enriched_df.iloc[0].values
    )
    # Ascending sort puts the largest bar at the top of a horizontal chart.
    importance_df = pd.DataFrame({
        'Feature': enriched_df.columns,
        'Contribution': contribution_pct,
    }).sort_values(by='Contribution', ascending=True)

    return prediction, ci_lower, ci_upper, _contribution_figure(importance_df)
| # ===================================================================== | |
| # 4. INTERFACE LAYOUT DEFINITION | |
| # ===================================================================== | |
def _magnitude_field(band, input_id, default, *, last=False):
    """Return a ``[Label, Input]`` pair for one photometric band.

    ``html_for`` ties the label to its input, so screen readers announce it
    and clicking the label focuses the field. ``last`` omits the bottom
    margin on the final field of a column.
    """
    input_class = "bg-dark border-secondary text-white"
    if not last:
        input_class += " mb-3"
    return [
        dbc.Label(band, html_for=input_id, className="text-info x-small"),
        dbc.Input(id=input_id, type="number", value=default, className=input_class),
    ]


def _stat_tile(title, value, *, value_class="fw-bold text-white", divider=True):
    """Return one column of the model-performance strip (static text)."""
    col_class = "text-center border-end border-secondary" if divider else "text-center"
    return dbc.Col([
        html.Small(title, className="text-info d-block x-small"),
        html.Span(value, className=value_class),
    ], className=col_class)


layout = dbc.Container([
    # Page header.
    html.Div([
        html.Span("NEBULA ENGINE V3.0 — PRODUCTION ENHANCED", className="text-info small fw-bold tracking-widest"),
        html.H1("Galaxy Redshift Estimator", className="text-white fw-bold display-4"),
    ], className="mb-4 py-3"),
    dbc.Row([
        # Left column: the six magnitude inputs and the trigger button.
        dbc.Col([
            dbc.Card([
                dbc.CardBody([
                    html.H5("Photometric Magnitudes", className="mb-4 text-white-50 fw-light"),
                    dbc.Row([
                        dbc.Col([
                            *_magnitude_field("u_apercor", "u_v", 24.9132),
                            *_magnitude_field("r_apercor", "r_v", 23.4690),
                            *_magnitude_field("z_apercor", "z_v", 21.7408, last=True),
                        ], width=6),
                        dbc.Col([
                            *_magnitude_field("g_apercor", "g_v", 24.4809),
                            *_magnitude_field("i_apercor", "i_v", 22.4933),
                            *_magnitude_field("y_apercor", "y_v", 21.5000, last=True),
                        ], width=6),
                    ]),
                    dbc.Button("ANALYZE SPECTRUM", id="run-btn", className="w-100 py-3 mt-4 fw-bold",
                               style={"background": "linear-gradient(45deg, #00d2ff, #9d50bb)", "border": "none", "borderRadius": "15px"}),
                ]),
            ], className="modern-card p-3 mb-4"),
        ], lg=6),
        # Right column: prediction readout, interval text, static model
        # metrics, and the feature-contribution chart.
        dbc.Col([
            dbc.Card([
                dbc.CardBody([
                    html.H6("ESTIMATED REDSHIFT", className="text-center text-white-50 letter-spacing-2 mb-1"),
                    html.Div("---", id="pred-out", className="text-center py-1", style={"fontSize": "3.8rem", "color": "#00d2ff", "fontWeight": "bold"}),
                    # Placeholder for the 95% confidence-interval readout.
                    html.Div("Awaiting matrix initialization...", id="ci-out", className="text-center text-muted small mb-3", style={"fontStyle": "italic"}),
                    # Hard-coded performance figures of the production model.
                    html.Div([
                        dbc.Row([
                            _stat_tile("MODEL R²", "0.8429"),
                            _stat_tile("GLOBAL MAE", "0.0538"),
                            _stat_tile("OUTLIER FRACTION", "3.17%", value_class="fw-bold text-success", divider=False),
                        ]),
                    ], className="py-2 px-1 mb-3", style={"background": "rgba(0,0,0,0.3)", "borderRadius": "12px"}),
                    html.H6("LOCALIZED PROFILE CONTEXT (SHAPE IMPORTANCE)", className="text-white-50 x-small letter-spacing-2 mb-2"),
                    html.Div([
                        dcc.Graph(id="importance-graph", config={"displayModeBar": False}),
                    ], style={"background": "rgba(0,0,0,0.2)", "borderRadius": "12px", "padding": "5px"}),
                ], className="d-flex flex-column justify-content-center h-100"),
            ], className="modern-card"),
        ], lg=6),
    ]),
], fluid=True)
| # ===================================================================== | |
# 5. REACTIVE PREDICTION CALLBACK
| # ===================================================================== | |
def _blank_figure():
    """Return an empty Plotly figure with a transparent background and white text."""
    return go.Figure().update_layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        font=dict(color='white'),
    )


# Without this registration the function is never called and the button has
# no effect. The magnitudes are State (not Input) so only clicks trigger it.
@callback(
    Output("pred-out", "children"),
    Output("ci-out", "children"),
    Output("importance-graph", "figure"),
    Input("run-btn", "n_clicks"),
    State("u_v", "value"),
    State("g_v", "value"),
    State("r_v", "value"),
    State("i_v", "value"),
    State("z_v", "value"),
    State("y_v", "value"),
)
def update_prediction_dashboard(n, u, g, r, i, z, y):
    """Run the redshift model when the ``run-btn`` button is clicked.

    Args:
        n: click count of ``run-btn`` (falsy before the first click).
        u, g, r, i, z, y: the six aperture-corrected magnitudes.

    Returns:
        ``(prediction text, interval text, figure)`` for the ``pred-out``,
        ``ci-out`` and ``importance-graph`` components.
    """
    # No click yet (e.g. initial page load): show placeholders.
    if not n:
        return "---", "Awaiting spectrum parameters...", _blank_figure()
    # The model failed to load at import time; report it instead of idling.
    if pipeline is None:
        return "ERROR", "Model pipeline unavailable; check the server log.", _blank_figure()

    magnitudes = [u, g, r, i, z, y]
    # Missing values arrive as None.
    if any(value is None for value in magnitudes):
        return "ERROR", "Incomplete matrix parameters provided.", _blank_figure()

    # One-row frame with the column names the pipeline's transformer expects.
    df = pd.DataFrame(
        [magnitudes],
        columns=['u_apercor', 'g_apercor', 'r_apercor', 'i_apercor', 'z_apercor', 'y_apercor'],
    )

    pred, ci_low, ci_high, importance_figure = run_advanced_inference(df)

    ci_text = f"95% Confidence Bounds: [{ci_low:.4f} — {ci_high:.4f}]"
    pred_text = f"{pred:.4f}"
    return pred_text, ci_text, importance_figure