Spaces:
Runtime error
Runtime error
handle tuple / list for resolutions
Browse files
audiodiffusion/__init__.py
CHANGED
|
@@ -213,11 +213,12 @@ class AudioDiffusionPipeline(DiffusionPipeline):
|
|
| 213 |
step_generator = step_generator or generator
|
| 214 |
# For backwards compatibility
|
| 215 |
if type(self.unet.sample_size) == int:
|
| 216 |
-
self.unet.sample_size =
|
| 217 |
-
self.unet.sample_size
|
| 218 |
if noise is None:
|
| 219 |
noise = torch.randn(
|
| 220 |
-
|
|
|
|
| 221 |
generator=generator)
|
| 222 |
images = noise
|
| 223 |
mask = None
|
|
|
|
| 213 |
step_generator = step_generator or generator
|
| 214 |
# For backwards compatibility
|
| 215 |
if type(self.unet.sample_size) == int:
|
| 216 |
+
self.unet.sample_size = (self.unet.sample_size,
|
| 217 |
+
self.unet.sample_size)
|
| 218 |
if noise is None:
|
| 219 |
noise = torch.randn(
|
| 220 |
+
(batch_size, self.unet.in_channels, self.unet.sample_size[0],
|
| 221 |
+
self.unet.sample_size[1]),
|
| 222 |
generator=generator)
|
| 223 |
images = noise
|
| 224 |
mask = None
|