Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import datetime as dt | |
| from pathlib import Path | |
| import joblib | |
| import pandas as pd | |
| import pydeck as pdk | |
| import streamlit as st | |
| from wqf7009_aqi.aqi import pm25_to_aqi_category_us_epa, pm25_to_aqi_us_epa | |
| from wqf7009_aqi.features import build_features_for_inference | |
| from wqf7009_aqi.geo import nearest_location | |
| from wqf7009_aqi.geocoding import geocode_name | |
| from wqf7009_aqi.openaq_archive import load_location_catalog | |
| from wqf7009_aqi.xai import explain_local_shap, generate_counterfactuals | |
| ARTIFACTS_DIR = Path("artifacts_50loc") | |
| def _load_artifacts(artifacts_dir: Path) -> tuple[object, dict]: | |
| model_path = artifacts_dir / "model.joblib" | |
| meta_path = artifacts_dir / "meta.joblib" | |
| if not model_path.exists() or not meta_path.exists(): | |
| raise FileNotFoundError( | |
| "Missing artifacts. Run: python -m wqf7009_aqi train --data data\\dataset.parquet --outdir artifacts" | |
| ) | |
| model = joblib.load(model_path) | |
| meta = joblib.load(meta_path) | |
| return model, meta | |
| def _load_dataset(path: str) -> pd.DataFrame: | |
| df = pd.read_parquet(path) | |
| df["date"] = pd.to_datetime(df["date"]) | |
| return df | |
# Page chrome.
st.set_page_config(page_title="AQI PM2.5 + XAI", layout="wide")
st.title("AQI PM2.5 predictor with explanations")
st.caption("Global OpenAQ data | Daily PM2.5 regression | SHAP + counterfactuals")

with st.sidebar:
    # --- Setup: model artifacts, feature dataset, station catalog ---
    st.header("Setup")
    artifacts_dir = Path(st.text_input("Artifacts directory", value=str(ARTIFACTS_DIR)))
    try:
        model, meta = _load_artifacts(artifacts_dir)
    except FileNotFoundError as exc:
        # Report missing artifacts in the UI (like the dataset check below)
        # instead of letting a raw traceback reach the user.
        st.error(str(exc))
        st.stop()

    data_path = Path(st.text_input("Dataset path", value=str(Path("data/dataset.parquet"))))
    if not data_path.exists():
        st.error(f"Dataset not found: {data_path}")
        st.stop()
    data = _load_dataset(str(data_path))

    catalog_path = Path(st.text_input("Location catalog", value=str(artifacts_dir / "location_catalog.parquet")))
    locations = load_location_catalog(catalog_path)
    # Filter catalog to only show stations that have data in the dataset
    dataset_location_ids = set(data["location_id"].unique())
    locations = locations[locations["location_id"].isin(dataset_location_ids)]
    if locations.empty:
        st.error("No locations with data found. The catalog and dataset may be out of sync.")
        st.stop()

    # --- Query: pick a station directly, or geocode a place and snap to the nearest one ---
    st.divider()
    st.header("Query")
    mode = st.radio("Location mode", ["Search stations", "Any location (geocode + nearest station)"])
    snapped_km = 0.0
    geocoded_lat = None
    geocoded_lon = None

    if mode == "Search stations":
        location_query = st.text_input("Search station name", value="")
        query_text = location_query.strip()
        if query_text:
            # regex=False: match the user's text literally, so characters such as
            # "(" or "+" cannot raise re.error and crash the app.
            matches = locations[
                locations["location"].str.contains(query_text, case=False, na=False, regex=False)
            ].head(50)
        else:
            matches = locations.head(50)
        if matches.empty:
            st.warning(f"No stations found matching '{location_query}'. Try a different search term or leave blank to see all stations.")
            st.stop()
        # Show the catalog's location_id column (what the rest of the app reports),
        # not the row label, which is not guaranteed to equal the station id.
        choice = st.selectbox(
            "Choose a station",
            options=matches.index,
            format_func=lambda i: f"{matches.loc[i, 'location']} (id {int(matches.loc[i, 'location_id'])})",
        )
    else:
        place = st.text_input("Enter any place name", value="")
        if not place.strip():
            st.info("Enter a place name to find the nearest air quality station.")
            st.stop()
        results = geocode_name(place, count=5)
        if not results:
            st.warning(f"No geocoding results for '{place}'. Try a different name.")
            st.stop()
        labels = [
            f"{r.get('name')}, {r.get('admin1','')}, {r.get('country','')} ({float(r.get('latitude')):.3f},{float(r.get('longitude')):.3f})"
            for r in results
        ]
        pick = st.selectbox("Geocoding results", options=list(range(len(results))), format_func=lambda i: labels[i])
        geocoded_lat = float(results[pick]["latitude"])
        geocoded_lon = float(results[pick]["longitude"])
        # Find nearest station
        nearest = nearest_location(lat=geocoded_lat, lon=geocoded_lon, catalog=locations, k=1)
        choice = nearest.index[0]
        snapped_km = float(nearest.iloc[0]["distance_km"])
        snapped_id = int(locations.loc[choice, "location_id"])
        st.caption(f"Snapped to nearest OpenAQ station: id `{snapped_id}` (distance `{snapped_km:.1f} km`).")
