Spaces:
Runtime error
Runtime error
Update min_dalle/models/dalle_bart_decoder.py
Browse files
min_dalle/models/dalle_bart_decoder.py
CHANGED
|
@@ -169,11 +169,14 @@ class DalleBartDecoder(nn.Module):
|
|
| 169 |
attention_mask,
|
| 170 |
token_index
|
| 171 |
)
|
|
|
|
| 172 |
decoder_state = self.final_ln(decoder_state)
|
| 173 |
logits = self.lm_head(decoder_state)
|
|
|
|
| 174 |
del decoder_state
|
| 175 |
temperature = settings[[0]]
|
| 176 |
top_k = settings[[1]].to(torch.long)
|
|
|
|
| 177 |
supercondition_factor = settings[[2]]
|
| 178 |
logits = logits[:, -1, : 2 ** 14]
|
| 179 |
logits: FloatTensor = (
|
|
|
|
| 169 |
attention_mask,
|
| 170 |
token_index
|
| 171 |
)
|
| 172 |
+
print(tracemalloc.get_traced_memory())
|
| 173 |
decoder_state = self.final_ln(decoder_state)
|
| 174 |
logits = self.lm_head(decoder_state)
|
| 175 |
+
print(tracemalloc.get_traced_memory())
|
| 176 |
del decoder_state
|
| 177 |
temperature = settings[[0]]
|
| 178 |
top_k = settings[[1]].to(torch.long)
|
| 179 |
+
print(tracemalloc.get_traced_memory())
|
| 180 |
supercondition_factor = settings[[2]]
|
| 181 |
logits = logits[:, -1, : 2 ** 14]
|
| 182 |
logits: FloatTensor = (
|