Spaces:
Sleeping
Sleeping
Commit ·
d6a5fdd
1
Parent(s): d81ea2d
Update block_baseline.py
Browse files- block_baseline.py +2 -2
block_baseline.py
CHANGED
|
@@ -23,7 +23,7 @@ def get_bins(vocab_size, block_size):
|
|
| 23 |
|
| 24 |
return bin2words, words2bin
|
| 25 |
|
| 26 |
-
def encode_block(model, enc, message, context, block_size, bin2words, words2bin, finish_sent=False, device='
|
| 27 |
length = len(message)
|
| 28 |
|
| 29 |
context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
|
|
@@ -96,7 +96,7 @@ def encode_block(model, enc, message, context, block_size, bin2words, words2bin,
|
|
| 96 |
|
| 97 |
return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
|
| 98 |
|
| 99 |
-
def decode_block(model, enc, text, context, block_size, bin2words, words2bin, device='
|
| 100 |
# inp is a list of token indices
|
| 101 |
# context is a list of token indices
|
| 102 |
inp = enc.encode(text)
|
|
|
|
| 23 |
|
| 24 |
return bin2words, words2bin
|
| 25 |
|
| 26 |
+
def encode_block(model, enc, message, context, block_size, bin2words, words2bin, finish_sent=False, device='cpu'):
|
| 27 |
length = len(message)
|
| 28 |
|
| 29 |
context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
|
|
|
|
| 96 |
|
| 97 |
return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
|
| 98 |
|
| 99 |
+
def decode_block(model, enc, text, context, block_size, bin2words, words2bin, device='cpu'):
|
| 100 |
# inp is a list of token indices
|
| 101 |
# context is a list of token indices
|
| 102 |
inp = enc.encode(text)
|