Spaces:
Running
on
Zero
Running
on
Zero
IceClear
commited on
Commit
·
10946ec
1
Parent(s):
d62797d
update
Browse files
projects/video_diffusion_sr/infer.py
CHANGED
|
@@ -123,12 +123,12 @@ class VideoDiffusionInfer():
|
|
| 123 |
def configure_diffusion(self):
|
| 124 |
self.schedule = create_schedule_from_config(
|
| 125 |
config=self.config.diffusion.schedule,
|
| 126 |
-
device=
|
| 127 |
)
|
| 128 |
self.sampling_timesteps = create_sampling_timesteps_from_config(
|
| 129 |
config=self.config.diffusion.timesteps.sampling,
|
| 130 |
schedule=self.schedule,
|
| 131 |
-
device=
|
| 132 |
)
|
| 133 |
self.sampler = create_sampler_from_config(
|
| 134 |
config=self.config.diffusion.sampler,
|
|
@@ -143,7 +143,7 @@ class VideoDiffusionInfer():
|
|
| 143 |
use_sample = self.config.vae.get("use_sample", True)
|
| 144 |
latents = []
|
| 145 |
if len(samples) > 0:
|
| 146 |
-
device =
|
| 147 |
dtype = getattr(torch, self.config.vae.dtype)
|
| 148 |
scale = self.config.vae.scaling_factor
|
| 149 |
shift = self.config.vae.get("shifting_factor", 0.0)
|
|
@@ -186,7 +186,7 @@ class VideoDiffusionInfer():
|
|
| 186 |
def vae_decode(self, latents: List[Tensor]) -> List[Tensor]:
|
| 187 |
samples = []
|
| 188 |
if len(latents) > 0:
|
| 189 |
-
device =
|
| 190 |
dtype = getattr(torch, self.config.vae.dtype)
|
| 191 |
scale = self.config.vae.scaling_factor
|
| 192 |
shift = self.config.vae.get("shifting_factor", 0.0)
|
|
@@ -340,9 +340,9 @@ class VideoDiffusionInfer():
|
|
| 340 |
self.dit.to("cpu")
|
| 341 |
|
| 342 |
# Vae decode.
|
| 343 |
-
self.vae.to(
|
| 344 |
samples = self.vae_decode(latents)
|
| 345 |
|
| 346 |
if dit_offload:
|
| 347 |
-
self.dit.to(
|
| 348 |
return samples
|
|
|
|
| 123 |
def configure_diffusion(self):
|
| 124 |
self.schedule = create_schedule_from_config(
|
| 125 |
config=self.config.diffusion.schedule,
|
| 126 |
+
device="cuda",
|
| 127 |
)
|
| 128 |
self.sampling_timesteps = create_sampling_timesteps_from_config(
|
| 129 |
config=self.config.diffusion.timesteps.sampling,
|
| 130 |
schedule=self.schedule,
|
| 131 |
+
device="cuda",
|
| 132 |
)
|
| 133 |
self.sampler = create_sampler_from_config(
|
| 134 |
config=self.config.diffusion.sampler,
|
|
|
|
| 143 |
use_sample = self.config.vae.get("use_sample", True)
|
| 144 |
latents = []
|
| 145 |
if len(samples) > 0:
|
| 146 |
+
device = "cuda"
|
| 147 |
dtype = getattr(torch, self.config.vae.dtype)
|
| 148 |
scale = self.config.vae.scaling_factor
|
| 149 |
shift = self.config.vae.get("shifting_factor", 0.0)
|
|
|
|
| 186 |
def vae_decode(self, latents: List[Tensor]) -> List[Tensor]:
|
| 187 |
samples = []
|
| 188 |
if len(latents) > 0:
|
| 189 |
+
device = "cuda"
|
| 190 |
dtype = getattr(torch, self.config.vae.dtype)
|
| 191 |
scale = self.config.vae.scaling_factor
|
| 192 |
shift = self.config.vae.get("shifting_factor", 0.0)
|
|
|
|
| 340 |
self.dit.to("cpu")
|
| 341 |
|
| 342 |
# Vae decode.
|
| 343 |
+
self.vae.to("cuda")
|
| 344 |
samples = self.vae_decode(latents)
|
| 345 |
|
| 346 |
if dit_offload:
|
| 347 |
+
self.dit.to("cuda")
|
| 348 |
return samples
|