Upload 3 files
Browse files- Dockerfile +47 -0
- app.py +177 -0
- requirements.txt +15 -0
Dockerfile
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 1. Use Python 3.10 (Standard for HF Spaces)
|
| 2 |
+
FROM python:3.10-slim
|
| 3 |
+
|
| 4 |
+
WORKDIR /app
|
| 5 |
+
|
| 6 |
+
# 2. Install System Dependencies
|
| 7 |
+
# Required for Cartopy (maps) and Git (for installing from source)
|
| 8 |
+
RUN apt-get update && apt-get install -y \
|
| 9 |
+
build-essential \
|
| 10 |
+
curl \
|
| 11 |
+
git \
|
| 12 |
+
libgeos-dev \
|
| 13 |
+
libproj-dev \
|
| 14 |
+
proj-bin \
|
| 15 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 16 |
+
|
| 17 |
+
# 3. CRITICAL: Install NATTEN from Pre-built Wheels
|
| 18 |
+
# If we don't do this, pip tries to compile it from source (takes 20+ mins) and fails.
|
| 19 |
+
# We match the CUDA version (cu121) and Python version (cp310).
|
| 20 |
+
RUN pip install natten==0.17.3 -f https://shi-labs.com/natten/wheels/cu121/
|
| 21 |
+
|
| 22 |
+
# 4. CRITICAL: Install Earth2Studio from GitHub
|
| 23 |
+
# We use the [stormscope] tag to tell it we have the dependencies ready.
|
| 24 |
+
RUN pip install "earth2studio[stormscope] @ git+https://github.com/NVIDIA/earth2studio.git"
|
| 25 |
+
|
| 26 |
+
# 5. Install Other Python Dependencies
|
| 27 |
+
# (Streamlit, Maps, etc.)
|
| 28 |
+
RUN pip install \
|
| 29 |
+
streamlit \
|
| 30 |
+
torch \
|
| 31 |
+
torchvision \
|
| 32 |
+
numpy \
|
| 33 |
+
matplotlib \
|
| 34 |
+
cartopy \
|
| 35 |
+
huggingface_hub \
|
| 36 |
+
scipy
|
| 37 |
+
|
| 38 |
+
# 6. Copy App Code
|
| 39 |
+
COPY . .
|
| 40 |
+
|
| 41 |
+
# 7. Launch App
|
| 42 |
+
CMD ["streamlit", "run", "app.py", \
|
| 43 |
+
"--server.port", "7860", \
|
| 44 |
+
"--server.address", "0.0.0.0", \
|
| 45 |
+
"--server.enableCORS", "false", \
|
| 46 |
+
"--server.enableXsrfProtection", "false", \
|
| 47 |
+
"--server.fileWatcherType", "none"]
|
app.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import time
|
| 3 |
+
import sys
|
| 4 |
+
import os
|
| 5 |
+
import gc
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import cartopy.crs as ccrs
|
| 10 |
+
import cartopy.feature as cfeature
|
| 11 |
+
from datetime import datetime, timedelta, timezone
|
| 12 |
+
|
| 13 |
+
# --- PAGE CONFIG ---
|
| 14 |
+
st.set_page_config(page_title="Canada Generative Radar (Earth2Studio)", layout="wide")
|
| 15 |
+
|
| 16 |
+
# --- VISIBLE LOGGING ---
|
| 17 |
+
status_container = st.empty()
|
| 18 |
+
|
| 19 |
+
def log_to_ui(msg, type="info"):
|
| 20 |
+
"""Helper to print logs to the UI and console."""
|
| 21 |
+
print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)
|
| 22 |
+
if type == "info": status_container.info(f"π {msg}")
|
| 23 |
+
elif type == "success": status_container.success(f"β
{msg}")
|
| 24 |
+
elif type == "error": status_container.error(f"β {msg}")
|
| 25 |
+
elif type == "warning": status_container.warning(f"β οΈ {msg}")
|
| 26 |
+
|
| 27 |
+
log_to_ui("π Initializing Radar App...")
|
| 28 |
+
|
| 29 |
+
# --- EARTH2STUDIO IMPORTS ---
|
| 30 |
+
try:
|
| 31 |
+
# 1. Import Earth2Studio
|
| 32 |
+
from earth2studio.models.px import StormScopeMRMS
|
| 33 |
+
from earth2studio.utils.time import to_time_array
|
| 34 |
+
|
| 35 |
+
# 2. Setup Device
|
| 36 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 37 |
+
print(f"β
Setup Complete. Device: {device}")
|
| 38 |
+
|
| 39 |
+
except ImportError as e:
|
| 40 |
+
st.error(f"CRITICAL ERROR: {e}")
|
| 41 |
+
st.info("The app is likely missing the 'StormScope' update. Check the Dockerfile installation logs.")
|
| 42 |
+
st.stop()
|
| 43 |
+
|
| 44 |
+
# --- CONFIG ---
|
| 45 |
+
# Canadian Regions of Interest
|
| 46 |
+
REGIONS = {
|
| 47 |
+
"Toronto / Southern Ontario": {"lat": 43.7, "lon": -79.4, "zoom": 4},
|
| 48 |
+
"Montreal / Quebec": {"lat": 45.5, "lon": -73.6, "zoom": 4},
|
| 49 |
+
"Vancouver / BC": {"lat": 49.3, "lon": -123.1, "zoom": 4},
|
| 50 |
+
"Calgary / Alberta": {"lat": 51.0, "lon": -114.1, "zoom": 5},
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
# --- MODEL MANAGERS (CACHED) ---
|
| 54 |
+
|
| 55 |
+
@st.cache_resource(show_spinner=False)
|
| 56 |
+
def load_radar_model():
|
| 57 |
+
"""
|
| 58 |
+
Loads the StormScopeMRMS model.
|
| 59 |
+
Cached so we don't reload 4GB+ weights on every interaction.
|
| 60 |
+
"""
|
| 61 |
+
gc.collect()
|
| 62 |
+
torch.cuda.empty_cache()
|
| 63 |
+
print("Loading StormScope Model...", flush=True)
|
| 64 |
+
|
| 65 |
+
# Load Model Package
|
| 66 |
+
package = StormScopeMRMS.load_default_package()
|
| 67 |
+
model = StormScopeMRMS.load_model(package)
|
| 68 |
+
model.to(device)
|
| 69 |
+
model.eval()
|
| 70 |
+
return model
|
| 71 |
+
|
| 72 |
+
@st.cache_data(show_spinner=False, ttl=1800) # Cache for 30 mins
|
| 73 |
+
def run_radar_generation(region_name, lat_center, lon_center):
|
| 74 |
+
"""
|
| 75 |
+
Runs the generative radar inference for a specific location.
|
| 76 |
+
Returns the image figure (matplotlib) to display.
|
| 77 |
+
"""
|
| 78 |
+
model = load_radar_model()
|
| 79 |
+
|
| 80 |
+
# 1. Setup Time (Use a recent past time to ensure data availability)
|
| 81 |
+
# Real-time radar data often lags by 30-60 mins in public buckets
|
| 82 |
+
now = datetime.now(timezone.utc) - timedelta(hours=1)
|
| 83 |
+
time_str = now.strftime("%Y-%m-%dT%H:00:00")
|
| 84 |
+
time_obj = to_time_array([time_str])
|
| 85 |
+
|
| 86 |
+
# 2. Fetch Input Data
|
| 87 |
+
# StormScope has a built-in fetch_data method that grabs the necessary
|
| 88 |
+
# initial conditions (usually previous radar frames) from the internet.
|
| 89 |
+
try:
|
| 90 |
+
# Note: This connects to NASA/NOAA servers. If they are down, this fails.
|
| 91 |
+
x, coords = model.fetch_data(time_obj)
|
| 92 |
+
x = x.to(device)
|
| 93 |
+
except Exception as e:
|
| 94 |
+
return None, f"Data Fetch Error (External Source): {str(e)}"
|
| 95 |
+
|
| 96 |
+
# 3. Run Inference (Generate Next Frame)
|
| 97 |
+
with torch.no_grad():
|
| 98 |
+
out, out_coords = model(x, coords)
|
| 99 |
+
|
| 100 |
+
# 4. Extract Data (Reflectivity)
|
| 101 |
+
# Output shape: [Batch, Time, Lat, Lon] or [Batch, Channel, Lat, Lon]
|
| 102 |
+
# StormScopeMRMS output channel 0 is typically reflectivity
|
| 103 |
+
radar_data = out[0, 0, :, :].cpu().numpy()
|
| 104 |
+
|
| 105 |
+
lats = out_coords['lat']
|
| 106 |
+
lons = out_coords['lon']
|
| 107 |
+
|
| 108 |
+
# 5. Crop / Focus on Canada Region requested
|
| 109 |
+
# We simply return the full array and coords, and let the UI handle zooming via plotting
|
| 110 |
+
return (radar_data, lats, lons), "Success"
|
| 111 |
+
|
| 112 |
+
# --- UI LAYOUT ---
|
| 113 |
+
st.title("π¨π¦ Canada Generative Radar (Earth2Studio)")
|
| 114 |
+
st.markdown("""
|
| 115 |
+
Using **NVIDIA Earth-2 StormScope** to generate high-resolution radar reflectivity.
|
| 116 |
+
*Note: This model is trained on US data but covers Southern Canada.*
|
| 117 |
+
""")
|
| 118 |
+
|
| 119 |
+
# Sidebar Controls
|
| 120 |
+
with st.sidebar:
|
| 121 |
+
st.header("Settings")
|
| 122 |
+
selected_region = st.selectbox("Choose Region", list(REGIONS.keys()))
|
| 123 |
+
|
| 124 |
+
region_info = REGIONS[selected_region]
|
| 125 |
+
lat_center = region_info["lat"]
|
| 126 |
+
lon_center = region_info["lon"]
|
| 127 |
+
zoom_deg = st.slider("Zoom (Degrees Radius)", 1.0, 10.0, 4.0)
|
| 128 |
+
|
| 129 |
+
# Main Execution Button
|
| 130 |
+
if st.button("π‘ Generate Radar Forecast", type="primary", use_container_width=True):
|
| 131 |
+
|
| 132 |
+
log_to_ui(f"Fetching data and generating radar for {selected_region}...", type="info")
|
| 133 |
+
|
| 134 |
+
t0 = time.time()
|
| 135 |
+
|
| 136 |
+
# Run Inference
|
| 137 |
+
result, msg = run_radar_generation(selected_region, lat_center, lon_center)
|
| 138 |
+
|
| 139 |
+
if result is None:
|
| 140 |
+
log_to_ui(msg, type="error")
|
| 141 |
+
else:
|
| 142 |
+
radar_data, lats, lons = result
|
| 143 |
+
elapsed = time.time() - t0
|
| 144 |
+
log_to_ui(f"Generation Complete in {elapsed:.2f}s", type="success")
|
| 145 |
+
|
| 146 |
+
# Plotting
|
| 147 |
+
st.subheader(f"Radar Reflectivity: {selected_region}")
|
| 148 |
+
|
| 149 |
+
fig = plt.figure(figsize=(10, 8))
|
| 150 |
+
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
|
| 151 |
+
|
| 152 |
+
# Map Features
|
| 153 |
+
ax.add_feature(cfeature.COASTLINE, linewidth=1)
|
| 154 |
+
ax.add_feature(cfeature.BORDERS, linestyle=':', linewidth=1)
|
| 155 |
+
ax.add_feature(cfeature.LAKES, alpha=0.3, color='blue')
|
| 156 |
+
|
| 157 |
+
# Set Extent (Zoom)
|
| 158 |
+
extent = [lon_center - zoom_deg, lon_center + zoom_deg,
|
| 159 |
+
lat_center - zoom_deg, lat_center + zoom_deg]
|
| 160 |
+
ax.set_extent(extent, crs=ccrs.PlateCarree())
|
| 161 |
+
|
| 162 |
+
# Plot Radar Data
|
| 163 |
+
# Use a transparent colormap for low values (0 reflectivity)
|
| 164 |
+
mesh = ax.pcolormesh(lons, lats, radar_data,
|
| 165 |
+
transform=ccrs.PlateCarree(),
|
| 166 |
+
cmap='nipy_spectral',
|
| 167 |
+
vmin=0, vmax=70, # dBZ range
|
| 168 |
+
shading='auto')
|
| 169 |
+
|
| 170 |
+
plt.colorbar(mesh, ax=ax, label='Reflectivity (dBZ)', shrink=0.7)
|
| 171 |
+
plt.title(f"Simulated Radar | Center: {lat_center}, {lon_center}")
|
| 172 |
+
|
| 173 |
+
st.pyplot(fig)
|
| 174 |
+
|
| 175 |
+
# Footer
|
| 176 |
+
st.markdown("---")
|
| 177 |
+
st.caption("Powered by NVIDIA Earth2Studio β’ Runs on Hugging Face GPU Spaces")
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit
|
| 2 |
+
pandas
|
| 3 |
+
numpy
|
| 4 |
+
requests
|
| 5 |
+
altair
|
| 6 |
+
earth2studio
|
| 7 |
+
torch
|
| 8 |
+
matplotlib
|
| 9 |
+
cartopy
|
| 10 |
+
h5netcdf
|
| 11 |
+
zarr
|
| 12 |
+
onnx
|
| 13 |
+
onnxruntime-gpu
|
| 14 |
+
einops
|
| 15 |
+
nvidia-physicsnemo
|