abreza committed on
Commit
da2ee9a
·
1 Parent(s): 81b0116

feat: improved number handling and audio processing

Browse files

- Add sentence splitting functionality
- Ensure audio segments are flattened for consistent output
- Improve number handling

Files changed (4) hide show
  1. interface.py +37 -11
  2. sentence_splitter.py +123 -0
  3. synthesis.py +115 -31
  4. text_utils.py +84 -0
interface.py CHANGED
@@ -32,11 +32,18 @@ def ge2pe_infer(model_name: str, text: str, use_rules: bool, use_dict: bool):
32
 
33
  def create_interface():
34
  with gr.Blocks(title="Persian Speech Suite", css=custom_css) as demo:
35
- gr.Markdown("# Persian Speech Suite: GE2PE & TTS\n" "A unified playground for Persian grapheme‑to‑phoneme conversion (GE2PE) **and** text‑to‑speech synthesis (Mana TTS).")
 
 
 
 
36
 
37
  with gr.Tabs():
38
  with gr.TabItem("Grapheme → Phoneme (GE2PE)"):
39
- gr.Markdown("Convert Persian text to its phonemic transcription. Choose between **Homo‑GE2PE** and **Homo‑T5**, optionally applying short‑vowel rules and/or a custom dictionary.")
 
 
 
40
 
41
  with gr.Row():
42
  model_selector = gr.Radio(
@@ -75,20 +82,34 @@ def create_interface():
75
  )
76
 
77
  with gr.TabItem("Text‑to‑Speech"):
78
- gr.Markdown("Generate natural‑sounding Persian speech from your text using Tacotron2 + HiFiGAN.")
79
-
80
- tts_input = gr.Textbox(
81
- label="Persian Text",
82
- placeholder="مدل تولید گفتار با دادگان نسل مانا",
83
- lines=5,
84
  )
85
 
86
- tts_button = gr.Button("Generate Speech", variant="primary")
87
- tts_output = gr.Audio(label="Generated Speech")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  tts_button.click(
90
  fn=generate_speech,
91
- inputs=[tts_input],
92
  outputs=[tts_output],
93
  )
94
 
@@ -98,6 +119,11 @@ def create_interface():
98
  ["ایران سرزمین زیبایی‌ها و افتخارات است."],
99
  ["فناوری هوش مصنوعی به سرعت در حال پیشرفت است."],
100
  ["مدل تولید گفتار با دادگان نسل مانا"],
 
 
 
 
 
101
  ],
102
  inputs=[tts_input],
103
  )
 
32
 
33
  def create_interface():
34
  with gr.Blocks(title="Persian Speech Suite", css=custom_css) as demo:
35
+ gr.Markdown(
36
+ "# Persian Speech Suite: GE2PE & TTS\n"
37
+ "A unified playground for Persian grapheme‑to‑phoneme conversion (GE2PE) **and** text‑to‑speech synthesis (Mana TTS).\n\n"
38
+ "✨ **Now supports long texts!** The TTS system automatically splits long texts into natural segments."
39
+ )
40
 
41
  with gr.Tabs():
42
  with gr.TabItem("Grapheme → Phoneme (GE2PE)"):
43
+ gr.Markdown(
44
+ "Convert Persian text to its phonemic transcription. "
45
+ "Choose between **Homo‑GE2PE** and **Homo‑T5**, optionally applying short‑vowel rules and/or a custom dictionary."
46
+ )
47
 
48
  with gr.Row():
49
  model_selector = gr.Radio(
 
82
  )
83
 
84
  with gr.TabItem("Text‑to‑Speech"):
85
+ gr.Markdown(
86
+ "Generate natural‑sounding Persian speech from your text using Tacotron2 + HiFiGAN.\n\n"
87
+ "✨ **New:** Supports long texts! The system automatically splits text into natural segments "
88
+ "and adds pauses between them for better readability."
 
 
89
  )
90
 
