# test/app.py
# Uploaded by NS-Genai ("Upload 3 files", commit 0b1ac45, verified)
import streamlit as st
import time
import sys
import os
import gc
import torch
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from datetime import datetime, timedelta, timezone
# --- PAGE CONFIG ---
# Sets the browser-tab title and selects the wide page layout.
# NOTE(review): Streamlit expects set_page_config to run before other page
# elements are emitted, so keep this ahead of every other st.* call — confirm
# against the Streamlit version pinned for this app.
st.set_page_config(page_title="Canada Generative Radar (Earth2Studio)", layout="wide")
# --- VISIBLE LOGGING ---
# One placeholder shared by the whole script; log_to_ui() writes every status
# message into it (presumably each new message replaces the previous one, as
# the placeholder holds a single element — confirm).
status_container = st.empty()
def log_to_ui(msg, type="info"):
    """Echo a status message to the console and to the shared UI placeholder.

    Args:
        msg: Text to report.
        type: Severity of the message: "info", "success", "error" or
            "warning". The name shadows the builtin but is kept because
            callers pass it by keyword. Any other value is displayed as
            "info" so the message is never silently dropped from the UI.
    """
    # Console copy first, timestamped, so container logs stay complete even
    # if the UI call below fails.
    print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)

    # Severity -> icon. Each key is also the name of the method invoked on
    # status_container (info / success / error / warning).
    icons = {
        "info": "📋",
        "success": "✅",
        "error": "❌",
        "warning": "⚠️",
    }
    level = type if type in icons else "info"
    getattr(status_container, level)(f"{icons[level]} {msg}")
log_to_ui("🚀 Initializing Radar App...")

# --- EARTH2STUDIO IMPORTS ---
# The Earth2Studio imports are guarded so that a missing or outdated install
# is reported on the page instead of surfacing as a raw traceback. Only the
# imports live in the try body: they are the only statements expected to raise
# ImportError.
try:
    from earth2studio.models.px import StormScopeMRMS
    from earth2studio.utils.time import to_time_array
except ImportError as e:
    st.error(f"CRITICAL ERROR: {e}")
    st.info("The app is likely missing the 'StormScope' update. Check the Dockerfile installation logs.")
    # Halt this script run; nothing below can work without the model class.
    st.stop()
else:
    # Prefer the GPU when one is visible, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"✅ Setup Complete. Device: {device}", flush=True)
# --- CONFIG ---
# Canadian regions of interest, keyed by the label shown in the region
# selector.
#   lat / lon : map centre in decimal degrees (negative lon = west).
#   zoom      : per-region view size.
#               NOTE(review): presumably a radius in degrees on the same scale
#               as the sidebar zoom slider — confirm how it is consumed.
REGIONS = {
"Toronto / Southern Ontario": {"lat": 43.7, "lon": -79.4, "zoom": 4},
"Montreal / Quebec": {"lat": 45.5, "lon": -73.6, "zoom": 4},
"Vancouver / BC": {"lat": 49.3, "lon": -123.1, "zoom": 4},
"Calgary / Alberta": {"lat": 51.0, "lon": -114.1, "zoom": 5},
}
# --- MODEL MANAGERS (CACHED) ---
@st.cache_resource(show_spinner=False)
def load_radar_model():
    """Build the StormScopeMRMS network and return it ready for inference.

    Wrapped in ``st.cache_resource`` so the instance is created once and then
    reused, instead of reloading the (per the original note, 4GB+) weights on
    every interaction.

    Returns:
        The model, moved to the module-level ``device`` and switched to
        evaluation mode.
    """
    # Drop unreachable Python objects and cached CUDA blocks before the large
    # allocation that follows.
    gc.collect()
    torch.cuda.empty_cache()

    print("Loading StormScope Model...", flush=True)

    # Resolve the default model package and instantiate the network from it.
    radar_net = StormScopeMRMS.load_model(StormScopeMRMS.load_default_package())
    radar_net.to(device)
    radar_net.eval()
    return radar_net
class _RadarFetchError(RuntimeError):
    """Raised when the model's input data cannot be fetched."""


@st.cache_data(show_spinner=False, ttl=1800)  # Cache for 30 mins
def _generate_radar_frame(time_str):
    """Fetch inputs for *time_str*, run one model step and return the frame.

    Cached on the timestamp alone, because the output does not depend on the
    region the user selected. Failures are raised rather than returned, so an
    error is never stored in the cache.

    Args:
        time_str: Initialisation time formatted as "%Y-%m-%dT%H:00:00".

    Returns:
        Tuple ``(radar_data, lats, lons)``.

    Raises:
        _RadarFetchError: If fetching the input data or moving it to the
            device fails.
    """
    model = load_radar_model()
    time_obj = to_time_array([time_str])

    # Original note: this reaches external NASA/NOAA servers and fails when
    # they are unavailable.
    try:
        x, coords = model.fetch_data(time_obj)
        x = x.to(device)
    except Exception as e:
        raise _RadarFetchError(str(e)) from e

    # Inference only: no gradients needed.
    with torch.no_grad():
        out, out_coords = model(x, coords)

    # NOTE(review): assumes the first two axes are batch and time/channel and
    # that index 0 of the second axis is reflectivity (per the original
    # comments) — confirm against the model's output coordinates.
    radar_data = out[0, 0, :, :].cpu().numpy()
    return radar_data, out_coords['lat'], out_coords['lon']


