File size: 3,018 Bytes
31112ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import time
import threading
import torch

# Guards every read/write of the shared "gen" dict (process_status,
# process_hierarchy, pause_msg) across the acquire/release helpers below.
gen_lock = threading.Lock()

def get_gen_info(state):
    """Return the per-session generation dict stored under ``state["gen"]``.

    Creates and registers an empty dict on first access (or when the stored
    value is ``None``), so callers always get a usable mutable mapping.
    """
    cache = state.get("gen", None)
    # `is None`, not `== None`: identity check is the idiom and cannot be
    # fooled by a value whose __eq__ compares equal to None.
    if cache is None:
        cache = dict()
        state["gen"] = cache
    return cache

def any_GPU_process_running(state, process_id, ignore_main = False):
    """Return True if some process currently owns the GPU.

    ``process_id`` is accepted for signature symmetry with the other
    acquire/release helpers but is not consulted here.  With
    ``ignore_main=True`` the main generation process ("process:main")
    does not count as an owner.
    """
    gen = get_gen_info(state)
    with gen_lock:
        process_status = gen.get("process_status", None)
        return process_status is not None and not (process_status == "process:main" and ignore_main)
    
def acquire_GPU_ressources(state, process_id, process_name, gr = None, custom_pause_msg = None, custom_wait_msg = None):
    """Block until this process owns the GPU, suspending the main generation
    process if necessary.

    Protocol (every transition is guarded by ``gen_lock``):
      * ``gen["process_status"]`` holds the current owner: ``None`` (free),
        ``"process:<id>"`` (owned by <id>), or ``"request:<id>"`` (ownership
        requested; presumably the main process polls this, pauses, and
        rewrites it to ``"process:<id>"`` — TODO confirm against the caller).
      * The displaced status is saved in ``gen["process_hierarchy"]`` under
        ``process_id`` so ``release_GPU_ressources`` can restore it.

    Args:
        state: session dict holding the shared "gen" info.
        process_id: unique identifier of the requesting process.
        process_name: human-readable name used in user-facing messages.
        gr: optional gradio-like module; when given, ``gr.Info`` is used to
            notify the user after ~5s of waiting.
        custom_pause_msg: overrides the default pause message stored for the
            suspended main process.
        custom_wait_msg: overrides the default "still waiting" notification.
    """
    gen = get_gen_info(state)
    original_process_status = None
    # Phase 1: claim the GPU outright, or post a request to the main process.
    while True:
        with gen_lock:
            process_hierarchy = gen.get("process_hierarchy", None)
            if process_hierarchy is None:
                process_hierarchy = dict()
                gen["process_hierarchy"]= process_hierarchy

            process_status = gen.get("process_status", None)
            if process_status is None:
                # GPU is free: take ownership immediately.
                original_process_status = process_status 
                gen["process_status"] = "process:" + process_id
                break
            elif process_status == "process:main":
                # Main generation owns the GPU: ask it to pause for us.
                original_process_status = process_status 
                gen["process_status"] = "request:" + process_id

                gen["pause_msg"] = custom_pause_msg if custom_pause_msg is not None else f"Generation Suspended while using {process_name}" 
                break
            elif process_status == "process:" + process_id:
                # Re-entrant acquire: we already own the GPU, nothing to do.
                break
        # Another (non-main) process owns the GPU: poll until it releases.
        time.sleep(0.1)

    # Phase 2: only entered when we displaced "process:main" above — wait
    # until the request is granted (status flips to "process:<our id>").
    if original_process_status is not None:
        total_wait = 0
        wait_time = 0.1
        wait_msg_displayed = False
        while True:
            with gen_lock:
                process_status = gen.get("process_status", None)
                if process_status == "process:" + process_id: break
                if process_status is None:
                    # handle case when main process has finished at some point in between the last check and now
                    gen["process_status"] = "process:" + process_id
                    break

            total_wait += wait_time
            # After ~5s of waiting, notify the user once that we are stalled.
            if round(total_wait,2) >= 5 and gr is not None and not wait_msg_displayed:
                wait_msg_displayed = True
                if custom_wait_msg is None:
                    gr.Info(f"Process {process_name} is Suspended while waiting that GPU Ressources become available")
                else:
                    gr.Info(custom_wait_msg)

            time.sleep(wait_time)
    
    with gen_lock:
        # Record what we displaced so release_GPU_ressources can restore it.
        process_hierarchy[process_id] = original_process_status
    torch.cuda.synchronize()

def release_GPU_ressources(state, process_id):
    """Release the GPU held by ``process_id`` and restore the status that
    ``acquire_GPU_ressources`` displaced (e.g. hand control back to
    "process:main", or ``None`` when the GPU was free before).
    """
    gen = get_gen_info(state)
    # Drain queued CUDA work before handing the GPU over to the next owner.
    torch.cuda.synchronize()
    with gen_lock:
        process_hierarchy = gen.get("process_hierarchy", {})
        # pop (not get): remove our saved entry so the hierarchy dict does
        # not grow unboundedly across repeated acquire/release cycles.
        gen["process_status"] = process_hierarchy.pop(process_id, None)