Spaces:
Runtime error
Runtime error
forgot device type
Browse files- wan/modules/model.py +4 -4
wan/modules/model.py
CHANGED
|
@@ -294,7 +294,7 @@ class WanAttentionBlock(nn.Module):
|
|
| 294 |
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
| 295 |
"""
|
| 296 |
assert e.dtype == torch.float32
|
| 297 |
-
with amp.autocast(dtype=torch.float32):
|
| 298 |
e = (self.modulation + e).chunk(6, dim=1)
|
| 299 |
assert e[0].dtype == torch.float32
|
| 300 |
|
|
@@ -309,7 +309,7 @@ class WanAttentionBlock(nn.Module):
|
|
| 309 |
def cross_attn_ffn(x, context, context_lens, e):
|
| 310 |
x = x + self.cross_attn(self.norm3(x), context, context_lens)
|
| 311 |
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
|
| 312 |
-
with amp.autocast(dtype=torch.float32):
|
| 313 |
x = x + y * e[5]
|
| 314 |
return x
|
| 315 |
|
|
@@ -341,7 +341,7 @@ class Head(nn.Module):
|
|
| 341 |
e(Tensor): Shape [B, C]
|
| 342 |
"""
|
| 343 |
assert e.dtype == torch.float32
|
| 344 |
-
with amp.autocast(dtype=torch.float32):
|
| 345 |
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
|
| 346 |
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
| 347 |
return x
|
|
@@ -542,7 +542,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 542 |
])
|
| 543 |
|
| 544 |
# time embeddings
|
| 545 |
-
with amp.autocast(dtype=torch.float32):
|
| 546 |
e = self.time_embedding(
|
| 547 |
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
| 548 |
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
|
|
|
| 294 |
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
| 295 |
"""
|
| 296 |
assert e.dtype == torch.float32
|
| 297 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
| 298 |
e = (self.modulation + e).chunk(6, dim=1)
|
| 299 |
assert e[0].dtype == torch.float32
|
| 300 |
|
|
|
|
| 309 |
def cross_attn_ffn(x, context, context_lens, e):
|
| 310 |
x = x + self.cross_attn(self.norm3(x), context, context_lens)
|
| 311 |
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
|
| 312 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
| 313 |
x = x + y * e[5]
|
| 314 |
return x
|
| 315 |
|
|
|
|
| 341 |
e(Tensor): Shape [B, C]
|
| 342 |
"""
|
| 343 |
assert e.dtype == torch.float32
|
| 344 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
| 345 |
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
|
| 346 |
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
| 347 |
return x
|
|
|
|
| 542 |
])
|
| 543 |
|
| 544 |
# time embeddings
|
| 545 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
| 546 |
e = self.time_embedding(
|
| 547 |
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
| 548 |
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|