91
+ with gr.Row():
92
+ with gr.Column(scale=2):
93
+ tts_input = gr.Textbox(
94
+ label="Persian Text",
95
+ placeholder="متن فارسی خود را اینجا بنویسید...",
96
+ lines=8,
97
+ )
98
+
99
+ with gr.Row():
100
+ tts_add_pauses = gr.Checkbox(
101
+ value=True,
102
+ label="Add pauses between segments",
103
+ info="Adds 300ms pause between text segments for natural flow"
104
+ )
105
+
106
+ tts_button = gr.Button("Generate Speech", variant="primary", size="lg")
107
+
108
+ tts_output = gr.Audio(label="Generated Speech", type="filepath")
109
 
110
  tts_button.click(
111
  fn=generate_speech,
112
+ inputs=[tts_input, gr.State(None), tts_add_pauses],
113
  outputs=[tts_output],
114
  )
115
 
 
119
  ["ایران سرزمین زیبایی‌ها و افتخارات است."],
120
  ["فناوری هوش مصنوعی به سرعت در حال پیشرفت است."],
121
  ["مدل تولید گفتار با دادگان نسل مانا"],
122
+ [
123
+ "هوش مصنوعی یکی از شگفت‌انگیزترین دستاوردهای بشر در قرن بیست و یکم است. "
124
+ "این فناوری توانایی یادگیری، استدلال و حل مسئله را به ماشین‌ها می‌دهد. "
125
+ "از پردازش زبان طبیعی گرفته تا بینایی کامپیوتری، هوش مصنوعی در حال تغییر دنیای ماست."
126
+ ],
127
  ],
128
  inputs=[tts_input],
129
  )
sentence_splitter.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import List
3
+
4
class PersianSentenceSplitter:
    """Split Persian text into TTS-friendly segments.

    Text is first split at strong sentence terminators; any sentence still
    longer than ``max_chars`` is split at weak boundaries (commas), and
    finally on whitespace as a last resort.
    """

    def __init__(self, max_chars: int = 200, min_chars: int = 50):
        # Hard upper bound on segment length, in characters.
        self.max_chars = max_chars
        # NOTE(review): min_chars is stored but never read anywhere in this
        # class — kept for interface compatibility / future merging of tiny
        # segments. Confirm before relying on it.
        self.min_chars = min_chars

        # Strong sentence terminators: Latin ".!?" plus Persian/Arabic
        # question mark and Urdu full stop.
        self.sentence_endings = r'[.!?؟۔]'

        # Weak boundaries, used only when a sentence exceeds max_chars.
        self.weak_boundaries = r'[،,;؛]'

    def clean_text(self, text: str) -> str:
        """Normalize whitespace, letter forms and digits; return stripped text."""
        # Collapse all runs of whitespace to a single space.
        text = re.sub(r'\s+', ' ', text)

        # Underscore is used as a stand-in for the zero-width non-joiner.
        text = text.replace('_', '\u200c')

        # Normalize Arabic kaf/yeh to their Persian forms.
        text = text.replace('ك', 'ک').replace('ي', 'ی')

        # Map Persian digits to ASCII.
        persian_digits = '۰۱۲۳۴۵۶۷۸۹'
        english_digits = '0123456789'
        digit_map = str.maketrans(persian_digits, english_digits)
        text = text.translate(digit_map)

        # Map Arabic-Indic digits to ASCII as well.
        arabic_digits = '٠١٢٣٤٥٦٧٨٩'
        arabic_map = str.maketrans(arabic_digits, english_digits)
        text = text.translate(arabic_map)

        return text.strip()

    def split_by_punctuation(self, text: str) -> List[str]:
        """Split at sentence terminators, keeping each terminator attached
        to its sentence."""
        # re.split with one capture group alternates text and terminator:
        # [chunk, end, chunk, end, ..., tail] — always odd length.
        segments = re.split(f'({self.sentence_endings})', text)

        sentences = []
        for i in range(0, len(segments) - 1, 2):
            # Re-attach the terminator to the text preceding it. (The
            # original had an unreachable `else` branch here: the range
            # step already guarantees i + 1 < len(segments).)
            sentence = (segments[i] + segments[i + 1]).strip()
            if sentence:
                sentences.append(sentence)

        # Trailing text with no terminator.
        if len(segments) % 2 == 1 and segments[-1].strip():
            sentences.append(segments[-1].strip())

        return sentences

    def split_long_sentence(self, sentence: str) -> List[str]:
        """Break a sentence longer than max_chars at weak boundaries,
        falling back to word-level splitting for oversized chunks."""
        if len(sentence) <= self.max_chars:
            return [sentence]

        chunks = []
        current_chunk = ""

        # Iterate text/boundary parts directly instead of range(len(...)).
        for part in re.split(f'({self.weak_boundaries})', sentence):
            if len(current_chunk + part) > self.max_chars and current_chunk:
                chunks.append(current_chunk.strip())
                current_chunk = part
            else:
                current_chunk += part

        if current_chunk.strip():
            chunks.append(current_chunk.strip())

        # Any chunk still too long gets force-split on whitespace.
        final_chunks = []
        for chunk in chunks:
            if len(chunk) > self.max_chars:
                final_chunks.extend(self.force_split_by_words(chunk))
            else:
                final_chunks.append(chunk)

        return final_chunks

    def force_split_by_words(self, text: str) -> List[str]:
        """Split text on whitespace into chunks of at most ~max_chars."""
        words = text.split()
        chunks = []
        current_chunk = []
        current_length = 0

        for word in words:
            word_length = len(word) + 1  # +1 for the joining space

            if current_length + word_length > self.max_chars and current_chunk:
                chunks.append(' '.join(current_chunk))
                current_chunk = [word]
                current_length = word_length
            else:
                current_chunk.append(word)
                current_length += word_length

        if current_chunk:
            chunks.append(' '.join(current_chunk))

        return chunks

    def split(self, text: str) -> List[str]:
        """Clean *text* and return segments each at most max_chars long
        (best effort; a single unbreakable word may still exceed it)."""
        text = self.clean_text(text)

        if not text:
            return []

        if len(text) <= self.max_chars:
            return [text]

        final_segments = []
        for sentence in self.split_by_punctuation(text):
            if len(sentence) > self.max_chars:
                final_segments.extend(self.split_long_sentence(sentence))
            else:
                final_segments.append(sentence)

        return [seg.strip() for seg in final_segments if seg.strip()]
