Spaces:
Runtime error
Runtime error
Update min_dalle/min_dalle.py
Browse files- min_dalle/min_dalle.py +3 -3
min_dalle/min_dalle.py
CHANGED
|
@@ -177,7 +177,7 @@ class MinDalle:
|
|
| 177 |
progressive_outputs: bool = False,
|
| 178 |
is_seamless: bool = False,
|
| 179 |
temperature: float = 1,
|
| 180 |
-
top_k: int =
|
| 181 |
supercondition_factor: int = 16,
|
| 182 |
is_verbose: bool = False
|
| 183 |
) -> Iterator[FloatTensor]:
|
|
@@ -252,7 +252,7 @@ class MinDalle:
|
|
| 252 |
)
|
| 253 |
|
| 254 |
with torch.cuda.amp.autocast(dtype=torch.float32):
|
| 255 |
-
if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 ==
|
| 256 |
yield self.image_grid_from_tokens(
|
| 257 |
image_tokens=image_tokens[1:].T,
|
| 258 |
is_seamless=is_seamless,
|
|
@@ -270,7 +270,7 @@ class MinDalle:
|
|
| 270 |
image_stream = self.generate_raw_image_stream(*args, **kwargs)
|
| 271 |
for image in image_stream:
|
| 272 |
grid_size = kwargs['grid_size']
|
| 273 |
-
image = image.view([grid_size *
|
| 274 |
image = image.transpose(1, 0)
|
| 275 |
image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
|
| 276 |
yield image
|
|
|
|
| 177 |
progressive_outputs: bool = False,
|
| 178 |
is_seamless: bool = False,
|
| 179 |
temperature: float = 1,
|
| 180 |
+
top_k: int = 128,
|
| 181 |
supercondition_factor: int = 16,
|
| 182 |
is_verbose: bool = False
|
| 183 |
) -> Iterator[FloatTensor]:
|
|
|
|
| 252 |
)
|
| 253 |
|
| 254 |
with torch.cuda.amp.autocast(dtype=torch.float32):
|
| 255 |
+
if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 128:
|
| 256 |
yield self.image_grid_from_tokens(
|
| 257 |
image_tokens=image_tokens[1:].T,
|
| 258 |
is_seamless=is_seamless,
|
|
|
|
| 270 |
image_stream = self.generate_raw_image_stream(*args, **kwargs)
|
| 271 |
for image in image_stream:
|
| 272 |
grid_size = kwargs['grid_size']
|
| 273 |
+
image = image.view([grid_size * 128, grid_size, 128, 3])
|
| 274 |
image = image.transpose(1, 0)
|
| 275 |
image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
|
| 276 |
yield image
|