import os
import sys

import dash
import dash_bootstrap_components as dbc
import joblib
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from dash import html, dcc, callback, Input, Output, State
from sklearn.base import BaseEstimator, TransformerMixin

dash.register_page(__name__)


# =====================================================================
# 1. PIPELINE DE-SERIALIZATION CLASS DEFINITION
# =====================================================================
class AstronomicalColorEngineer(BaseEstimator, TransformerMixin):
    """Append adjacent-band colour indices to six aperture-corrected magnitudes.

    The class is meant to mirror the transformer stored inside the
    serialized pipeline, so its name, its ``feature_order`` attribute and
    the order of the columns it returns should not be changed.
    """

    def __init__(self):
        # Six raw magnitudes followed by the five derived colours; this is
        # the exact column order returned by ``transform``.
        self.feature_order = [
            'u_apercor', 'g_apercor', 'r_apercor',
            'i_apercor', 'z_apercor', 'y_apercor',
            'u_g', 'g_r', 'r_i', 'i_z', 'z_y',
        ]

    def fit(self, X, y=None):
        """Stateless transformer: nothing is learned, return ``self``."""
        return self

    def transform(self, X):
        """Return a new frame holding the magnitudes plus colour columns.

        Args:
            X: A DataFrame containing the six ``*_apercor`` columns, or any
                array-like with six columns in u, g, r, i, z, y order.

        Returns:
            A DataFrame with a fresh 0-based index whose columns follow
            ``self.feature_order``.
        """
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(
                X,
                columns=['u_apercor', 'g_apercor', 'r_apercor',
                         'i_apercor', 'z_apercor', 'y_apercor'],
            )

        # Work on a copy with a clean index so the caller's frame is left
        # untouched and row-wise arithmetic lines up for single-row input.
        frame = X.copy().reset_index(drop=True)

        # Each colour is the bluer band minus the next redder band:
        # u-g, g-r, r-i, i-z, z-y.
        for bluer, redder in zip("ugriz", "grizy"):
            frame[f"{bluer}_{redder}"] = (
                frame[f"{bluer}_apercor"] - frame[f"{redder}_apercor"]
            )

        return frame[self.feature_order]
# =====================================================================
# 2. SAFE PIPELINE INITIALIZATION & NAMESPACE INJECTION
# =====================================================================
# Expose the transformer class on ``__main__`` so joblib can resolve it
# while unpickling (presumably the pipeline was saved from a script where
# the class lived in ``__main__`` -- confirm against the training code).
sys.modules['__main__'].AstronomicalColorEngineer = AstronomicalColorEngineer

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
model_path = os.path.join(BASE_DIR, 'models', 'xgb_production_pipeline.pkl')

# NOTE(review): joblib.load unpickles arbitrary Python objects; the model
# file must come from a trusted source.
# Loading is best-effort on purpose: the page still imports when the model
# is missing, and ``pipeline`` is left as None for callers to check.
try:
    pipeline = joblib.load(model_path)
except Exception as e:
    print(f"Error loading pipeline tracking binary: {e}")
    pipeline = None


# =====================================================================
# 3. INTERPRETABILITY & CONFIDENCE INTERVAL GENERATOR HELPER
# =====================================================================
def _build_contribution_figure(importance_df):
    """Render the per-feature contribution table as a horizontal bar chart."""
    fig = go.Figure(go.Bar(
        x=importance_df['Contribution'],
        y=importance_df['Feature'],
        orientation='h',
        marker=dict(
            color=importance_df['Contribution'],
            colorscale=[[0, '#3a1c5c'], [0.5, '#770A7A'], [1, '#00d2ff']],
            line=dict(color='rgba(255,255,255,0.1)', width=1)
        )
    ))
    fig.update_layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        font=dict(color='white', size=10),
        margin=dict(l=75, r=15, t=10, b=10),
        height=260,
        xaxis=dict(
            title="Relative Feature Weight Contribution (%)",
            gridcolor='rgba(255,255,255,0.05)',
            title_font=dict(size=10)
        ),
        yaxis=dict(gridcolor='rgba(255,255,255,0.05)')
    )
    return fig


