Spaces:
Runtime error
Runtime error
Update min_dalle/models/dalle_bart_decoder.py
Browse files
min_dalle/models/dalle_bart_decoder.py
CHANGED
|
@@ -175,7 +175,7 @@ class DalleBartDecoder(nn.Module):
|
|
| 175 |
logits /= temperature
|
| 176 |
logits.exp_()
|
| 177 |
logits *= is_kept.to(torch.float32)
|
| 178 |
-
image_tokens = torch.multinomial(logits, 1)[:, 0]
|
| 179 |
-
del logits
|
| 180 |
-
|
| 181 |
-
return
|
|
|
|
| 175 |
logits /= temperature
|
| 176 |
logits.exp_()
|
| 177 |
logits *= is_kept.to(torch.float32)
|
| 178 |
+
#image_tokens = torch.multinomial(logits, 1)[:, 0]
|
| 179 |
+
#del logits
|
| 180 |
+
|
| 181 |
+
return torch.multinomial(logits, 1)[:, 0], attention_state
|