IceClear committed on
Commit
10946ec
·
1 Parent(s): d62797d
projects/video_diffusion_sr/infer.py CHANGED
@@ -123,12 +123,12 @@ class VideoDiffusionInfer():
123
  def configure_diffusion(self):
124
  self.schedule = create_schedule_from_config(
125
  config=self.config.diffusion.schedule,
126
- device=get_device(),
127
  )
128
  self.sampling_timesteps = create_sampling_timesteps_from_config(
129
  config=self.config.diffusion.timesteps.sampling,
130
  schedule=self.schedule,
131
- device=get_device(),
132
  )
133
  self.sampler = create_sampler_from_config(
134
  config=self.config.diffusion.sampler,
@@ -143,7 +143,7 @@ class VideoDiffusionInfer():
143
  use_sample = self.config.vae.get("use_sample", True)
144
  latents = []
145
  if len(samples) > 0:
146
- device = get_device()
147
  dtype = getattr(torch, self.config.vae.dtype)
148
  scale = self.config.vae.scaling_factor
149
  shift = self.config.vae.get("shifting_factor", 0.0)
@@ -186,7 +186,7 @@ class VideoDiffusionInfer():
186
  def vae_decode(self, latents: List[Tensor]) -> List[Tensor]:
187
  samples = []
188
  if len(latents) > 0:
189
- device = get_device()
190
  dtype = getattr(torch, self.config.vae.dtype)
191
  scale = self.config.vae.scaling_factor
192
  shift = self.config.vae.get("shifting_factor", 0.0)
@@ -340,9 +340,9 @@ class VideoDiffusionInfer():
340
  self.dit.to("cpu")
341
 
342
  # Vae decode.
343
- self.vae.to(get_device())
344
  samples = self.vae_decode(latents)
345
 
346
  if dit_offload:
347
- self.dit.to(get_device())
348
  return samples
 
123
  def configure_diffusion(self):
124
  self.schedule = create_schedule_from_config(
125
  config=self.config.diffusion.schedule,
126
+ device="cuda",
127
  )
128
  self.sampling_timesteps = create_sampling_timesteps_from_config(
129
  config=self.config.diffusion.timesteps.sampling,
130
  schedule=self.schedule,
131
+ device="cuda",
132
  )
133
  self.sampler = create_sampler_from_config(
134
  config=self.config.diffusion.sampler,
 
143
  use_sample = self.config.vae.get("use_sample", True)
144
  latents = []
145
  if len(samples) > 0:
146
+ device = "cuda"
147
  dtype = getattr(torch, self.config.vae.dtype)
148
  scale = self.config.vae.scaling_factor
149
  shift = self.config.vae.get("shifting_factor", 0.0)
 
186
  def vae_decode(self, latents: List[Tensor]) -> List[Tensor]:
187
  samples = []
188
  if len(latents) > 0:
189
+ device = "cuda"
190
  dtype = getattr(torch, self.config.vae.dtype)
191
  scale = self.config.vae.scaling_factor
192
  shift = self.config.vae.get("shifting_factor", 0.0)
 
340
  self.dit.to("cpu")
341
 
342
  # Vae decode.
343
+ self.vae.to("cuda")
344
  samples = self.vae_decode(latents)
345
 
346
  if dit_offload:
347
+ self.dit.to("cuda")
348
  return samples