Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,508 Bytes
26a63c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import os
import shutil
import threading
from pathlib import Path
import torch
def import_from_transformers_modules(
    pretrained_model_name_or_path, file_name, class_name
):
    """Load a class from a model repo's remote-code module.

    Caches ``file_name`` from ``pretrained_model_name_or_path`` via
    transformers' dynamic-module machinery and returns the attribute
    ``class_name`` defined inside it.
    """
    import transformers

    dynamic_utils = transformers.dynamic_module_utils
    cached_module = dynamic_utils.get_cached_module_file(
        pretrained_model_name_or_path, file_name
    )
    return dynamic_utils.get_class_in_module(class_name, cached_module)
def deepspeed_zero_init_disabled_context_manager():
    """
    Return a list of context managers that disable DeepSpeed zero.Init,
    or an empty list when no DeepSpeed plugin is active.
    """
    import accelerate

    # No accelerator state yet means no DeepSpeed plugin can exist.
    if not accelerate.state.is_initialized():
        return []
    plugin = accelerate.state.AcceleratorState().deepspeed_plugin
    if plugin is None:
        return []
    return [plugin.zero3_init_context_manager(enable=False)]
def remove_excess_checkpoints(
    save_directory,
    checkpoints_total_limit: int = None,
    checkpoint_prefix="checkpoint",
    is_main_process: bool = True,
):
    """Prune the oldest checkpoint directories under ``save_directory``.

    Keeps at most ``checkpoints_total_limit`` entries whose names start with
    ``checkpoint_prefix``; the oldest ones (lowest trailing step number) are
    removed with ``shutil.rmtree``. No-op when the limit is ``None`` or on
    non-main processes, so only one rank ever deletes.
    """
    # _after_ saving state, check if this save put us over the limit
    if not is_main_process or checkpoints_total_limit is None:
        return

    checkpoints = [
        d for d in os.listdir(save_directory) if d.startswith(checkpoint_prefix)
    ]
    # Sort by the trailing step number. The previous key,
    # int(x.split("-")[2]), raised IndexError for the default
    # "checkpoint-<step>" layout (only indices 0 and 1 exist); taking the
    # last "-" component yields the step no matter how many dashes the
    # prefix itself contains.
    checkpoints.sort(key=lambda name: int(name.split("-")[-1]))

    # _after_ we save the new checkpoint, we need to have at _most_
    # `checkpoints_total_limit` checkpoints
    if len(checkpoints) <= checkpoints_total_limit:
        return
    num_to_remove = len(checkpoints) - checkpoints_total_limit
    removing_checkpoints = checkpoints[0:num_to_remove]
    print(
        f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
    )
    print(f"removing checkpoints: {', '.join(removing_checkpoints)}")
    for removing_checkpoint in removing_checkpoints:
        shutil.rmtree(os.path.join(save_directory, removing_checkpoint))
