"""Streamlit app: generative radar reflectivity over selected Canadian regions.

Runs Earth2Studio's StormScopeMRMS model once per initialisation hour and
plots the [0, 0] slice of its output on a Cartopy map cropped around the
region chosen in the sidebar.
"""
import streamlit as st
import time
import sys
import os
import gc
import torch
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from datetime import datetime, timedelta, timezone

# --- PAGE CONFIG ---
# Must be the first Streamlit call in the script.
st.set_page_config(page_title="Canada Generative Radar (Earth2Studio)", layout="wide")

# --- VISIBLE LOGGING ---
# Single placeholder: each log_to_ui call replaces the previous status message.
status_container = st.empty()


def log_to_ui(msg, type="info"):  # noqa: A002 - keyword name kept for existing callers
    """Print a timestamped log line to stdout and show it in the status area.

    Args:
        msg: Message text.
        type: "info", "success", "error" or "warning". Any other value is
            rendered as "info" rather than being silently dropped.
    """
    print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)
    renderers = {
        "info": (status_container.info, "📋"),
        "success": (status_container.success, "✅"),
        "error": (status_container.error, "❌"),
        "warning": (status_container.warning, "⚠️"),
    }
    render, emoji = renderers.get(type, renderers["info"])
    render(f"{emoji} {msg}")


log_to_ui("🚀 Initializing Radar App...")

# --- EARTH2STUDIO IMPORTS ---
try:
    # 1. Import Earth2Studio
    from earth2studio.models.px import StormScopeMRMS
    from earth2studio.utils.time import to_time_array

    # 2. Setup Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"✅ Setup Complete. Device: {device}")
except ImportError as e:
    st.error(f"CRITICAL ERROR: {e}")
    st.info("The app is likely missing the 'StormScope' update. Check the Dockerfile installation logs.")
    st.stop()

# --- CONFIG ---
# Canadian Regions of Interest. "zoom" is the default half-width (degrees)
# of the plotted window for that region.
REGIONS = {
    "Toronto / Southern Ontario": {"lat": 43.7, "lon": -79.4, "zoom": 4},
    "Montreal / Quebec": {"lat": 45.5, "lon": -73.6, "zoom": 4},
    "Vancouver / BC": {"lat": 49.3, "lon": -123.1, "zoom": 4},
    "Calgary / Alberta": {"lat": 51.0, "lon": -114.1, "zoom": 5},
}

# Colour range for the reflectivity plot (dBZ, as labelled on the colorbar).
REFLECTIVITY_VMIN = 0
REFLECTIVITY_VMAX = 70
# Cells weaker than this (or non-finite) are masked so they render transparent
# and the basemap stays visible.
# NOTE(review): assumes the model output is reflectivity in dBZ, as the
# original plot labels imply; TODO confirm.
MIN_VISIBLE_DBZ = 5.0


# --- MODEL MANAGERS (CACHED) ---
@st.cache_resource(show_spinner=False)
def load_radar_model():
    """Load StormScopeMRMS, move it to `device`, and switch it to eval mode.

    Wrapped in st.cache_resource so the weights are loaded once per process
    instead of on every Streamlit rerun.
    """
    gc.collect()
    torch.cuda.empty_cache()  # free cached GPU blocks before the large allocation
    print("Loading StormScope Model...", flush=True)

    # Load Model Package
    package = StormScopeMRMS.load_default_package()
    model = StormScopeMRMS.load_model(package)
    model.to(device)
    model.eval()
    return model


class _DataFetchError(RuntimeError):
    """Raised when the model's input data cannot be fetched."""


def _init_time_str():
    """Return the initialisation time: one hour ago (UTC), floored to the hour."""
    # Real-time radar data often lags by 30-60 mins in public buckets
    init_time = datetime.now(timezone.utc) - timedelta(hours=1)
    return init_time.strftime("%Y-%m-%dT%H:00:00")


@st.cache_data(show_spinner=False, ttl=1800)  # Cache for 30 mins
def _generate_radar_field(time_str):
    """Run one StormScope step initialised at `time_str`.

    The result does not depend on the selected region, so it is cached only
    by init time and reused when the user switches regions.

    Returns:
        (radar_data, lats, lons): the [0, 0] slice of the model output as a
        NumPy array, plus the output "lat"/"lon" coordinates.

    Raises:
        _DataFetchError: if fetching inputs or moving them to the device fails.
            st.cache_data does not cache raised exceptions, so a transient
            outage is retried on the next click instead of being replayed.
    """
    model = load_radar_model()
    time_obj = to_time_array([time_str])

    try:
        # Note: This connects to external servers. If they are down, this fails.
        # NOTE(review): confirm StormScopeMRMS exposes `fetch_data(time)`
        # returning (tensor, coords); this file is the only evidence of it.
        x, coords = model.fetch_data(time_obj)
        x = x.to(device)
    except Exception as e:
        raise _DataFetchError(str(e)) from e

    with torch.no_grad():
        out, out_coords = model(x, coords)

    # NOTE(review): assumes out[0, 0] is a 2-D (lat, lon) reflectivity field,
    # i.e. that the output is 4-D; TODO confirm against StormScopeMRMS's output coords.
    radar_data = out[0, 0, :, :].cpu().numpy()
    return radar_data, out_coords["lat"], out_coords["lon"]