def run_advanced_inference(raw_df, nmad_scale=0.019678, z_score=1.96):
    """Predict a redshift for the first row of *raw_df* and explain it.

    The ``color_engineer`` and ``model`` steps are pulled out of the loaded
    pipeline and applied one after the other.
    NOTE(review): any other step the pipeline may contain between those two
    is bypassed here -- confirm the pipeline has exactly these two steps.

    Args:
        raw_df: DataFrame with the six ``*_apercor`` magnitude columns.
            Only the first row is used.
        nmad_scale: Scatter scale used for the interval; the half-width is
            ``z_score * nmad_scale * (1 + z)``. Defaults to 0.019678.
        z_score: Normal-quantile multiplier (1.96 gives a ~95% interval).

    Returns:
        Tuple ``(prediction, ci_lower, ci_upper, fig)``: three floats and a
        Plotly figure of relative feature contributions in percent.

    Raises:
        RuntimeError: If the pipeline failed to load at import time.
    """
    if pipeline is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise RuntimeError("Prediction pipeline is not loaded.")

    transformer = pipeline.named_steps['color_engineer']
    xgb_model = pipeline.named_steps['model']

    # 1. Expand the six magnitudes into magnitudes + colour indices.
    enriched_df = transformer.transform(raw_df)

    # 2. Base redshift prediction (z) for the first (only) row.
    prediction = float(xgb_model.predict(enriched_df)[0])

    # 3. Interval bounds; the error is scaled by (1 + z) and the lower
    #    bound is clamped at zero.
    error_margin = z_score * nmad_scale * (1 + prediction)
    ci_lower = max(0.0, prediction - error_margin)
    ci_upper = prediction + error_margin

    # 4. Heuristic per-sample weighting: the model's global feature
    #    importances multiplied by the absolute feature values of this
    #    sample. This is not a SHAP attribution.
    global_importances = np.asarray(xgb_model.feature_importances_, dtype=float)
    sample_values = np.abs(enriched_df.iloc[0].to_numpy(dtype=float))
    raw_contributions = global_importances * sample_values

    # Normalise to percentages; fall back to zeros when the total is zero
    # or non-finite so the chart never receives NaN/inf values.
    total = raw_contributions.sum()
    if np.isfinite(total) and total > 0:
        contribution_pct = raw_contributions / total * 100
    else:
        contribution_pct = np.zeros_like(raw_contributions)

    importance_df = pd.DataFrame({
        'Feature': enriched_df.columns,
        'Contribution': contribution_pct
    }).sort_values(by='Contribution', ascending=True)

    # 5. Horizontal contribution graph.
    fig = _build_contribution_figure(importance_df)

    return prediction, ci_lower, ci_upper, fig
# =====================================================================
# 4. INTERFACE LAYOUT DEFINITION
# =====================================================================
def _magnitude_column(fields):
    """Build one half-width column of labelled numeric inputs.

    *fields* is a sequence of ``(label, input_id, default)`` triples. Every
    input except the last one in the column carries the ``mb-3`` spacing
    class.
    """
    children = []
    final_pos = len(fields) - 1
    for pos, (label, input_id, default) in enumerate(fields):
        if pos == final_pos:
            input_class = "bg-dark border-secondary text-white"
        else:
            input_class = "bg-dark border-secondary text-white mb-3"
        children.append(dbc.Label(label, className="text-info x-small"))
        children.append(
            dbc.Input(id=input_id, type="number", value=default,
                      className=input_class)
        )
    return dbc.Col(children, width=6)


def _metric_column(title, value, value_class, column_class):
    """Build one statistic cell: a small caption above a bold value."""
    return dbc.Col([
        html.Small(title, className="text-info d-block x-small"),
        html.Span(value, className=value_class),
    ], className=column_class)


# Page banner.
_page_header = html.Div([
    html.Span("NEBULA ENGINE V3.0 — PRODUCTION ENHANCED",
              className="text-info small fw-bold tracking-widest"),
    html.H1("Galaxy Redshift Estimator",
            className="text-white fw-bold display-4"),
], className="mb-4 py-3")

# Left column: user input form panel.
_input_card = dbc.Card([
    dbc.CardBody([
        html.H5("Photometric Magnitudes",
                className="mb-4 text-white-50 fw-light"),
        dbc.Row([
            _magnitude_column([
                ("u_apercor", "u_v", 24.9132),
                ("r_apercor", "r_v", 23.4690),
                ("z_apercor", "z_v", 21.7408),
            ]),
            _magnitude_column([
                ("g_apercor", "g_v", 24.4809),
                ("i_apercor", "i_v", 22.4933),
                ("y_apercor", "y_v", 21.5000),
            ]),
        ]),
        dbc.Button(
            "ANALYZE SPECTRUM",
            id="run-btn",
            className="w-100 py-3 mt-4 fw-bold",
            style={"background": "linear-gradient(45deg, #00d2ff, #9d50bb)",
                   "border": "none",
                   "borderRadius": "15px"},
        )
    ])
], className="modern-card p-3 mb-4")

