import torch
import torch.nn as nn
import torch.nn.functional as F
from dotmap import DotMap

from salad.model_components.simple_module import TimePointWiseEncoder, TimestepEmbedder
from salad.model_components.transformer import (
    PositionalEncoding,
    TimeTransformerDecoder,
    TimeTransformerEncoder,
)


class UnCondDiffNetwork(nn.Module):
    def __init__(self, input_dim, residual, **kwargs):
        """
        Unconditional denoiser: a Transformer encoder that predicts the
        diffusion noise for a set of latents.
        """
        super().__init__()
        self.input_dim = input_dim
        self.residual = residual
        # Stash all remaining hyperparameters on the instance and expose them
        # through DotMap so they can be read as attributes (self.hparams.foo).
        self.__dict__.update(kwargs)
        self.hparams = DotMap(self.__dict__)
        self._build_model()

    def _build_model(self):
        self.act = F.leaky_relu
        # Timestep conditioning: either a learned embedding of the noise level
        # beta, or the raw 3-dim context [beta, sin(beta), cos(beta)].
        if self.hparams.get("use_timestep_embedder"):
            self.time_embedder = TimestepEmbedder(self.hparams.timestep_embedder_dim)
            dim_ctx = self.hparams.timestep_embedder_dim
        else:
            dim_ctx = 3
| """ | |
| Encoder part | |
| """ | |
| enc_dim = self.hparams.embedding_dim | |
| self.embedding = nn.Linear(self.hparams.input_dim, enc_dim) | |
| if not self.hparams.get("encoder_type"): | |
| self.encoder = TimeTransformerEncoder( | |
| enc_dim, | |
| dim_ctx=dim_ctx, | |
| num_heads=self.hparams.num_heads | |
| if self.hparams.get("num_heads") | |
| else 4, | |
| use_time=True, | |
| num_layers=self.hparams.enc_num_layers, | |
| last_fc=True, | |
| last_fc_dim_out=self.hparams.input_dim, | |
| ) | |
| else: | |
| if self.hparams.encoder_type == "transformer": | |
| self.encoder = TimeTransformerEncoder( | |
| enc_dim, | |
| dim_ctx=dim_ctx, | |
| num_heads=self.hparams.num_heads | |
| if self.hparams.get("num_heads") | |
| else 4, | |
| use_time=True, | |
| num_layers=self.hparams.enc_num_layers, | |
| last_fc=True, | |
| last_fc_dim_out=self.hparams.input_dim, | |
| dropout=self.hparams.get("attn_dropout", 0.0) | |
| ) | |
| else: | |
| raise ValueError | |

    def forward(self, x, beta):
        """
        Input:
            x: [B,G,D] latent
            beta: [B] diffusion timestep / noise level
        Output:
            eta: [B,G,D] predicted noise
        """
        B, G = x.shape[:2]
        if self.hparams.get("use_timestep_embedder"):
            time_emb = self.time_embedder(beta).unsqueeze(1)  # [B,1,emb_dim]
        else:
            beta = beta.view(B, 1, 1)
            time_emb = torch.cat(
                [beta, torch.sin(beta), torch.cos(beta)], dim=-1
            )  # [B,1,3]
        ctx = time_emb

        x_emb = self.embedding(x)
        out = self.encoder(x_emb, ctx=ctx)
        if self.hparams.residual:
            out = out + x  # predict the residual on top of the noisy input
        return out
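

# A minimal smoke test for UnCondDiffNetwork. The hyperparameter values below
# (embedding_dim=512, enc_num_layers=6, the batch and set sizes) are
# illustrative assumptions, not values prescribed by the SALAD configs.
def _demo_uncond():
    net = UnCondDiffNetwork(
        input_dim=16,
        residual=True,
        embedding_dim=512,
        enc_num_layers=6,
    )
    x = torch.randn(2, 8, 16)  # [B, G, D] latent set
    beta = torch.rand(2)       # [B] noise levels
    eta = net(x, beta)
    assert eta.shape == x.shape  # noise prediction keeps the latent shape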


class CondDiffNetwork(nn.Module):
    def __init__(self, input_dim, residual, **kwargs):
        """
        Conditional denoiser: a Transformer encoder over the condition,
        followed by a decoder over the noisy input.
        """
        super().__init__()
        self.input_dim = input_dim
        self.residual = residual
        # Same hyperparameter handling as UnCondDiffNetwork.
        self.__dict__.update(kwargs)
        self.hparams = DotMap(self.__dict__)
        self._build_model()

    def _build_model(self):
        self.act = F.leaky_relu
        # Timestep conditioning, as in UnCondDiffNetwork.
        if self.hparams.get("use_timestep_embedder"):
            self.time_embedder = TimestepEmbedder(self.hparams.timestep_embedder_dim)
            dim_ctx = self.hparams.timestep_embedder_dim
        else:
            dim_ctx = 3
| """ | |
| Encoder part | |
| """ | |
| enc_dim = self.hparams.context_embedding_dim | |
| self.context_embedding = nn.Linear(self.hparams.context_dim, enc_dim) | |
| if self.hparams.encoder_type == "transformer": | |
| self.encoder = TimeTransformerEncoder( | |
| enc_dim, | |
| 3, | |
| num_heads=4, | |
| use_time=self.hparams.encoder_use_time, | |
| num_layers=self.hparams.enc_num_layers | |
| if self.hparams.get("enc_num_layers") | |
| else 3, | |
| last_fc=False, | |
| ) | |
| elif self.hparams.encoder_type == "pointwise": | |
| self.encoder = TimePointWiseEncoder( | |
| enc_dim, | |
| dim_ctx=None, | |
| use_time=self.hparams.encoder_use_time, | |
| num_layers=self.hparams.enc_num_layers, | |
| ) | |
| else: | |
| raise ValueError | |
| """ | |
| Decoder part | |
| """ | |
| dec_dim = self.hparams.embedding_dim | |
| input_dim = self.hparams.input_dim | |
| self.query_embedding = nn.Linear(self.hparams.input_dim, dec_dim) | |
| if self.hparams.decoder_type == "transformer_decoder": | |
| self.decoder = TimeTransformerDecoder( | |
| dec_dim, | |
| enc_dim, | |
| dim_ctx=dim_ctx, | |
| num_heads=4, | |
| last_fc=True, | |
| last_fc_dim_out=input_dim, | |
| num_layers=self.hparams.dec_num_layers | |
| if self.hparams.get("dec_num_layers") | |
| else 3, | |
| ) | |
| elif self.hparams.decoder_type == "transformer_encoder": | |
| self.decoder = TimeTransformerEncoder( | |
| dec_dim, | |
| dim_ctx=enc_dim + dim_ctx, | |
| num_heads=4, | |
| last_fc=True, | |
| last_fc_dim_out=input_dim, | |
| num_layers=self.hparams.dec_num_layers | |
| if self.hparams.get("dec_num_layers") | |
| else 3, | |
| ) | |
| else: | |
| raise ValueError | |

    def forward(self, x, beta, context):
        """
        Input:
            x: [B,G,D] intrinsic
            beta: [B] diffusion timestep / noise level
            context: [B,G,D2] or [B,D2] condition
        Output:
            eta: [B,G,D] predicted noise
        """
        B, G = x.shape[:2]
        if self.hparams.get("use_timestep_embedder"):
            time_emb = self.time_embedder(beta).unsqueeze(1)  # [B,1,emb_dim]
        else:
            beta = beta.view(B, 1, 1)
            time_emb = torch.cat(
                [beta, torch.sin(beta), torch.cos(beta)], dim=-1
            )  # [B,1,3]
        ctx = time_emb
| """ | |
| Encoding | |
| """ | |
| cout = self.context_embedding(context) | |
| cout = self.encoder(cout, ctx=ctx if self.hparams.encoder_use_time else None) | |
| if cout.ndim == 2: | |
| cout = cout.unsqueeze(1).expand(-1, G, -1) | |
| """ | |
| Decoding | |
| """ | |
| out = self.query_embedding(x) | |
| if self.hparams.get("use_pos_encoding"): | |
| out = self.pos_encoding(out) | |
| if self.hparams.decoder_type == "transformer_encoder": | |
| try: | |
| ctx = ctx.expand(-1, G, -1) | |
| if cout.ndim == 2: | |
| cout = cout.unsqueeze(1) | |
| cout = cout.expand(-1, G, -1) | |
| ctx = torch.cat([ctx, cout], -1) | |
| except Exception as e: | |
| print(e, G, ctx.shape, cout.shape) | |
| out = self.decoder(out, ctx=ctx) | |
| else: | |
| out = self.decoder(out, cout, ctx=ctx) | |
| # if hasattr(self, "last_fc"): | |
| # out = self.last_fc(out) | |
| if self.hparams.residual: | |
| out = out + x | |
| return out | |
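

# A minimal smoke test for CondDiffNetwork, mirroring _demo_uncond above. All
# hyperparameter values are illustrative assumptions; this exercises the
# cross-attention "transformer_decoder" variant with a per-element condition.
def _demo_cond():
    net = CondDiffNetwork(
        input_dim=16,
        residual=True,
        embedding_dim=512,
        context_embedding_dim=512,
        context_dim=32,
        encoder_type="transformer",
        encoder_use_time=False,
        decoder_type="transformer_decoder",
        enc_num_layers=3,
        dec_num_layers=3,
    )
    x = torch.randn(2, 8, 16)        # [B, G, D] noisy intrinsics
    beta = torch.rand(2)             # [B] noise levels
    context = torch.randn(2, 8, 32)  # [B, G, D2] per-element condition
    eta = net(x, beta, context)
    assert eta.shape == x.shape


if __name__ == "__main__":
    _demo_uncond()
    _demo_cond()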