Updated Code to run on Colab Notebook without Errors
#13
by
theshresthshukla
- opened
README.md
CHANGED
|
@@ -153,6 +153,9 @@ def decode_snac_tokens(snac_tokens, snac_model):
|
|
| 153 |
if not snac_tokens or len(snac_tokens) % 7 != 0:
|
| 154 |
return None
|
| 155 |
|
|
|
|
|
|
|
|
|
|
| 156 |
# De-interleave tokens into 3 hierarchical levels
|
| 157 |
codes_lvl = [[] for _ in range(3)]
|
| 158 |
llm_codebook_offsets = [AUDIO_CODE_BASE_OFFSET + i * 4096 for i in range(7)]
|
|
@@ -172,7 +175,7 @@ def decode_snac_tokens(snac_tokens, snac_model):
|
|
| 172 |
# Convert to tensors for SNAC decoder
|
| 173 |
hierarchical_codes = []
|
| 174 |
for lvl_codes in codes_lvl:
|
| 175 |
-
tensor = torch.tensor(lvl_codes, dtype=torch.int32, device=
|
| 176 |
if torch.any((tensor < 0) | (tensor > 4095)):
|
| 177 |
raise ValueError("Invalid SNAC token values")
|
| 178 |
hierarchical_codes.append(tensor)
|
|
|
|
| 153 |
if not snac_tokens or len(snac_tokens) % 7 != 0:
|
| 154 |
return None
|
| 155 |
|
| 156 |
+
# Get the device of the SNAC model. Fixed by Shresth to run on colab notebook :)
|
| 157 |
+
snac_device = next(snac_model.parameters()).device
|
| 158 |
+
|
| 159 |
# De-interleave tokens into 3 hierarchical levels
|
| 160 |
codes_lvl = [[] for _ in range(3)]
|
| 161 |
llm_codebook_offsets = [AUDIO_CODE_BASE_OFFSET + i * 4096 for i in range(7)]
|
|
|
|
| 175 |
# Convert to tensors for SNAC decoder
|
| 176 |
hierarchical_codes = []
|
| 177 |
for lvl_codes in codes_lvl:
|
| 178 |
+
tensor = torch.tensor(lvl_codes, dtype=torch.int32, device=snac_device).unsqueeze(0)
|
| 179 |
if torch.any((tensor < 0) | (tensor > 4095)):
|
| 180 |
raise ValueError("Invalid SNAC token values")
|
| 181 |
hierarchical_codes.append(tensor)
|