carmelog committed on
Commit
7a84f11
·
1 Parent(s): 156dc4d

fix: update some comments, env vars, and settings for h200 and HF backend issues

Browse files
Files changed (1) hide show
  1. app.py +18 -76
app.py CHANGED
@@ -5,7 +5,7 @@ from typing import Optional, Tuple
5
  # Save original asyncio.run BEFORE any imports that might patch it (nest_asyncio)
6
  _ORIGINAL_ASYNCIO_RUN = asyncio.run
7
 
8
- # On ZeroGPU (shared A10G), TF32 matmul paths can occasionally trip cuBLAS errors in
9
  # some einsum-heavy models. Prefer full FP32 math for stability.
10
  os.environ.setdefault("NVIDIA_TF32_OVERRIDE", "0")
11
  # ZeroGPU H200-specific workarounds for cuBLAS strided-batch GEMM issues
@@ -126,70 +126,28 @@ def _run_inference(forecast_date: str, nsteps: int):
126
 
127
  _ensure_cache_dirs()
128
 
129
- # Memory management for ZeroGPU
130
- torch.backends.cudnn.benchmark = False # More stable on shared GPU
131
- # Prefer full FP32 math (avoid TF32) for stability on shared A10G
132
- try:
133
- torch.set_float32_matmul_precision("highest")
134
- except Exception:
135
- pass
136
- try:
137
- torch.backends.cuda.matmul.allow_tf32 = False
138
- except Exception:
139
- pass
140
- try:
141
- torch.backends.cudnn.allow_tf32 = False
142
- except Exception:
143
- pass
144
- # Avoid reduced-precision reductions (guarded for older torch versions)
145
- try:
146
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
147
- except Exception:
148
- pass
149
- try:
150
- torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
151
- except Exception:
152
- pass
153
- try:
154
- torch.cuda.set_device(0)
155
- except Exception:
156
- pass
157
  torch.cuda.empty_cache()
158
 
159
- # Some cuBLAS "INVALID_VALUE" failures originate from non-contiguous einsum operands
160
- # producing unsupported strides for batched GEMM. Force contiguity at the einsum boundary.
161
  _orig_einsum = torch.einsum
162
-
163
- def _einsum_contiguous(equation, *operands):
164
- ops = []
165
- for op in operands:
166
- if isinstance(op, torch.Tensor) and not op.is_contiguous():
167
- ops.append(op.contiguous())
168
- else:
169
- ops.append(op)
170
- return _orig_einsum(equation, *ops)
171
-
172
- torch.einsum = _einsum_contiguous # type: ignore[assignment]
173
 
174
  # Load model inside GPU function (ZeroGPU requirement)
175
- try:
176
- from earth2studio.models.px.fcn import FCN
177
- except Exception:
178
- from earth2studio.models.px import FCN
179
 
180
  package = FCN.load_default_package()
181
  model = FCN.load_model(package)
182
 
 
183
  device = torch.device("cuda")
184
- # Ensure FP32 weights for cuBLAS stability
185
- try:
186
- model = model.float()
187
- except Exception:
188
- pass
189
- model = model.to(device)
190
- model.eval() # Ensure eval mode
191
-
192
- # Clear memory after model load
193
  torch.cuda.empty_cache()
194
 
195
  # CRITICAL: Warmup CUDA/cuBLAS context on ZeroGPU's H200 before complex ops
@@ -221,27 +179,11 @@ def _run_inference(forecast_date: str, nsteps: int):
221
 
222
  return lon, lat, all_fields
223
  finally:
224
- # Restore einsum in case this worker is reused
225
- try:
226
- torch.einsum = _orig_einsum # type: ignore[assignment]
227
- except Exception:
228
- pass
229
- # Free GPU memory aggressively
230
- try:
231
- model.to("cpu")
232
- except Exception:
233
- pass
234
- try:
235
- del model
236
- del data
237
- del io
238
- except Exception:
239
- pass
240
- try:
241
- torch.cuda.empty_cache()
242
- torch.cuda.synchronize()
243
- except Exception:
244
- pass
245
 
246
 
247
  def run_forecast(forecast_date: str, nsteps: int):
 
5
  # Save original asyncio.run BEFORE any imports that might patch it (nest_asyncio)
6
  _ORIGINAL_ASYNCIO_RUN = asyncio.run
7
 
8
+ # On ZeroGPU H200, TF32 matmul paths can occasionally trip cuBLAS errors in
9
  # some einsum-heavy models. Prefer full FP32 math for stability.
10
  os.environ.setdefault("NVIDIA_TF32_OVERRIDE", "0")
11
  # ZeroGPU H200-specific workarounds for cuBLAS strided-batch GEMM issues
 
126
 
127
  _ensure_cache_dirs()
128
 
129
+ # Critical precision settings for ZeroGPU H200 cuBLAS stability
130
+ torch.backends.cudnn.benchmark = False
131
+ torch.set_float32_matmul_precision("highest") # Full FP32, no TF32
132
+ torch.backends.cuda.matmul.allow_tf32 = False
133
+ torch.backends.cudnn.allow_tf32 = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  torch.cuda.empty_cache()
135
 
136
+ # Force einsum operand contiguity to avoid cuBLAS strided-batch GEMM errors
 
137
  _orig_einsum = torch.einsum
138
+ torch.einsum = lambda eq, *ops: _orig_einsum(
139
+ eq, *[op.contiguous() if torch.is_tensor(op) else op for op in ops]
140
+ ) # type: ignore[assignment]
 
 
 
 
 
 
 
 
141
 
142
  # Load model inside GPU function (ZeroGPU requirement)
143
+ from earth2studio.models.px import FCN
 
 
 
144
 
145
  package = FCN.load_default_package()
146
  model = FCN.load_model(package)
147
 
148
+ # Move to GPU with FP32 precision
149
  device = torch.device("cuda")
150
+ model = model.float().to(device).eval()
 
 
 
 
 
 
 
 
151
  torch.cuda.empty_cache()
152
 
153
  # CRITICAL: Warmup CUDA/cuBLAS context on ZeroGPU's H200 before complex ops
 
179
 
180
  return lon, lat, all_fields
181
  finally:
182
+ # Cleanup: restore einsum and free GPU memory
183
+ torch.einsum = _orig_einsum # type: ignore[assignment]
184
+ del model, data, io
185
+ torch.cuda.empty_cache()
186
+ torch.cuda.synchronize()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
 
189
  def run_forecast(forecast_date: str, nsteps: int):