Spaces:
Runtime error
Runtime error
Hugo Flores Garcia
commited on
Commit
·
5dafbac
1
Parent(s):
df7025d
fixes #3
Browse files- vampnet/beats.py +2 -1
vampnet/beats.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import Tuple
|
|
| 9 |
from typing import Union
|
| 10 |
|
| 11 |
import librosa
|
|
|
|
| 12 |
import numpy as np
|
| 13 |
from audiotools import AudioSignal
|
| 14 |
|
|
@@ -203,7 +204,7 @@ class WaveBeat(BeatTracker):
|
|
| 203 |
def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
|
| 204 |
from wavebeat.dstcn import dsTCNModel
|
| 205 |
|
| 206 |
-
model = dsTCNModel.load_from_checkpoint(ckpt_path)
|
| 207 |
model.eval()
|
| 208 |
|
| 209 |
self.device = device
|
|
|
|
| 9 |
from typing import Union
|
| 10 |
|
| 11 |
import librosa
|
| 12 |
+
import torch
|
| 13 |
import numpy as np
|
| 14 |
from audiotools import AudioSignal
|
| 15 |
|
|
|
|
| 204 |
def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
|
| 205 |
from wavebeat.dstcn import dsTCNModel
|
| 206 |
|
| 207 |
+
model = dsTCNModel.load_from_checkpoint(ckpt_path, map_location=torch.device(device))
|
| 208 |
model.eval()
|
| 209 |
|
| 210 |
self.device = device
|