Spaces:
Runtime error
Runtime error
Update min_dalle/models/dalle_bart_decoder.py
Browse files
min_dalle/models/dalle_bart_decoder.py
CHANGED
|
@@ -162,14 +162,13 @@ class DalleBartDecoder(nn.Module):
|
|
| 162 |
print(tracemalloc.get_traced_memory())
|
| 163 |
|
| 164 |
for i in range(self.layer_count):
|
| 165 |
-
decoder_state, attention_state[i] = self.layers[i].forward(
|
| 166 |
decoder_state,
|
| 167 |
encoder_state,
|
| 168 |
attention_state[i],
|
| 169 |
attention_mask,
|
| 170 |
token_index
|
| 171 |
)
|
| 172 |
-
del decoder_state
|
| 173 |
print(tracemalloc.get_traced_memory())
|
| 174 |
decoder_state = self.final_ln(decoder_state)
|
| 175 |
logits = self.lm_head(decoder_state)
|
|
|
|
| 162 |
print(tracemalloc.get_traced_memory())
|
| 163 |
|
| 164 |
for i in range(self.layer_count):
|
| 165 |
+
del decoder_state, attention_state[i] = self.layers[i].forward(
|
| 166 |
decoder_state,
|
| 167 |
encoder_state,
|
| 168 |
attention_state[i],
|
| 169 |
attention_mask,
|
| 170 |
token_index
|
| 171 |
)
|
|
|
|
| 172 |
print(tracemalloc.get_traced_memory())
|
| 173 |
decoder_state = self.final_ln(decoder_state)
|
| 174 |
logits = self.lm_head(decoder_state)
|