# Spaces: Runtime error
import logging

import torch
from einops import rearrange, repeat

from lvdm.models.utils_diffusion import timestep_embedding

# Optional dependency: record whether xformers imports cleanly.
# The flag keeps its original (misspelled) name because code outside this
# module may reference it.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:  # best-effort: any import-time failure just disables xformers
    XFORMERS_IS_AVAILBLE = False

mainlogger = logging.getLogger('mainlogger')
def TemporalTransformer_forward(self, x, context=None, is_imgbatch=False):
    """Replacement ``forward`` for a temporal transformer (attention across frames).

    Each (batch item, spatial position) pair becomes an independent sequence of
    ``t`` frame tokens. These sequences are run through
    ``self.transformer_blocks``, projected back, and added to the input as a
    residual.

    Args:
        x: video features of shape ``(b, c, t, h, w)``.
        context: forwarded to the transformer blocks.
            In the ``only_self_att`` path it is passed through unchanged; it may
            be a dict, e.g. the ``{'context', 'pose_emb'}`` dict built in
            ``selfattn_forward_unet``.
            In the cross-attention path it must be a tensor whose leading dim is
            ``b * t``. Only each sample's first-frame context is used there.
        is_imgbatch: if True, a ``torch.eye`` temporal mask is used instead of
            any causal mask.

    Returns:
        Tensor of shape ``(b, c, t, h, w)``.
    """
    b, c, t, h, w = x.shape
    x_in = x
    x = self.norm(x)
    # One temporal sequence per (batch, spatial position), batch-major.
    x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
    if not self.use_linear:
        x = self.proj_in(x)
    x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
    if self.use_linear:
        x = self.proj_in(x)

    # Build the temporal mask directly on x's device instead of allocating it
    # on the CPU and copying it over on every forward pass. An image batch
    # overrides any causal mask, so only build the one that is used.
    temp_mask = None
    if is_imgbatch:
        temp_mask = torch.eye(t, device=x.device).unsqueeze(0)
    elif self.causal_attention:
        temp_mask = torch.tril(torch.ones(1, t, t, device=x.device))

    mask = None
    if temp_mask is not None:
        mask = repeat(temp_mask, 'l i j -> (l bhw) i j', bhw=b * h * w)

    if self.only_self_att:
        # NOTE: if no context is given, cross-attention defaults to self-attention.
        for block in self.transformer_blocks:
            x = block(x, context=context, mask=mask)
        x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
    else:
        x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
        context = rearrange(context, '(b t) l con -> b t l con', t=t).contiguous()
        for block in self.transformer_blocks:
            # Process one batch item at a time: some kernels cannot handle a
            # dimension larger than 65,535.
            for j in range(b):
                # First frame's context tokens, repeated once per spatial position.
                unit_context = context[j][0:1]
                context_j = repeat(unit_context, 't l con -> (t r) l con', r=(h * w)).contiguous()
                # NOTE: the causal mask is not applied in the cross-attention case.
                x[j] = block(x[j], context=context_j)

    if self.use_linear:
        x = self.proj_out(x)
        x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
    else:
        x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
        x = self.proj_out(x)
        x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h, w=w).contiguous()

    if self.use_image_dataset:
        # Zero the temporal branch's contribution while keeping it in the
        # autograd graph (presumably so its parameters still get gradients).
        return 0.0 * x + x_in
    return x + x_in
