Spaces:
Runtime error
Runtime error
Update min_dalle/models/dalle_bart_decoder.py
Browse files
min_dalle/models/dalle_bart_decoder.py
CHANGED
|
@@ -2,6 +2,7 @@ from typing import Tuple, List
|
|
| 2 |
import torch
|
| 3 |
from torch import nn, LongTensor, FloatTensor, BoolTensor
|
| 4 |
from .dalle_bart_encoder import GLU, AttentionBase
|
|
|
|
| 5 |
|
| 6 |
IMAGE_TOKEN_COUNT = 256
|
| 7 |
|
|
@@ -100,6 +101,8 @@ class DecoderLayer(nn.Module):
|
|
| 100 |
decoder_state = self.glu.forward(decoder_state)
|
| 101 |
decoder_state = residual + decoder_state
|
| 102 |
|
|
|
|
|
|
|
| 103 |
return decoder_state, attention_state
|
| 104 |
|
| 105 |
|
|
@@ -170,6 +173,7 @@ class DalleBartDecoder(nn.Module):
|
|
| 170 |
logits[:image_count] * (1 - supercondition_factor) +
|
| 171 |
logits[image_count:] * supercondition_factor
|
| 172 |
)
|
|
|
|
| 173 |
logits_sorted, _ = logits.sort(descending=True)
|
| 174 |
is_kept = logits >= logits_sorted[:, top_k - 1]
|
| 175 |
del top_k
|
|
@@ -179,7 +183,9 @@ class DalleBartDecoder(nn.Module):
|
|
| 179 |
del temperature
|
| 180 |
logits.exp_()
|
| 181 |
logits *= is_kept.to(torch.float32)
|
|
|
|
| 182 |
image_tokens = torch.multinomial(logits, 1)[:, 0]
|
| 183 |
del logits
|
|
|
|
| 184 |
|
| 185 |
return image_tokens, attention_state
|
|
|
|
| 2 |
import torch
|
| 3 |
from torch import nn, LongTensor, FloatTensor, BoolTensor
|
| 4 |
from .dalle_bart_encoder import GLU, AttentionBase
|
| 5 |
+
import gc
|
| 6 |
|
| 7 |
IMAGE_TOKEN_COUNT = 256
|
| 8 |
|
|
|
|
| 101 |
decoder_state = self.glu.forward(decoder_state)
|
| 102 |
decoder_state = residual + decoder_state
|
| 103 |
|
| 104 |
+
|
| 105 |
+
|
| 106 |
return decoder_state, attention_state
|
| 107 |
|
| 108 |
|
|
|
|
| 173 |
logits[:image_count] * (1 - supercondition_factor) +
|
| 174 |
logits[image_count:] * supercondition_factor
|
| 175 |
)
|
| 176 |
+
del supercondition_factor
|
| 177 |
logits_sorted, _ = logits.sort(descending=True)
|
| 178 |
is_kept = logits >= logits_sorted[:, top_k - 1]
|
| 179 |
del top_k
|
|
|
|
| 183 |
del temperature
|
| 184 |
logits.exp_()
|
| 185 |
logits *= is_kept.to(torch.float32)
|
| 186 |
+
del is_kept
|
| 187 |
image_tokens = torch.multinomial(logits, 1)[:, 0]
|
| 188 |
del logits
|
| 189 |
+
gc.collect()
|
| 190 |
|
| 191 |
return image_tokens, attention_state
|