| import math |
|
|
| import diffusers |
| import torch |
|
|
# When the MPS backend is available, rebind ``torch.empty`` to ``torch.zeros`` so that
# every later ``torch.empty(...)`` call in the process returns a zero-filled tensor.
# NOTE(review): presumably a workaround for garbage/NaN values coming from
# uninitialised MPS memory -- confirm the motivation before removing.
if torch.backends.mps.is_available():
    torch.empty = torch.zeros
|
|
|
|
# Handle on the stock ``layer_norm`` taken before the wrapper below is installed;
# the wrapper always delegates the real computation to it.
_torch_layer_norm = torch.nn.functional.layer_norm


def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
    """Run ``layer_norm``, computing in float32 for float16 tensors on MPS.

    Args:
        input: Tensor to normalise.
        normalized_shape: Forwarded unchanged to the original ``layer_norm``.
        weight: Optional scale tensor; cast to float32 on the MPS/float16 path.
        bias: Optional shift tensor; cast to float32 on the MPS/float16 path.
        eps: Forwarded unchanged to the original ``layer_norm``.

    Returns:
        The original ``layer_norm`` result. On the MPS/float16 path the
        computation runs on float32 copies and the result is cast back with
        ``.half()``; every other input is passed straight through.
    """
    needs_fp32 = input.device.type == "mps" and input.dtype == torch.float16
    if not needs_fp32:
        return _torch_layer_norm(input, normalized_shape, weight, bias, eps)

    input_fp32 = input.float()
    weight_fp32 = None if weight is None else weight.float()
    bias_fp32 = None if bias is None else bias.float()
    normalised = _torch_layer_norm(input_fp32, normalized_shape, weight_fp32, bias_fp32, eps)
    return normalised.half()


# Install the wrapper so callers of ``torch.nn.functional.layer_norm`` pick it up.
torch.nn.functional.layer_norm = new_layer_norm
|
|
|
|
# Original, unpatched ``torch.Tensor.permute``; the wrapper below delegates to it.
_torch_tensor_permute = torch.Tensor.permute


def new_torch_tensor_permute(input, *dims):
    """Permute ``input`` and make the result contiguous for float16 tensors on MPS.

    Args:
        input: Tensor to permute (this is ``self`` when invoked as
            ``tensor.permute(...)``).
        *dims: Desired dimension order, forwarded unchanged to the original
            ``permute``.

    Returns:
        The permuted tensor. For float16 tensors on an MPS device the result
        is passed through ``.contiguous()``; for every other tensor it is
        exactly what the original ``permute`` returned.
    """
    result = _torch_tensor_permute(input, *dims)
    # Compare the device *type*: ``input.device`` is a ``torch.device`` object
    # (e.g. ``mps:0``), which never compares equal to the plain string "mps",
    # so testing ``input.device == "mps"`` left this branch unreachable.
    if input.device.type == "mps" and input.dtype == torch.float16:
        result = result.contiguous()
    return result


# Install the wrapper so ``tensor.permute(...)`` goes through it.
torch.Tensor.permute = new_torch_tensor_permute
|
|
|
|
# Handle on the stock ``torch.lerp`` taken before the wrapper below is installed.
_torch_lerp = torch.lerp


def new_torch_lerp(input, end, weight, *, out=None):
    """Run ``torch.lerp``, computing in float32 for float16 tensors on MPS.

    Args:
        input: Start tensor.
        end: End tensor; cast to float32 on the MPS/float16 path.
        weight: Interpolation weight. Tensor weights are cast to float32 on
            the MPS/float16 path; any other weight is forwarded as-is.
        out: Optional destination tensor (keyword-only).

    Returns:
        On the MPS/float16 path, a new float16 tensor holding the float32
        result cast with ``.half()``; when ``out`` is given it additionally
        receives a copy of that result, but the returned tensor is not
        ``out`` itself. For every other input, whatever the original
        ``torch.lerp`` returns.
    """
    if not (input.device.type == "mps" and input.dtype == torch.float16):
        return _torch_lerp(input, end, weight, out=out)

    weight_fp32 = weight.float() if isinstance(weight, torch.Tensor) else weight
    # The float32 computation needs a float32 destination; use a scratch
    # buffer shaped like ``out`` and copy back afterwards.
    scratch = None if out is None else torch.zeros_like(out, dtype=torch.float32)
    blended = _torch_lerp(input.float(), end.float(), weight_fp32, out=scratch)
    if out is not None:
        out.copy_(scratch.half())
    return blended.half()