# Bound the date picker by the history available for the selected station.
loc_id = int(locations.loc[choice, "location_id"])
loc_subset = data[data["location_id"] == loc_id]
if loc_subset.empty:
    # No rows for this station: fall back to the global window stored in the metadata.
    valid_min = dt.date.fromisoformat(meta["date_min"])
    valid_max = dt.date.fromisoformat(meta["date_max"])
else:
    station_days = loc_subset["date"].dt.date
    valid_min, valid_max = station_days.min(), station_days.max()

with st.expander("Historical Analysis (Audit)"):
    # Defaults to the most recent available day.
    query_date = st.date_input(
        "Date (within available history)",
        min_value=valid_min,
        max_value=valid_max,
        value=valid_max,
    )
# Catalog record for the selected station.
row = locations.loc[choice]
lat = float(row["lat"])
lon = float(row["lon"])
location_name = str(row["location"])


def _point_layer(point_lat, point_lon, radius, fill_rgba, **layer_kwargs):
    """Return a pydeck ScatterplotLayer that draws a single circle at the given point."""
    return pdk.Layer(
        "ScatterplotLayer",
        data=pd.DataFrame([{"lat": point_lat, "lon": point_lon}]),
        get_position="[lon, lat]",
        get_radius=radius,
        get_fill_color=fill_rgba,
        **layer_kwargs,
    )


left, right = st.columns([1.2, 1.0], gap="large")
with left:
    st.subheader(f"{location_name} (id {int(row['location_id'])})")
    snap_note = f" | snapped {snapped_km:.1f} km" if snapped_km else ""
    st.write(f"Lat/Lon: `{lat:.4f}, {lon:.4f}` | Date: `{query_date.isoformat()}`{snap_note}")

    has_geocode = geocoded_lat is not None and geocoded_lon is not None
    if has_geocode:
        # Red circle = searched place, blue circle = station it was snapped to.
        layers = [
            _point_layer(geocoded_lat, geocoded_lon, 15000, [255, 50, 50, 200], pickable=True),
            _point_layer(lat, lon, 15000, [30, 144, 255, 200], pickable=True),
        ]
        # Centre on the midpoint; zoom out one level per 500 km of snapping, clamped to [0, 8].
        view_lat = (geocoded_lat + lat) / 2
        view_lon = (geocoded_lon + lon) / 2
        view_zoom = max(0, min(8, 8 - int(snapped_km / 500)))
    else:
        # Station only (blue circle).
        layers = [_point_layer(lat, lon, 20000, [30, 144, 255, 180])]
        view_lat, view_lon, view_zoom = lat, lon, 8

    st.pydeck_chart(
        pdk.Deck(
            initial_view_state=pdk.ViewState(latitude=view_lat, longitude=view_lon, zoom=view_zoom, pitch=0),
            layers=layers,
        ),
        key=f"map_{choice}_{geocoded_lat}_{geocoded_lon}",  # a new key per location forces a redraw
    )
    if has_geocode:
        st.caption("Your searched location vs nearest air quality station")
