| | import torch
|
| | from contextlib import contextmanager
|
| |
|
| |
|
# When True, models are kept resident on the GPU and never evicted back to CPU.
high_vram = False

# Module-wide device handles used by all load/unload helpers below.
gpu = torch.device('cuda')
cpu = torch.device('cpu')

# Allocate a tiny tensor on the GPU at import time — presumably to force CUDA
# context initialization up front (TODO confirm intent) — then release the
# scratch allocation from PyTorch's caching allocator.
torch.zeros((1, 1)).to(gpu, torch.float32)
torch.cuda.empty_cache()

# Registry of models currently resident on the GPU; mutated by
# load_models_to_gpu / unload_all_models via `global`.
models_in_gpu = []
|
@contextmanager
def movable_bnb_model(m):
    """Temporarily hide a model's ``quantization_method`` attribute.

    While the context is active, the attribute is stashed in
    ``quantization_method_backup`` and removed from ``m`` so that device
    moves (``m.to(...)``) are permitted; on exit the attribute is put back.
    Models without the attribute pass through untouched.
    """
    had_marker = hasattr(m, 'quantization_method')
    if had_marker:
        # Stash the marker and strip it so .to(device) is not refused.
        m.quantization_method_backup = m.quantization_method
        del m.quantization_method
    try:
        yield
    finally:
        # Restore the marker whenever a stash exists, even if the body raised.
        if hasattr(m, 'quantization_method_backup'):
            m.quantization_method = m.quantization_method_backup
            del m.quantization_method_backup
|
| |
|
| |
|
def load_models_to_gpu(models):
    """Ensure exactly the given models are loaded on the GPU.

    In low-VRAM mode (``high_vram`` is False), any currently-resident model
    that is not in ``models`` is first moved back to the CPU. Models already
    on the GPU are left in place; the rest are moved to the GPU. The
    module-level ``models_in_gpu`` registry is updated accordingly and the
    CUDA caching allocator is trimmed at the end.

    Args:
        models: a single model or a tuple/list of models (deduplicated by
            identity/hash before processing).
    """
    global models_in_gpu

    if not isinstance(models, (tuple, list)):
        models = [models]

    # Build each set once instead of re-materializing set(models) per
    # comprehension, and use set membership (O(1)) rather than scanning lists.
    requested = set(models)
    resident = set(models_in_gpu)
    models_to_remain = [m for m in requested if m in resident]
    models_to_load = [m for m in requested if m not in resident]

    if not high_vram:
        # Evict everything on the GPU that is not being kept for this request.
        keep = set(models_to_remain)
        for m in resident:
            if m in keep:
                continue
            with movable_bnb_model(m):
                m.to(cpu)
            print('Unload to CPU:', m.__class__.__name__)
        models_in_gpu = models_to_remain

    for m in models_to_load:
        with movable_bnb_model(m):
            m.to(gpu)
        print('Load to GPU:', m.__class__.__name__)

    models_in_gpu = list(set(models_in_gpu + models))
    torch.cuda.empty_cache()
|
| |
|
| |
|
def unload_all_models(extra_models=None):
    """Move every tracked model (plus any ``extra_models``) off the GPU.

    The extras are merged into the ``models_in_gpu`` registry so that the
    subsequent empty load request evicts them along with everything else.

    Args:
        extra_models: None, a single model, or a tuple/list of models.
    """
    global models_in_gpu

    if extra_models is None:
        extra_models = []
    elif not isinstance(extra_models, (tuple, list)):
        extra_models = [extra_models]

    # Register the extras, then request an empty GPU working set.
    models_in_gpu = list(set(models_in_gpu) | set(extra_models))
    return load_models_to_gpu([])
|
| |
|