# Static validation statistics shown under the prediction.
_metrics_strip = html.Div([
    dbc.Row([
        _metric_column("MODEL R²", "0.8429", "fw-bold text-white",
                       "text-center border-end border-secondary"),
        _metric_column("GLOBAL MAE", "0.0538", "fw-bold text-white",
                       "text-center border-end border-secondary"),
        _metric_column("OUTLIER FRACTION", "3.17%", "fw-bold text-success",
                       "text-center"),
    ]),
], className="py-2 px-1 mb-3",
   style={"background": "rgba(0,0,0,0.3)", "borderRadius": "12px"})

# Right column: prediction, interval text and the contribution figure.
_results_card = dbc.Card([
    dbc.CardBody([
        html.H6("ESTIMATED REDSHIFT",
                className="text-center text-white-50 letter-spacing-2 mb-1"),
        html.Div("---", id="pred-out", className="text-center py-1",
                 style={"fontSize": "3.8rem", "color": "#00d2ff",
                        "fontWeight": "bold"}),
        # 95% confidence bounds
        html.Div("Awaiting matrix initialization...", id="ci-out",
                 className="text-center text-muted small mb-3",
                 style={"fontStyle": "italic"}),
        _metrics_strip,
        # Feature contribution figure container
        html.H6("LOCALIZED PROFILE CONTEXT (SHAPE IMPORTANCE)",
                className="text-white-50 x-small letter-spacing-2 mb-2"),
        html.Div([
            dcc.Graph(id="importance-graph",
                      config={"displayModeBar": False})
        ], style={"background": "rgba(0,0,0,0.2)", "borderRadius": "12px",
                  "padding": "5px"})
    ], className="d-flex flex-column justify-content-center h-100")
], className="modern-card")

layout = dbc.Container([
    _page_header,
    dbc.Row([
        dbc.Col([_input_card], lg=6),
        dbc.Col([_results_card], lg=6)
    ])
], fluid=True)
# =====================================================================
# 5. REACTIVE CALLBACK CONTROL
# =====================================================================
def _blank_figure():
    """Return an empty figure with transparent backgrounds and white text."""
    return go.Figure().update_layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        font=dict(color='white')
    )


@callback(
    [Output("pred-out", "children"),
     Output("ci-out", "children"),
     Output("importance-graph", "figure")],
    Input("run-btn", "n_clicks"),
    [State("u_v", "value"), State("g_v", "value"), State("r_v", "value"),
     State("i_v", "value"), State("z_v", "value"), State("y_v", "value")]
)
def update_prediction_dashboard(n, u, g, r, i, z, y):
    """Run inference for the six magnitudes when the button is clicked.

    Args:
        n: Click count of the "run-btn" button; falsy means no click yet.
        u, g, r, i, z, y: Current values of the six magnitude inputs, or
            None for an input that holds no value.

    Returns:
        Tuple ``(prediction_text, interval_text, figure)`` for the three
        declared outputs.
    """
    # No click registered yet: show the idle placeholder.
    if not n:
        return "---", "Awaiting spectrum parameters...", _blank_figure()

    # The model failed to load at import time; say so instead of looking
    # as if the page were still waiting for input.
    if pipeline is None:
        return ("ERROR",
                "Prediction pipeline unavailable — model failed to load.",
                _blank_figure())

    if any(value is None for value in (u, g, r, i, z, y)):
        # Use the themed placeholder so the card does not flash a default
        # light-coloured empty plot.
        return "ERROR", "Incomplete matrix parameters provided.", _blank_figure()

    # Structure raw elements into a DataFrame mapping the pipeline's
    # anticipated inputs.
    df = pd.DataFrame(
        [[u, g, r, i, z, y]],
        columns=['u_apercor', 'g_apercor', 'r_apercor',
                 'i_apercor', 'z_apercor', 'y_apercor']
    )

    # Compute using the advanced inference engine.
    pred, ci_low, ci_high, importance_figure = run_advanced_inference(df)

    ci_text = f"95% Confidence Bounds: [{ci_low:.4f} — {ci_high:.4f}]"
    pred_text = f"{pred:.4f}"

    return pred_text, ci_text, importance_figure