Spaces:
Runtime error
Runtime error
Update min_dalle/models/dalle_bart_decoder.py
Browse files
min_dalle/models/dalle_bart_decoder.py
CHANGED
|
@@ -161,7 +161,7 @@ class DalleBartDecoder(nn.Module):
|
|
| 161 |
)
|
| 162 |
decoder_state = self.final_ln(decoder_state)
|
| 163 |
logits = self.lm_head(decoder_state)
|
| 164 |
-
del
|
| 165 |
temperature = settings[[0]]
|
| 166 |
top_k = settings[[1]].to(torch.long)
|
| 167 |
supercondition_factor = settings[[2]]
|
|
@@ -176,6 +176,7 @@ class DalleBartDecoder(nn.Module):
|
|
| 176 |
logits -= logits_sorted[:, [0]]
|
| 177 |
del logits_sorted
|
| 178 |
logits /= temperature
|
|
|
|
| 179 |
logits.exp_()
|
| 180 |
logits *= is_kept.to(torch.float32)
|
| 181 |
image_tokens = torch.multinomial(logits, 1)[:, 0]
|
|
|
|
| 161 |
)
|
| 162 |
decoder_state = self.final_ln(decoder_state)
|
| 163 |
logits = self.lm_head(decoder_state)
|
| 164 |
+
del decoder_state
|
| 165 |
temperature = settings[[0]]
|
| 166 |
top_k = settings[[1]].to(torch.long)
|
| 167 |
supercondition_factor = settings[[2]]
|
|
|
|
| 176 |
logits -= logits_sorted[:, [0]]
|
| 177 |
del logits_sorted
|
| 178 |
logits /= temperature
|
| 179 |
+
del temperature
|
| 180 |
logits.exp_()
|
| 181 |
logits *= is_kept.to(torch.float32)
|
| 182 |
image_tokens = torch.multinomial(logits, 1)[:, 0]
|