synthesis.py CHANGED
@@ -1,83 +1,167 @@
1
  import os
2
  import sys
 
3
  import numpy as np
4
  import torch
5
  import soundfile as sf
6
  import spaces
7
  from config import models_path, results_path, sample_path, BASE_DIR
 
 
8
 
9
  encoder = None
10
  synthesizer = None
11
  vocoder = None
 
12
 
13
  def load_models():
14
- global encoder, synthesizer, vocoder
15
-
16
  try:
17
  sys.path.append(os.path.join(BASE_DIR, 'pmt2'))
18
-
19
  from encoder import inference as encoder_module
20
  from synthesizer.inference import Synthesizer
21
  from parallel_wavegan.utils import load_model as vocoder_hifigan
22
-
23
  global encoder
24
  encoder = encoder_module
25
-
26
  print("Loading encoder model...")
27
  encoder.load_model(os.path.join(models_path, 'encoder.pt'))
28
-
29
  print("Loading synthesizer model...")
30
  synthesizer = Synthesizer(os.path.join(models_path, 'synthesizer.pt'))
31
-
32
  print("Loading HiFiGAN vocoder...")
33
  vocoder = vocoder_hifigan(os.path.join(models_path, 'vocoder_HiFiGAN.pkl'))
34
  vocoder.remove_weight_norm()
35
  vocoder = vocoder.eval().to('cuda' if torch.cuda.is_available() else 'cpu')
36
-
 
 
 
37
  return True
38
  except Exception as e:
39
  import traceback
40
  print(f"Error loading models: {traceback.format_exc()}")
41
  return False
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  @spaces.GPU(duration=120)
44
- def generate_speech(text, reference_audio=None):
45
  if not text or text.strip() == "":
46
  return None
47
-
48
  try:
49
  if reference_audio is None:
50
  ref_wav_path = sample_path
51
  else:
52
  ref_wav_path = os.path.join(results_path, "reference_audio.wav")
53
  sf.write(ref_wav_path, reference_audio[1], reference_audio[0])
54
-
55
  print(f"Using reference audio: {ref_wav_path}")
56
-
57
  wav = synthesizer.load_preprocess_wav(ref_wav_path)
58
-
59
  encoder_wav = encoder.preprocess_wav(wav)
60
  embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
