Spaces:
Runtime error
Runtime error
Update min_dalle/min_dalle.py
Browse files- min_dalle/min_dalle.py +6 -6
min_dalle/min_dalle.py
CHANGED
|
@@ -17,7 +17,7 @@ torch.backends.cudnn.enabled = True
|
|
| 17 |
torch.backends.cudnn.allow_tf32 = True
|
| 18 |
|
| 19 |
MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
|
| 20 |
-
IMAGE_TOKEN_COUNT =
|
| 21 |
|
| 22 |
|
| 23 |
class MinDalle:
|
|
@@ -177,7 +177,7 @@ class MinDalle:
|
|
| 177 |
progressive_outputs: bool = False,
|
| 178 |
is_seamless: bool = False,
|
| 179 |
temperature: float = 1,
|
| 180 |
-
top_k: int =
|
| 181 |
supercondition_factor: int = 16,
|
| 182 |
is_verbose: bool = False
|
| 183 |
) -> Iterator[FloatTensor]:
|
|
@@ -239,8 +239,8 @@ class MinDalle:
|
|
| 239 |
break
|
| 240 |
st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
|
| 241 |
|
| 242 |
-
#torch.cuda.empty_cache()
|
| 243 |
-
#torch.
|
| 244 |
with torch.cuda.amp.autocast(dtype=self.dtype):
|
| 245 |
image_tokens[i + 1], attention_state = self.decoder.forward(
|
| 246 |
settings=settings,
|
|
@@ -252,7 +252,7 @@ class MinDalle:
|
|
| 252 |
)
|
| 253 |
|
| 254 |
with torch.cuda.amp.autocast(dtype=torch.float32):
|
| 255 |
-
if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 ==
|
| 256 |
yield self.image_grid_from_tokens(
|
| 257 |
image_tokens=image_tokens[1:].T,
|
| 258 |
is_seamless=is_seamless,
|
|
@@ -270,7 +270,7 @@ class MinDalle:
|
|
| 270 |
image_stream = self.generate_raw_image_stream(*args, **kwargs)
|
| 271 |
for image in image_stream:
|
| 272 |
grid_size = kwargs['grid_size']
|
| 273 |
-
image = image.view([grid_size *
|
| 274 |
image = image.transpose(1, 0)
|
| 275 |
image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
|
| 276 |
yield image
|
|
|
|
| 17 |
torch.backends.cudnn.allow_tf32 = True
|
| 18 |
|
| 19 |
MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
|
| 20 |
+
IMAGE_TOKEN_COUNT = 256
|
| 21 |
|
| 22 |
|
| 23 |
class MinDalle:
|
|
|
|
| 177 |
progressive_outputs: bool = False,
|
| 178 |
is_seamless: bool = False,
|
| 179 |
temperature: float = 1,
|
| 180 |
+
top_k: int = 256,
|
| 181 |
supercondition_factor: int = 16,
|
| 182 |
is_verbose: bool = False
|
| 183 |
) -> Iterator[FloatTensor]:
|
|
|
|
| 239 |
break
|
| 240 |
st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
|
| 241 |
|
| 242 |
+
#torch.cuda.empty_cache()
|
| 243 |
+
#torch.cpu.empty_cache()
|
| 244 |
with torch.cuda.amp.autocast(dtype=self.dtype):
|
| 245 |
image_tokens[i + 1], attention_state = self.decoder.forward(
|
| 246 |
settings=settings,
|
|
|
|
| 252 |
)
|
| 253 |
|
| 254 |
with torch.cuda.amp.autocast(dtype=torch.float32):
|
| 255 |
+
if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256:
|
| 256 |
yield self.image_grid_from_tokens(
|
| 257 |
image_tokens=image_tokens[1:].T,
|
| 258 |
is_seamless=is_seamless,
|
|
|
|
| 270 |
image_stream = self.generate_raw_image_stream(*args, **kwargs)
|
| 271 |
for image in image_stream:
|
| 272 |
grid_size = kwargs['grid_size']
|
| 273 |
+
image = image.view([grid_size * 256, grid_size, 256, 3])
|
| 274 |
image = image.transpose(1, 0)
|
| 275 |
image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
|
| 276 |
yield image
|