|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
from typing import Optional |
|
|
|
|
|
from ..._common import precision |
|
|
from ...functional import geglu, matmul, softmax, split |
|
|
from ...layers import Conv2d, GroupNorm, LayerNorm, Linear |
|
|
from ...module import Module, ModuleList |
|
|
|
|
|
|
|
|
class AttentionBlock(Module):
    """Multi-head self-attention over the spatial positions of a 4-D input.

    The input is group-normalised, flattened so that every ``height * width``
    position becomes one sequence element, run through attention built from a
    single fused QKV projection, projected again, reshaped back to the input
    layout, added to the residual and divided by ``rescale_output_factor``.

    Args:
        channels: Channel count of the input; also the width of the QKV and
            output projections.
        num_head_channels: Channels per attention head. When ``None`` a single
            head covering all ``channels`` is used.
        num_groups: ``num_groups`` forwarded to ``GroupNorm``.
        rescale_output_factor: Divisor applied to ``attention + residual``.
        eps: ``eps`` forwarded to ``GroupNorm``.
    """

    def __init__(self,
                 channels: int,
                 num_head_channels: Optional[int] = None,
                 num_groups: int = 32,
                 rescale_output_factor: float = 1.0,
                 eps: float = 1e-5):
        super().__init__()
        self.channels = channels

        if num_head_channels is None:
            # Single-head fallback. The head size has to be a real integer
            # because transpose_for_scores() puts it into a view shape;
            # leaving it as None made the block unusable on this path.
            self.num_heads = 1
            self.num_head_size = channels
        else:
            self.num_heads = channels // num_head_channels
            self.num_head_size = num_head_channels

        self.group_norm = GroupNorm(num_channels=channels,
                                    num_groups=num_groups,
                                    eps=eps,
                                    affine=True)

        # One projection produces query, key and value side by side.
        self.qkv = Linear(channels, channels * 3)
        self.rescale_output_factor = rescale_output_factor
        # NOTE(review): the third positional argument is kept as the literal 1
        # from the original code -- presumably Linear's bias flag; confirm
        # against Linear's signature.
        self.proj_attn = Linear(channels, channels, 1)

    def transpose_for_scores(self, projection):
        """Split the last axis into heads and swap axes 1 and 2.

        The last axis of ``projection`` is replaced by
        ``(num_heads, num_head_size)`` and the result is permuted with
        ``[0, 2, 1, 3]``.
        """
        new_projection_shape = projection.size()[:-1] + (self.num_heads,
                                                         self.num_head_size)

        new_projection = projection.view(new_projection_shape).permute(
            [0, 2, 1, 3])
        return new_projection

    def forward(self, hidden_states):
        """Apply the attention block.

        Args:
            hidden_states: Tensor whose ``size()`` unpacks into
                ``(batch, channel, height, width)``; ``is_dynamic()`` must be
                false.

        Returns:
            ``(attention_output + hidden_states) / rescale_output_factor``,
            viewed as ``[batch, channel, height, width]``.
        """
        assert not hidden_states.is_dynamic()

        residual = hidden_states
        batch, channel, height, width = hidden_states.size()

        # Normalise, then flatten the spatial axes and move channels last:
        # [batch, channel, h, w] -> [batch, h * w, channel].
        hidden_states = self.group_norm(hidden_states)
        hidden_states = hidden_states.view([batch, channel,
                                            height * width]).transpose(1, 2)

        # Fused projection, then cut the last axis into three chunks of
        # `channel` each.
        qkv_proj = self.qkv(hidden_states)

        query_proj, key_proj, value_proj = split(qkv_proj, channel, dim=2)

        query_states = self.transpose_for_scores(query_proj)
        key_states = self.transpose_for_scores(key_proj)
        value_states = self.transpose_for_scores(value_proj)

        # Scores, scaling and softmax are built inside the
        # precision('float32') context; the scale is sqrt(channels / heads).
        with precision('float32'):
            attention_scores = matmul(query_states,
                                      (key_states).transpose(-1, -2))
            attention_scores = attention_scores / math.sqrt(
                self.channels / self.num_heads)
            attention_probs = softmax(attention_scores, dim=-1)

        hidden_states = matmul(attention_probs, value_states)
        hidden_states = hidden_states.permute([0, 2, 1, 3])

        # Merge the trailing (heads, head_size) axes back into `channels`.
        new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels, )
        hidden_states = hidden_states.view(new_hidden_states_shape)

        # Output projection, then restore the original 4-D layout.
        hidden_states = self.proj_attn(hidden_states)
        hidden_states = hidden_states.transpose(-1, -2).view(
            [batch, channel, height, width])

        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states
