| """HSIGene UNet - LocalControlUNetModel for hyperspectral generation.""" | |
| import torch | |
| from .diffusion import UNetModel | |
| from .utils import timestep_embedding | |
class HSIGeneUNet(UNetModel):
    """UNet that accepts metadata and local_control from LocalAdapter.

    Extends ``UNetModel`` so that:

    * an optional ``metadata`` tensor is added to the timestep embedding
      before that embedding is passed to every block, and
    * an optional sequence of ``local_control`` residuals is added to the
      middle-block output and to every skip connection consumed by the
      decoder (``self.output_blocks``).
    """

    def forward(
        self,
        x,
        timesteps=None,
        metadata=None,
        context=None,
        local_control=None,
        meta=False,
        **kwargs,
    ):
        """Run the UNet with optional metadata and local-control residuals.

        Args:
            x: Input tensor. It is cast to ``self.dtype`` for the network
                body, and the decoder output is cast back to ``x.dtype``
                before ``self.out`` is applied.
            timesteps: Diffusion timesteps, forwarded to
                ``timestep_embedding``.
            metadata: Optional tensor added to the output of
                ``self.time_embed``. It presumably has the same shape as that
                embedding (batch, embed_dim); TODO confirm against the caller.
                It is skipped when ``None``.
            context: Passed unchanged to every input, middle and output
                block.
            local_control: Optional sequence of residual tensors, consumed
                from the end. The last element is added to the middle-block
                output. Each following element (moving backwards) is added
                to the skip connection fed to the next output block. It must
                therefore hold at least ``len(self.output_blocks) + 1``
                tensors; with fewer, an ``IndexError`` is raised. The
                caller's sequence is not modified. Residuals are skipped
                entirely when ``None``.
            meta: Accepted for API compatibility; not used by this method.
            **kwargs: Ignored.

        Returns:
            The result of ``self.out`` applied to the final decoder features.
        """
        # Pop from a private copy so the caller's list is left intact (popping
        # the original would break callers that reuse it) and so tuples work.
        controls = None if local_control is None else list(local_control)

        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)
        if metadata is not None:
            emb = emb + metadata

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)  # encoder features reused as decoder skip connections
        h = self.middle_block(h, emb, context)

        if controls is not None:
            # Out-of-place add: modifying middle_block's output in place can
            # trip autograd if that tensor was saved for the backward pass.
            h = h + controls.pop()

        for module in self.output_blocks:
            skip = hs.pop()
            if controls is not None:
                skip = skip + controls.pop()
            h = torch.cat([h, skip], dim=1)
            h = module(h, emb, context)

        h = h.type(x.dtype)
        return self.out(h)