hubert_base / djcm_module.py
Yoshitaka16's picture
Update djcm_module.py
f7d8262 verified
raw
history blame contribute delete
685 Bytes
import torch
import numpy as np
class DJCMExtractor:
    """Inference wrapper around a TorchScript DJCM model that estimates f0.

    The model is loaded once with ``torch.jit.load`` and put in eval mode.
    Each call converts the audio to a batched float32 tensor, runs the model
    with gradient tracking disabled, and returns the result as a NumPy array.
    """

    def __init__(self, model_path, device="cuda"):
        """Load the TorchScript model.

        Args:
            model_path: Location of the serialized TorchScript DJCM model
                (anything ``torch.jit.load`` accepts).
            device: Torch device the model is mapped to and that inputs are
                placed on. Defaults to ``"cuda"``; pass ``"cpu"`` on machines
                without a GPU.
        """
        self.device = device
        # map_location remaps the saved tensors onto the target device at load.
        self.model = torch.jit.load(model_path, map_location=device)
        self.model.eval()

    def __call__(self, audio, sr=16000):
        """Estimate the f0 contour of an audio signal.

        Args:
            audio: 1-D array of samples (numpy, float32 expected); it is
                copied and converted to a float32 tensor on ``self.device``.
            sr: Sample rate passed through to the model (default 16 kHz;
                adjust to match what the DJCM model expects).

        Returns:
            The f0 contour as a NumPy array with size-1 dimensions removed,
            but always at least 1-D: a single-frame result comes back with
            shape ``(1,)`` instead of as a 0-d scalar array.
        """
        # Prepend a batch axis; for 1-D input: (num_samples,) -> (1, num_samples).
        x = torch.tensor(audio, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            # NOTE(review): assumes the model's forward takes (audio, sr);
            # adjust this call if the DJCM model needs different inputs.
            f0 = self.model(x, sr)
        # squeeze() drops the batch axis (and any other size-1 axes). On a
        # single-frame output it would yield a 0-d array, so restore 1-D.
        return np.atleast_1d(f0.squeeze().cpu().numpy())