Yoshitaka16 commited on
Commit
9cd255d
·
verified ·
1 Parent(s): 9974212

Update djcm_module.py

Browse files
Files changed (1) hide show
  1. djcm_module.py +12 -5
djcm_module.py CHANGED
@@ -1,16 +1,23 @@
1
  import torch
 
 
2
 
3
  class DJCM:
4
- def __init__(self, model_path="models/djcm.pt", device="cuda"):
5
  self.device = device
 
 
 
6
  self.model = torch.load(model_path, map_location=device)
7
  self.model.eval()
8
 
9
- def extract_f0(self, audio_tensor):
10
  """
11
- audio_tensor: torch.Tensor, shape [1, T]
12
- return: f0 tensor
13
  """
 
14
  with torch.no_grad():
15
  f0 = self.model(audio_tensor)
16
- return f0
 
 
1
  import torch
2
+ import os
3
+ import numpy as np
4
 
5
  class DJCM:
6
+ def __init__(self, model_path="/content/HRVC/models/rvc/predictors/djcm.pt", device="cuda"):
7
  self.device = device
8
+ self.model_path = model_path
9
+ if not os.path.exists(model_path):
10
+ raise FileNotFoundError(f"DJCM model not found at {model_path}")
11
  self.model = torch.load(model_path, map_location=device)
12
  self.model.eval()
13
 
14
+ def infer_from_audio(self, audio_np):
15
  """
16
+ audio_np: np.ndarray, shape [T]
17
+ return: f0 np.ndarray
18
  """
19
+ audio_tensor = torch.from_numpy(audio_np).float().unsqueeze(0).to(self.device)
20
  with torch.no_grad():
21
  f0 = self.model(audio_tensor)
22
+ return f0.squeeze().cpu().numpy()
23
+