| | import streamlit as st |
| | import time |
| | import sys |
| | import os |
| | import gc |
| | import torch |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import cartopy.crs as ccrs |
| | import cartopy.feature as cfeature |
| | from datetime import datetime, timedelta, timezone |
| |
|
| | |
# Browser-tab title and wide layout for the whole page.
st.set_page_config(
    page_title="Canada Generative Radar (Earth2Studio)",
    layout="wide",
)

# One placeholder element; log_to_ui() overwrites it with the latest status.
status_container = st.empty()
| |
|
# Status-box prefix for each supported level. Written as escape sequences so
# the glyphs cannot be damaged by a lossy re-encoding of this source file.
_LOG_ICONS = {
    # The original info glyph was corrupted beyond recovery; the information
    # sign is used as a stand-in.
    "info": "\u2139\ufe0f",
    "success": "\u2705",        # white heavy check mark
    "error": "\u274c",          # cross mark
    "warning": "\u26a0\ufe0f",  # warning sign
}


def log_to_ui(msg, type="info"):
    """Print a timestamped log line to the console and mirror it in the UI.

    Args:
        msg: Text to log.
        type: One of ``"info"``, ``"success"``, ``"error"`` or ``"warning"``.
            Selects which status box of ``status_container`` is rendered.
            Any other value is shown as ``"info"`` instead of being dropped
            from the UI. (The name shadows the builtin but is kept because
            callers pass it by keyword.)
    """
    print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)

    level = type if type in _LOG_ICONS else "info"
    # The placeholder exposes one method per level, named after the level.
    render = getattr(status_container, level)
    render(f"{_LOG_ICONS[level]} {msg}")
| |
|
# log_to_ui() prepends its own level icon, so the message carries none; the
# glyph that used to lead this string had been corrupted beyond recovery.
log_to_ui("Initializing Radar App...")
| |
|
| | |
# Import Earth2Studio behind a guard so a missing or outdated install shows a
# readable message in the page and halts the script, instead of a traceback.
try:
    from earth2studio.models.px import StormScopeMRMS
    from earth2studio.utils.time import to_time_array
except ImportError as e:
    st.error(f"CRITICAL ERROR: {e}")
    st.info("The app is likely missing the 'StormScope' update. Check the Dockerfile installation logs.")
    st.stop()
else:
    # Only the imports can raise ImportError, so device selection lives here.
    # Use the GPU when CUDA reports one, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\u2705 Setup Complete. Device: {device}")
| |
|
| | |
| | |
# Selectable map presets, keyed by the label shown in the region dropdown.
#   "lat" / "lon": map-centre coordinates in degrees (negative lon = west).
#   "zoom":        per-region view size; presumably a half-width in degrees,
#                  the same unit as the sidebar zoom slider (confirm).
REGIONS = {
    "Toronto / Southern Ontario": {"lat": 43.7, "lon": -79.4, "zoom": 4},
    "Montreal / Quebec": {"lat": 45.5, "lon": -73.6, "zoom": 4},
    "Vancouver / BC": {"lat": 49.3, "lon": -123.1, "zoom": 4},
    "Calgary / Alberta": {"lat": 51.0, "lon": -114.1, "zoom": 5},
}
| |
|
| | |
| |
|
@st.cache_resource(show_spinner=False)
def load_radar_model():
    """Build the StormScopeMRMS model and return it in eval mode on ``device``.

    Decorated with ``st.cache_resource`` so the large (4GB+) weights are not
    reloaded on every interaction.
    """
    # Drop garbage and cached CUDA blocks before the big allocation.
    gc.collect()
    torch.cuda.empty_cache()
    print("Loading StormScope Model...", flush=True)

    radar_model = StormScopeMRMS.load_model(StormScopeMRMS.load_default_package())
    radar_model.to(device)
    radar_model.eval()
    return radar_model
| |
|
class _RadarFetchError(RuntimeError):
    """Raised when the model's input data cannot be fetched for a time step."""


@st.cache_data(show_spinner=False, ttl=1800)
def _generate_radar_for_time(time_str):
    """Fetch inputs for ``time_str``, run one inference pass, return plot data.

    The cache is keyed on the time string alone, so every region shares one
    inference result per time step.

    Args:
        time_str: Time step formatted as ``%Y-%m-%dT%H:00:00``.

    Returns:
        ``(radar_data, lats, lons)``: a NumPy array taken from the model
        output plus the ``'lat'`` and ``'lon'`` entries of the output coords.

    Raises:
        _RadarFetchError: If fetching or moving the input data fails. The
            failure is raised rather than returned so that ``st.cache_data``
            does not store it and the next call retries.
    """
    model = load_radar_model()
    time_obj = to_time_array([time_str])

    try:
        x, coords = model.fetch_data(time_obj)
        x = x.to(device)
    except Exception as e:
        raise _RadarFetchError(str(e)) from e

    # Inference only: no gradients needed.
    with torch.no_grad():
        out, out_coords = model(x, coords)

    # NOTE(review): assumes the first two axes of `out` can be indexed at 0
    # and the remaining two form the 2-D grid -- confirm against the model.
    radar_data = out[0, 0, :, :].cpu().numpy()
    return radar_data, out_coords['lat'], out_coords['lon']