def selfattn_forward_unet(self, x, timesteps, context=None, y=None, features_adapter=None, is_imgbatch=False, T=None, **kwargs):
    """Replacement video UNet ``forward`` with optional camera-pose conditioning.

    Frames are folded into the batch dimension, ``(b, c, t, h, w) -> (b*t, c, h, w)``,
    for every layer except the temporal ones. The output is unfolded back to
    ``(b, c', t, h, w)`` at the end.

    Args:
        x: input of shape ``(b, c, t, h, w)``.
        timesteps: diffusion timesteps, embedded via ``timestep_embedding``.
        context: conditioning tensor, presumably with leading dim ``b``. Unless
            ``is_imgbatch`` is set, it is repeated ``t`` times along dim 0 to
            match the ``(b t)`` frame batch.
        y: optional micro-condition values. They are embedded and added to the
            time embedding when ``self.micro_condition`` is set.
        features_adapter: optional list of ``(b, c, t, h, w)`` adapter features.
            One is added after every third input block.
        is_imgbatch: when True, ``context`` is used without repetition.
            The flag is also forwarded to every block.
        T: accepted for interface compatibility; unused here.
        **kwargs: if it contains ``pose_emb``, ``context`` is wrapped as
            ``{'context': context, 'pose_emb': pose_emb}`` before being passed
            to the blocks.

    Returns:
        Tensor of shape ``(b, c', t, h, w)``, where ``c'`` is the output
        channel count of ``self.out``.

    Raises:
        AssertionError: if the number of adapter features differs from the
            number of injection points used.
    """
    b, _, t, _, _ = x.shape
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb)
    if self.micro_condition and y is not None:
        micro_emb = timestep_embedding(y, self.model_channels, repeat_only=False)
        emb = emb + self.micro_embed(micro_emb)

    ## repeat t times for context [(b t) 77 768] & time embedding
    if not is_imgbatch:
        context = context.repeat_interleave(repeats=t, dim=0)
    if 'pose_emb' in kwargs:
        # Bundle the pose embedding with the text context. The patched
        # spatial/temporal block forwards unwrap the entry they need.
        context = {'context': context, 'pose_emb': kwargs.pop('pose_emb')}
    emb = emb.repeat_interleave(repeats=t, dim=0)

    ## always in shape (b t) c h w, except for temporal layer
    x = rearrange(x, 'b c t h w -> (b t) c h w')
    if features_adapter is not None:
        features_adapter = [rearrange(feature, 'b c t h w -> (b t) c h w') for feature in features_adapter]

    h = x.type(self.dtype)
    adapter_idx = 0
    hs = []
    for block_idx, module in enumerate(self.input_blocks):
        h = module(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
        if block_idx == 0 and self.addition_attention:
            h = self.init_attn(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
        # Plug in adapter features after every third input block.
        if (block_idx + 1) % 3 == 0 and features_adapter is not None:
            h = h + features_adapter[adapter_idx]
            adapter_idx += 1
        hs.append(h)
    # Explicit raise (not `assert`) so the check survives `python -O`.
    if features_adapter is not None and len(features_adapter) != adapter_idx:
        raise AssertionError('Wrong features_adapter')

    h = self.middle_block(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
    for module in self.output_blocks:
        # Skip connection: concatenate the matching input-block activation.
        h = torch.cat([h, hs.pop()], dim=1)
        h = module(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
    h = h.type(x.dtype)
    out = self.out(h)
    # reshape back to (b c t h w)
    return rearrange(out, '(b t) c h w -> b c t h w', b=b)
def spatial_forward_BasicTransformerBlock(self, x, context=None, mask=None):
    """Spatial transformer-block forward that also accepts a dict ``context``.

    A dict context is reduced to its ``'context'`` entry; any other keys are
    ignored here. Three pre-norm residual sub-layers follow:
    ``attn1``, which receives the context only when ``self.disable_self_attn``
    is set (otherwise ``None``); ``attn2``, which always receives the context;
    and the feed-forward ``ff``. ``mask`` goes to both attention layers.
    """
    cond = context['context'] if isinstance(context, dict) else context
    first_ctx = cond if self.disable_self_attn else None

    out = self.attn1(self.norm1(x), context=first_ctx, mask=mask) + x
    out = self.attn2(self.norm2(out), context=cond, mask=mask) + out
    return self.ff(self.norm3(out)) + out
def temporal_selfattn_forward_BasicTransformerBlock(self, x, context=None, mask=None):
    """Temporal transformer-block forward with optional camera-pose injection.

    Any incoming text context is discarded, so both attention layers are
    called with ``context=None``. If ``context`` is a dict containing
    ``'pose_emb'``, the flattened pose embedding is concatenated to the
    features after ``attn1`` and fused by ``self.cc_projection``.

    Args:
        x: token tensor. The pose branch assumes shape ``(B * hw, t, c)`` with
            the batch index varying slowest. TODO confirm against the caller.
        context: ignored unless it is a dict with a ``'pose_emb'`` entry. That
            entry must be a rank-4 tensor ``(B, t, pose_dim, pose_embedding_dim)``.
        mask: forwarded to both attention layers.

    Returns:
        Tensor whose last dim is ``cc_projection``'s output width when a pose is
        injected. Presumably this is ``c``, i.e. the same shape as ``x``.
    """
    pose_emb = None
    if isinstance(context, dict) and 'pose_emb' in context:
        pose_emb = context['pose_emb']

    # The text context is always dropped here, so attn1/attn2 self-attend.
    x = self.attn1(self.norm1(x), context=None, mask=mask) + x

    # Add camera pose
    if pose_emb is not None:
        B, t, _, _ = pose_emb.shape
        hw = x.shape[0] // B  # tokens sequences per batch item
        # Flatten each frame's pose into one vector, then give every sequence
        # of batch item i a copy (repeat_interleave keeps batch-major order).
        pose_emb = pose_emb.reshape(B, t, -1).repeat_interleave(repeats=hw, dim=0)
        x = self.cc_projection(torch.cat([x, pose_emb], dim=-1))

    x = self.attn2(self.norm2(x), context=None, mask=mask) + x
    x = self.ff(self.norm3(x)) + x
    return x