def run_radar_generation(region_name, lat_center, lon_center):
    """Run the generative radar inference and report the outcome.

    The full output grid is returned; zooming to the requested region is left
    to the plotting code.

    Args:
        region_name: Unused by the computation; kept for caller compatibility.
        lat_center: Unused by the computation; kept for caller compatibility.
        lon_center: Unused by the computation; kept for caller compatibility.

    Returns:
        ``((radar_data, lats, lons), "Success")`` on success, or
        ``(None, "Data Fetch Error (External Source): ...")`` when the input
        data could not be fetched.
    """
    # Use the top of the hour, one hour back: per the original note, public
    # real-time radar data often lags by 30-60 minutes.
    now = datetime.now(timezone.utc) - timedelta(hours=1)
    time_str = now.strftime("%Y-%m-%dT%H:00:00")

    try:
        frame = _generate_radar_frame(time_str)
    except _RadarFetchError as e:
        # Not cached: the next request retries the fetch.
        return None, f"Data Fetch Error (External Source): {str(e)}"
    return frame, "Success"
# --- UI LAYOUT ---
st.title("🇨🇦 Canada Generative Radar (Earth2Studio)")
st.markdown("""
Using **NVIDIA Earth-2 StormScope** to generate high-resolution radar reflectivity.
*Note: This model is trained on US data but covers Southern Canada.*
""")

# Sidebar Controls
with st.sidebar:
    st.header("Settings")
    selected_region = st.selectbox("Choose Region", list(REGIONS.keys()))
    region_info = REGIONS[selected_region]
    lat_center = region_info["lat"]
    lon_center = region_info["lon"]
    # Start the slider at the region's configured zoom (clamped to the slider
    # range) instead of a fixed value, so the per-region setting is honoured.
    default_zoom = min(max(float(region_info.get("zoom", 4.0)), 1.0), 10.0)
    zoom_deg = st.slider("Zoom (Degrees Radius)", 1.0, 10.0, default_zoom)

# Main Execution Button
if st.button("📡 Generate Radar Forecast", type="primary", use_container_width=True):
    log_to_ui(f"Fetching data and generating radar for {selected_region}...", type="info")
    t0 = time.time()

    # Run Inference
    result, msg = run_radar_generation(selected_region, lat_center, lon_center)

    if result is None:
        log_to_ui(msg, type="error")
    else:
        radar_data, lats, lons = result
        elapsed = time.time() - t0
        log_to_ui(f"Generation Complete in {elapsed:.2f}s", type="success")

        # Plotting
        st.subheader(f"Radar Reflectivity: {selected_region}")
        fig = plt.figure(figsize=(10, 8))
        try:
            ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())

            # Map Features
            ax.add_feature(cfeature.COASTLINE, linewidth=1)
            ax.add_feature(cfeature.BORDERS, linestyle=':', linewidth=1)
            ax.add_feature(cfeature.LAKES, alpha=0.3, color='blue')

            # Set Extent (Zoom): a square window of +/- zoom_deg degrees
            # around the selected centre.
            extent = [lon_center - zoom_deg, lon_center + zoom_deg,
                      lat_center - zoom_deg, lat_center + zoom_deg]
            ax.set_extent(extent, crs=ccrs.PlateCarree())

            # Plot Radar Data
            # Mask non-positive and NaN cells so "no echo" areas are left
            # transparent and the map features stay visible underneath.
            visible_echo = np.ma.masked_where(~(radar_data > 0), radar_data)
            mesh = ax.pcolormesh(lons, lats, visible_echo,
                                 transform=ccrs.PlateCarree(),
                                 cmap='nipy_spectral',
                                 vmin=0, vmax=70,  # dBZ range
                                 shading='auto')

            # Use the figure/axes methods rather than pyplot's implicit
            # "current figure", which is shared global state.
            fig.colorbar(mesh, ax=ax, label='Reflectivity (dBZ)', shrink=0.7)
            ax.set_title(f"Simulated Radar | Center: {lat_center}, {lon_center}")
            st.pyplot(fig)
        finally:
            # pyplot keeps every figure alive until it is closed; without
            # this each rerun would leak one figure.
            plt.close(fig)

# Footer
st.markdown("---")
st.caption("Powered by NVIDIA Earth2Studio • Runs on Hugging Face GPU Spaces")