Spaces:
Runtime error
Runtime error
Update min_dalle/models/dalle_bart_decoder.py
Browse files
min_dalle/models/dalle_bart_decoder.py
CHANGED
|
@@ -3,6 +3,7 @@ import torch
|
|
| 3 |
from torch import nn, LongTensor, FloatTensor, BoolTensor
|
| 4 |
from .dalle_bart_encoder import GLU, AttentionBase
|
| 5 |
import gc
|
|
|
|
| 6 |
|
| 7 |
IMAGE_TOKEN_COUNT = 256
|
| 8 |
|
|
@@ -154,6 +155,12 @@ class DalleBartDecoder(nn.Module):
|
|
| 154 |
decoder_state += self.embed_positions.forward(token_index_batched)
|
| 155 |
decoder_state = self.layernorm_embedding.forward(decoder_state)
|
| 156 |
decoder_state = decoder_state[:, None]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
for i in range(self.layer_count):
|
| 158 |
decoder_state, attention_state[i] = self.layers[i].forward(
|
| 159 |
decoder_state,
|
|
@@ -173,6 +180,7 @@ class DalleBartDecoder(nn.Module):
|
|
| 173 |
logits[:image_count] * (1 - supercondition_factor) +
|
| 174 |
logits[image_count:] * supercondition_factor
|
| 175 |
)
|
|
|
|
| 176 |
del supercondition_factor
|
| 177 |
logits_sorted, _ = logits.sort(descending=True)
|
| 178 |
is_kept = logits >= logits_sorted[:, top_k - 1]
|
|
@@ -188,4 +196,6 @@ class DalleBartDecoder(nn.Module):
|
|
| 188 |
del logits
|
| 189 |
gc.collect()
|
| 190 |
|
|
|
|
|
|
|
| 191 |
return image_tokens, attention_state
|
|
|
|
| 3 |
from torch import nn, LongTensor, FloatTensor, BoolTensor
|
| 4 |
from .dalle_bart_encoder import GLU, AttentionBase
|
| 5 |
import gc
|
| 6 |
+
import tracemalloc
|
| 7 |
|
| 8 |
IMAGE_TOKEN_COUNT = 256
|
| 9 |
|
|
|
|
| 155 |
decoder_state += self.embed_positions.forward(token_index_batched)
|
| 156 |
decoder_state = self.layernorm_embedding.forward(decoder_state)
|
| 157 |
decoder_state = decoder_state[:, None]
|
| 158 |
+
|
| 159 |
+
tracemalloc.start()
|
| 160 |
+
print("--")
|
| 161 |
+
# display the (current, peak) traced memory sizes in bytes
|
| 162 |
+
print(tracemalloc.get_traced_memory())
|
| 163 |
+
|
| 164 |
for i in range(self.layer_count):
|
| 165 |
decoder_state, attention_state[i] = self.layers[i].forward(
|
| 166 |
decoder_state,
|
|
|
|
| 180 |
logits[:image_count] * (1 - supercondition_factor) +
|
| 181 |
logits[image_count:] * supercondition_factor
|
| 182 |
)
|
| 183 |
+
print(tracemalloc.get_traced_memory())
|
| 184 |
del supercondition_factor
|
| 185 |
logits_sorted, _ = logits.sort(descending=True)
|
| 186 |
is_kept = logits >= logits_sorted[:, top_k - 1]
|
|
|
|
| 196 |
del logits
|
| 197 |
gc.collect()
|
| 198 |
|
| 199 |
+
print(tracemalloc.get_traced_memory())
|
| 200 |
+
|
| 201 |
return image_tokens, attention_state
|