|
|
|
|
|
|
|
|
def _transpose_for_scores(tensor, heads):
    """Split the last axis of a 3-D tensor into heads and move heads forward.

    ``tensor.size()`` is unpacked as three dimensions. The last one is split
    into ``heads`` chunks of ``dim // heads`` and the result is permuted with
    ``[0, 2, 1, 3]``, i.e. the new heads axis is placed at position 1.
    """
    n_batch, n_tokens, width = tensor.size()
    per_head = width // heads
    split_heads = tensor.view([n_batch, n_tokens, heads, per_head])
    return split_heads.permute([0, 2, 1, 3])
|
|
|
|
|
|
|
|
def _attention(query, key, value, scale):
    """Scaled dot-product attention on tensors that are already head-split.

    Computes ``softmax((query @ key^T) * scale, dim=-1) @ value`` and permutes
    the result with ``[0, 2, 1, 3]`` before returning it.
    """
    key_t = key.transpose(-1, -2)
    scores = matmul(query, key_t)
    scores = scores * scale
    weights = softmax(scores, dim=-1)
    attended = matmul(weights, value)
    return attended.permute([0, 2, 1, 3])
|
|
|
|
|
|
|
|
class SelfAttention(Module):
    """Multi-head self-attention driven by one fused QKV projection.

    Args:
        query_dim: Feature size of the input and of the output projection.
        heads: Number of attention heads.
        dim_head: Feature size per head; ``heads * dim_head`` is the inner
            width and ``dim_head ** -0.5`` the score scale.
    """

    def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64):
        super().__init__()
        self.inner_dim = dim_head * heads
        self.scale = dim_head**-0.5
        self.heads = heads
        # Stored but not read anywhere in this class.
        self._slice_size = None

        # Bias-free projection emitting query, key and value side by side.
        self.to_qkv = Linear(query_dim, 3 * self.inner_dim, bias=False)
        self.to_out = Linear(self.inner_dim, query_dim)

    def forward(self, hidden_states, mask=None):
        """Attend ``hidden_states`` to itself.

        ``hidden_states.is_dynamic()`` must be false. ``mask`` is accepted but
        not used by this implementation.
        """
        assert not hidden_states.is_dynamic()

        fused = self.to_qkv(hidden_states)

        # Cut the fused projection into inner_dim-wide chunks along axis 2,
        # then give each chunk its own heads axis.
        q_part, k_part, v_part = split(fused, self.inner_dim, dim=2)
        query = _transpose_for_scores(q_part, self.heads)
        key = _transpose_for_scores(k_part, self.heads)
        value = _transpose_for_scores(v_part, self.heads)
        attended = _attention(query, key, value, self.scale)

        # Fold the two trailing axes back into a single feature axis.
        batch_size, seq_len, head_size, head_dim = attended.size()
        merged = attended.view([batch_size, seq_len, head_size * head_dim])
        return self.to_out(merged)
|
|
|
|
|
|
|
|
class CrossAttention(Module):
    """Multi-head attention whose keys and values come from a context tensor.

    Args:
        query_dim: Feature size of the query input and of the output.
        context_dim: Feature size of the context; falls back to ``query_dim``
            when ``None``.
        heads: Number of attention heads.
        dim_head: Feature size per head; ``heads * dim_head`` is the inner
            width and ``dim_head ** -0.5`` the score scale.
    """

    def __init__(self,
                 query_dim: int,
                 context_dim: Optional[int] = None,
                 heads: int = 8,
                 dim_head: int = 64):
        super().__init__()
        self.inner_dim = dim_head * heads
        if context_dim is None:
            context_dim = query_dim
        self.scale = dim_head**-0.5
        self.heads = heads
        # Stored but not read anywhere in this class.
        self._slice_size = None

        self.to_q = Linear(query_dim, self.inner_dim, bias=False)
        # Keys and values share one bias-free projection of the context.
        self.to_kv = Linear(context_dim, 2 * self.inner_dim, bias=False)
        self.to_out = Linear(self.inner_dim, query_dim)

    def forward(self, hidden_states, context=None, mask=None):
        """Attend ``hidden_states`` to ``context``.

        When ``context`` is ``None`` the keys and values are computed from
        ``hidden_states`` itself. Both tensors must report
        ``is_dynamic() == False``. ``mask`` is accepted but not used by this
        implementation.
        """
        assert not hidden_states.is_dynamic()
        query = self.to_q(hidden_states)

        if context is None:
            context = hidden_states
        assert not context.is_dynamic()
        fused_kv = self.to_kv(context)

        query = _transpose_for_scores(query, self.heads)
        # Cut the fused projection into inner_dim-wide chunks along axis 2.
        k_part, v_part = split(fused_kv, self.inner_dim, dim=2)
        key = _transpose_for_scores(k_part, self.heads)
        value = _transpose_for_scores(v_part, self.heads)
        attended = _attention(query, key, value, self.scale)

        # Fold the two trailing axes back into a single feature axis.
        batch_size, seq_len, head_size, head_dim = attended.size()
        merged = attended.view([batch_size, seq_len, head_size * head_dim])
        return self.to_out(merged)
