Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
|
@@ -1121,12 +1121,14 @@ class StreamMultiDiffusion(nn.Module):
|
|
| 1121 |
else:
|
| 1122 |
x_t_latent_plus_uc = x_t_latent # (T * p, 4, h, w)
|
| 1123 |
|
|
|
|
| 1124 |
model_pred = self.unet(
|
| 1125 |
x_t_latent_plus_uc.to(self.dtype), # (B, 4, h, w)
|
| 1126 |
t_list, # (B,)
|
| 1127 |
encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
|
| 1128 |
return_dict=False,
|
| 1129 |
)[0] # (B, 4, h, w)
|
|
|
|
| 1130 |
|
| 1131 |
if self.bootstrap_steps[0] > 0:
|
| 1132 |
# Uncentering.
|
|
|
|
| 1121 |
else:
|
| 1122 |
x_t_latent_plus_uc = x_t_latent # (T * p, 4, h, w)
|
| 1123 |
|
| 1124 |
+
print('1111111111111111111111', x_t_latent_plus_uc.dtype, self.unet.dtype, self.prompt_embeds.dtype)
|
| 1125 |
model_pred = self.unet(
|
| 1126 |
x_t_latent_plus_uc.to(self.dtype), # (B, 4, h, w)
|
| 1127 |
t_list, # (B,)
|
| 1128 |
encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
|
| 1129 |
return_dict=False,
|
| 1130 |
)[0] # (B, 4, h, w)
|
| 1131 |
+
print('222222222222222', model_pred.dtype)
|
| 1132 |
|
| 1133 |
if self.bootstrap_steps[0] > 0:
|
| 1134 |
# Uncentering.
|