Spaces:
Running
on
Zero
Running
on
Zero
fix: update some comments, env vars, and settings for H200 and HF backend issues
Browse files
app.py
CHANGED
|
@@ -5,7 +5,7 @@ from typing import Optional, Tuple
|
|
| 5 |
# Save original asyncio.run BEFORE any imports that might patch it (nest_asyncio)
|
| 6 |
_ORIGINAL_ASYNCIO_RUN = asyncio.run
|
| 7 |
|
| 8 |
-
# On ZeroGPU
|
| 9 |
# some einsum-heavy models. Prefer full FP32 math for stability.
|
| 10 |
os.environ.setdefault("NVIDIA_TF32_OVERRIDE", "0")
|
| 11 |
# ZeroGPU H200-specific workarounds for cuBLAS strided-batch GEMM issues
|
|
@@ -126,70 +126,28 @@ def _run_inference(forecast_date: str, nsteps: int):
|
|
| 126 |
|
| 127 |
_ensure_cache_dirs()
|
| 128 |
|
| 129 |
-
#
|
| 130 |
-
torch.backends.cudnn.benchmark = False
|
| 131 |
-
#
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
except Exception:
|
| 135 |
-
pass
|
| 136 |
-
try:
|
| 137 |
-
torch.backends.cuda.matmul.allow_tf32 = False
|
| 138 |
-
except Exception:
|
| 139 |
-
pass
|
| 140 |
-
try:
|
| 141 |
-
torch.backends.cudnn.allow_tf32 = False
|
| 142 |
-
except Exception:
|
| 143 |
-
pass
|
| 144 |
-
# Avoid reduced-precision reductions (guarded for older torch versions)
|
| 145 |
-
try:
|
| 146 |
-
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
|
| 147 |
-
except Exception:
|
| 148 |
-
pass
|
| 149 |
-
try:
|
| 150 |
-
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
|
| 151 |
-
except Exception:
|
| 152 |
-
pass
|
| 153 |
-
try:
|
| 154 |
-
torch.cuda.set_device(0)
|
| 155 |
-
except Exception:
|
| 156 |
-
pass
|
| 157 |
torch.cuda.empty_cache()
|
| 158 |
|
| 159 |
-
#
|
| 160 |
-
# producing unsupported strides for batched GEMM. Force contiguity at the einsum boundary.
|
| 161 |
_orig_einsum = torch.einsum
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
for op in operands:
|
| 166 |
-
if isinstance(op, torch.Tensor) and not op.is_contiguous():
|
| 167 |
-
ops.append(op.contiguous())
|
| 168 |
-
else:
|
| 169 |
-
ops.append(op)
|
| 170 |
-
return _orig_einsum(equation, *ops)
|
| 171 |
-
|
| 172 |
-
torch.einsum = _einsum_contiguous # type: ignore[assignment]
|
| 173 |
|
| 174 |
# Load model inside GPU function (ZeroGPU requirement)
|
| 175 |
-
|
| 176 |
-
from earth2studio.models.px.fcn import FCN
|
| 177 |
-
except Exception:
|
| 178 |
-
from earth2studio.models.px import FCN
|
| 179 |
|
| 180 |
package = FCN.load_default_package()
|
| 181 |
model = FCN.load_model(package)
|
| 182 |
|
|
|
|
| 183 |
device = torch.device("cuda")
|
| 184 |
-
|
| 185 |
-
try:
|
| 186 |
-
model = model.float()
|
| 187 |
-
except Exception:
|
| 188 |
-
pass
|
| 189 |
-
model = model.to(device)
|
| 190 |
-
model.eval() # Ensure eval mode
|
| 191 |
-
|
| 192 |
-
# Clear memory after model load
|
| 193 |
torch.cuda.empty_cache()
|
| 194 |
|
| 195 |
# CRITICAL: Warmup CUDA/cuBLAS context on ZeroGPU's H200 before complex ops
|
|
@@ -221,27 +179,11 @@ def _run_inference(forecast_date: str, nsteps: int):
|
|
| 221 |
|
| 222 |
return lon, lat, all_fields
|
| 223 |
finally:
|
| 224 |
-
#
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
# Free GPU memory aggressively
|
| 230 |
-
try:
|
| 231 |
-
model.to("cpu")
|
| 232 |
-
except Exception:
|
| 233 |
-
pass
|
| 234 |
-
try:
|
| 235 |
-
del model
|
| 236 |
-
del data
|
| 237 |
-
del io
|
| 238 |
-
except Exception:
|
| 239 |
-
pass
|
| 240 |
-
try:
|
| 241 |
-
torch.cuda.empty_cache()
|
| 242 |
-
torch.cuda.synchronize()
|
| 243 |
-
except Exception:
|
| 244 |
-
pass
|
| 245 |
|
| 246 |
|
| 247 |
def run_forecast(forecast_date: str, nsteps: int):
|
|
|
|
| 5 |
# Save original asyncio.run BEFORE any imports that might patch it (nest_asyncio)
|
| 6 |
_ORIGINAL_ASYNCIO_RUN = asyncio.run
|
| 7 |
|
| 8 |
+
# On ZeroGPU H200, TF32 matmul paths can occasionally trip cuBLAS errors in
|
| 9 |
# some einsum-heavy models. Prefer full FP32 math for stability.
|
| 10 |
os.environ.setdefault("NVIDIA_TF32_OVERRIDE", "0")
|
| 11 |
# ZeroGPU H200-specific workarounds for cuBLAS strided-batch GEMM issues
|
|
|
|
| 126 |
|
| 127 |
_ensure_cache_dirs()
|
| 128 |
|
| 129 |
+
# Critical precision settings for ZeroGPU H200 cuBLAS stability
|
| 130 |
+
torch.backends.cudnn.benchmark = False
|
| 131 |
+
torch.set_float32_matmul_precision("highest") # Full FP32, no TF32
|
| 132 |
+
torch.backends.cuda.matmul.allow_tf32 = False
|
| 133 |
+
torch.backends.cudnn.allow_tf32 = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
torch.cuda.empty_cache()
|
| 135 |
|
| 136 |
+
# Force einsum operand contiguity to avoid cuBLAS strided-batch GEMM errors
|
|
|
|
| 137 |
_orig_einsum = torch.einsum
|
| 138 |
+
torch.einsum = lambda eq, *ops: _orig_einsum(
|
| 139 |
+
eq, *[op.contiguous() if torch.is_tensor(op) else op for op in ops]
|
| 140 |
+
) # type: ignore[assignment]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
# Load model inside GPU function (ZeroGPU requirement)
|
| 143 |
+
from earth2studio.models.px import FCN
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
package = FCN.load_default_package()
|
| 146 |
model = FCN.load_model(package)
|
| 147 |
|
| 148 |
+
# Move to GPU with FP32 precision
|
| 149 |
device = torch.device("cuda")
|
| 150 |
+
model = model.float().to(device).eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
torch.cuda.empty_cache()
|
| 152 |
|
| 153 |
# CRITICAL: Warmup CUDA/cuBLAS context on ZeroGPU's H200 before complex ops
|
|
|
|
| 179 |
|
| 180 |
return lon, lat, all_fields
|
| 181 |
finally:
|
| 182 |
+
# Cleanup: restore einsum and free GPU memory
|
| 183 |
+
torch.einsum = _orig_einsum # type: ignore[assignment]
|
| 184 |
+
del model, data, io
|
| 185 |
+
torch.cuda.empty_cache()
|
| 186 |
+
torch.cuda.synchronize()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
|
| 189 |
def run_forecast(forecast_date: str, nsteps: int):
|