Spaces: Runtime error
handle steps correctly
Browse files
- audiodiffusion/__init__.py +10 -13
- scripts/train_unconditional.py +0 -1
audiodiffusion/__init__.py
CHANGED
|
@@ -177,11 +177,8 @@ class AudioDiffusionPipeline(DiffusionPipeline):
|
|
| 177 |
(float, List[np.ndarray]): sample rate and raw audios
|
| 178 |
"""
|
| 179 |
|
| 180 |
-
if steps is None:
|
| 181 |
-
|
| 182 |
-
# Unfortunately, the schedule is set up in the constructor
|
| 183 |
-
scheduler = self.scheduler.__class__(num_train_timesteps=steps)
|
| 184 |
-
scheduler.set_timesteps(steps)
|
| 185 |
mask = None
|
| 186 |
images = noise = torch.randn(
|
| 187 |
(batch_size, self.unet.in_channels, self.unet.sample_size,
|
|
@@ -204,7 +201,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
|
|
| 204 |
input_images = 0.18215 * input_images
|
| 205 |
|
| 206 |
if start_step > 0:
|
| 207 |
-
images[0, 0] = scheduler.add_noise(
|
| 208 |
torch.tensor(input_images[:, np.newaxis, np.newaxis, :]),
|
| 209 |
noise, torch.tensor(steps - start_step))
|
| 210 |
|
|
@@ -213,18 +210,18 @@ class AudioDiffusionPipeline(DiffusionPipeline):
|
|
| 213 |
mel.x_res)
|
| 214 |
mask_start = int(mask_start_secs * pixels_per_second)
|
| 215 |
mask_end = int(mask_end_secs * pixels_per_second)
|
| 216 |
-
mask = scheduler.add_noise(
|
| 217 |
torch.tensor(input_images[:, np.newaxis, :]), noise,
|
| 218 |
-
torch.tensor(scheduler.timesteps[start_step:]))
|
| 219 |
|
| 220 |
images = images.to(self.device)
|
| 221 |
for step, t in enumerate(
|
| 222 |
-
self.progress_bar(scheduler.timesteps[start_step:])):
|
| 223 |
model_output = self.unet(images, t)['sample']
|
| 224 |
-
images = scheduler.step(model_output,
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
|
| 229 |
if mask is not None:
|
| 230 |
if mask_start > 0:
|
|
|
|
| 177 |
(float, List[np.ndarray]): sample rate and raw audios
|
| 178 |
"""
|
| 179 |
|
| 180 |
+
if steps is not None:
|
| 181 |
+
self.scheduler.set_timesteps(steps)
|
|
|
|
|
|
|
|
|
|
| 182 |
mask = None
|
| 183 |
images = noise = torch.randn(
|
| 184 |
(batch_size, self.unet.in_channels, self.unet.sample_size,
|
|
|
|
| 201 |
input_images = 0.18215 * input_images
|
| 202 |
|
| 203 |
if start_step > 0:
|
| 204 |
+
images[0, 0] = self.scheduler.add_noise(
|
| 205 |
torch.tensor(input_images[:, np.newaxis, np.newaxis, :]),
|
| 206 |
noise, torch.tensor(steps - start_step))
|
| 207 |
|
|
|
|
| 210 |
mel.x_res)
|
| 211 |
mask_start = int(mask_start_secs * pixels_per_second)
|
| 212 |
mask_end = int(mask_end_secs * pixels_per_second)
|
| 213 |
+
mask = self.scheduler.add_noise(
|
| 214 |
torch.tensor(input_images[:, np.newaxis, :]), noise,
|
| 215 |
+
torch.tensor(self.scheduler.timesteps[start_step:]))
|
| 216 |
|
| 217 |
images = images.to(self.device)
|
| 218 |
for step, t in enumerate(
|
| 219 |
+
self.progress_bar(self.scheduler.timesteps[start_step:])):
|
| 220 |
model_output = self.unet(images, t)['sample']
|
| 221 |
+
images = self.scheduler.step(model_output,
|
| 222 |
+
t,
|
| 223 |
+
images,
|
| 224 |
+
generator=generator)['prev_sample']
|
| 225 |
|
| 226 |
if mask is not None:
|
| 227 |
if mask_start > 0:
|
scripts/train_unconditional.py
CHANGED
|
@@ -274,7 +274,6 @@ def main(args):
|
|
| 274 |
mel=mel,
|
| 275 |
generator=generator,
|
| 276 |
batch_size=args.eval_batch_size,
|
| 277 |
-
steps=args.num_train_steps,
|
| 278 |
)
|
| 279 |
|
| 280 |
# denormalize the images and save to tensorboard
|
|
|
|
| 274 |
mel=mel,
|
| 275 |
generator=generator,
|
| 276 |
batch_size=args.eval_batch_size,
|
|
|
|
| 277 |
)
|
| 278 |
|
| 279 |
# denormalize the images and save to tensorboard
|