Spaces:
Runtime error
Runtime error
Change deprecated cuda amp calls
Browse files- wan/modules/model.py +3 -3
wan/modules/model.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
import math
|
| 3 |
|
| 4 |
import torch
|
| 5 |
-
import torch.
|
| 6 |
import torch.nn as nn
|
| 7 |
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 8 |
from diffusers.models.modeling_utils import ModelMixin
|
|
@@ -28,7 +28,7 @@ def sinusoidal_embedding_1d(dim, position):
|
|
| 28 |
return x
|
| 29 |
|
| 30 |
|
| 31 |
-
@amp.autocast(enabled=False)
|
| 32 |
def rope_params(max_seq_len, dim, theta=10000):
|
| 33 |
assert dim % 2 == 0
|
| 34 |
freqs = torch.outer(
|
|
@@ -39,7 +39,7 @@ def rope_params(max_seq_len, dim, theta=10000):
|
|
| 39 |
return freqs
|
| 40 |
|
| 41 |
|
| 42 |
-
@amp.autocast(enabled=False)
|
| 43 |
def rope_apply(x, grid_sizes, freqs):
|
| 44 |
n, c = x.size(2), x.size(3) // 2
|
| 45 |
|
|
|
|
| 2 |
import math
|
| 3 |
|
| 4 |
import torch
|
| 5 |
+
import torch.amp as amp
|
| 6 |
import torch.nn as nn
|
| 7 |
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 8 |
from diffusers.models.modeling_utils import ModelMixin
|
|
|
|
| 28 |
return x
|
| 29 |
|
| 30 |
|
| 31 |
+
@amp.autocast("cuda", enabled=False)
|
| 32 |
def rope_params(max_seq_len, dim, theta=10000):
|
| 33 |
assert dim % 2 == 0
|
| 34 |
freqs = torch.outer(
|
|
|
|
| 39 |
return freqs
|
| 40 |
|
| 41 |
|
| 42 |
+
@amp.autocast("cuda", enabled=False)
|
| 43 |
def rope_apply(x, grid_sizes, freqs):
|
| 44 |
n, c = x.size(2), x.size(3) // 2
|
| 45 |
|