Spaces:
Runtime error
Runtime error
| import torch | |
def conv_forward(self):
    """Build a replacement ``forward`` for a resnet-block-style module.

    The returned closure reproduces the module's residual-block forward pass
    and additionally stashes the post-``conv2`` activation on
    ``self.record_hidden_state`` so it can be collected later (see
    ``get_conv_feat``).

    Args:
        self: the resnet block module being patched (despite the name, this is
            a free function — the module is captured by the closure).

    Returns:
        A ``forward(input_tensor, temb, scale=1.0)`` callable.
    """

    def forward(input_tensor, temb, scale=1.0):
        # NOTE(review): ``scale`` is accepted for signature compatibility with
        # the original forward but is unused here.
        hidden_states = input_tensor
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes.
            # see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            # Project the time embedding to (B, C, 1, 1, 1) and tile along the
            # third axis so it matches the hidden states.
            # NOTE(review): assumes hidden_states is 5-D (e.g. a video/3-D
            # UNet, B x C x T x H x W) — confirm against the caller.
            temb = self.time_emb_proj(self.nonlinearity(temb))[
                :, :, None, None, None
            ].repeat(1, 1, hidden_states.shape[2], 1, 1)

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            # Deliberately shadows the ``scale`` parameter: here the two
            # channel halves of temb act as per-channel scale and shift.
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        # Record the post-conv2 feature map for later retrieval by
        # get_conv_feat.
        self.record_hidden_state = hidden_states

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
        return output_tensor

    return forward
def get_conv_feat(unet):
    """Collect the recorded hidden state of every up-block resnet.

    Args:
        unet: a UNet whose up-block resnets were patched by ``prep_unet_conv``
            (each resnet must carry a ``record_hidden_state`` attribute).

    Returns:
        Dict mapping ``"up_blocks.{i}.resnets.{j}"`` to the tensor stored by
        the patched forward pass.
    """
    hidden_state_dict = {}
    for i, up_block in enumerate(unet.up_blocks):
        for j, resnet in enumerate(up_block.resnets):
            hidden_state_dict[f"up_blocks.{i}.resnets.{j}"] = resnet.record_hidden_state
    return hidden_state_dict
def prep_unet_conv(unet):
    """Patch every up-block resnet so its forward pass records hidden states.

    Replaces each resnet's bound ``forward`` with the closure built by
    ``conv_forward``, which stores the post-conv2 activation on the module.

    Args:
        unet: the UNet to patch (mutated in place).

    Returns:
        The same ``unet``, returned for call-chaining convenience.
    """
    for up_block in unet.up_blocks:
        for resnet in up_block.resnets:
            resnet.forward = conv_forward(resnet)
    return unet