Nick021402 committed on
Commit
daae1c2
·
verified ·
1 Parent(s): 3491ff8

Update tts_engine.py

Browse files
Files changed (1) hide show
  1. tts_engine.py +10 -29
tts_engine.py CHANGED
@@ -45,13 +45,13 @@ class NariDIAEngine:
45
  logger.info("Nari DIA model initialized successfully.")
46
 
47
  except Exception as e:
48
- logger.error(f"Failed to initialize Nari DIA model: {e}")
49
  self.model = None
50
 
51
  def synthesize_segment(
52
  self,
53
  text: str,
54
- speaker: str, # This will be 'speaker1', 'speaker2' from segmenter
55
  output_path: str
56
  ) -> Optional[str]:
57
  """
@@ -59,7 +59,7 @@ class NariDIAEngine:
59
 
60
  Args:
61
  text: Text to synthesize
62
- speaker: Speaker identifier (e.g., 'speaker1', 'speaker2')
63
  output_path: Path to save the audio file
64
 
65
  Returns:
@@ -71,24 +71,14 @@ class NariDIAEngine:
71
 
72
  try:
73
  # Nari DIA expects [S1] or [S2] tags.
74
- # Map your generic 'speaker1'/'speaker2' to Nari DIA's tags.
75
- # Note: Nari DIA primarily supports S1 and S2. If your segmenter
76
- # generates more speakers, they will be mapped to S1/S2.
77
-
78
- # Map based on the speaker index derived from segmenter
79
- if speaker == "speaker1":
80
- dia_speaker_tag = "[S1]"
81
- elif speaker == "speaker2":
82
- dia_speaker_tag = "[S2]"
83
  else:
84
- # For speaker3, speaker4, etc., we'll alternate or default to S1/S2
85
- # A more sophisticated mapping might be needed if you want more than 2 distinct voices
86
- # For simplicity, we'll just alternate beyond speaker2
87
- if int(speaker.replace('speaker', '')) % 2 == 1: # Odd speakers to S1
88
- dia_speaker_tag = "[S1]"
89
- else: # Even speakers to S2
90
- dia_speaker_tag = "[S2]"
91
- logger.warning(f"Nari DIA primarily supports [S1] and [S2]. Mapping '{speaker}' to '{dia_speaker_tag}'.")
92
 
93
  # Nari DIA expects the speaker tag at the beginning of the segment
94
  full_text_input = f"{dia_speaker_tag} {text}"
@@ -102,7 +92,6 @@ class NariDIAEngine:
102
  inputs = {k: v.to("cuda") for k, v in inputs.items()}
103
 
104
  with torch.no_grad():
105
- # The .generate method should return audio waveform
106
  audio_waveform = self.model.generate(**inputs).cpu().numpy().squeeze()
107
 
108
  # Nari DIA's sampling rate (check documentation if different)
@@ -119,11 +108,3 @@ class NariDIAEngine:
119
  except Exception as e:
120
  logger.error(f"Failed to synthesize segment with Nari DIA: {e}", exc_info=True) # exc_info to print full traceback
121
  return None
122
-
123
- # Remove the mock audio generation function as it's no longer needed
124
- # def _generate_mock_audio(self, text: str, speaker: str) -> np.ndarray:
125
- # """
126
- # Generate mock audio data for demonstration.
127
- # In a real implementation, this would be replaced with actual TTS.
128
- # """
129
- # # ... (existing mock audio generation code) ...
 
45
  logger.info("Nari DIA model initialized successfully.")
46
 
47
  except Exception as e:
48
+ logger.error(f"Failed to initialize Nari DIA model: {e}", exc_info=True)
49
  self.model = None
50
 
51
  def synthesize_segment(
52
  self,
53
  text: str,
54
+ speaker: str, # This will be 'S1' or 'S2' from segmenter
55
  output_path: str
56
  ) -> Optional[str]:
57
  """
 
59
 
60
  Args:
61
  text: Text to synthesize
62
+ speaker: Speaker identifier ('S1' or 'S2' expected from segmenter)
63
  output_path: Path to save the audio file
64
 
65
  Returns:
 
71
 
72
  try:
73
  # Nari DIA expects [S1] or [S2] tags.
74
+ # The segmenter is now directly outputting "S1" or "S2".
75
+ # We just need to wrap it in brackets.
76
+ if speaker in ["S1", "S2"]:
77
+ dia_speaker_tag = f"[{speaker}]"
 
 
 
 
 
78
  else:
79
+ # Fallback in case segmenter outputs something unexpected
80
+ logger.warning(f"Unexpected speaker tag '{speaker}' from segmenter. Defaulting to [S1].")
81
+ dia_speaker_tag = "[S1]"
 
 
 
 
 
82
 
83
  # Nari DIA expects the speaker tag at the beginning of the segment
84
  full_text_input = f"{dia_speaker_tag} {text}"
 
92
  inputs = {k: v.to("cuda") for k, v in inputs.items()}
93
 
94
  with torch.no_grad():
 
95
  audio_waveform = self.model.generate(**inputs).cpu().numpy().squeeze()
96
 
97
  # Nari DIA's sampling rate (check documentation if different)
 
108
  except Exception as e:
109
  logger.error(f"Failed to synthesize segment with Nari DIA: {e}", exc_info=True) # exc_info to print full traceback
110
  return None