# Look up the pre-computed feature row for the chosen station and day.
# The date picker's bounds come from calendar days (``.dt.date``), so compare on the
# calendar day here too: normalizing stored timestamps to midnight keeps the lookup
# working even if they carry a time-of-day component.
query_ts = pd.Timestamp(query_date)
match = data[
    data["location_id"].eq(int(row["location_id"]))
    & (data["date"].dt.normalize() == query_ts)
]
if match.empty:
    st.error("No feature row for this date/location. Choose a later date with enough history.")
    st.stop()
# NOTE(review): if several rows share this station/day, only the first one is used.
feature_row = match.iloc[0]

# Single-row model input, point prediction, and its US EPA AQI translation.
X = build_features_for_inference(feature_row=feature_row, feature_schema=meta["feature_schema"])
pred_pm25 = float(model.predict(X)[0])
aqi = pm25_to_aqi_us_epa(pred_pm25)
badge = pm25_to_aqi_category_us_epa(pred_pm25)
with right:
    # Headline numbers for the selected station/day.
    st.subheader("Prediction")
    pm_col, aqi_col, category_col = st.columns(3)
    pm_col.metric("Predicted PM2.5 (ug/m3)", f"{pred_pm25:.1f}")
    aqi_col.metric("Derived AQI (US EPA)", f"{aqi:.0f}")
    category_col.metric("AQI Category", badge)

    # The lag/rolling inputs that fed the prediction, shown as a one-row table.
    st.subheader("Recent PM2.5 history (lag features)")
    lag_table = feature_row[meta["feature_schema"]["history_features"]].to_frame().T
    st.dataframe(lag_table, use_container_width=True, hide_index=True)
st.divider()
tab1, tab2 = st.tabs(["Local explanation (SHAP)", "Counterfactuals (DiCE)"])

# Markdown reading guide rendered above the SHAP chart.
_SHAP_READING_GUIDE = (
    "\n"
    "**How to read this chart:**\n"
    "- Each bar shows how much a feature pushed the prediction higher (red/positive) or lower (blue/negative)\n"
    "- Longer bars = stronger impact on the prediction\n"
    "- The chart shows which historical PM2.5 measurements (recent days, weekly averages) had the biggest influence on today's predicted value\n"
)

with tab1:
    st.caption("Local post-hoc explanation for this single prediction.")
    st.info(_SHAP_READING_GUIDE)
    shap_fig = explain_local_shap(model=model, X_row=X)
    st.pyplot(shap_fig, clear_figure=True)
# --- Counterfactual (DiCE) tab ---------------------------------------------------

# Minimum absolute difference (ug/m3) for a lag feature to count as "changed".
_CF_CHANGE_THRESHOLD = 0.01


def _feature_interpretation(col_name: str) -> str:
    """Return a human-readable phrase for a lag/rolling history feature name.

    Matching is by substring, checked in the order lag1, lag7, roll3, roll7;
    unrecognised names are returned unchanged.
    """
    if "lag1" in col_name:
        return "yesterday's PM2.5"
    if "lag7" in col_name:
        return "last week's PM2.5"
    if "roll3" in col_name:
        return "the 3-day rolling average PM2.5"
    if "roll7" in col_name:
        return "the 7-day rolling average PM2.5"
    return col_name


def _feature_changes(cf_row: pd.Series, query_row: pd.Series, history_feats) -> list[dict]:
    """Describe how a counterfactual row differs from the query row on history features.

    Only features present in both rows whose absolute difference exceeds
    ``_CF_CHANGE_THRESHOLD`` are reported, in ``history_feats`` order. NaN
    differences are skipped. Each entry has:

    - ``feature``: the readable name.
    - ``signed_delta`` / ``delta``: the signed and absolute change.
    - ``signed_pct`` / ``pct``: the signed and absolute percent change (0 when
      the original value is 0).
    - ``direction``: "increase" or "decrease".
    - ``original`` / ``new_val``: the values before and after.
    """
    changes = []
    for col in history_feats:
        if col not in query_row.index or col not in cf_row.index:
            continue
        original = query_row[col]
        cf_val = cf_row[col]
        delta = cf_val - original
        # `not (... > ...)` also skips NaN differences.
        if not (abs(delta) > _CF_CHANGE_THRESHOLD):
            continue
        signed_pct = (delta / original * 100) if original != 0 else 0
        changes.append(
            {
                "feature": _feature_interpretation(col),
                "delta": abs(delta),
                "signed_delta": delta,
                "pct": abs(signed_pct),
                "signed_pct": signed_pct,
                "direction": "decrease" if delta < 0 else "increase",
                "original": original,
                "new_val": cf_val,
            }
        )
    return changes