def run_radar_generation(region_name, lat_center, lon_center):
    """Run the generative radar inference for the most recent usable hour.

    Args:
        region_name: Accepted for interface compatibility; not used.
        lat_center: Accepted for interface compatibility; not used.
        lon_center: Accepted for interface compatibility; not used.
            The model output is not cropped here; the caller restricts the
            view when plotting.

    Returns:
        ``((radar_data, lats, lons), "Success")`` on success, or
        ``(None, "Data Fetch Error (External Source): ...")`` when the input
        data could not be fetched. Fetch failures are never cached.
    """
    # Step back one hour and truncate to the top of that hour (UTC).
    now = datetime.now(timezone.utc) - timedelta(hours=1)
    time_str = now.strftime("%Y-%m-%dT%H:00:00")

    try:
        return _generate_radar_for_time(time_str), "Success"
    except _RadarFetchError as e:
        return None, f"Data Fetch Error (External Source): {e}"
| |
|
| | |
# Page heading. The leading flag is the two regional-indicator code points
# "C" + "A", written as escapes so a lossy re-encoding cannot corrupt it.
st.title("\U0001F1E8\U0001F1E6 Canada Generative Radar (Earth2Studio)")
st.markdown("""
Using **NVIDIA Earth-2 StormScope** to generate high-resolution radar reflectivity.
*Note: This model is trained on US data but covers Southern Canada.*
""")
| |
|
| | |
# Sidebar controls. The names bound here (selected_region, lat_center,
# lon_center, zoom_deg) are read by the main panel further down.
with st.sidebar:
    st.header("Settings")
    selected_region = st.selectbox("Choose Region", list(REGIONS))

    region_info = REGIONS[selected_region]
    lat_center = region_info["lat"]
    lon_center = region_info["lon"]

    # Seed the slider with the preset's own zoom instead of a fixed 4.0, so
    # the per-region value in REGIONS actually takes effect. It is cast to
    # float because the slider bounds are floats.
    # NOTE(review): switching region changes the default and is expected to
    # reset the slider -- confirm on the deployed Streamlit version.
    zoom_deg = st.slider(
        "Zoom (Degrees Radius)",
        1.0,
        10.0,
        float(region_info["zoom"]),
    )
| |
|
| | |
# Main action: run the model and draw the selected region.
# The label starts with the satellite-antenna emoji, written as an escape.
if st.button("\U0001F4E1 Generate Radar Forecast", type="primary", use_container_width=True):
    log_to_ui(f"Fetching data and generating radar for {selected_region}...", type="info")

    # perf_counter is monotonic, so the elapsed time cannot go negative.
    t0 = time.perf_counter()
    result, msg = run_radar_generation(selected_region, lat_center, lon_center)

    if result is None:
        # run_radar_generation signals failure as (None, message).
        log_to_ui(msg, type="error")
    else:
        radar_data, lats, lons = result
        elapsed = time.perf_counter() - t0
        log_to_ui(f"Generation Complete in {elapsed:.2f}s", type="success")

        st.subheader(f"Radar Reflectivity: {selected_region}")

        fig = plt.figure(figsize=(10, 8))
        try:
            ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())

            # Base-map context.
            ax.add_feature(cfeature.COASTLINE, linewidth=1)
            ax.add_feature(cfeature.BORDERS, linestyle=':', linewidth=1)
            ax.add_feature(cfeature.LAKES, alpha=0.3, color='blue')

            # Square window of +/- zoom_deg degrees around the region centre.
            extent = [
                lon_center - zoom_deg, lon_center + zoom_deg,
                lat_center - zoom_deg, lat_center + zoom_deg,
            ]
            ax.set_extent(extent, crs=ccrs.PlateCarree())

            mesh = ax.pcolormesh(
                lons, lats, radar_data,
                transform=ccrs.PlateCarree(),
                cmap='nipy_spectral',
                vmin=0, vmax=70,
                shading='auto',
            )

            # Figure/axes methods are used instead of plt.colorbar/plt.title
            # so nothing depends on pyplot's global "current axes" state.
            fig.colorbar(mesh, ax=ax, label='Reflectivity (dBZ)', shrink=0.7)
            ax.set_title(f"Simulated Radar | Center: {lat_center}, {lon_center}")

            st.pyplot(fig)
        finally:
            # pyplot keeps every figure alive until it is closed; without
            # this, each click leaks one figure in the server process.
            plt.close(fig)
| |
|
| | |
# Footer: horizontal rule plus attribution. The separator is a bullet
# (U+2022), written as an escape so a lossy re-encoding cannot corrupt it.
st.markdown("---")
st.caption("Powered by NVIDIA Earth2Studio \u2022 Runs on Hugging Face GPU Spaces")