|
|
|
|
|
|
|
|
class FeedForward(Module):
    """Two-layer feed-forward network with a ``geglu`` activation in between.

    Args:
        dim: Input feature size.
        dim_out: Output feature size; falls back to ``dim`` when ``None``.
        mult: Multiplier giving the inner width ``int(dim * mult)``.
    """

    def __init__(self, dim: int, dim_out: Optional[int] = None, mult: int = 4):
        super().__init__()
        hidden = int(dim * mult)
        if dim_out is None:
            dim_out = dim
        # The input projection is twice as wide as what proj_out consumes;
        # presumably geglu uses one half to gate the other -- confirm against
        # the geglu implementation.
        self.proj_in = Linear(dim, hidden * 2)
        self.proj_out = Linear(hidden, dim_out)

    def forward(self, hidden_states):
        """Return ``proj_out(geglu(proj_in(hidden_states)))``."""
        expanded = self.proj_in(hidden_states)
        gated = geglu(expanded)
        return self.proj_out(gated)
|
|
|
|
|
|
|
|
class BasicTransformerBlock(Module):
    """Transformer block: self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is applied to a ``LayerNorm`` of its input
    and its output is added back to that (un-normalised) input.

    Args:
        dim: Feature size of the hidden states.
        n_heads: Number of heads for both attention layers.
        d_head: Feature size per head for both attention layers.
        context_dim: Feature size of the cross-attention context, forwarded
            to ``CrossAttention``.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        d_head: int,
        context_dim: Optional[int] = None,
    ):
        super().__init__()
        # Both attention layers share the same query width and head layout.
        shared = dict(query_dim=dim, heads=n_heads, dim_head=d_head)
        self.attn1 = SelfAttention(**shared)
        self.ff = FeedForward(dim)
        self.attn2 = CrossAttention(context_dim=context_dim, **shared)
        self.norm1 = LayerNorm(dim)
        self.norm2 = LayerNorm(dim)
        self.norm3 = LayerNorm(dim)

    def forward(self, hidden_states, context=None):
        """Run the three residual sub-layers in order.

        ``context`` is passed only to the cross-attention layer.
        """
        x = hidden_states
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
|
|
|
|
|
|
|
|
class Transformer2DModel(Module):
    """Stack of transformer blocks applied to a 4-D input.

    The input is group-normalised, projected with a 1x1 convolution to
    ``num_attention_heads * attention_head_dim`` channels, flattened to a
    sequence, passed through ``num_layers`` ``BasicTransformerBlock`` layers,
    reshaped back, projected to ``in_channels`` with another 1x1 convolution
    and added to the original input.

    Args:
        num_attention_heads: Heads per attention layer.
        attention_head_dim: Feature size per head.
        in_channels: Channel count of the input. It is passed straight to
            ``GroupNorm`` and ``Conv2d``.
            NOTE(review): the default is ``None``; callers presumably always
            supply a value -- confirm.
        num_layers: Number of transformer blocks.
        norm_num_groups: ``num_groups`` forwarded to ``GroupNorm``.
        cross_attention_dim: Context feature size forwarded to every block.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        self.norm = GroupNorm(num_groups=norm_num_groups,
                              num_channels=in_channels,
                              eps=1e-6,
                              affine=True)

        # Both projections are 1x1 convolutions with unit stride, no padding.
        pointwise = dict(kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        self.proj_in = Conv2d(in_channels, inner_dim, **pointwise)

        self.transformer_blocks = ModuleList([
            BasicTransformerBlock(inner_dim,
                                  num_attention_heads,
                                  attention_head_dim,
                                  context_dim=cross_attention_dim)
            for _ in range(num_layers)
        ])
        self.proj_out = Conv2d(inner_dim, in_channels, **pointwise)

    def forward(self, hidden_states, context=None):
        """Apply the model.

        Args:
            hidden_states: Tensor whose ``size()`` unpacks into four
                dimensions (the second is ignored here); ``is_dynamic()``
                must be false.
            context: Optional tensor forwarded to every transformer block.

        Returns:
            ``proj_out(...) + hidden_states``.
        """
        assert not hidden_states.is_dynamic()
        batch, _, height, width = hidden_states.size()
        residual = hidden_states

        projected = self.proj_in(self.norm(hidden_states))
        inner_dim = projected.size()[1]

        # Channels last, then flatten the two spatial axes into one sequence
        # axis: [batch, inner_dim, h, w] -> [batch, h * w, inner_dim].
        tokens = projected.permute([0, 2, 3, 1]).view(
            [batch, height * width, inner_dim])
        for block in self.transformer_blocks:
            tokens = block(tokens, context=context)

        # Undo the flattening and move channels back to axis 1.
        restored = tokens.view([batch, height, width,
                                inner_dim]).permute([0, 3, 1, 2])
        return self.proj_out(restored) + residual
|
|
|