Christopher Tan committed on
Commit
d2fd76d
·
1 Parent(s): 8fbc522

clearing cache before switching models

Browse files
__pycache__/app.cpython-313.pyc CHANGED
Binary files a/__pycache__/app.cpython-313.pyc and b/__pycache__/app.cpython-313.pyc differ
 
__pycache__/inference_openvla.cpython-313.pyc CHANGED
Binary files a/__pycache__/inference_openvla.cpython-313.pyc and b/__pycache__/inference_openvla.cpython-313.pyc differ
 
app.py CHANGED
@@ -598,6 +598,34 @@ def cleanup_workers():
598
  atexit.register(cleanup_workers)
599
 
600
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
601
  @dataclasses.dataclass
602
  class InferenceRequest:
603
  """Normalized payload for invoking model backends from the UI."""
@@ -632,6 +660,9 @@ def run_pi0_inference(request: InferenceRequest) -> Tuple[Optional[str], str]:
632
  """Dispatch OpenPI inference to subprocess"""
633
  model_key = "openpi" # Define model_key for this function
634
  try:
 
 
 
635
  request.progress(0, desc="Starting OpenPI worker...")
636
  worker = get_inference_worker(model_key)
637
 
@@ -774,6 +805,9 @@ def run_openvla_inference(request: InferenceRequest) -> Tuple[Optional[str], str
774
  """Dispatch OpenVLA inference to subprocess"""
775
  model_key = "openvla" # Define model_key for this function
776
  try:
 
 
 
777
  request.progress(0, desc="Starting OpenVLA worker...")
778
  worker = get_inference_worker(model_key)
779
 
 
598
  atexit.register(cleanup_workers)
599
 
600
 
601
+ def terminate_other_worker(current_model_key: str):
602
+ """Terminate the other model's worker to free GPU memory when switching models."""
603
+ global _INFERENCE_WORKERS, _WORKER_STDERR
604
+
605
+ # Find the other model key
606
+ other_model_key = "openvla" if current_model_key == "openpi" else "openpi"
607
+
608
+ # Check if the other worker is running
609
+ other_worker = _INFERENCE_WORKERS.get(other_model_key)
610
+ if other_worker and other_worker.poll() is None:
611
+ print(f"Terminating {other_model_key} worker to free GPU memory for {current_model_key}...", flush=True)
612
+ try:
613
+ other_worker.terminate()
614
+ try:
615
+ other_worker.wait(timeout=5)
616
+ print(f"✓ {other_model_key} worker terminated successfully", flush=True)
617
+ except subprocess.TimeoutExpired:
618
+ print(f"⚠️ {other_model_key} worker didn't terminate gracefully, killing...", flush=True)
619
+ other_worker.kill()
620
+ other_worker.wait()
621
+ except Exception as e:
622
+ print(f"⚠️ Error terminating {other_model_key} worker: {e}", flush=True)
623
+ finally:
624
+ # Mark as terminated
625
+ _INFERENCE_WORKERS[other_model_key] = None
626
+ _WORKER_STDERR[other_model_key] = [] # Clear stderr buffer
627
+
628
+
629
  @dataclasses.dataclass
630
  class InferenceRequest:
631
  """Normalized payload for invoking model backends from the UI."""
 
660
  """Dispatch OpenPI inference to subprocess"""
661
  model_key = "openpi" # Define model_key for this function
662
  try:
663
+ # Terminate OpenVLA worker if running to free GPU memory
664
+ terminate_other_worker(model_key)
665
+
666
  request.progress(0, desc="Starting OpenPI worker...")
667
  worker = get_inference_worker(model_key)
668
 
 
805
  """Dispatch OpenVLA inference to subprocess"""
806
  model_key = "openvla" # Define model_key for this function
807
  try:
808
+ # Terminate OpenPI worker if running to free GPU memory
809
+ terminate_other_worker(model_key)
810
+
811
  request.progress(0, desc="Starting OpenVLA worker...")
812
  worker = get_inference_worker(model_key)
813
 
inference_openvla.py CHANGED
@@ -29,6 +29,8 @@ try:
29
  os.environ["DISPLAY"] = ":99"
30
  os.environ["LIBGL_ALWAYS_SOFTWARE"] = "1"
31
  os.environ["GALLIUM_DRIVER"] = "llvmpipe"
 
 
32
  # Debug: verify environment variables are set
33
  print(f"DEBUG: MUJOCO_GL={os.environ.get('MUJOCO_GL')}, PYOPENGL_PLATFORM={os.environ.get('PYOPENGL_PLATFORM')}", file=sys.stderr, flush=True)
34
  except Exception as e:
@@ -203,6 +205,30 @@ DEFAULT_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
203
  DEFAULT_DOWNSAMPLE_RATE = 25
204
  CAMERA_RESOLUTION = (256, 256)
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  # Environment registry
207
  _ENV_CLASSES = {
208
  "CubeHandover": (CubeHandover, "handover the rod from one hand to the other hand"),
@@ -240,8 +266,7 @@ _ENV_CLASSES = {
240
  "StackTwoBlocksPositionOrientation": (StackTwoBlocksPositionAndOrientation, "stack the two cubes")
241
  }
242
 
243
- # Global model cache
244
- _MODEL_CACHE = {}
245
 
246
 
247
  def get_checkpoint_path(task_name: str, ckpt_path: Optional[str] = None) -> str:
@@ -330,6 +355,12 @@ def load_vla_model(ckpt_path: str, device: str = DEFAULT_DEVICE) -> Tuple[AutoPr
330
  if ckpt_path in _MODEL_CACHE:
331
  return _MODEL_CACHE[ckpt_path]
332
 
 
 
 
 
 
 
333
  if not os.path.exists(ckpt_path):
334
  raise FileNotFoundError(f"Checkpoint path does not exist: {ckpt_path}")
335
 
 
29
  os.environ["DISPLAY"] = ":99"
30
  os.environ["LIBGL_ALWAYS_SOFTWARE"] = "1"
31
  os.environ["GALLIUM_DRIVER"] = "llvmpipe"
32
+ # PyTorch CUDA memory allocator settings to reduce fragmentation
33
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
34
  # Debug: verify environment variables are set
35
  print(f"DEBUG: MUJOCO_GL={os.environ.get('MUJOCO_GL')}, PYOPENGL_PLATFORM={os.environ.get('PYOPENGL_PLATFORM')}", file=sys.stderr, flush=True)
36
  except Exception as e:
 
205
  DEFAULT_DOWNSAMPLE_RATE = 25
206
  CAMERA_RESOLUTION = (256, 256)
207
 
208
+ # Model cache
209
+ _MODEL_CACHE: Dict[str, Tuple[AutoProcessor, AutoModelForVision2Seq]] = {}
210
+
211
+
212
+ def clear_gpu_memory():
213
+ """Clear PyTorch GPU memory and model cache."""
214
+ global _MODEL_CACHE
215
+
216
+ # Clear the model cache
217
+ if _MODEL_CACHE:
218
+ print(f"Clearing {len(_MODEL_CACHE)} cached model(s) to free GPU memory...", file=sys.stderr, flush=True)
219
+ _MODEL_CACHE.clear()
220
+
221
+ # Clear PyTorch CUDA cache
222
+ try:
223
+ import gc
224
+ if torch.cuda.is_available():
225
+ torch.cuda.empty_cache()
226
+ torch.cuda.synchronize()
227
+ gc.collect()
228
+ print("GPU memory cleared successfully", file=sys.stderr, flush=True)
229
+ except Exception as e:
230
+ print(f"Warning: Could not fully clear GPU memory: {e}", file=sys.stderr, flush=True)
231
+
232
  # Environment registry
233
  _ENV_CLASSES = {
234
  "CubeHandover": (CubeHandover, "handover the rod from one hand to the other hand"),
 
266
  "StackTwoBlocksPositionOrientation": (StackTwoBlocksPositionAndOrientation, "stack the two cubes")
267
  }
268
 
269
+ # Model cache is defined above (line 209)
 
270
 
271
 
272
  def get_checkpoint_path(task_name: str, ckpt_path: Optional[str] = None) -> str:
 
355
  if ckpt_path in _MODEL_CACHE:
356
  return _MODEL_CACHE[ckpt_path]
357
 
358
+ # Clear GPU memory before loading new model if cache is not empty
359
+ # This helps when switching from OpenPI to OpenVLA
360
+ if _MODEL_CACHE:
361
+ print("Clearing GPU memory before loading new model...", file=sys.stderr, flush=True)
362
+ clear_gpu_memory()
363
+
364
  if not os.path.exists(ckpt_path):
365
  raise FileNotFoundError(f"Checkpoint path does not exist: {ckpt_path}")
366