Spaces:
Sleeping
Sleeping
Christopher Tan committed on
Commit ·
d2fd76d
1
Parent(s): 8fbc522
clearing cache before switching models
Browse files- __pycache__/app.cpython-313.pyc +0 -0
- __pycache__/inference_openvla.cpython-313.pyc +0 -0
- app.py +34 -0
- inference_openvla.py +33 -2
__pycache__/app.cpython-313.pyc
CHANGED
|
Binary files a/__pycache__/app.cpython-313.pyc and b/__pycache__/app.cpython-313.pyc differ
|
|
|
__pycache__/inference_openvla.cpython-313.pyc
CHANGED
|
Binary files a/__pycache__/inference_openvla.cpython-313.pyc and b/__pycache__/inference_openvla.cpython-313.pyc differ
|
|
|
app.py
CHANGED
|
@@ -598,6 +598,34 @@ def cleanup_workers():
|
|
| 598 |
atexit.register(cleanup_workers)
|
| 599 |
|
| 600 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 601 |
@dataclasses.dataclass
|
| 602 |
class InferenceRequest:
|
| 603 |
"""Normalized payload for invoking model backends from the UI."""
|
|
@@ -632,6 +660,9 @@ def run_pi0_inference(request: InferenceRequest) -> Tuple[Optional[str], str]:
|
|
| 632 |
"""Dispatch OpenPI inference to subprocess"""
|
| 633 |
model_key = "openpi" # Define model_key for this function
|
| 634 |
try:
|
|
|
|
|
|
|
|
|
|
| 635 |
request.progress(0, desc="Starting OpenPI worker...")
|
| 636 |
worker = get_inference_worker(model_key)
|
| 637 |
|
|
@@ -774,6 +805,9 @@ def run_openvla_inference(request: InferenceRequest) -> Tuple[Optional[str], str
|
|
| 774 |
"""Dispatch OpenVLA inference to subprocess"""
|
| 775 |
model_key = "openvla" # Define model_key for this function
|
| 776 |
try:
|
|
|
|
|
|
|
|
|
|
| 777 |
request.progress(0, desc="Starting OpenVLA worker...")
|
| 778 |
worker = get_inference_worker(model_key)
|
| 779 |
|
|
|
|
| 598 |
atexit.register(cleanup_workers)
|
| 599 |
|
| 600 |
|
def terminate_other_worker(current_model_key: str):
    """Terminate the other model's worker to free GPU memory when switching models."""
    global _INFERENCE_WORKERS, _WORKER_STDERR

    # Only two backends exist, so the "other" one is whichever we are not.
    other_model_key = "openvla" if current_model_key == "openpi" else "openpi"

    # poll() is None while the subprocess is still alive; nothing to do otherwise.
    other_worker = _INFERENCE_WORKERS.get(other_model_key)
    if other_worker is None or other_worker.poll() is not None:
        return

    print(f"Terminating {other_model_key} worker to free GPU memory for {current_model_key}...", flush=True)
    try:
        other_worker.terminate()
        try:
            other_worker.wait(timeout=5)
            print(f"✓ {other_model_key} worker terminated successfully", flush=True)
        except subprocess.TimeoutExpired:
            # Graceful terminate timed out — escalate to SIGKILL and reap.
            print(f"⚠️ {other_model_key} worker didn't terminate gracefully, killing...", flush=True)
            other_worker.kill()
            other_worker.wait()
    except Exception as e:
        print(f"⚠️ Error terminating {other_model_key} worker: {e}", flush=True)
    finally:
        # Mark as terminated
        _INFERENCE_WORKERS[other_model_key] = None
        _WORKER_STDERR[other_model_key] = []  # Clear stderr buffer
| 629 |
@dataclasses.dataclass
|
| 630 |
class InferenceRequest:
|
| 631 |
"""Normalized payload for invoking model backends from the UI."""
|
|
|
|
| 660 |
"""Dispatch OpenPI inference to subprocess"""
|
| 661 |
model_key = "openpi" # Define model_key for this function
|
| 662 |
try:
|
| 663 |
+
# Terminate OpenVLA worker if running to free GPU memory
|
| 664 |
+
terminate_other_worker(model_key)
|
| 665 |
+
|
| 666 |
request.progress(0, desc="Starting OpenPI worker...")
|
| 667 |
worker = get_inference_worker(model_key)
|
| 668 |
|
|
|
|
| 805 |
"""Dispatch OpenVLA inference to subprocess"""
|
| 806 |
model_key = "openvla" # Define model_key for this function
|
| 807 |
try:
|
| 808 |
+
# Terminate OpenPI worker if running to free GPU memory
|
| 809 |
+
terminate_other_worker(model_key)
|
| 810 |
+
|
| 811 |
request.progress(0, desc="Starting OpenVLA worker...")
|
| 812 |
worker = get_inference_worker(model_key)
|
| 813 |
|
inference_openvla.py
CHANGED
|
@@ -29,6 +29,8 @@ try:
|
|
| 29 |
os.environ["DISPLAY"] = ":99"
|
| 30 |
os.environ["LIBGL_ALWAYS_SOFTWARE"] = "1"
|
| 31 |
os.environ["GALLIUM_DRIVER"] = "llvmpipe"
|
|
|
|
|
|
|
| 32 |
# Debug: verify environment variables are set
|
| 33 |
print(f"DEBUG: MUJOCO_GL={os.environ.get('MUJOCO_GL')}, PYOPENGL_PLATFORM={os.environ.get('PYOPENGL_PLATFORM')}", file=sys.stderr, flush=True)
|
| 34 |
except Exception as e:
|
|
@@ -203,6 +205,30 @@ DEFAULT_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
| 203 |
DEFAULT_DOWNSAMPLE_RATE = 25
|
| 204 |
CAMERA_RESOLUTION = (256, 256)
|
| 205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
# Environment registry
|
| 207 |
_ENV_CLASSES = {
|
| 208 |
"CubeHandover": (CubeHandover, "handover the rod from one hand to the other hand"),
|
|
@@ -240,8 +266,7 @@ _ENV_CLASSES = {
|
|
| 240 |
"StackTwoBlocksPositionOrientation": (StackTwoBlocksPositionAndOrientation, "stack the two cubes")
|
| 241 |
}
|
| 242 |
|
| 243 |
-
#
|
| 244 |
-
_MODEL_CACHE = {}
|
| 245 |
|
| 246 |
|
| 247 |
def get_checkpoint_path(task_name: str, ckpt_path: Optional[str] = None) -> str:
|
|
@@ -330,6 +355,12 @@ def load_vla_model(ckpt_path: str, device: str = DEFAULT_DEVICE) -> Tuple[AutoPr
|
|
| 330 |
if ckpt_path in _MODEL_CACHE:
|
| 331 |
return _MODEL_CACHE[ckpt_path]
|
| 332 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
if not os.path.exists(ckpt_path):
|
| 334 |
raise FileNotFoundError(f"Checkpoint path does not exist: {ckpt_path}")
|
| 335 |
|
|
|
|
| 29 |
os.environ["DISPLAY"] = ":99"
|
| 30 |
os.environ["LIBGL_ALWAYS_SOFTWARE"] = "1"
|
| 31 |
os.environ["GALLIUM_DRIVER"] = "llvmpipe"
|
| 32 |
+
# PyTorch CUDA memory allocator settings to reduce fragmentation
|
| 33 |
+
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
| 34 |
# Debug: verify environment variables are set
|
| 35 |
print(f"DEBUG: MUJOCO_GL={os.environ.get('MUJOCO_GL')}, PYOPENGL_PLATFORM={os.environ.get('PYOPENGL_PLATFORM')}", file=sys.stderr, flush=True)
|
| 36 |
except Exception as e:
|
|
|
|
| 205 |
DEFAULT_DOWNSAMPLE_RATE = 25
|
| 206 |
CAMERA_RESOLUTION = (256, 256)
|
| 207 |
|
| 208 |
+
# Model cache
|
| 209 |
+
_MODEL_CACHE: Dict[str, Tuple[AutoProcessor, AutoModelForVision2Seq]] = {}
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def clear_gpu_memory():
|
| 213 |
+
"""Clear PyTorch GPU memory and model cache."""
|
| 214 |
+
global _MODEL_CACHE
|
| 215 |
+
|
| 216 |
+
# Clear the model cache
|
| 217 |
+
if _MODEL_CACHE:
|
| 218 |
+
print(f"Clearing {len(_MODEL_CACHE)} cached model(s) to free GPU memory...", file=sys.stderr, flush=True)
|
| 219 |
+
_MODEL_CACHE.clear()
|
| 220 |
+
|
| 221 |
+
# Clear PyTorch CUDA cache
|
| 222 |
+
try:
|
| 223 |
+
import gc
|
| 224 |
+
if torch.cuda.is_available():
|
| 225 |
+
torch.cuda.empty_cache()
|
| 226 |
+
torch.cuda.synchronize()
|
| 227 |
+
gc.collect()
|
| 228 |
+
print("GPU memory cleared successfully", file=sys.stderr, flush=True)
|
| 229 |
+
except Exception as e:
|
| 230 |
+
print(f"Warning: Could not fully clear GPU memory: {e}", file=sys.stderr, flush=True)
|
| 231 |
+
|
| 232 |
# Environment registry
|
| 233 |
_ENV_CLASSES = {
|
| 234 |
"CubeHandover": (CubeHandover, "handover the rod from one hand to the other hand"),
|
|
|
|
| 266 |
"StackTwoBlocksPositionOrientation": (StackTwoBlocksPositionAndOrientation, "stack the two cubes")
|
| 267 |
}
|
| 268 |
|
| 269 |
+
# Model cache is defined above (line 209)
|
|
|
|
| 270 |
|
| 271 |
|
| 272 |
def get_checkpoint_path(task_name: str, ckpt_path: Optional[str] = None) -> str:
|
|
|
|
| 355 |
if ckpt_path in _MODEL_CACHE:
|
| 356 |
return _MODEL_CACHE[ckpt_path]
|
| 357 |
|
| 358 |
+
# Clear GPU memory before loading new model if cache is not empty
|
| 359 |
+
# This helps when switching from OpenPI to OpenVLA
|
| 360 |
+
if _MODEL_CACHE:
|
| 361 |
+
print("Clearing GPU memory before loading new model...", file=sys.stderr, flush=True)
|
| 362 |
+
clear_gpu_memory()
|
| 363 |
+
|
| 364 |
if not os.path.exists(ckpt_path):
|
| 365 |
raise FileNotFoundError(f"Checkpoint path does not exist: {ckpt_path}")
|
| 366 |
|