File size: 6,327 Bytes
0b1ac45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import streamlit as st
import time
import sys
import os
import gc
import torch
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from datetime import datetime, timedelta, timezone

# --- PAGE CONFIG ---
# Configure the browser tab title and use the full page width for the map.
st.set_page_config(
    layout="wide",
    page_title="Canada Generative Radar (Earth2Studio)",
)

# --- VISIBLE LOGGING ---
# One shared placeholder: log_to_ui() writes into it, so each new status
# message replaces the previous one instead of stacking up.
status_container = st.empty()

def log_to_ui(msg, type="info"):
    """Print a timestamped message to the console and show it in the status box.

    Args:
        msg: Text to display.
        type: "info", "success", "error" or "warning"; selects the Streamlit
            status style and icon. Any other value falls back to "info" so
            the message is never silently dropped from the UI.
    """
    # NOTE: `type` shadows the builtin, but callers pass it by keyword
    # (type="..."), so the name is part of the interface and is kept.
    print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)
    styles = {
        "info": (status_container.info, "📋"),
        "success": (status_container.success, "✅"),
        "error": (status_container.error, "❌"),
        "warning": (status_container.warning, "⚠️"),
    }
    show, icon = styles.get(type, styles["info"])
    show(f"{icon} {msg}")

log_to_ui("🚀 Initializing Radar App...")

# --- EARTH2STUDIO IMPORTS ---
# Import defensively: if the installed Earth2Studio lacks StormScope, show a
# readable error in the app instead of a raw traceback.
try:
    from earth2studio.models.px import StormScopeMRMS
    from earth2studio.utils.time import to_time_array
except ImportError as e:
    st.error(f"CRITICAL ERROR: {e}")
    st.info("The app is likely missing the 'StormScope' update. Check the Dockerfile installation logs.")
    st.stop()  # ends this script run; nothing below executes

# Device selection is kept out of the try: it cannot raise ImportError.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✅ Setup Complete. Device: {device}")

# --- CONFIG ---
# Canadian regions offered in the sidebar: map centre (lat/lon, degrees)
# plus a per-region "zoom" radius value.
REGIONS = {
    name: {"lat": lat, "lon": lon, "zoom": zoom}
    for name, lat, lon, zoom in (
        ("Toronto / Southern Ontario", 43.7, -79.4, 4),
        ("Montreal / Quebec", 45.5, -73.6, 4),
        ("Vancouver / BC", 49.3, -123.1, 4),
        ("Calgary / Alberta", 51.0, -114.1, 5),
    )
}

# --- MODEL MANAGERS (CACHED) ---

@st.cache_resource(show_spinner=False)
def load_radar_model():
    """Build the StormScopeMRMS model once and reuse it.

    ``st.cache_resource`` keeps the returned object alive across reruns, so
    the pretrained weights are only loaded on the first call.

    Returns:
        The model, moved to the module-level ``device`` and switched to
        eval mode for inference.
    """
    # Free whatever Python/CUDA memory can be released before loading weights.
    gc.collect()
    torch.cuda.empty_cache()
    print("Loading StormScope Model...", flush=True)

    radar_model = StormScopeMRMS.load_model(StormScopeMRMS.load_default_package())
    radar_model.to(device)
    radar_model.eval()  # inference only
    return radar_model

class _DataFetchError(RuntimeError):
    """Input-data download failed; raised (never returned) so it is not cached."""


@st.cache_data(show_spinner=False, ttl=1800)  # Cache for 30 mins
def _generate_radar_frame(time_str):
    """Run StormScope inference for the initial-condition hour ``time_str``.

    Cached per input hour. The output does not depend on the selected region
    (the full model domain is returned and the UI zooms by setting the map
    extent), so all regions share one cache entry.

    Args:
        time_str: Initial-condition time formatted "%Y-%m-%dT%H:00:00" (UTC).

    Returns:
        ``(radar_data, lats, lons)``: a 2-D numpy array plus the model's
        ``lat`` / ``lon`` output coordinates.

    Raises:
        _DataFetchError: if the input data could not be fetched. Streamlit does
            not cache calls that raise, so the next attempt retries the fetch.
    """
    model = load_radar_model()
    time_obj = to_time_array([time_str])

    # NOTE(review): fetch_data is assumed to be provided by StormScopeMRMS and to
    # download the initial conditions from external (NASA/NOAA) servers.
    try:
        x, coords = model.fetch_data(time_obj)
        x = x.to(device)
    except Exception as e:  # any external-source failure is reported to the user
        raise _DataFetchError(f"Data Fetch Error (External Source): {str(e)}") from e

    with torch.no_grad():
        out, out_coords = model(x, coords)

    # Keep the trailing two (spatial) axes and take the first entry of every
    # leading axis. For a 4-D [batch, channel, lat, lon] output this is exactly
    # out[0, 0, :, :]; it also tolerates extra leading axes (time, lead time, ...).
    # NOTE(review): assumes the first channel is reflectivity — confirm against
    # the model's output variables.
    radar_data = out.reshape(-1, *out.shape[-2:])[0].cpu().numpy()

    return radar_data, out_coords["lat"], out_coords["lon"]


