Yoshitaka16 commited on
Commit
f7d8262
·
verified ·
1 Parent(s): 148757c

Update djcm_module.py

Browse files
Files changed (1) hide show
  1. djcm_module.py +9 -13
djcm_module.py CHANGED
@@ -1,23 +1,19 @@
1
  import torch
2
- import os
3
  import numpy as np
4
 
5
- class DJCM:
6
- def __init__(self, model_path="/content/HRVC/models/rvc/predictors/djcm.pt", device="cuda"):
7
  self.device = device
8
- self.model_path = model_path
9
- if not os.path.exists(model_path):
10
- raise FileNotFoundError(f"DJCM model not found at {model_path}")
11
- self.model = torch.load(model_path, map_location=device)
12
  self.model.eval()
13
 
14
- def infer_from_audio(self, audio_np):
15
  """
16
- audio_np: np.ndarray, shape [T]
17
- return: f0 np.ndarray
 
18
  """
19
- audio_tensor = torch.from_numpy(audio_np).float().unsqueeze(0).to(self.device)
20
  with torch.no_grad():
21
- f0 = self.model(audio_tensor)
22
  return f0.squeeze().cpu().numpy()
23
-
 
1
  import torch
 
2
  import numpy as np
3
 
4
+ class DJCMExtractor:
5
+ def __init__(self, model_path, device="cuda"):
6
  self.device = device
7
+ self.model = torch.jit.load(model_path, map_location=device)
 
 
 
8
  self.model.eval()
9
 
10
+ def __call__(self, audio, sr=16000):
11
  """
12
+ audio: numpy array (1D, float32)
13
+ sr: sample rate (default 16k atau sesuaikan dengan DJCM)
14
+ return: f0 contour (numpy array 1D)
15
  """
16
+ x = torch.tensor(audio, dtype=torch.float32, device=self.device).unsqueeze(0)
17
  with torch.no_grad():
18
+ f0 = self.model(x, sr) # Sesuaikan kalau model DJCM butuh input lain
19
  return f0.squeeze().cpu().numpy()