HusseinBashir committed on
Commit
35e3ab8
·
verified ·
1 Parent(s): 1b065b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -61
app.py CHANGED
@@ -4,9 +4,8 @@ import numpy as np
4
  import scipy.io.wavfile
5
  from transformers import VitsModel, AutoTokenizer
6
  import re
7
- import time
8
 
9
- # Load model and tokenizer
10
  model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
11
  tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -68,7 +67,6 @@ def number_to_words(number):
68
  return str(number)
69
 
70
  def normalize_text(text):
71
- text = text.lower()
72
  # Remove commas from numbers like 1,000,000
73
  text = re.sub(r'(\d{1,3})(,\d{3})+', lambda m: m.group(0).replace(",", ""), text)
74
 
@@ -86,81 +84,40 @@ def normalize_text(text):
86
  '$': 'doolar',
87
  '=': 'egwal',
88
  '+': 'balaas',
89
- '%': 'boqolkiiba',
90
- '&': 'iyo',
91
- '@': 'at',
92
- '#': 'hash',
93
  }
94
  for sym, word in symbol_map.items():
95
  text = text.replace(sym, ' ' + word + ' ')
96
 
97
- # Special rule for 'z' or 'Z' prefix or suffix to sound as 's'
98
- # Replace 'z' or 'Z' at start or end of word with 's'
99
- def replace_z(match):
100
- word = match.group()
101
- # Replace z or Z at start or end with s
102
- if word.startswith('z'):
103
- word = 's' + word[1:]
104
- if word.endswith('z'):
105
- word = word[:-1] + 's'
106
- return word
107
-
108
- # Apply regex word by word for words containing z or Z
109
- text = re.sub(r'\b[z][a-z]*\b', replace_z, text) # words starting with z
110
- text = re.sub(r'\b[a-z]*[z]\b', replace_z, text) # words ending with z
111
-
112
- # Optional character normalization (kuma jirto 'z' sababtoo ah hadda la maamulo)
113
- text = text.replace("kh", "qa").replace("sh", "sha'a").replace("dh", "dha'a")
114
 
115
  return text
116
 
117
  def tts(text):
118
- paragraphs = [p for p in text.strip().split("\n") if p.strip()]
119
  audio_list = []
120
 
121
- # Calculate max total duration allowed based on paragraph count
122
- n = len(paragraphs)
123
- if n <= 5:
124
- max_duration = 30 # seconds
125
- elif n <= 20:
126
- max_duration = 60
127
- else:
128
- max_duration = 120
129
-
130
- # Generate waveform per paragraph and keep track of lengths
131
- waveforms = []
132
- for para in paragraphs:
133
  norm_para = normalize_text(para)
134
  inputs = tokenizer(norm_para, return_tensors="pt").to(device)
135
  with torch.no_grad():
136
  waveform = model(**inputs).waveform.squeeze().cpu().numpy()
137
- waveforms.append(waveform)
138
-
139
- # Calculate total length of raw waveform (in samples)
140
- total_samples = sum(wf.shape[0] for wf in waveforms)
141
- sampling_rate = model.config.sampling_rate
142
 
143
- # Compute speed factor to fit into max_duration seconds
144
- total_duration = total_samples / sampling_rate
145
- speed_factor = total_duration / max_duration if total_duration > max_duration else 1.0
146
-
147
- # Adjust waveforms speed by resampling (speed up if needed)
148
- from scipy.signal import resample
149
-
150
- for i, wf in enumerate(waveforms):
151
- new_length = int(len(wf) / speed_factor)
152
- waveforms[i] = resample(wf, new_length)
153
-
154
- # Add 0.3 sec pause between paragraphs except last one
155
- pause = np.zeros(int(sampling_rate * 0.3))
156
- for i, wf in enumerate(waveforms):
157
- audio_list.append(wf)
158
- if i < len(waveforms) -1:
159
- audio_list.append(pause)
160
 
161
  final_audio = np.concatenate(audio_list)
162
- filename = f"output_{int(time.time())}.wav"
163
- scipy.io.wavfile.write(filename, rate=sampling_rate, data=(final_audio * 32767).astype(np.int16))
164
  return filename
165
 
166
  gr.Interface(
 
4
  import scipy.io.wavfile
5
  from transformers import VitsModel, AutoTokenizer
6
  import re
 
7
 
8
+ # Load fine-tuned model from Hugging Face Hub or local path
9
  model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
10
  tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
67
  return str(number)
68
 
69
  def normalize_text(text):
 
70
  # Remove commas from numbers like 1,000,000
71
  text = re.sub(r'(\d{1,3})(,\d{3})+', lambda m: m.group(0).replace(",", ""), text)
72
 
 
84
  '$': 'doolar',
85
  '=': 'egwal',
86
  '+': 'balaas',
87
+ '-': 'miinas'
 
 
 
88
  }
89
  for sym, word in symbol_map.items():
90
  text = text.replace(sym, ' ' + word + ' ')
91
 
92
+ # Optional character normalization
93
+ text = text.replace("KH", "qa").replace("Z", "S")
94
+ text = text.replace("SH", "SHa'a").replace("DH", "Dha'a")
95
+ text = text.replace("ZamZam", "SamSam")
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  return text
98
 
99
  def tts(text):
100
+ paragraphs = text.strip().split("\n")
101
  audio_list = []
102
 
103
+ for i, para in enumerate(paragraphs):
104
+ if not para.strip():
105
+ continue
 
 
 
 
 
 
 
 
 
106
  norm_para = normalize_text(para)
107
  inputs = tokenizer(norm_para, return_tensors="pt").to(device)
108
  with torch.no_grad():
109
  waveform = model(**inputs).waveform.squeeze().cpu().numpy()
 
 
 
 
 
110
 
111
+ # Add pause between paragraphs (only if it's not the last one)
112
+ if i < len(paragraphs) - 1:
113
+ pause = np.zeros(int(model.config.sampling_rate * 0.8)) # 0.8 seconds pause
114
+ audio_list.append(np.concatenate((waveform, pause)))
115
+ else:
116
+ audio_list.append(waveform)
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  final_audio = np.concatenate(audio_list)
119
+ filename = "output.wav"
120
+ scipy.io.wavfile.write(filename, rate=model.config.sampling_rate, data=(final_audio * 32767).astype(np.int16))
121
  return filename
122
 
123
  gr.Interface(