import gc

import accelerate
import torch


def get_module_by_name_suffix(model, module_name: str):
    """Return the first submodule whose qualified name ends with ``module_name``.

    Returns ``None`` if no submodule matches.
    """
    for name, module in model.named_modules():
        if name.endswith(module_name):
            return module
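
# Usage sketch (the module name below is hypothetical, not part of this file):
# on a transformers-style model, a submodule with a qualified name such as
# "model.layers.0.self_attn" can be fetched by any unique suffix:
#
#   attn = get_module_by_name_suffix(model, "layers.0.self_attn")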


def simple_dispatch_model(model, device_map):
    """Dispatch ``model`` across the devices in ``device_map`` using accelerate hooks."""
    from accelerate.hooks import add_hook_to_module, AlignDevicesHook

    # A device_map with the root key "" places the whole model on one device.
    if "" in device_map:
        d = device_map[""]
        model = model.to(torch.device(d))
        model.hf_device_map = device_map
        return model

    tied_params = accelerate.utils.modeling.find_tied_parameters(model)
    if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {
        "cpu",
        "disk",
    }:
        main_device = "cpu"
    else:
        main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]

    # Chain the CPU-offloaded modules: each hook moves its module to the main
    # device on forward and offloads the previous module in the chain back to CPU.
    cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
    prev_hook = None
    for n, _ in cpu_offload_group:
        m = get_module_by_name_suffix(model, n)
        _, prev_hook = accelerate.cpu_offload_with_hook(
            m, execution_device=main_device, prev_module_hook=prev_hook
        )

    # Close the chain: running the first offloaded module offloads the last one.
    if len(cpu_offload_group) > 1:
        get_module_by_name_suffix(
            model, cpu_offload_group[0][0]
        )._hf_hook.prev_module_hook = prev_hook

    # Modules pinned to a real device get an AlignDevicesHook so their inputs
    # and outputs are moved to that device.
    for n, d in device_map.items():
        m = get_module_by_name_suffix(model, n)
        if d != "cpu":
            d = torch.device(d)
            hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
            add_hook_to_module(m, hook)
    accelerate.utils.modeling.retie_parameters(model, tied_params)
    model.hf_device_map = device_map

    return model
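
# Usage sketch (the device_map below is illustrative; a real one typically
# comes from accelerate.infer_auto_device_map or a Transformers `device_map`):
#
#   device_map = {"model.embed_tokens": 0, "model.layers.0": 0,
#                 "model.layers.1": "cpu", "lm_head": 0}
#   model = simple_dispatch_model(model, device_map)
#   # modules mapped to "cpu" are offloaded and moved to device 0 on demand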


def set_module_name(model, name, value):
    """Replace the submodule at dotted path ``name`` with ``value``."""
    if "." in name:
        parent_name, child_name = name.rsplit(".", 1)
        parent = model.get_submodule(parent_name)
    else:
        parent = model
        child_name = name

    setattr(parent, child_name, value)
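
# Usage sketch (hypothetical names): swap a linear layer for a quantized
# replacement while leaving the rest of the model untouched:
#
#   set_module_name(model, "model.layers.0.mlp.down_proj", quantized_linear)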


def clear_memory(weight=None):
    """Run garbage collection and release cached CUDA memory.

    Note: ``del weight`` only drops the local reference; the caller must
    delete or overwrite its own references for the tensor to actually be freed.
    """
    if weight is not None:
        del weight
    gc.collect()
    torch.cuda.empty_cache()
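
# Typical call site: drop the caller's reference first, then reclaim memory:
#
#   del weight
#   clear_memory()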


def compute_memory_used_pct(device):
    """Return peak CUDA memory allocated on ``device`` as a percentage of its total memory."""
    memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
    memory_pct = (
        memory_used
        / (torch.cuda.get_device_properties(device).total_memory / (1024**3))
        * 100
    )
    return memory_pct
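
# Worked example: with a peak allocation of 6 GiB on a 24 GiB card, this
# returns 6 / 24 * 100 = 25.0. The value reflects the peak allocation since
# program start (or the last torch.cuda.reset_peak_memory_stats call), not
# the current usage.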


def get_best_device():
    """Return the preferred available device: MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return "mps"
    elif torch.cuda.is_available():
        return "cuda:0"
    else:
        return "cpu"


def get_lowest_memory_device_index():
    """Return the index of the CUDA device with the lowest peak memory usage."""
    device = None
    curr_device_memory_pct = 0
    for device_index in range(torch.cuda.device_count()):
        device_memory_pct = compute_memory_used_pct(device_index)
        if device is None or device_memory_pct < curr_device_memory_pct:
            device = device_index
            curr_device_memory_pct = device_memory_pct

    return device
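

# Minimal smoke test; this guard block is an illustrative addition, not part
# of the public API of this module.
if __name__ == "__main__":
    device = get_best_device()
    print(f"best available device: {device}")
    if torch.cuda.is_available():
        print(f"lowest-memory CUDA device index: {get_lowest_memory_device_index()}")
        print(f"peak memory used on cuda:0: {compute_memory_used_pct(0):.1f}%")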