Update all files for BitDance-14B-64x-diffusers
Browse files
bitdance_diffusers/modeling_diffusion_head.py
CHANGED
|
@@ -111,7 +111,7 @@ def euler_maruyama(
|
|
| 111 |
dt = t_all[1:] - t_all[:-1]
|
| 112 |
|
| 113 |
t = torch.tensor(0.0, device=c.device, dtype=torch.float32)
|
| 114 |
-
t_batch = torch.zeros(c.shape[0], device=c.device)
|
| 115 |
for i in range(num_sampling_steps):
|
| 116 |
t_batch[:] = t
|
| 117 |
combined = torch.cat([x] * cfg_mult, dim=0)
|
|
@@ -152,6 +152,7 @@ class TimestepEmbedder(nn.Module):
|
|
| 152 |
|
| 153 |
def forward(self, t: torch.Tensor) -> torch.Tensor:
|
| 154 |
t_freq = timestep_embedding(t, self.frequency_embedding_size)
|
|
|
|
| 155 |
return self.mlp(t_freq)
|
| 156 |
|
| 157 |
|
|
@@ -301,6 +302,10 @@ class TransEncoder(nn.Module):
|
|
| 301 |
nn.init.constant_(self.final_layer.linear.bias, 0)
|
| 302 |
|
| 303 |
def forward(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
x = self.input_proj(x)
|
| 305 |
t = self.time_embed(t).unsqueeze(1)
|
| 306 |
c = self.cond_embed(c)
|
|
|
|
| 111 |
dt = t_all[1:] - t_all[:-1]
|
| 112 |
|
| 113 |
t = torch.tensor(0.0, device=c.device, dtype=torch.float32)
|
| 114 |
+
t_batch = torch.zeros(c.shape[0], device=c.device, dtype=c.dtype)
|
| 115 |
for i in range(num_sampling_steps):
|
| 116 |
t_batch[:] = t
|
| 117 |
combined = torch.cat([x] * cfg_mult, dim=0)
|
|
|
|
| 152 |
|
| 153 |
def forward(self, t: torch.Tensor) -> torch.Tensor:
|
| 154 |
t_freq = timestep_embedding(t, self.frequency_embedding_size)
|
| 155 |
+
t_freq = t_freq.to(self.mlp[0].weight.dtype)
|
| 156 |
return self.mlp(t_freq)
|
| 157 |
|
| 158 |
|
|
|
|
| 302 |
nn.init.constant_(self.final_layer.linear.bias, 0)
|
| 303 |
|
| 304 |
def forward(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
| 305 |
+
dtype = next(self.parameters()).dtype
|
| 306 |
+
x = x.to(dtype)
|
| 307 |
+
t = t.to(dtype)
|
| 308 |
+
c = c.to(dtype)
|
| 309 |
x = self.input_proj(x)
|
| 310 |
t = self.time_embed(t).unsqueeze(1)
|
| 311 |
c = self.cond_embed(c)
|