BiliSakura committed
Commit b26647d · verified · 1 Parent(s): 2668ff9

Update all files for BitDance-14B-64x-diffusers

bitdance_diffusers/modeling_diffusion_head.py CHANGED
@@ -111,7 +111,7 @@ def euler_maruyama(
     dt = t_all[1:] - t_all[:-1]
 
     t = torch.tensor(0.0, device=c.device, dtype=torch.float32)
-    t_batch = torch.zeros(c.shape[0], device=c.device)
+    t_batch = torch.zeros(c.shape[0], device=c.device, dtype=c.dtype)
     for i in range(num_sampling_steps):
         t_batch[:] = t
         combined = torch.cat([x] * cfg_mult, dim=0)
@@ -152,6 +152,7 @@ class TimestepEmbedder(nn.Module):
 
     def forward(self, t: torch.Tensor) -> torch.Tensor:
         t_freq = timestep_embedding(t, self.frequency_embedding_size)
+        t_freq = t_freq.to(self.mlp[0].weight.dtype)
         return self.mlp(t_freq)
 
 
@@ -301,6 +302,10 @@ class TransEncoder(nn.Module):
         nn.init.constant_(self.final_layer.linear.bias, 0)
 
     def forward(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+        dtype = next(self.parameters()).dtype
+        x = x.to(dtype)
+        t = t.to(dtype)
+        c = c.to(dtype)
         x = self.input_proj(x)
         t = self.time_embed(t).unsqueeze(1)
         c = self.cond_embed(c)
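
All three hunks guard against the same failure mode: when the head's parameters are loaded in half precision (e.g. bf16 or fp16), float32 tensors created during sampling (`t_batch`, the sinusoidal timestep embedding) or passed in by the caller reach half-precision `nn.Linear` weights, and PyTorch raises a dtype-mismatch RuntimeError. A minimal sketch of that failure and of the casting pattern the commit applies; the toy `head` module here is hypothetical, not the BitDance code:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the diffusion head, loaded in bfloat16.
head = nn.Linear(16, 16).to(torch.bfloat16)
x = torch.randn(2, 16)  # float32 by default, like t_batch before the fix

try:
    head(x)
except RuntimeError as err:  # mat1/mat2 dtype mismatch inside nn.Linear
    print("without cast:", err)

# The pattern used in the commit: read the parameter dtype, cast inputs to it.
dtype = next(head.parameters()).dtype
out = head(x.to(dtype))
print("with cast:", out.dtype)  # torch.bfloat16
```

Allocating `t_batch` with `dtype=c.dtype` and casting `t_freq` to `self.mlp[0].weight.dtype` apply the same idea at the two points where float32 tensors are materialized inside the head.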