NS-Genai commited on
Commit
0b1ac45
Β·
verified Β·
1 Parent(s): 0f7029e

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +47 -0
  2. app.py +177 -0
  3. requirements.txt +15 -0
Dockerfile ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1. Use Python 3.10 (Standard for HF Spaces)
2
+ FROM python:3.10-slim
3
+
4
+ WORKDIR /app
5
+
6
+ # 2. Install System Dependencies
7
+ # Required for Cartopy (maps) and Git (for installing from source)
8
+ RUN apt-get update && apt-get install -y \
9
+ build-essential \
10
+ curl \
11
+ git \
12
+ libgeos-dev \
13
+ libproj-dev \
14
+ proj-bin \
15
+ && rm -rf /var/lib/apt/lists/*
16
+
17
+ # 3. CRITICAL: Install NATTEN from Pre-built Wheels
18
+ # If we don't do this, pip tries to compile it from source (takes 20+ mins) and fails.
19
+ # We match the CUDA version (cu121) and Python version (cp310).
20
+ RUN pip install natten==0.17.3 -f https://shi-labs.com/natten/wheels/cu121/
21
+
22
+ # 4. CRITICAL: Install Earth2Studio from GitHub
23
+ # We use the [stormscope] tag to tell it we have the dependencies ready.
24
+ RUN pip install "earth2studio[stormscope] @ git+https://github.com/NVIDIA/earth2studio.git"
25
+
26
+ # 5. Install Other Python Dependencies
27
+ # (Streamlit, Maps, etc.)
28
+ RUN pip install \
29
+ streamlit \
30
+ torch \
31
+ torchvision \
32
+ numpy \
33
+ matplotlib \
34
+ cartopy \
35
+ huggingface_hub \
36
+ scipy
37
+
38
+ # 6. Copy App Code
39
+ COPY . .
40
+
41
+ # 7. Launch App
42
+ CMD ["streamlit", "run", "app.py", \
43
+ "--server.port", "7860", \
44
+ "--server.address", "0.0.0.0", \
45
+ "--server.enableCORS", "false", \
46
+ "--server.enableXsrfProtection", "false", \
47
+ "--server.fileWatcherType", "none"]
app.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import time
3
+ import sys
4
+ import os
5
+ import gc
6
+ import torch
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ import cartopy.crs as ccrs
10
+ import cartopy.feature as cfeature
11
+ from datetime import datetime, timedelta, timezone
12
+
13
+ # --- PAGE CONFIG ---
14
+ st.set_page_config(page_title="Canada Generative Radar (Earth2Studio)", layout="wide")
15
+
16
+ # --- VISIBLE LOGGING ---
17
+ status_container = st.empty()
18
+
19
+ def log_to_ui(msg, type="info"):
20
+ """Helper to print logs to the UI and console."""
21
+ print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)
22
+ if type == "info": status_container.info(f"πŸ“‹ {msg}")
23
+ elif type == "success": status_container.success(f"βœ… {msg}")
24
+ elif type == "error": status_container.error(f"❌ {msg}")
25
+ elif type == "warning": status_container.warning(f"⚠️ {msg}")
26
+
27
+ log_to_ui("πŸš€ Initializing Radar App...")
28
+
29
+ # --- EARTH2STUDIO IMPORTS ---
30
+ try:
31
+ # 1. Import Earth2Studio
32
+ from earth2studio.models.px import StormScopeMRMS
33
+ from earth2studio.utils.time import to_time_array
34
+
35
+ # 2. Setup Device
36
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
+ print(f"βœ… Setup Complete. Device: {device}")
38
+
39
+ except ImportError as e:
40
+ st.error(f"CRITICAL ERROR: {e}")
41
+ st.info("The app is likely missing the 'StormScope' update. Check the Dockerfile installation logs.")
42
+ st.stop()
43
+
44
+ # --- CONFIG ---
45
+ # Canadian Regions of Interest
46
+ REGIONS = {
47
+ "Toronto / Southern Ontario": {"lat": 43.7, "lon": -79.4, "zoom": 4},
48
+ "Montreal / Quebec": {"lat": 45.5, "lon": -73.6, "zoom": 4},
49
+ "Vancouver / BC": {"lat": 49.3, "lon": -123.1, "zoom": 4},
50
+ "Calgary / Alberta": {"lat": 51.0, "lon": -114.1, "zoom": 5},
51
+ }
52
+
53
+ # --- MODEL MANAGERS (CACHED) ---
54
+
55
+ @st.cache_resource(show_spinner=False)
56
+ def load_radar_model():
57
+ """
58
+ Loads the StormScopeMRMS model.
59
+ Cached so we don't reload 4GB+ weights on every interaction.
60
+ """
61
+ gc.collect()
62
+ torch.cuda.empty_cache()
63
+ print("Loading StormScope Model...", flush=True)
64
+
65
+ # Load Model Package
66
+ package = StormScopeMRMS.load_default_package()
67
+ model = StormScopeMRMS.load_model(package)
68
+ model.to(device)
69
+ model.eval()
70
+ return model
71
+
72
+ @st.cache_data(show_spinner=False, ttl=1800) # Cache for 30 mins
73
+ def run_radar_generation(region_name, lat_center, lon_center):
74
+ """
75
+ Runs the generative radar inference for a specific location.
76
+ Returns the image figure (matplotlib) to display.
77
+ """
78
+ model = load_radar_model()
79
+
80
+ # 1. Setup Time (Use a recent past time to ensure data availability)
81
+ # Real-time radar data often lags by 30-60 mins in public buckets
82
+ now = datetime.now(timezone.utc) - timedelta(hours=1)
83
+ time_str = now.strftime("%Y-%m-%dT%H:00:00")
84
+ time_obj = to_time_array([time_str])
85
+
86
+ # 2. Fetch Input Data
87
+ # StormScope has a built-in fetch_data method that grabs the necessary
88
+ # initial conditions (usually previous radar frames) from the internet.
89
+ try:
90
+ # Note: This connects to NASA/NOAA servers. If they are down, this fails.
91
+ x, coords = model.fetch_data(time_obj)
92
+ x = x.to(device)
93
+ except Exception as e:
94
+ return None, f"Data Fetch Error (External Source): {str(e)}"
95
+
96
+ # 3. Run Inference (Generate Next Frame)
97
+ with torch.no_grad():
98
+ out, out_coords = model(x, coords)
99
+
100
+ # 4. Extract Data (Reflectivity)
101
+ # Output shape: [Batch, Time, Lat, Lon] or [Batch, Channel, Lat, Lon]
102
+ # StormScopeMRMS output channel 0 is typically reflectivity
103
+ radar_data = out[0, 0, :, :].cpu().numpy()
104
+
105
+ lats = out_coords['lat']
106
+ lons = out_coords['lon']
107
+
108
+ # 5. Crop / Focus on Canada Region requested
109
+ # We simply return the full array and coords, and let the UI handle zooming via plotting
110
+ return (radar_data, lats, lons), "Success"
111
+
112
+ # --- UI LAYOUT ---
113
+ st.title("πŸ‡¨πŸ‡¦ Canada Generative Radar (Earth2Studio)")
114
+ st.markdown("""
115
+ Using **NVIDIA Earth-2 StormScope** to generate high-resolution radar reflectivity.
116
+ *Note: This model is trained on US data but covers Southern Canada.*
117
+ """)
118
+
119
+ # Sidebar Controls
120
+ with st.sidebar:
121
+ st.header("Settings")
122
+ selected_region = st.selectbox("Choose Region", list(REGIONS.keys()))
123
+
124
+ region_info = REGIONS[selected_region]
125
+ lat_center = region_info["lat"]
126
+ lon_center = region_info["lon"]
127
+ zoom_deg = st.slider("Zoom (Degrees Radius)", 1.0, 10.0, 4.0)
128
+
129
+ # Main Execution Button
130
+ if st.button("πŸ“‘ Generate Radar Forecast", type="primary", use_container_width=True):
131
+
132
+ log_to_ui(f"Fetching data and generating radar for {selected_region}...", type="info")
133
+
134
+ t0 = time.time()
135
+
136
+ # Run Inference
137
+ result, msg = run_radar_generation(selected_region, lat_center, lon_center)
138
+
139
+ if result is None:
140
+ log_to_ui(msg, type="error")
141
+ else:
142
+ radar_data, lats, lons = result
143
+ elapsed = time.time() - t0
144
+ log_to_ui(f"Generation Complete in {elapsed:.2f}s", type="success")
145
+
146
+ # Plotting
147
+ st.subheader(f"Radar Reflectivity: {selected_region}")
148
+
149
+ fig = plt.figure(figsize=(10, 8))
150
+ ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
151
+
152
+ # Map Features
153
+ ax.add_feature(cfeature.COASTLINE, linewidth=1)
154
+ ax.add_feature(cfeature.BORDERS, linestyle=':', linewidth=1)
155
+ ax.add_feature(cfeature.LAKES, alpha=0.3, color='blue')
156
+
157
+ # Set Extent (Zoom)
158
+ extent = [lon_center - zoom_deg, lon_center + zoom_deg,
159
+ lat_center - zoom_deg, lat_center + zoom_deg]
160
+ ax.set_extent(extent, crs=ccrs.PlateCarree())
161
+
162
+ # Plot Radar Data
163
+ # Use a transparent colormap for low values (0 reflectivity)
164
+ mesh = ax.pcolormesh(lons, lats, radar_data,
165
+ transform=ccrs.PlateCarree(),
166
+ cmap='nipy_spectral',
167
+ vmin=0, vmax=70, # dBZ range
168
+ shading='auto')
169
+
170
+ plt.colorbar(mesh, ax=ax, label='Reflectivity (dBZ)', shrink=0.7)
171
+ plt.title(f"Simulated Radar | Center: {lat_center}, {lon_center}")
172
+
173
+ st.pyplot(fig)
174
+
175
+ # Footer
176
+ st.markdown("---")
177
+ st.caption("Powered by NVIDIA Earth2Studio β€’ Runs on Hugging Face GPU Spaces")
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ pandas
3
+ numpy
4
+ requests
5
+ altair
6
+ earth2studio
7
+ torch
8
+ matplotlib
9
+ cartopy
10
+ h5netcdf
11
+ zarr
12
+ onnx
13
+ onnxruntime-gpu
14
+ einops
15
+ nvidia-physicsnemo