import time
import threading

import torch

gen_lock = threading.Lock()

_MAIN_PROCESS_RUNNING_KEY = "main_process_running"


def get_gen_info(state):
    cache = state.get("gen", None)
    if cache is None:
        cache = dict()
        state["gen"] = cache
    return cache


def _main_generation_active_locked(gen):
    return bool(gen.get(_MAIN_PROCESS_RUNNING_KEY, False))


def set_main_generation_running(state, running):
    gen = get_gen_info(state)
    with gen_lock:
        if running:
            gen[_MAIN_PROCESS_RUNNING_KEY] = True
        else:
            gen.pop(_MAIN_PROCESS_RUNNING_KEY, None)


def any_GPU_process_running(state, process_id, ignore_main = False):
    gen = get_gen_info(state)
    with gen_lock:
        process_status = gen.get("process_status", None)
        # A stale "process:main" entry does not count once the main generation is no longer active.
        if process_status == "process:main" and not _main_generation_active_locked(gen):
            return False
        return process_status is not None and not (process_status == "process:main" and ignore_main)


def _get_gpu_residents(gen):
    residents = gen.get("gpu_residents", None)
    if residents is None:
        residents = {}
        gen["gpu_residents"] = residents
    return residents


def _drop_gpu_resident_locked(gen, process_id):
    _get_gpu_residents(gen).pop(process_id, None)


def _collect_resident_release_actions_locked(gen, requester_id = None):
    # Gather the release callbacks of registered residents that must give up their VRAM
    # before the requester takes over. Must be called with gen_lock held.
    release_actions = []
    residents = _get_gpu_residents(gen)
    for resident_id, resident_info in list(residents.items()):
        if resident_id == requester_id:
            residents.pop(resident_id, None)
            continue
        if not bool(resident_info.get("force_release_on_acquire", False)):
            continue
        release_callback = resident_info.get("release_vram_callback", None)
        if not callable(release_callback):
            residents.pop(resident_id, None)
            continue
        release_actions.append((resident_id, resident_info.get("process_name", resident_id), release_callback))
        residents.pop(resident_id, None)
    return release_actions


def _run_release_actions(release_actions):
    for resident_id, process_name, release_callback in release_actions:
        try:
            release_callback()
        except Exception as exc:
            print(f"[GPU] Unable to release resident VRAM for {process_name} ({resident_id}): {exc}")
    if len(release_actions) > 0 and torch.cuda.is_available():
        torch.cuda.synchronize()


def register_GPU_resident(state, process_id, process_name, release_vram_callback = None, force_release_on_acquire = True):
    gen = get_gen_info(state)
    with gen_lock:
        _get_gpu_residents(gen)[process_id] = {
            "process_name": process_name,
            "release_vram_callback": release_vram_callback,
            "force_release_on_acquire": bool(force_release_on_acquire),
        }


def unregister_GPU_resident(state, process_id):
    gen = get_gen_info(state)
    with gen_lock:
        _drop_gpu_resident_locked(gen, process_id)


def force_release_GPU_resident(state, process_id):
    gen = get_gen_info(state)
    release_callback = None
    with gen_lock:
        resident_info = _get_gpu_residents(gen).pop(process_id, None)
        if resident_info is not None:
            release_callback = resident_info.get("release_vram_callback", None)
    # Invoke the callback outside the lock to avoid holding it during VRAM cleanup.
    if callable(release_callback):
        release_callback()
        if torch.cuda.is_available():
            torch.cuda.synchronize()


def acquire_main_GPU_ressources(state):
    gen = get_gen_info(state)
    release_actions = []
    while True:
        with gen_lock:
            process_status = gen.get("process_status", None)
            if process_status is None or process_status == "process:main":
                release_actions = _collect_resident_release_actions_locked(gen, requester_id="main")
                gen["process_status"] = "process:main"
                break
        time.sleep(0.1)
    _run_release_actions(release_actions)
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def acquire_GPU_ressources(state, process_id, process_name, gr = None, custom_pause_msg = None,
                           custom_wait_msg = None):
    gen = get_gen_info(state)
    original_process_status = None
    release_actions = []
    while True:
        with gen_lock:
            process_hierarchy = gen.get("process_hierarchy", None)
            if process_hierarchy is None:
                process_hierarchy = dict()
                gen["process_hierarchy"] = process_hierarchy
            process_status = gen.get("process_status", None)
            if process_status is None:
                # GPU is free: take ownership immediately.
                _drop_gpu_resident_locked(gen, process_id)
                original_process_status = None
                release_actions = _collect_resident_release_actions_locked(gen, requester_id=process_id)
                gen["process_status"] = "process:" + process_id
                break
            elif process_status == "request:" + process_id and not _main_generation_active_locked(gen):
                # Our earlier request is still pending and main is no longer active: promote it.
                _drop_gpu_resident_locked(gen, process_id)
                original_process_status = None
                release_actions = _collect_resident_release_actions_locked(gen, requester_id=process_id)
                gen["process_status"] = "process:" + process_id
                break
            elif process_status == "process:main":
                if not _main_generation_active_locked(gen):
                    # Main is marked as owner but not actually generating: take over.
                    _drop_gpu_resident_locked(gen, process_id)
                    original_process_status = None
                    release_actions = _collect_resident_release_actions_locked(gen, requester_id=process_id)
                    gen["process_status"] = "process:" + process_id
                    break
                # Main generation is active: ask it to pause and wait for the handover.
                original_process_status = process_status
                gen["process_status"] = "request:" + process_id
                gen["pause_msg"] = custom_pause_msg if custom_pause_msg is not None else f"Generation Suspended while using {process_name}"
                break
            elif process_status == "process:" + process_id:
                # We already own the GPU.
                _drop_gpu_resident_locked(gen, process_id)
                break
        time.sleep(0.1)
    _run_release_actions(release_actions)
    if original_process_status is not None:
        total_wait = 0
        wait_time = 0.1
        wait_msg_displayed = False
        while True:
            with gen_lock:
                process_status = gen.get("process_status", None)
                if process_status == "process:" + process_id:
                    break
                if process_status is None or (process_status == "request:" + process_id and not _main_generation_active_locked(gen)):
                    # Handle the case where the main process finished between the last check and now.
                    gen["process_status"] = "process:" + process_id
                    break
            total_wait += wait_time
            if round(total_wait, 2) >= 5 and gr is not None and not wait_msg_displayed:
                wait_msg_displayed = True
                if custom_wait_msg is None:
                    gr.Info(f"Process {process_name} is suspended while waiting for GPU resources to become available")
                else:
                    gr.Info(custom_wait_msg)
            time.sleep(wait_time)
        with gen_lock:
            # Remember the suspended owner so release_GPU_ressources can restore it.
            process_hierarchy[process_id] = original_process_status
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def release_GPU_ressources(state, process_id, keep_resident = False, process_name = None, release_vram_callback = None, force_release_on_acquire = True):
    gen = get_gen_info(state)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    with gen_lock:
        if keep_resident:
            # Stay registered so a later acquire can reclaim our VRAM through the callback.
            _get_gpu_residents(gen)[process_id] = {
                "process_name": process_name or process_id,
                "release_vram_callback": release_vram_callback,
                "force_release_on_acquire": bool(force_release_on_acquire),
            }
        else:
            _drop_gpu_resident_locked(gen, process_id)
        process_hierarchy = gen.get("process_hierarchy", {})
        restore_status = process_hierarchy.pop(process_id, None)
        if restore_status == "process:main" and not _main_generation_active_locked(gen):
            restore_status = None
        gen["process_status"] = restore_status
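

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not called anywhere in this module). It shows how
# a secondary task could wrap its GPU work with the acquire/release helpers
# above. The shared `state` dict is the same one passed to the other helpers;
# `gpu_work` and `free_vram` are hypothetical callables supplied by the caller.
def example_run_with_gpu(state, gpu_work, free_vram = None, gr = None):
    process_id = "example_task"
    process_name = "Example Task"
    # Suspend the main generation (or wait for another task) before using the GPU.
    acquire_GPU_ressources(state, process_id, process_name, gr = gr)
    try:
        return gpu_work()
    finally:
        # If a VRAM-release callback is provided, stay registered as a GPU resident
        # so a later acquire can reclaim the memory; otherwise drop out immediately.
        release_GPU_ressources(
            state,
            process_id,
            keep_resident = free_vram is not None,
            process_name = process_name,
            release_vram_callback = free_vram,
        )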