RedShift_App / pages /prediction.py
hellosara's picture
Upload 28 files
cdf1899 verified
Raw
History Blame Contribute Delete
11.6 kB
import logging
import os
import sys

import dash
import dash_bootstrap_components as dbc
import joblib
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from dash import html, dcc, callback, Input, Output, State
from sklearn.base import BaseEstimator, TransformerMixin
# Register this module as a Dash Pages page. No path/name/title arguments are
# passed, so Dash's defaults apply (by Dash convention these are inferred from
# the module name); the page content is this module's `layout` object.
dash.register_page(__name__)
# =====================================================================
# 1. PIPELINE DE-SERIALIZATION CLASS DEFINITION
# =====================================================================
class AstronomicalColorEngineer(BaseEstimator, TransformerMixin):
    """
    Append adjacent-band colour indices to aperture-corrected magnitudes.

    This definition must mirror the transformer that was pickled into the
    production pipeline: unpickling restores only instance state, so the
    methods below are what actually execute at inference time.

    Input columns:  ``u_apercor`` ... ``y_apercor`` (six magnitudes).
    Output columns: those six magnitudes, then the five colours
    ``u_g, g_r, r_i, i_z, z_y`` (each band minus its redder neighbour).
    """

    # Photometric bands in the order used to pair neighbours into colours.
    _BANDS = ('u', 'g', 'r', 'i', 'z', 'y')

    def __init__(self):
        magnitude_columns = [f'{band}_apercor' for band in self._BANDS]
        colour_columns = [
            f'{blue}_{red}' for blue, red in zip(self._BANDS, self._BANDS[1:])
        ]
        # Final column order fed to the downstream model.
        self.feature_order = magnitude_columns + colour_columns

    def fit(self, X, y=None):
        """Stateless transformer: nothing to learn, return self."""
        return self

    def transform(self, X):
        """
        Return a new frame holding the six magnitudes plus five colours.

        Non-DataFrame input (e.g. a 2-D array) is wrapped using the six
        magnitude column names, in u, g, r, i, z, y order. The result has a
        fresh 0..n-1 index and columns ordered by ``self.feature_order``.
        """
        magnitude_columns = [f'{band}_apercor' for band in self._BANDS]
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X, columns=magnitude_columns)

        frame = X.copy().reset_index(drop=True)
        # Colour index = bluer-band magnitude minus the next redder band.
        for blue, red in zip(self._BANDS, self._BANDS[1:]):
            frame[f'{blue}_{red}'] = frame[f'{blue}_apercor'] - frame[f'{red}_apercor']
        return frame[self.feature_order]
# =====================================================================
# 2. SAFE PIPELINE INITIALIZATION & NAMESPACE INJECTION
# =====================================================================
# The pickled pipeline references AstronomicalColorEngineer by module; the fact
# that it is injected into `__main__` suggests it was trained in a script/notebook
# where the class lived in `__main__`. Exposing the class on the `__main__` module
# namespace lets unpickling resolve it when this page is imported by the Dash app.
sys.modules['__main__'].AstronomicalColorEngineer = AstronomicalColorEngineer

# This file lives in <app root>/pages/, so the app root is two dirname() calls up.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
model_path = os.path.join(BASE_DIR, 'models', 'xgb_production_pipeline.pkl')

# NOTE: joblib.load unpickles arbitrary objects; only ever point this at a
# trusted, locally shipped model file.
try:
    pipeline = joblib.load(model_path)
except Exception:
    # Best-effort: keep the page importable. The callback checks
    # `pipeline is None`. Log the full traceback, since unpickling failures
    # (missing classes, version skew) are undiagnosable from str(e) alone.
    logging.getLogger(__name__).exception(
        "Error loading prediction pipeline from %s", model_path
    )
    pipeline = None
# =====================================================================
# 3. INTERPRETABILITY & CONFIDENCE INTERVAL GENERATOR HELPER
# =====================================================================
def _redshift_interval(prediction, nmad_scale, z_score):
    """Return (lower, upper) = z ± z_score * nmad_scale * (1 + z), lower clipped at 0."""
    # Photometric-redshift scatter grows with redshift: sigma ≈ NMAD * (1 + z).
    margin = z_score * nmad_scale * (1 + prediction)
    # NOTE(review): if the model ever predicts z < 0, the clipped lower bound
    # (0.0) will sit above the point estimate.
    return max(0.0, prediction - margin), prediction + margin


def _relative_contributions(importances, sample_values):
    """
    Scale global importances by |feature value| and normalise to percentages.

    Returns all zeros when the weights do not sum to a positive finite number
    (e.g. every input is 0), rather than dividing by zero and yielding NaN bars.
    """
    weights = np.asarray(importances, dtype=float) * np.abs(
        np.asarray(sample_values, dtype=float)
    )
    total = weights.sum()
    if not np.isfinite(total) or total <= 0:
        return np.zeros_like(weights)
    return weights / total * 100