61
-
62
- texts = [text]
63
- embeds = [embed] * len(texts)
64
- specs = synthesizer.synthesize_spectrograms(texts, embeds)
65
- spec = np.concatenate(specs, axis=1)
66
-
67
- x = torch.from_numpy(spec.T).to('cuda' if torch.cuda.is_available() else 'cpu')
68
-
69
- with torch.no_grad():
70
- wav = vocoder.inference(x)
71
-
72
- wav = wav.cpu().numpy()
73
- wav = wav / np.abs(wav).max() * 0.97
74
-
75
- output_filename = f"generated_{hash(text) % 10000}.wav"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  output_path = os.path.join(results_path, output_filename)
77
- sf.write(output_path, wav, synthesizer.sample_rate)
78
-
 
 
 
79
  return output_path
80
-
81
  except Exception as e:
82
  import traceback
83
  error_details = traceback.format_exc()
 
1
  import os
2
  import sys
3
+ import re
4
  import numpy as np
5
  import torch
6
  import soundfile as sf
7
  import spaces
8
  from config import models_path, results_path, sample_path, BASE_DIR
9
+ from sentence_splitter import PersianSentenceSplitter
10
+ from text_utils import convert_number_to_text
11
 
12
  encoder = None
13
  synthesizer = None
14
  vocoder = None
15
+ sentence_splitter = None
16
 
17
def load_models():
    """Load the encoder, synthesizer, HiFiGAN vocoder and sentence splitter
    into module globals.

    Returns:
        bool: True on success, False on any failure (traceback is printed).
    """
    global encoder, synthesizer, vocoder, sentence_splitter

    try:
        # Make the bundled pmt2 package importable before the imports below.
        sys.path.append(os.path.join(BASE_DIR, 'pmt2'))

        from encoder import inference as encoder_module
        from synthesizer.inference import Synthesizer
        from parallel_wavegan.utils import load_model as vocoder_hifigan

        # NOTE(review): redundant — 'encoder' is already declared global at
        # the top of the function; harmless duplicate declaration.
        global encoder
        encoder = encoder_module

        print("Loading encoder model...")
        encoder.load_model(os.path.join(models_path, 'encoder.pt'))

        print("Loading synthesizer model...")
        synthesizer = Synthesizer(os.path.join(models_path, 'synthesizer.pt'))

        print("Loading HiFiGAN vocoder...")
        vocoder = vocoder_hifigan(os.path.join(models_path, 'vocoder_HiFiGAN.pkl'))
        # Weight norm is a training-time aid; remove it for inference.
        vocoder.remove_weight_norm()
        vocoder = vocoder.eval().to('cuda' if torch.cuda.is_available() else 'cpu')

        # Splitter limits here (150/30) intentionally differ from the class
        # defaults (200/50) — tuned for this synthesizer.
        sentence_splitter = PersianSentenceSplitter(max_chars=150, min_chars=30)

        print("Models loaded successfully!")
        return True
    except Exception as e:
        import traceback
        print(f"Error loading models: {traceback.format_exc()}")
        return False
49
 
50
+
51
def normalize_text_for_synthesis(text: str) -> str:
    """Normalize Persian text right before synthesis.

    Normalizes Arabic letter forms, converts '_' to ZWNJ, collapses
    whitespace, and spells out digit runs (Persian, ASCII or Arabic-Indic,
    with optional thousands separators) as Persian words.
    """
    # Normalize Arabic kaf/yeh to Persian forms.
    text = text.replace('ك', 'ک').replace('ي', 'ی')

    # Underscore is used as a stand-in for the zero-width non-joiner.
    text = text.replace('_', '\u200c')

    text = re.sub(r'\s+', ' ', text)
    text = text.strip()

    # Digit runs, optionally grouped by ASCII comma, Arabic comma or the
    # Arabic thousands separator.
    number_pattern = r'[۰-۹0-9٠-٩]+(?:[,،٬][۰-۹0-9٠-٩]+)*'

    def replace_number(match):
        num_str = match.group(0)
        try:
            return convert_number_to_text(num_str)
        except Exception:
            # Narrowed from a bare `except:` — keep best-effort behavior
            # (leave the token unchanged) without swallowing SystemExit /
            # KeyboardInterrupt.
            return num_str

    text = re.sub(number_pattern, replace_number, text)

    return text