def run_radar_generation(region_name, lat_center, lon_center):
    """Generate the latest radar reflectivity frame.

    The region arguments are accepted for interface compatibility; the full
    model domain is always returned and the caller zooms to the region.

    Returns:
        ``((radar_data, lats, lons), "Success")`` on success, or
        ``(None, error_message)`` if the input data could not be fetched.
        Inference errors are not caught and propagate to the caller.
    """
    # Real-time radar data often lags by 30-60 mins in public buckets, so use
    # the top of the hour from one hour ago as the initial-condition time.
    now = datetime.now(timezone.utc) - timedelta(hours=1)
    time_str = now.strftime("%Y-%m-%dT%H:00:00")

    try:
        radar_data, lats, lons = _generate_radar_frame(time_str)
    except _DataFetchError as e:
        return None, str(e)
    return (radar_data, lats, lons), "Success"

# --- UI LAYOUT ---

def _plot_radar(radar_data, lats, lons, lat_center, lon_center, zoom_deg):
    """Return a Cartopy map of reflectivity centred on (lat_center, lon_center).

    The view is a square window of +/- ``zoom_deg`` degrees around the centre.
    Cells with reflectivity <= 0 dBZ or non-finite values are masked, so they
    render transparent and the basemap stays visible underneath.
    """
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())

    # Map Features
    ax.add_feature(cfeature.COASTLINE, linewidth=1)
    ax.add_feature(cfeature.BORDERS, linestyle=':', linewidth=1)
    ax.add_feature(cfeature.LAKES, alpha=0.3, color='blue')

    # Set Extent (Zoom)
    extent = [lon_center - zoom_deg, lon_center + zoom_deg,
              lat_center - zoom_deg, lat_center + zoom_deg]
    ax.set_extent(extent, crs=ccrs.PlateCarree())

    # Masked cells use the colormap's "bad" colour, which is transparent by
    # default, so no-echo areas don't paint over the map.
    shown = np.ma.masked_less_equal(np.ma.masked_invalid(radar_data), 0.0)
    mesh = ax.pcolormesh(lons, lats, shown,
                         transform=ccrs.PlateCarree(),
                         cmap='nipy_spectral',
                         vmin=0, vmax=70,  # dBZ range
                         shading='auto')

    fig.colorbar(mesh, ax=ax, label='Reflectivity (dBZ)', shrink=0.7)
    ax.set_title(f"Simulated Radar | Center: {lat_center}, {lon_center}")
    return fig


st.title("🇨🇦 Canada Generative Radar (Earth2Studio)")
st.markdown("""
Using **NVIDIA Earth-2 StormScope** to generate high-resolution radar reflectivity.
*Note: This model is trained on US data but covers Southern Canada.*
""")

# Sidebar Controls
with st.sidebar:
    st.header("Settings")
    selected_region = st.selectbox("Choose Region", list(REGIONS.keys()))

    region_info = REGIONS[selected_region]
    lat_center = region_info["lat"]
    lon_center = region_info["lon"]
    # Default the zoom to the region's configured radius; float() because
    # st.slider requires min/max/value to share one numeric type.
    zoom_deg = st.slider("Zoom (Degrees Radius)", 1.0, 10.0,
                         float(region_info.get("zoom", 4.0)))

# Main Execution Button
if st.button("📡 Generate Radar Forecast", type="primary", use_container_width=True):

    log_to_ui(f"Fetching data and generating radar for {selected_region}...", type="info")

    t0 = time.perf_counter()
    result, msg = run_radar_generation(selected_region, lat_center, lon_center)

    if result is None:
        log_to_ui(msg, type="error")
    else:
        radar_data, lats, lons = result
        elapsed = time.perf_counter() - t0
        log_to_ui(f"Generation Complete in {elapsed:.2f}s", type="success")

        st.subheader(f"Radar Reflectivity: {selected_region}")
        fig = _plot_radar(radar_data, lats, lons, lat_center, lon_center, zoom_deg)
        try:
            st.pyplot(fig)
        finally:
            # pyplot keeps every figure alive until closed; without this each
            # click would leak a figure into pyplot's global registry.
            plt.close(fig)

# Footer
st.markdown("---")
st.caption("Powered by NVIDIA Earth2Studio • Runs on Hugging Face GPU Spaces")