def run_radar_generation(region_name, lat_center, lon_center):
    """Generate a radar reflectivity field for display.

    The arguments are kept for compatibility. The model covers its full
    domain, and cropping to the region happens at plot time.

    Returns:
        ((radar_data, lats, lons), "Success") on success, or
        (None, "Data Fetch Error (External Source): <details>") when input
        data could not be fetched. Other errors (e.g. during inference)
        propagate to the caller.
    """
    try:
        return _generate_radar_field(_init_time_str()), "Success"
    except _DataFetchError as e:
        return None, f"Data Fetch Error (External Source): {e}"


def _plot_radar(radar_data, lats, lons, lat_center, lon_center, zoom_deg):
    """Build a map figure of `radar_data` cropped to +/- `zoom_deg` around the centre."""
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())

    # Map Features
    ax.add_feature(cfeature.COASTLINE, linewidth=1)
    ax.add_feature(cfeature.BORDERS, linestyle=':', linewidth=1)
    ax.add_feature(cfeature.LAKES, alpha=0.3, color='blue')

    # Set Extent (Zoom)
    extent = [lon_center - zoom_deg, lon_center + zoom_deg,
              lat_center - zoom_deg, lat_center + zoom_deg]
    ax.set_extent(extent, crs=ccrs.PlateCarree())

    # Mask weak/invalid cells. Otherwise they take the colormap's low end and
    # cover the basemap; masked cells render transparent.
    with np.errstate(invalid="ignore"):  # NaN comparisons are expected here
        hidden = ~(radar_data >= MIN_VISIBLE_DBZ)
    visible = np.ma.masked_where(hidden, radar_data)

    mesh = ax.pcolormesh(lons, lats, visible,
                         transform=ccrs.PlateCarree(),
                         cmap='nipy_spectral',
                         vmin=REFLECTIVITY_VMIN, vmax=REFLECTIVITY_VMAX,
                         shading='auto')
    fig.colorbar(mesh, ax=ax, label='Reflectivity (dBZ)', shrink=0.7)
    ax.set_title(f"Simulated Radar | Center: {lat_center}, {lon_center}")
    return fig


# --- UI LAYOUT ---
st.title("🇨🇦 Canada Generative Radar (Earth2Studio)")
st.markdown("""
Using **NVIDIA Earth-2 StormScope** to generate high-resolution radar reflectivity.
*Note: This model is trained on US data but covers Southern Canada.*
""")

# Sidebar Controls
with st.sidebar:
    st.header("Settings")
    selected_region = st.selectbox("Choose Region", list(REGIONS.keys()))
    region_info = REGIONS[selected_region]
    lat_center = region_info["lat"]
    lon_center = region_info["lon"]
    # Default the zoom to the region's configured radius.
    zoom_deg = st.slider("Zoom (Degrees Radius)", 1.0, 10.0, float(region_info["zoom"]))

# Main Execution Button
if st.button("📡 Generate Radar Forecast", type="primary", use_container_width=True):
    log_to_ui(f"Fetching data and generating radar for {selected_region}...", type="info")
    t0 = time.perf_counter()

    # Run Inference
    result, msg = run_radar_generation(selected_region, lat_center, lon_center)

    if result is None:
        log_to_ui(msg, type="error")
    else:
        radar_data, lats, lons = result
        elapsed = time.perf_counter() - t0
        log_to_ui(f"Generation Complete in {elapsed:.2f}s", type="success")

        # Plotting
        st.subheader(f"Radar Reflectivity: {selected_region}")
        fig = _plot_radar(radar_data, lats, lons, lat_center, lon_center, zoom_deg)
        st.pyplot(fig)
        # pyplot keeps every figure alive until closed, so close it to stop
        # figures piling up across clicks and reruns.
        plt.close(fig)

# Footer
st.markdown("---")
st.caption("Powered by NVIDIA Earth2Studio • Runs on Hugging Face GPU Spaces")