def is_distributed_training():
    """Return True when a distributed run is detected.

    True if torch.distributed is already initialized, or if the WORLD_SIZE
    environment variable (default 1) is greater than 1.
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        return True
    return int(os.environ.get("WORLD_SIZE", 1)) > 1
def contain_invalid_grad(optimizer):
    """Return a truthy value when any gradient in ``optimizer`` is NaN or ±Inf.

    Under distributed training the per-rank verdict is all-reduced with MAX
    so every rank agrees. NOTE(review): that branch moves the flag via
    ``.cuda()``, so it assumes a CUDA(-like) device on every rank — confirm
    for NPU deployments.
    """
    invalid_grad = False
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            grad = param.grad
            if grad is None:
                continue
            # torch.isfinite is False exactly for NaN, +Inf and -Inf, so one
            # fused check replaces the old isnan/isinf/isneginf triple
            # (isneginf was redundant anyway: isinf already covers -Inf).
            if not torch.isfinite(grad).all():
                invalid_grad = True
                break
        if invalid_grad:
            # Short-circuit: one bad gradient is enough, no need to scan on.
            break
    if is_distributed_training():
        invalid_grad_flag = torch.tensor(
            [1.0 if invalid_grad else 0.0],
            dtype=torch.float32,
            requires_grad=False,
        ).cuda()
        torch.distributed.all_reduce(
            invalid_grad_flag, op=torch.distributed.ReduceOp.MAX
        )
        invalid_grad = invalid_grad_flag.item() > 0
    return invalid_grad
def patch_npu_record_stream():
    """Monkey-patch ``torch.Tensor.record_stream`` with a device synchronize.

    Renames the privateuse1 backend to "npu" and wraps ``record_stream`` so
    every call is followed by ``torch.cuda.synchronize()``.
    NOTE(review): presumably this works around record_stream not being
    honored by the NPU runtime — confirm against the target torch-npu stack.
    """
    torch.utils.rename_privateuse1_backend("npu")
    original_record_stream = torch.Tensor.record_stream

    def _synced_record_stream(*args, **kwargs):
        result = original_record_stream(*args, **kwargs)
        torch.cuda.synchronize()
        return result

    torch.Tensor.record_stream = _synced_record_stream
def patch_npu_diffusers_get_1d_rotary_pos_embed():
    """Monkey-patch ``diffusers.models.embeddings.get_1d_rotary_pos_embed``.

    Installs a local re-implementation of the 1-D rotary positional
    embedding used by flux / hunyuan-dit / cogvideox / stable-audio /
    lumina. NOTE(review): presumably this exists because some op in the
    upstream version misbehaves on NPU (the function name suggests so); the
    math below should track the upstream diffusers implementation — confirm
    when bumping the diffusers version.
    """
    from typing import Union
    import numpy as np
    import diffusers
    def __get_1d_rotary_pos_embed(
        dim: int,
        pos: Union[np.ndarray, int],
        theta: float = 10000.0,
        use_real=False,
        linear_factor=1.0,
        ntk_factor=1.0,
        repeat_interleave_real=True,
        freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
    ):
        """Compute 1-D rotary embeddings for positions ``pos``.

        Returns a ``(cos, sin)`` pair of [S, D] tensors when ``use_real``
        is True, otherwise a complex ``freqs_cis`` tensor of shape [S, D/2].
        """
        # Rotary embeddings pair up channels, so the dim must be even.
        assert dim % 2 == 0
        # Accept an int (sequence length) or an ndarray of positions.
        if isinstance(pos, int):
            pos = torch.arange(pos)
        if isinstance(pos, np.ndarray):
            pos = torch.from_numpy(pos) # type: ignore # [S]
        # NTK scaling stretches the rotary base frequency.
        theta = theta * ntk_factor
        # Inverse-frequency spectrum: theta^(-2i/dim) for i in [0, D/2),
        # additionally divided by linear_factor (linear position scaling).
        freqs = (
            1.0
            / (
                theta
                ** (
                    torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[
                        : (dim // 2)
                    ]
                    / dim
                )
            )
            / linear_factor
        ) # [D/2]
        # Outer product: one rotation angle per (position, frequency) pair.
        freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
        if use_real and repeat_interleave_real:
            # flux, hunyuan-dit, cogvideox
            # NOTE(review): the second .float() is redundant (the tensor is
            # already float after the first cast) but harmless; kept to
            # match upstream.
            freqs_cos = (
                freqs.cos().float().repeat_interleave(2, dim=1).float()
            ) # [S, D]
            freqs_sin = (
                freqs.sin().float().repeat_interleave(2, dim=1).float()
            ) # [S, D]
            return freqs_cos, freqs_sin
        elif use_real:
            # stable audio
            freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
            freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
            return freqs_cos, freqs_sin
        else:
            # lumina
            freqs_cis = torch.polar(
                torch.ones_like(freqs), freqs
            ) # complex64 # [S, D/2]
            return freqs_cis
    # Install the replacement globally for all diffusers callers.
    diffusers.models.embeddings.get_1d_rotary_pos_embed = __get_1d_rotary_pos_embed
|