71
+
72
+
73
def synthesize_segment(text_segment: str, embed: np.ndarray) -> np.ndarray:
    """Synthesize a single text segment into a mono waveform.

    Returns the waveform as a 1-D numpy array, or None if synthesis of this
    segment fails (the traceback is printed).
    """
    try:
        text_segment = normalize_text_for_synthesis(text_segment)

        # One segment in, one mel spectrogram out.
        mel = synthesizer.synthesize_spectrograms([text_segment], [embed])[0]

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        mel_tensor = torch.from_numpy(mel.T).to(device)

        # Vocoder inference only — no gradients needed.
        with torch.no_grad():
            audio = vocoder.inference(mel_tensor)

        audio = audio.cpu().numpy()

        # Collapse any singleton channel dimension to 1-D.
        return audio.squeeze() if audio.ndim > 1 else audio

    except Exception as e:
        import traceback
        print(f"Error synthesizing segment '{text_segment[:50]}...': {traceback.format_exc()}")
        return None
96
+
97
+
98
def add_silence(duration_ms: int = 300) -> np.ndarray:
    """Return a silent float32 buffer of *duration_ms* milliseconds at the
    synthesizer's sample rate."""
    n_samples = int(synthesizer.sample_rate * duration_ms / 1000)
    return np.zeros(n_samples, dtype=np.float32)
102
+
103
+
104
  @spaces.GPU(duration=120)
105
+ def generate_speech(text, reference_audio=None, add_pauses: bool = True):
106
  if not text or text.strip() == "":
107
  return None
108
+
109
  try:
110
  if reference_audio is None:
111
  ref_wav_path = sample_path
112
  else:
113
  ref_wav_path = os.path.join(results_path, "reference_audio.wav")
114
  sf.write(ref_wav_path, reference_audio[1], reference_audio[0])
115
+
116
  print(f"Using reference audio: {ref_wav_path}")
117
+
118
  wav = synthesizer.load_preprocess_wav(ref_wav_path)
119
+
120
  encoder_wav = encoder.preprocess_wav(wav)
121
  embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
122
+
123
+ text_segments = sentence_splitter.split(text)
124
+
125
+ print(f"Split text into {len(text_segments)} segments:")
126
+ for i, segment in enumerate(text_segments, 1):
127
+ print(f" Segment {i}: {segment[:60]}{'...' if len(segment) > 60 else ''}")
128
+
129
+ audio_segments = []
130
+ silence = add_silence(300) if add_pauses else None # 300ms pause
131
+
132
+ for i, segment in enumerate(text_segments):
133
+ print(f"Processing segment {i+1}/{len(text_segments)}...")
134
+
135
+ segment_wav = synthesize_segment(segment, embed)
136
+
137
+ if segment_wav is not None:
138
+ segment_wav = segment_wav.flatten() if segment_wav.ndim > 1 else segment_wav
139
+ audio_segments.append(segment_wav)
140
+
141
+ if add_pauses and i < len(text_segments) - 1:
142
+ audio_segments.append(silence)
143
+ else:
144
+ print(f"Warning: Failed to synthesize segment {i+1}")
145
+
146
+ if not audio_segments:
147
+ print("Error: No audio segments were generated successfully")
148
+ return None
149
+
150
+ audio_segments = [seg.flatten() if seg.ndim > 1 else seg for seg in audio_segments]
151
+
152
+ final_wav = np.concatenate(audio_segments)
153
+
154
+ final_wav = final_wav / np.abs(final_wav).max() * 0.97
155
+
156
+ output_filename = f"generated_{abs(hash(text)) % 100000}.wav"
157
  output_path = os.path.join(results_path, output_filename)
158
+ sf.write(output_path, final_wav, synthesizer.sample_rate)
159
+
160
+ print(f"✓ Successfully generated speech: {output_path}")
161
+ print(f" Total duration: {len(final_wav) / synthesizer.sample_rate:.2f} seconds")
162
+
163
  return output_path
164
+
165
  except Exception as e:
166
  import traceback
167
  error_details = traceback.format_exc()
