Updated Code to run on Colab Notebook without Errors

#13
Files changed (1) hide show
  1. README.md +4 -1
README.md CHANGED
@@ -153,6 +153,9 @@ def decode_snac_tokens(snac_tokens, snac_model):
153
  if not snac_tokens or len(snac_tokens) % 7 != 0:
154
  return None
155
 
 
 
 
156
  # De-interleave tokens into 3 hierarchical levels
157
  codes_lvl = [[] for _ in range(3)]
158
  llm_codebook_offsets = [AUDIO_CODE_BASE_OFFSET + i * 4096 for i in range(7)]
@@ -172,7 +175,7 @@ def decode_snac_tokens(snac_tokens, snac_model):
172
  # Convert to tensors for SNAC decoder
173
  hierarchical_codes = []
174
  for lvl_codes in codes_lvl:
175
- tensor = torch.tensor(lvl_codes, dtype=torch.int32, device=snac_model.device).unsqueeze(0)
176
  if torch.any((tensor < 0) | (tensor > 4095)):
177
  raise ValueError("Invalid SNAC token values")
178
  hierarchical_codes.append(tensor)
 
153
  if not snac_tokens or len(snac_tokens) % 7 != 0:
154
  return None
155
 
156
+ # Get the device of the SNAC model. Fixed by Shresth to run on colab notebook :)
157
+ snac_device = next(snac_model.parameters()).device
158
+
159
  # De-interleave tokens into 3 hierarchical levels
160
  codes_lvl = [[] for _ in range(3)]
161
  llm_codebook_offsets = [AUDIO_CODE_BASE_OFFSET + i * 4096 for i in range(7)]
 
175
  # Convert to tensors for SNAC decoder
176
  hierarchical_codes = []
177
  for lvl_codes in codes_lvl:
178
+ tensor = torch.tensor(lvl_codes, dtype=torch.int32, device=snac_device).unsqueeze(0)
179
  if torch.any((tensor < 0) | (tensor > 4095)):
180
  raise ValueError("Invalid SNAC token values")
181
  hierarchical_codes.append(tensor)