codewithjarair committed on
Commit
8973227
·
verified ·
1 Parent(s): 722bcf4

Update kokoro_engine.py

Browse files
Files changed (1) hide show
  1. kokoro_engine.py +29 -20
kokoro_engine.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- from kokoro import KModel
3
  import numpy as np
4
  import os
5
 
@@ -7,7 +7,11 @@ class KokoroEngine:
7
  def __init__(self, model_path="hexgrad/Kokoro-82M", device=None):
8
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
9
  print(f"Initializing KokoroEngine on {self.device}...")
10
- self.model = KModel(model_path).to(self.device).eval()
 
 
 
 
11
 
12
  # Available voices categorized
13
  self.voices = {
@@ -22,6 +26,14 @@ class KokoroEngine:
22
  "Portuguese": ["pf_dora", "pm_alex"]
23
  }
24
 
 
 
 
 
 
 
 
 
25
  def get_voice_list(self):
26
  all_voices = []
27
  for category in self.voices.values():
@@ -31,23 +43,20 @@ class KokoroEngine:
31
  def generate(self, text, voice="af_heart", speed=1.0, lang='a'):
32
  """
33
  Generates audio from text using a specified voice.
 
 
 
 
 
34
 
35
- Args:
36
- text (str): The text to synthesize.
37
- voice (str or torch.Tensor): The voice ID to use or a voice tensor.
38
- speed (float): The speed factor (default 1.0).
39
- lang (str): Language code (default 'a').
 
 
 
40
 
41
- Returns:
42
- tuple: (audio_numpy, sample_rate)
43
- """
44
- # If voice is a path to a custom .pt file, load it
45
- if isinstance(voice, str) and (voice.endswith(".pt") or voice.endswith(".bin")):
46
- if os.path.exists(voice):
47
- voice = torch.load(voice, map_location=self.device)
48
- else:
49
- print(f"Warning: Voice file {voice} not found. Falling back to af_heart.")
50
- voice = "af_heart"
51
-
52
- audio, out_ps = self.model(text, voice=voice, speed=speed, lang=lang)
53
- return audio, 24000 # Kokoro standard sample rate is 24k
 
1
  import torch
2
+ from kokoro import KModel, KPipeline
3
  import numpy as np
4
  import os
5
 
 
7
  def __init__(self, model_path="hexgrad/Kokoro-82M", device=None):
8
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
9
  print(f"Initializing KokoroEngine on {self.device}...")
10
+
11
+ # Load the base model
12
+ self.model = KModel().to(self.device).eval()
13
+ # Initialize a dictionary to cache pipelines for different languages
14
+ self.pipelines = {}
15
 
16
  # Available voices categorized
17
  self.voices = {
 
26
  "Portuguese": ["pf_dora", "pm_alex"]
27
  }
28
 
29
def get_pipeline(self, lang_code):
    """Return the pipeline cached for ``lang_code``, building it on first use.

    Pipelines are stored in ``self.pipelines`` keyed by language code, so a
    given language is only constructed once per engine instance.
    """
    cache = self.pipelines

    # Fast path: this language has already been requested.
    if lang_code in cache:
        return cache[lang_code]

    print(f"Creating pipeline for language: {lang_code}")
    # model=self.model is passed so every pipeline shares the underlying
    # weights instead of loading its own copy.
    created = KPipeline(lang_code=lang_code, model=self.model, device=self.device)
    cache[lang_code] = created
    return created
36
+
37
  def get_voice_list(self):
38
  all_voices = []
39
  for category in self.voices.values():
 
43
  def generate(self, text, voice="af_heart", speed=1.0, lang='a'):
44
  """
45
  Generates audio from text using a specified voice.
46
+ """
47
+ pipeline = self.get_pipeline(lang)
48
+
49
+ # Generator returns (gs, ps, audio)
50
+ generator = pipeline(text, voice=voice, speed=speed)
51
 
52
+ # Collect all audio segments
53
+ all_audio = []
54
+ for gs, ps, audio in generator:
55
+ if audio is not None:
56
+ all_audio.append(audio)
57
+
58
+ if not all_audio:
59
+ return None, 24000
60
 
61
+ final_audio = np.concatenate(all_audio)
62
+ return final_audio, 24000