def _contribution_figure(importance_df):
    """Build the transparent, dark-themed horizontal bar chart of contributions."""
    fig = go.Figure(go.Bar(
        x=importance_df['Contribution'],
        y=importance_df['Feature'],
        orientation='h',
        marker=dict(
            color=importance_df['Contribution'],
            colorscale=[[0, '#3a1c5c'], [0.5, '#770A7A'], [1, '#00d2ff']],
            line=dict(color='rgba(255,255,255,0.1)', width=1),
        ),
    ))
    fig.update_layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        font=dict(color='white', size=10),
        margin=dict(l=75, r=15, t=10, b=10),
        height=260,
        xaxis=dict(
            title="Relative Feature Weight Contribution (%)",
            gridcolor='rgba(255,255,255,0.05)',
            title_font=dict(size=10),
        ),
        yaxis=dict(gridcolor='rgba(255,255,255,0.05)'),
    )
    return fig


def run_advanced_inference(raw_df, *, nmad_scale=0.019678, z_score=1.96,
                           model_pipeline=None):
    """
    Predict a redshift for one photometry row and build the explanatory outputs.

    The pipeline's ``color_engineer`` and ``model`` steps are run separately so
    the engineered feature frame is available for the contribution chart.

    Parameters
    ----------
    raw_df : pandas.DataFrame
        Row(s) of ``*_apercor`` magnitudes; only the first row is used.
    nmad_scale : float, keyword-only
        NMAD used to size the error bar (default 0.019678, the value this page
        has always used — presumably the validation NMAD; confirm).
    z_score : float, keyword-only
        Multiplier on sigma; 1.96 gives a nominal 95% interval under a
        Gaussian-error assumption.
    model_pipeline : optional, keyword-only
        Object with ``named_steps['color_engineer']`` and ``named_steps['model']``;
        defaults to the module-level ``pipeline``.

    Returns
    -------
    tuple
        ``(prediction, ci_lower, ci_upper, figure)`` where ``prediction`` is a
        float and ``figure`` is a Plotly bar chart of contribution percentages.

    Raises
    ------
    RuntimeError
        If no pipeline is available.
    """
    active = pipeline if model_pipeline is None else model_pipeline
    if active is None:
        raise RuntimeError("Prediction pipeline is not loaded.")

    transformer = active.named_steps['color_engineer']
    xgb_model = active.named_steps['model']

    # 1. Engineer the 11 model features (6 magnitudes + 5 colours).
    enriched_df = transformer.transform(raw_df)

    # 2. Point estimate for the first (and only expected) row.
    prediction = float(xgb_model.predict(enriched_df)[0])

    # 3. Heteroscedastic confidence interval.
    ci_lower, ci_upper = _redshift_interval(prediction, nmad_scale, z_score)

    # 4. Heuristic "local" profile: global importances × |feature value|.
    # NOTE(review): this is not a true per-prediction attribution (e.g. SHAP);
    # raw band magnitudes (default inputs ~21-25) will typically dwarf the
    # colour features in this weighting.
    contribution_pct = _relative_contributions(
        xgb_model.feature_importances_, enriched_df.iloc[0].values
    )
    importance_df = pd.DataFrame({
        'Feature': enriched_df.columns,
        'Contribution': contribution_pct,
    }).sort_values(by='Contribution', ascending=True)

    # 5. Chart.
    return prediction, ci_lower, ci_upper, _contribution_figure(importance_df)
# =====================================================================
# 4. INTERFACE LAYOUT DEFINITION
# =====================================================================
def _magnitude_column(fields):
    """
    Return a half-width column of stacked magnitude label/input pairs.

    ``fields`` is a sequence of ``(label_text, input_id, default_value)``;
    every input except the last one carries an ``mb-3`` bottom margin.
    """
    children = []
    final_index = len(fields) - 1
    for position, (label_text, input_id, default_value) in enumerate(fields):
        input_class = (
            "bg-dark border-secondary text-white"
            if position == final_index
            else "bg-dark border-secondary text-white mb-3"
        )
        children += [
            dbc.Label(label_text, className="text-info x-small"),
            dbc.Input(id=input_id, type="number", value=default_value, className=input_class),
        ]
    return dbc.Col(children, width=6)


def _metric_cell(title, value, value_class, cell_class):
    """Return one column of the static model-metrics strip."""
    return dbc.Col(
        [
            html.Small(title, className="text-info d-block x-small"),
            html.Span(value, className=value_class),
        ],
        className=cell_class,
    )


# Left card: the six photometric inputs (ids u_v..y_v are read by the callback).
_input_card = dbc.Card(
    [
        dbc.CardBody(
            [
                html.H5("Photometric Magnitudes", className="mb-4 text-white-50 fw-light"),
                dbc.Row(
                    [
                        _magnitude_column([
                            ("u_apercor", "u_v", 24.9132),
                            ("r_apercor", "r_v", 23.4690),
                            ("z_apercor", "z_v", 21.7408),
                        ]),
                        _magnitude_column([
                            ("g_apercor", "g_v", 24.4809),
                            ("i_apercor", "i_v", 22.4933),
                            ("y_apercor", "y_v", 21.5000),
                        ]),
                    ]
                ),
                dbc.Button(
                    "ANALYZE SPECTRUM",
                    id="run-btn",
                    className="w-100 py-3 mt-4 fw-bold",
                    style={
                        "background": "linear-gradient(45deg, #00d2ff, #9d50bb)",
                        "border": "none",
                        "borderRadius": "15px",
                    },
                ),
            ]
        )
    ],
    className="modern-card p-3 mb-4",
)