text_utils.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Single digit characters (Persian and ASCII forms) to Persian words.
PERSIAN_DIGITS = {
    '۰': 'صفر', '۱': 'یک', '۲': 'دو', '۳': 'سه', '۴': 'چهار',
    '۵': 'پنج', '۶': 'شش', '۷': 'هفت', '۸': 'هشت', '۹': 'نه',
    '0': 'صفر', '1': 'یک', '2': 'دو', '3': 'سه', '4': 'چهار',
    '5': 'پنج', '6': 'شش', '7': 'هفت', '8': 'هشت', '9': 'نه'
}

# Persian words for teens, tens and hundreds, used by convert_three_digit.
PERSIAN_NUMBERS = {
    10: 'ده', 11: 'یازده', 12: 'دوازده', 13: 'سیزده', 14: 'چهارده',
    15: 'پانزده', 16: 'شانزده', 17: 'هفده', 18: 'هجده', 19: 'نوزده',
    20: 'بیست', 30: 'سی', 40: 'چهل', 50: 'پنجاه',
    60: 'شصت', 70: 'هفتاد', 80: 'هشتاد', 90: 'نود',
    100: 'صد', 200: 'دویست', 300: 'سیصد', 400: 'چهارصد', 500: 'پانصد',
    600: 'ششصد', 700: 'هفتصد', 800: 'هشتصد', 900: 'نهصد'
}


def convert_three_digit(num: int) -> str:
    """Convert 0 <= num <= 999 to Persian words ('' for 0)."""
    if num == 0:
        return ''

    if num < 10:
        return PERSIAN_DIGITS[str(num)]
    elif num < 20:
        return PERSIAN_NUMBERS[num]
    elif num < 100:
        tens = (num // 10) * 10
        ones = num % 10
        if ones == 0:
            return PERSIAN_NUMBERS[tens]
        return PERSIAN_NUMBERS[tens] + ' و ' + PERSIAN_DIGITS[str(ones)]
    else:
        hundreds = (num // 100) * 100
        remainder = num % 100
        if remainder == 0:
            return PERSIAN_NUMBERS[hundreds]
        return PERSIAN_NUMBERS[hundreds] + ' و ' + convert_three_digit(remainder)


def convert_number_to_text(num_str: str, phone_mode: bool = False) -> str:
    """Convert a digit string (Persian, Arabic-Indic or ASCII digits, with
    optional thousands separators) to Persian words.

    With phone_mode=True the digits are read out one by one. If the input
    cannot be parsed it is returned unchanged.
    """
    try:
        # Strip thousands separators: ASCII comma, Arabic comma '،' (which
        # the number regex in synthesis.py accepts but the original code
        # forgot to strip, making such numbers fail to int()), and the
        # Arabic thousands separator '٬'.
        num_str = num_str.replace(',', '').replace('،', '').replace('٬', '').replace(' ', '')

        # Normalize Persian AND Arabic-Indic digits to ASCII so that both
        # the PERSIAN_DIGITS lookups (phone_mode) and int() work uniformly.
        # (The original translated only the Persian forms, so phone_mode
        # raised KeyError on Arabic-Indic digits.)
        digit_map = str.maketrans('۰۱۲۳۴۵۶۷۸۹٠١٢٣٤٥٦٧٨٩', '0123456789' * 2)
        num_str = num_str.translate(digit_map)

        if phone_mode:
            return ' '.join(PERSIAN_DIGITS[d] for d in num_str if d.isdigit())

        num = int(num_str)

        if num == 0:
            return 'صفر'

        if num < 0:
            return 'منفی ' + convert_number_to_text(str(abs(num)))

        if num < 1000:
            return convert_three_digit(num)

        parts = []

        if num >= 1_000_000_000:
            billions = num // 1_000_000_000
            parts.append(convert_three_digit(billions) + ' میلیارد')
            num %= 1_000_000_000

        if num >= 1_000_000:
            millions = num // 1_000_000
            parts.append(convert_three_digit(millions) + ' میلیون')
            num %= 1_000_000

        if num >= 1000:
            thousands = num // 1000
            parts.append(convert_three_digit(thousands) + ' هزار')
            num %= 1000

        if num > 0:
            parts.append(convert_three_digit(num))

        return ' و '.join(parts)

    except (ValueError, KeyError):
        # Narrowed from a bare `except:`: unparseable input falls through
        # unchanged instead of swallowing arbitrary exceptions.
        return num_str