def _scenario_explanation(improvement_val, improvement_pct, target_pm25, changes) -> str:
    """Build the markdown summary for one scenario (``changes`` sorted largest first)."""
    intro = (
        f"To improve PM2.5 by **{improvement_val:.1f} µg/m³** ({improvement_pct:.1f}%) "
        f"and achieve a target of **{target_pm25:.1f} µg/m³**, "
    )
    if not changes:
        return intro + "no significant changes are needed."
    if len(changes) == 1:
        c = changes[0]
        return intro + (
            f"{c['feature']} must {c['direction']} by **{c['delta']:.2f} µg/m³** ({c['pct']:.1f}%), "
            f"from {c['original']:.2f} to {c['new_val']:.2f}."
        )
    steps = "".join(
        f"{i}. **{c['feature'].capitalize()}** must {c['direction']} by **{c['delta']:.2f} µg/m³** "
        f"({c['pct']:.1f}%), from {c['original']:.2f} to {c['new_val']:.2f}\n"
        for i, c in enumerate(changes, 1)
    )
    return intro + "the following changes are needed:\n\n" + steps


with tab2:
    st.caption("Counterfactual scenario analysis: small changes in lag features that would reduce predicted PM2.5.")
    desired = st.number_input(
        "Desired PM2.5 upper bound (ug/m3)",
        min_value=1.0,
        value=min(15.0, max(1.0, pred_pm25 - 5.0)),
    )
    # Broad catch on purpose: this tab is optional, so any failure (missing DiCE
    # data, DiCE search errors, missing plotting dependency) degrades to a
    # warning instead of breaking the rest of the page.
    try:
        dice_path = artifacts_dir / "dice_data.parquet"
        if not dice_path.exists():
            # Name the directory actually searched (it is configurable in the sidebar).
            raise FileNotFoundError(f"Missing {dice_path} (re-run training).")
        dice_data = pd.read_parquet(dice_path)
        cfs = generate_counterfactuals(
            model=model,
            meta=meta,
            dice_data=dice_data,
            query_X=X,
            desired_upper=desired,
            total_CFs=3,
        )
        if cfs.empty:
            st.warning("No counterfactuals found. Try increasing the desired PM2.5 upper bound.")
        else:
            history_features = meta["feature_schema"]["history_features"]

            # One record per counterfactual, numbered in DiCE's output order.
            scenarios = []
            for number, (_, cf_row) in enumerate(cfs.iterrows(), start=1):
                changes = _feature_changes(cf_row, feature_row, history_features)
                total_change = sum(c["delta"] for c in changes)
                improvement = pred_pm25 - cf_row["pm25"]
                improvement_pct = (improvement / pred_pm25) * 100 if pred_pm25 != 0 else 0
                # Feasibility = lower total change is more feasible (100 when nothing changes).
                feasibility = 100 / (1 + total_change) if total_change > 0 else 100
                scenarios.append(
                    {
                        "Scenario": f"Scenario {number}",
                        "Target PM2.5": cf_row["pm25"],
                        "Improvement": f"{improvement:.1f} µg/m³ ({improvement_pct:.1f}%)",
                        "Features Changed": len(changes),
                        "Total Change": total_change,
                        "Feasibility Score": feasibility,
                        "Improvement_Val": improvement,
                        "Improvement_Pct": improvement_pct,
                        "cf_row": cf_row,
                        # Largest absolute change first, for the summary and the table.
                        "changes": sorted(changes, key=lambda c: c["delta"], reverse=True),
                    }
                )
            # Most feasible first; Python's sort is stable, so ties keep DiCE order.
            scenarios.sort(key=lambda s: s["Feasibility Score"], reverse=True)

            # Overview table with a recommendation on the most feasible scenario.
            st.subheader("Scenario Overview")
            display_cols = ["Scenario", "Target PM2.5", "Improvement", "Features Changed", "Feasibility Score"]
            overview = pd.DataFrame(
                [{col: s[col] for col in display_cols} for s in scenarios],
                columns=display_cols,
            )
            overview["Recommendation"] = [
                "⭐ Most Feasible" if i == 0 else "✓ Alternative" for i in range(len(overview))
            ]
            st.dataframe(
                overview.style.background_gradient(subset=["Feasibility Score"], cmap="RdYlGn")
                .format({"Target PM2.5": "{:.1f}", "Feasibility Score": "{:.1f}"}),
                use_container_width=True,
                hide_index=True,
            )

            # Effort vs Benefit Chart
            st.subheader("Scenario Trade-offs: Effort vs Benefit")
            import plotly.express as px  # optional dependency; an ImportError lands in the except below

            plot_df = pd.DataFrame(
                {
                    "Scenario": [s["Scenario"] for s in scenarios],
                    "Total Change": [s["Total Change"] for s in scenarios],
                    "Improvement_Val": [s["Improvement_Val"] for s in scenarios],
                    "Target PM2.5": [s["Target PM2.5"] for s in scenarios],
                }
            )
            tradeoff_fig = px.scatter(
                plot_df,
                x="Total Change",
                y="Improvement_Val",
                text="Scenario",
                labels={"Total Change": "Total Change Required (Effort)", "Improvement_Val": "Improvement (µg/m³)"},
                hover_data={"Scenario": True, "Target PM2.5": ":.1f", "Total Change": ":.2f", "Improvement_Val": ":.2f"},
            )
            tradeoff_fig.update_traces(textposition="top center", marker=dict(size=15, color="#1f77b4"))
            tradeoff_fig.update_layout(height=400, showlegend=False)
            st.plotly_chart(tradeoff_fig, use_container_width=True)
            st.caption("💡 Best scenarios are in the top-left: high improvement with low effort")

            # Expandable scenarios with natural language explanations
            st.divider()
            st.subheader("Detailed Scenarios")
            for rank, scenario in enumerate(scenarios):
                cf_row = scenario["cf_row"]
                changes = scenario["changes"]
                is_best = rank == 0
                expander_label = f"{scenario['Scenario']}: {scenario['Target PM2.5']:.1f} µg/m³ ({scenario['Improvement']})"
                if is_best:
                    expander_label += " ⭐ Most Feasible"
                with st.expander(expander_label, expanded=is_best):
                    st.markdown("### Natural Language Summary")
                    st.markdown(
                        _scenario_explanation(
                            scenario["Improvement_Val"],
                            scenario["Improvement_Pct"],
                            scenario["Target PM2.5"],
                            changes,
                        )
                    )
                    st.divider()

                    col1, col2, col3, col4 = st.columns(4)
                    col1.metric("Current PM2.5", f"{pred_pm25:.1f} µg/m³")
                    col2.metric(
                        "Scenario PM2.5",
                        f"{cf_row['pm25']:.1f} µg/m³",
                        delta=f"{cf_row['pm25'] - pred_pm25:.1f}",
                        delta_color="inverse",
                    )
                    col3.metric("Improvement", scenario["Improvement"])
                    col4.metric("Feasibility", f"{scenario['Feasibility Score']:.0f}/100")

                    if changes:
                        st.write("**Detailed Feature Changes**")
                        change_table = pd.DataFrame(
                            [
                                {
                                    "Feature": c["feature"].capitalize(),
                                    "Current": f"{c['original']:.2f}",
                                    "Scenario": f"{c['new_val']:.2f}",
                                    # Signed value AND signed percent, so a decrease reads "-3.00 (-20.0%)".
                                    "Change": f"{c['signed_delta']:+.2f} ({c['signed_pct']:+.1f}%)",
                                    "Impact": "Major" if c["pct"] > 20 else "Moderate" if c["pct"] > 10 else "Minor",
                                }
                                for c in changes
                            ]
                        )
                        st.dataframe(change_table, use_container_width=True, hide_index=True)
                        st.info(
                            "\n"
                            f"**Interpretation**: This scenario shows that {scenario['Features Changed']} historical factors\n"
                            "drive the current prediction. While we cannot change past PM2.5 levels, this reveals which\n"
                            "temporal patterns (recent days vs weekly trends) have the strongest influence on today's air quality.\n"
                        )
                    else:
                        st.write("No significant feature changes detected.")
            st.caption("Note: Lag features are historical context; treat counterfactuals as 'what-if' scenarios for understanding key drivers.")
    except Exception as e:
        st.warning(f"Counterfactual generation not available: {e}")