# Install the wrapper so callers of ``torch.lerp`` pick it up.
torch.lerp = new_torch_lerp
|
|
|
|
# Handle on the stock ``interpolate`` taken before the wrapper below is installed.
_torch_interpolate = torch.nn.functional.interpolate


def new_torch_interpolate(
    input,
    size=None,
    scale_factor=None,
    mode="nearest",
    align_corners=None,
    recompute_scale_factor=None,
    antialias=False,
):
    """Run ``interpolate``, computing in float32 for float16 tensors on MPS.

    All arguments other than ``input`` are forwarded positionally and
    unchanged to the original ``torch.nn.functional.interpolate``.

    Returns:
        The original ``interpolate`` result. For float16 tensors on an MPS
        device the input is cast to float32 first and the result is cast back
        with ``.half()``; every other input is passed straight through.
    """
    on_mps_fp16 = input.device.type == "mps" and input.dtype == torch.float16
    source = input.float() if on_mps_fp16 else input
    resized = _torch_interpolate(
        source, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
    )
    return resized.half() if on_mps_fp16 else resized


# Install the wrapper so callers of ``torch.nn.functional.interpolate`` pick it up.
torch.nn.functional.interpolate = new_torch_interpolate
|
|
| |
# Keep a reference to diffusers' original SlicedAttnProcessor before it is replaced
# further down; ChunkedSlicedAttnProcessor instantiates it in __init__ and delegates
# to it from __call__ for the cases its own chunked path does not handle.
_SlicedAttnProcessor = diffusers.models.attention_processor.SlicedAttnProcessor
|
|
|
|
class ChunkedSlicedAttnProcessor:
    r"""
    Attention processor that handles one batch-head at a time and computes each head's
    attention scores in row chunks bounded by a fixed byte budget.

    Calls with `attn.upcast_attention` set are delegated to an instance of the original
    `SlicedAttnProcessor` created in `__init__`.

    Args:
        slice_size (`int`):
            Must be an `int`. The value is currently ignored: the processor always works
            with a slice size of 1.
    """

    def __init__(self, slice_size):
        assert isinstance(slice_size, int)
        # The requested value is deliberately discarded; get_attention_scores_chunked
        # asserts that it only ever receives a single batch-head.
        slice_size = 1
        self.slice_size = slice_size
        # Fallback processor used by __call__ when the chunked path does not apply.
        self._sliced_attn_processor = _SlicedAttnProcessor(slice_size)

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Compute attention for `hidden_states` using the chunked per-head path.

        Args:
            attn: Attention module supplying the projections (`to_q`, `to_k`, `to_v`,
                `to_out`), the head reshaping helpers and the configuration flags read
                below. NOTE(review): presumably a diffusers `Attention` instance -- confirm.
            hidden_states: Input tensor; 4-D inputs are flattened to 3-D and restored
                at the end.
            encoder_hidden_states: Optional tensor used for keys/values; when `None`
                the (possibly group-normed) `hidden_states` are used instead.
            attention_mask: Optional mask, passed through `attn.prepare_attention_mask`.

        Returns:
            The attention output, with the residual added when
            `attn.residual_connection` is set, divided by `attn.rescale_output_factor`.
        """
        # Anything the chunked path does not handle goes to the stock sliced processor.
        if self.slice_size != 1 or attn.upcast_attention:
            return self._sliced_attn_processor(attn, hidden_states, encoder_hidden_states, attention_mask)

        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten the two trailing dims and move channels last: (B, C, H*W) -> (B, H*W, C).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Mask preparation uses the shape of whichever tensor will provide the keys.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # Output buffer that get_attention_scores_chunked fills in place, slice by slice.
        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # Scratch buffer for the attention scores, shared by all slices.
        chunk_tmp_tensor = torch.empty(
            self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
        )

        for i in range(batch_size_attention // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            # hidden_states[start_idx:end_idx] is a view, so the helper's in-place
            # writes land in the output buffer allocated above.
            self.get_attention_scores_chunked(
                attn,
                query_slice,
                key_slice,
                attn_mask_slice,
                hidden_states[start_idx:end_idx],
                value[start_idx:end_idx],
                chunk_tmp_tensor,
            )

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # NOTE(review): presumably the output projection followed by dropout, as in
        # diffusers' own processors -- confirm against the Attention module.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def get_attention_scores_chunked(self, attn, query, key, attention_mask, hidden_states, value, chunk):
        """Compute attention for a single batch-head, writing into `hidden_states`.

        The query rows are processed in chunks. For each chunk the scaled scores
        `attn.scale * query_chunk @ key^T` (plus the matching rows of
        `attention_mask` when given) are written into `chunk`, soft-maxed over the
        last dim and multiplied by `value`; the product is written in place into the
        corresponding rows of `hidden_states`.

        Args:
            attn: Object providing `upcast_attention` and `scale`.
            query: Query tensor whose first dim must be 1.
            key: Key tensor whose first dim must be 1.
            attention_mask: Optional additive mask, sliced along dim 1 per chunk.
            hidden_states: Output tensor (first dim 1), modified in place.
            value: Value tensor whose first dim must be 1.
            chunk: Scratch tensor used as the `out=` target for the scores.

        Returns:
            None. Results are delivered through `hidden_states`.
        """
        # Only a single batch-head per call is supported.
        assert query.shape[0] == 1
        assert key.shape[0] == 1
        assert value.shape[0] == 1
        assert hidden_states.shape[0] == 1

        if attn.upcast_attention:
            query = query.float()
            key = key.float()

        # Bytes per score element; after the optional upcast above this is 4 for float32.
        out_item_size = query.element_size()

        # Byte budget for the scores of one chunk.
        chunk_size = 2**29

        out_size = query.shape[1] * key.shape[1] * out_item_size
        # Clamp to at least one chunk: with zero query or key tokens the computed
        # count is 0, which previously caused a ZeroDivisionError just below.
        chunks_count = max(1, min(query.shape[1], math.ceil((out_size - 1) / chunk_size)))
        chunk_step = max(1, query.shape[1] // chunks_count)

        key = key.transpose(-1, -2)

        def _get_chunk_view(tensor, start, length):
            # Rows [start, start + length) along dim 1, clipped at the end of the tensor.
            if start + length > tensor.shape[1]:
                length = tensor.shape[1] - start
            return tensor[:, start : start + length]

        # Placeholder "input" for baddbmm on the unmasked path; with beta=0 its values
        # do not contribute, so one tensor can be shared by every chunk.
        no_bias = None
        if attention_mask is None:
            no_bias = torch.zeros((1, 1, 1), device=query.device, dtype=query.dtype)

        for chunk_pos in range(0, query.shape[1], chunk_step):
            if attention_mask is not None:
                torch.baddbmm(
                    _get_chunk_view(attention_mask, chunk_pos, chunk_step),
                    _get_chunk_view(query, chunk_pos, chunk_step),
                    key,
                    beta=1,
                    alpha=attn.scale,
                    out=chunk,
                )
            else:
                torch.baddbmm(
                    no_bias,
                    _get_chunk_view(query, chunk_pos, chunk_step),
                    key,
                    beta=0,
                    alpha=attn.scale,
                    out=chunk,
                )
            # Rebinds the local name: later iterations use the softmax result as their
            # `out=` buffer rather than the tensor the caller passed in.
            chunk = chunk.softmax(dim=-1)
            torch.bmm(chunk, value, out=_get_chunk_view(hidden_states, chunk_pos, chunk_step))
|
| |
|
|
|
|
# Replace diffusers' SlicedAttnProcessor with the chunked variant: any later lookup of
# diffusers.models.attention_processor.SlicedAttnProcessor resolves to
# ChunkedSlicedAttnProcessor. Names bound before this line runs (e.g. via an earlier
# ``from ... import SlicedAttnProcessor``) still refer to the original class.
diffusers.models.attention_processor.SlicedAttnProcessor = ChunkedSlicedAttnProcessor
|
|