# Copied from EasyVolumetricVideo
# Author: Zhen Xu https://github.com/dendenxu

import os
import asyncio
import imageio
import torch

from torch import Tensor
from tqdm import tqdm
from threading import Thread
from functools import wraps, partial
from typing import Callable, List, Dict
from multiprocessing.pool import ThreadPool

def async_call_func(func):
    """Wrap a blocking function so that it can be awaited from asyncio code"""
    @wraps(func)
    async def wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        # run_in_executor only forwards positional arguments, so bind kwargs with partial
        return await loop.run_in_executor(None, partial(func, *args, **kwargs))
    return wrapper
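
# Usage sketch (illustrative; `blocking_io` is a hypothetical function):
#
#   @async_call_func
#   def blocking_io(path):
#       with open(path, 'rb') as f:
#           return f.read()
#
#   async def main():
#       data = await blocking_io('frame.bin')  # runs in the default thread pool executor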

def _save_image_impl(save_img, save_path):
    """Common implementation for saving images synchronously or asynchronously"""
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)  # dirname is '' for bare filenames
    imageio.imwrite(save_path, save_img)


def cat_dict(dict_list: List[Dict], dim=0, reserved_keys=('select_indice', 'dens_volume_fine', 'z_vals_fine', 'z_vals_mid_fine', 'points_fine', 'dens_volume', 'prob_volume', 't_vals', 't2z_func', 'img_feat', 't_mean_std_min_max')):
    """Concatenate tensor values that share a key across a list of dicts, skipping reserved per-chunk keys"""
    return {k: torch.cat([item[k] for item in dict_list], dim) for k in dict_list[0] if isinstance(dict_list[0][k], Tensor) and k not in reserved_keys}

def cat_list(list_list: List[List[Tensor]], dim: int = 0):
    """Concatenate position-wise: the i-th tensors of every sublist are concatenated together"""
    return [torch.cat([item[i] for item in list_list], dim=dim) for i in range(len(list_list[0]))]

def cat_tensor(tensor_list: List[Tensor], dim=0):
    return torch.cat(tensor_list, dim=dim)
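
# Usage sketch (illustrative shapes): merging per-chunk network outputs along the chunk dimension.
#
#   chunks = [{'rgb': torch.rand(4096, 3)}, {'rgb': torch.rand(4096, 3)}]
#   merged = cat_dict(chunks, dim=0)  # merged['rgb'].shape == (8192, 3)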

# Build an indexing tuple that selects [chunk_index, chunk_index + chunk_size) along chunk_dim
slice_func = lambda chunk_index, chunk_dim, chunk_size: tuple([slice(None)] * chunk_dim + [slice(chunk_index, chunk_index + chunk_size)])
def chunkify(func, cat_func, chunk_tensors: List[Tensor], chunk_dim: int, chunk_size: int, **kwargs):
    '''
    func: function to be chunkified
    cat_func: function to concatenate the results (e.g. cat_tensor, cat_list or cat_dict)
    chunk_tensors: list of tensors to be chunkified
    chunk_dim: dimension to be chunkified
    chunk_size: size of each chunk
    '''
    total_chunk_size = chunk_tensors[0].shape[chunk_dim]
    assert all([total_chunk_size == chunk_tensors[i].shape[chunk_dim] for i in range(1, len(chunk_tensors))])
    return cat_func([func(*[chunk_tensor[slice_func(i, chunk_dim, chunk_size)] for chunk_tensor in chunk_tensors], **kwargs) for i in range(0, total_chunk_size, chunk_size)], chunk_dim)
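
# Usage sketch: evaluating a function over a large batch in fixed-size chunks.
#
#   rays = torch.rand(10000, 6)
#   rgb = chunkify(lambda x: x[..., :3], cat_tensor, [rays], chunk_dim=0, chunk_size=4096)
#   assert rgb.shape == (10000, 3)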

def async_call(fn):
    """Fire-and-forget: run fn in a background thread and return immediately; the thread
    is non-daemonic, so pending work still finishes before the process exits"""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        Thread(target=fn, args=args, kwargs=kwargs).start()
    return wrapper

@async_call
def save_image_async(save_img, save_path):
    """Save image asynchronously"""
    _save_image_impl(save_img, save_path)

def save_image(save_img, save_path):
    """Save image synchronously"""
    _save_image_impl(save_img, save_path)
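
# Usage sketch (illustrative; `pred` is a hypothetical HxWx3 uint8 array): the async variant
# returns immediately and writes in the background, keeping a render loop from stalling on disk I/O.
#
#   save_image_async(pred, 'out/frame_0000.png')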

def parallel_execution(*args, action: Callable, num_processes=32, print_progress=False, sequential=False, async_return=False, desc=None, **kwargs):
    # NOTE: lists whose length matches the batch are distributed one element per call;
    # all other args / kwargs are broadcast to every call as-is

    def get_length(args: List, kwargs: Dict):
        for a in args:
            if isinstance(a, list):
                return len(a)
        for v in kwargs.values():
            if isinstance(v, list):
                return len(v)
        raise NotImplementedError('at least one positional or keyword argument must be a list')

    def get_action_args(length: int, args: List, kwargs: Dict, i: int):
        action_args = [(arg[i] if isinstance(arg, list) and len(arg) == length else arg) for arg in args]
        # TODO: Support all types of iterable
        action_kwargs = {key: (kwargs[key][i] if isinstance(kwargs[key], list) and len(kwargs[key]) == length else kwargs[key]) for key in kwargs}
        return action_args, action_kwargs

    if not sequential:
        # Create ThreadPool
        pool = ThreadPool(processes=num_processes)

        # Spawn threads
        results = []
        asyncs = []
        length = get_length(args, kwargs)
        for i in range(length):
            action_args, action_kwargs = get_action_args(length, args, kwargs, i)
            async_result = pool.apply_async(action, action_args, action_kwargs)
            asyncs.append(async_result)

        # Join threads and get return values
        if not async_return:
            for async_result in tqdm(asyncs, desc=desc, disable=not print_progress):
                results.append(async_result.get())  # will sync the corresponding thread
            pool.close()
            pool.join()
            return results
        else:
            # Caller is responsible for collecting the results and closing the pool
            return pool
    else:
        results = []
        length = get_length(args, kwargs)
        for i in tqdm(range(length), desc=desc, disable=not print_progress):
            action_args, action_kwargs = get_action_args(length, args, kwargs, i)
            results.append(action(*action_args, **action_kwargs))
        return results
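
# Usage sketch (illustrative; `images` is a hypothetical list of HxWx3 uint8 arrays):
#
#   paths = [f'out/frame_{i:04d}.png' for i in range(len(images))]
#   parallel_execution(images, paths, action=save_image, num_processes=8,
#                      print_progress=True, desc='Saving frames')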