HSIGene / unet /model.py
BiliSakura's picture
Add files using upload-large-folder tool
66a2b45 verified
"""HSIGene UNet - LocalControlUNetModel for hyperspectral generation."""
import torch
from .diffusion import UNetModel
from .utils import timestep_embedding
class HSIGeneUNet(UNetModel):
    """UNet that accepts metadata and local_control from LocalAdapter.

    Extends ``UNetModel.forward`` so that:

    * ``metadata`` (when given) is added to the timestep embedding that every
      block receives, and
    * ``local_control`` residuals (when given) are added to the middle-block
      output and to every encoder skip connection consumed by the decoder.
    """

    def forward(
        self,
        x,
        timesteps=None,
        metadata=None,
        context=None,
        local_control=None,
        meta=False,
        **kwargs,
    ):
        """Run the UNet with optional metadata and local-control residuals.

        Args:
            x: Input tensor. It is cast to ``self.dtype`` before the blocks run,
                and the result is cast back to ``x.dtype`` before ``self.out``.
            timesteps: Timesteps fed to ``timestep_embedding``.
            metadata: Optional tensor added to ``self.time_embed(...)``. It must
                broadcast against the time embedding (presumably
                ``(batch, time_embed_dim)`` -- TODO confirm). ``None`` means the
                time embedding is used on its own.
            context: Conditioning passed unchanged to every block.
            local_control: Optional sequence of residual tensors. Items are
                consumed from the END:
                - the last item is added to the middle-block output;
                - each following item (walking backwards) is added to the
                  encoder skip popped by the matching output block.
                At least ``len(self.output_blocks) + 1`` items are needed, and
                any extra leading items are ignored. The caller's sequence is
                not modified. ``None`` disables the residuals.
            meta: Unused by this method; accepted for call-site compatibility
                (NOTE(review): confirm whether it was meant to gate
                ``metadata``).
            **kwargs: Ignored.

        Returns:
            ``self.out`` applied to the final decoder features.
        """
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)
        if metadata is not None:
            emb = emb + metadata

        # Work on a private copy: popping from the caller's list would empty
        # it and break any later forward pass that reuses the same controls.
        controls = None if local_control is None else list(local_control)

        hs = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)

        h = self.middle_block(h, emb, context)
        if controls is not None:
            # In-place add, as in the original, so h keeps its dtype.
            h += controls.pop()

        for module in self.output_blocks:
            # Skips are consumed last-in-first-out, mirroring the encoder.
            skip = hs.pop()
            if controls is not None:
                skip = skip + controls.pop()
            h = torch.cat([h, skip], dim=1)
            h = module(h, emb, context)

        h = h.type(x.dtype)
        return self.out(h)