Spaces:
Running
on
Zero
Running
on
Zero
import torch
import math

from .model import QwenImageTransformer2DModel


class QwenImageControlNetModel(QwenImageTransformer2DModel):
    """ControlNet variant of the QwenImage diffusion transformer.

    Builds the base transformer without its final output layer and adds:
      * one Linear projection per transformer block, producing the per-block
        residual that the main model consumes, and
      * an embedder for the control hint latent (plus any extra condition
        channels).

    ``forward`` returns ``{"input": tuple_of_residuals}`` rather than a
    denoised latent.
    """

    def __init__(
        self,
        extra_condition_channels=0,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):
        """Construct the controlnet.

        Args:
            extra_condition_channels: additional channels appended to the
                hint latent before embedding (0 means the hint has the same
                channel count as the model input).
            dtype, device: forwarded to parameter construction.
            operations: factory object providing ``Linear`` (and friends),
                as used throughout the base model.
            **kwargs: forwarded to the base transformer constructor.
        """
        # final_layer=False: this model only emits intermediate residuals,
        # so the base model's output head is never needed.
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)

        # Number of double blocks in the full-size main model. The (usually
        # shorter) list of controlnet outputs is repeated/truncated in
        # forward() to cover exactly this many blocks.
        self.main_model_double = 60

        # One projection per controlnet transformer block: hidden states ->
        # residual injected into the corresponding main-model block(s).
        self.controlnet_blocks = torch.nn.ModuleList(
            operations.Linear(self.inner_dim, self.inner_dim, device=device, dtype=dtype)
            for _ in range(len(self.transformer_blocks))
        )
        # Embeds the control hint (in_channels + extra condition channels).
        self.controlnet_x_embedder = operations.Linear(
            self.in_channels + extra_condition_channels, self.inner_dim, device=device, dtype=dtype
        )

    def forward(
        self,
        x,
        timesteps,
        context,
        attention_mask=None,
        guidance: torch.Tensor = None,
        ref_latents=None,
        hint=None,
        transformer_options=None,
        **kwargs
    ):
        """Run the controlnet and collect per-block residuals.

        Args:
            x: input latent (patchified internally by ``process_img``).
            timesteps: diffusion timestep(s) for the time embedding.
            context: text conditioning tokens; assumed (batch, seq, dim) —
                TODO confirm against the base model's txt_in.
            attention_mask: optional mask over ``context`` tokens.
            guidance: optional guidance scale tensor; scaled by 1000 before
                being fed to the time/text embedder.
            ref_latents: accepted for interface parity; unused here.
            hint: control image latent, embedded and added to the image
                token embedding.
            transformer_options: accepted for interface parity; unused here.

        Returns:
            dict with key ``"input"``: a tuple of residual tensors, one per
            main-model double block (``main_model_double`` entries).
        """
        # Guard against the shared-mutable-default pitfall while keeping the
        # historical "empty dict" semantics for any introspecting caller.
        transformer_options = {} if transformer_options is None else transformer_options

        timestep = timesteps
        encoder_hidden_states = context
        encoder_hidden_states_mask = attention_mask

        hidden_states, img_ids, _ = self.process_img(x)
        hint, _, _ = self.process_img(hint)

        # Text tokens get positional ids starting just past the image grid so
        # their rotary positions do not collide with image positions.
        txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
        txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
        del ids, txt_ids, img_ids

        # Inject the control hint additively into the image token embedding.
        hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

        if guidance is not None:
            guidance = guidance * 1000

        temb = (
            self.time_text_embed(timestep, hidden_states)
            if guidance is None
            else self.time_text_embed(timestep, guidance, hidden_states)
        )

        # Each controlnet block's residual is repeated so that the shorter
        # controlnet covers every double block of the main model; the final
        # slice truncates any overshoot from the ceiling division.
        repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks))
        controlnet_block_samples = ()
        for i, block in enumerate(self.transformer_blocks):
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
            )
            controlnet_block_samples = controlnet_block_samples + (self.controlnet_blocks[i](hidden_states),) * repeat

        return {"input": controlnet_block_samples[:self.main_model_double]}