# Static validation figures displayed as text (not computed on this page).
_metrics_strip = html.Div(
    [
        dbc.Row(
            [
                _metric_cell("MODEL R²", "0.8429", "fw-bold text-white",
                             "text-center border-end border-secondary"),
                _metric_cell("GLOBAL MAE", "0.0538", "fw-bold text-white",
                             "text-center border-end border-secondary"),
                _metric_cell("OUTLIER FRACTION", "3.17%", "fw-bold text-success",
                             "text-center"),
            ]
        ),
    ],
    className="py-2 px-1 mb-3",
    style={"background": "rgba(0,0,0,0.3)", "borderRadius": "12px"},
)

# Right card: prediction, interval text, metrics and contribution chart
# (ids pred-out, ci-out, importance-graph are written by the callback).
_results_card = dbc.Card(
    [
        dbc.CardBody(
            [
                html.H6("ESTIMATED REDSHIFT",
                        className="text-center text-white-50 letter-spacing-2 mb-1"),
                html.Div(
                    "---",
                    id="pred-out",
                    className="text-center py-1",
                    style={"fontSize": "3.8rem", "color": "#00d2ff", "fontWeight": "bold"},
                ),
                html.Div(
                    "Awaiting matrix initialization...",
                    id="ci-out",
                    className="text-center text-muted small mb-3",
                    style={"fontStyle": "italic"},
                ),
                _metrics_strip,
                html.H6("LOCALIZED PROFILE CONTEXT (SHAPE IMPORTANCE)",
                        className="text-white-50 x-small letter-spacing-2 mb-2"),
                html.Div(
                    [dcc.Graph(id="importance-graph", config={"displayModeBar": False})],
                    style={"background": "rgba(0,0,0,0.2)", "borderRadius": "12px", "padding": "5px"},
                ),
            ],
            className="d-flex flex-column justify-content-center h-100",
        )
    ],
    className="modern-card",
)

layout = dbc.Container(
    [
        html.Div(
            [
                html.Span("NEBULA ENGINE V3.0 — PRODUCTION ENHANCED",
                          className="text-info small fw-bold tracking-widest"),
                html.H1("Galaxy Redshift Estimator", className="text-white fw-bold display-4"),
            ],
            className="mb-4 py-3",
        ),
        dbc.Row(
            [
                dbc.Col([_input_card], lg=6),
                dbc.Col([_results_card], lg=6),
            ]
        ),
    ],
    fluid=True,
)
# =====================================================================
# 5. REACTIVE CALLBACK CONTROL
# =====================================================================
@callback(
    [Output("pred-out", "children"),
     Output("ci-out", "children"),
     Output("importance-graph", "figure")],
    Input("run-btn", "n_clicks"),
    [State("u_v", "value"), State("g_v", "value"), State("r_v", "value"),
     State("i_v", "value"), State("z_v", "value"), State("y_v", "value")]
)
def update_prediction_dashboard(n, u, g, r, i, z, y):
    """
    Run the redshift estimate when the analyze button is clicked.

    Parameters are the button's click count followed by the six magnitude
    input values (None when a field is empty).

    Returns ``(prediction_text, interval_text, figure)`` for the three outputs.
    Every non-success path returns a message plus a transparent empty figure,
    so the dark card never flashes a default white Plotly canvas.
    """
    empty_fig = go.Figure().update_layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        font=dict(color='white'),
    )

    # Initial page load (n_clicks is None) or no click yet.
    if not n:
        return "---", "Awaiting spectrum parameters...", empty_fig
    # Model failed to load at import time: say so instead of "awaiting".
    if pipeline is None:
        return "ERROR", "Prediction model unavailable — check server logs.", empty_fig
    if any(value is None for value in (u, g, r, i, z, y)):
        return "ERROR", "Incomplete matrix parameters provided.", empty_fig

    # One-row frame with the column names the pipeline's transformer expects.
    df = pd.DataFrame(
        [[u, g, r, i, z, y]],
        columns=['u_apercor', 'g_apercor', 'r_apercor', 'i_apercor', 'z_apercor', 'y_apercor']
    )

    # UI boundary: log the failure and show a message rather than a broken page.
    try:
        pred, ci_low, ci_high, importance_figure = run_advanced_inference(df)
    except Exception:
        logging.getLogger(__name__).exception("Redshift inference failed")
        return "ERROR", "Inference failed — check server logs.", empty_fig

    ci_text = f"95% Confidence Bounds: [{ci_low:.4f}, {ci_high:.4f}]"
    return f"{pred:.4f}", ci_text, importance_figure