Spaces:
Running
Running
Update wan/vace.py
Browse files — wan/vace.py (+3 −3)
wan/vace.py
CHANGED
|
@@ -15,7 +15,7 @@ from PIL import Image
|
|
| 15 |
import torchvision.transforms.functional as TF
|
| 16 |
import torch
|
| 17 |
import torch.nn.functional as F
|
| 18 |
-
import torch.
|
| 19 |
import torch.distributed as dist
|
| 20 |
import torch.multiprocessing as mp
|
| 21 |
from tqdm import tqdm
|
|
@@ -362,7 +362,7 @@ class WanVace(WanT2V):
|
|
| 362 |
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
|
| 363 |
|
| 364 |
# evaluation mode
|
| 365 |
-
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
|
| 366 |
|
| 367 |
if sample_solver == 'unipc':
|
| 368 |
sample_scheduler = FlowUniPCMultistepScheduler(
|
|
@@ -616,7 +616,7 @@ class WanVaceMP(WanVace):
|
|
| 616 |
no_sync = getattr(model, 'no_sync', noop_no_sync)
|
| 617 |
|
| 618 |
# evaluation mode
|
| 619 |
-
with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync():
|
| 620 |
|
| 621 |
if sample_solver == 'unipc':
|
| 622 |
sample_scheduler = FlowUniPCMultistepScheduler(
|
|
|
|
| 15 |
import torchvision.transforms.functional as TF
|
| 16 |
import torch
|
| 17 |
import torch.nn.functional as F
|
| 18 |
+
import torch.amp as amp
|
| 19 |
import torch.distributed as dist
|
| 20 |
import torch.multiprocessing as mp
|
| 21 |
from tqdm import tqdm
|
|
|
|
| 362 |
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
|
| 363 |
|
| 364 |
# evaluation mode
|
| 365 |
+
with amp.autocast("cuda", dtype=self.param_dtype), torch.no_grad(), no_sync():
|
| 366 |
|
| 367 |
if sample_solver == 'unipc':
|
| 368 |
sample_scheduler = FlowUniPCMultistepScheduler(
|
|
|
|
| 616 |
no_sync = getattr(model, 'no_sync', noop_no_sync)
|
| 617 |
|
| 618 |
# evaluation mode
|
| 619 |
+
with amp.autocast("cuda", dtype=param_dtype), torch.no_grad(), no_sync():
|
| 620 |
|
| 621 |
if sample_solver == 'unipc':
|
| 622 |
sample_scheduler = FlowUniPCMultistepScheduler(
|