Somalitts committed
Commit 65c7d6b · verified · 1 Parent(s): a415f53

Update app.py

Files changed (1):
  1. app.py +162 -105
app.py CHANGED
@@ -1,153 +1,210 @@
 import os
 import re
- import io
 import torch
- import numpy as np
 import torchaudio
- from fastapi import FastAPI, UploadFile, File, HTTPException, Query
- from fastapi.responses import StreamingResponse
 from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
- from speechbrain.inference import EncoderClassifier

- # ─── Setup ─────────────────────────────────────────────────────────────────────

- app = FastAPI(title="Somali Multi-Voice TTS API")
 device = "cuda" if torch.cuda.is_available() else "cpu"
- print(f"Using device: {device}")

- VOICE_SAMPLE_FILES = ["1.wav"]

- # Use Hugging Face writable directory
- EMBEDDING_DIR = "/tmp/speaker_embeddings"
- os.makedirs(EMBEDDING_DIR, exist_ok=True)

- # Load models once
- processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
- model = SpeechT5ForTextToSpeech.from_pretrained("Somalitts/8aad").to(device)
- vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
 speaker_model = EncoderClassifier.from_hparams(
     source="speechbrain/spkrec-xvect-voxceleb",
     run_opts={"device": device},
-     savedir="/tmp/spkrec-xvect-voxceleb"
 )

- speaker_embeddings_cache = {}

- # ─── Embedding Function ────────────────────────────────────────────────────────

- def get_speaker_embedding(wav_file_path):
-     if wav_file_path in speaker_embeddings_cache:
-         return speaker_embeddings_cache[wav_file_path]

-     embedding_path = os.path.join(EMBEDDING_DIR, os.path.basename(wav_file_path) + ".pt")
-     if os.path.exists(embedding_path):
-         embedding = torch.load(embedding_path, map_location=device)
-         speaker_embeddings_cache[wav_file_path] = embedding
-         return embedding

-     if not os.path.exists(wav_file_path):
-         raise HTTPException(status_code=404, detail=f"Voice file not found: {wav_file_path}")

-     audio, sr = torchaudio.load(wav_file_path)
-     if sr != 16000:
-         audio = torchaudio.functional.resample(audio, sr, 16000)
-     if audio.shape[0] > 1:
-         audio = torch.mean(audio, dim=0, keepdim=True)

     with torch.no_grad():
-         embedding = speaker_model.encode_batch(audio.to(device))
-         embedding = torch.nn.functional.normalize(embedding, dim=2).squeeze()

-     torch.save(embedding.cpu(), embedding_path)
-     speaker_embeddings_cache[wav_file_path] = embedding.to(device)
-     return embedding.to(device)

- # ─── Text Normalization Functions ──────────────────────────────────────────────

 number_words = {
-     0: "eber", 1: "kow", 2: "labo", 3: "saddex", 4: "afar", 5: "shan",
-     6: "lix", 7: "toddobo", 8: "siddeed", 9: "sagaal", 10: "toban",
-     11: "kow iyo toban", 12: "labo iyo toban", 13: "saddex iyo toban",
-     14: "afar iyo toban", 15: "shan iyo toban", 16: "lix iyo toban",
-     17: "toddobo iyo toban", 18: "siddeed iyo toban", 19: "sagaal iyo toban",
-     20: "labaatan", 30: "soddon", 40: "afartan", 50: "konton",
-     60: "lixdan", 70: "toddobaatan", 80: "siddeetan", 90: "sagaashan",
-     100: "boqol", 1000: "kun"
 }

 def number_to_words(n):
-     if n in number_words:
-         return number_words[n]
-     if n < 100:
-         return number_words[n // 10 * 10] + (" iyo " + number_words[n % 10] if n % 10 else "")
-     if n < 1000:
-         return (number_words[n // 100] + " boqol" if n // 100 > 1 else "boqol") + (
-             " iyo " + number_to_words(n % 100) if n % 100 else "")
-     if n < 1_000_000:
-         return (number_to_words(n // 1000) + " kun" if n // 1000 > 1 else "kun") + (
-             " iyo " + number_to_words(n % 1000) if n % 1000 else "")
-     return str(n)

 def replace_numbers_with_words(text):
     return re.sub(r'\b\d+\b', lambda m: number_to_words(int(m.group())), text)

 def normalize_text(text):
     text = text.lower()
     text = replace_numbers_with_words(text)
     text = re.sub(r'[^\w\s\']', '', text)
     return text

- def split_long_text_into_chunks(text, max_words=18):
-     words = text.split()
-     return [' '.join(words[i:i + max_words]) for i in range(0, len(words), max_words)]

- # ─── Routes ────────────────────────────────────────────────────────────────────

- @app.get("/")
- async def root():
-     return {"message": "Welcome to Somali Multi-Voice TTS API"}

- @app.post("/tts")
- async def text_to_speech_api(text: str = Query(..., min_length=1), voice_file: str = Query(...)):
-     if voice_file not in VOICE_SAMPLE_FILES:
-         raise HTTPException(status_code=400, detail=f"Voice file '{voice_file}' not found.")

-     try:
-         speaker_embedding = get_speaker_embedding(voice_file)
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))

-     text_chunks = split_long_text_into_chunks(text)
     audio_chunks = []

-     for idx, chunk in enumerate(text_chunks):
-         chunk = chunk.strip()
-         if not chunk:
             continue
-         norm_text = normalize_text(chunk)
-         inputs = processor(text=norm_text, return_tensors="pt").to(device)

-         with torch.no_grad():
-             speech = model.generate(
-                 input_ids=inputs["input_ids"],
-                 speaker_embeddings=speaker_embedding.unsqueeze(0),
-                 do_sample=True,
-                 top_k=50,
-                 temperature=0.75,
-                 repetition_penalty=1.2,
-                 max_new_tokens=512
-             )
-         audio = vocoder(speech).cpu().squeeze().numpy()
         audio_chunks.append(audio)

-         # Add short pause between chunks
-         if idx < len(text_chunks) - 1:
-             pause = np.zeros(int(16000 * 0.8))
-             audio_chunks.append(pause)

     final_audio = np.concatenate(audio_chunks)
-     buffer = io.BytesIO()
-     torchaudio.save(buffer, torch.tensor(final_audio).unsqueeze(0), 16000, format="wav")
-     buffer.seek(0)

-     return StreamingResponse(buffer, media_type="audio/wav", headers={"Content-Disposition": "inline; filename=tts_output.wav"})
 
 import os
 import re
+ import uuid
 import torch
 import torchaudio
+ import soundfile as sf
+ import numpy as np
+ from fastapi import FastAPI
+ from fastapi.responses import FileResponse
+ from pydantic import BaseModel
 from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
+ from speechbrain.inference.speaker import EncoderClassifier

+ app = FastAPI()
 device = "cuda" if torch.cuda.is_available() else "cpu"
+ CACHE_DIR = "/tmp/hf-cache"

+ # Load models (female only)
+ processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts", cache_dir=CACHE_DIR)
+ vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan", cache_dir=CACHE_DIR).to(device)
+ model_female = SpeechT5ForTextToSpeech.from_pretrained("Somalitts/8aad", cache_dir=CACHE_DIR).to(device)

+ # Speaker encoder
 speaker_model = EncoderClassifier.from_hparams(
     source="speechbrain/spkrec-xvect-voxceleb",
     run_opts={"device": device},
+     savedir="/tmp/spk_model"
 )

+ # Load female embedding only
+ def get_embedding(wav_path, pt_path):
+     if os.path.exists(pt_path):
+         return torch.load(pt_path).to(device)
+     audio, sr = torchaudio.load(wav_path)
+     audio = torchaudio.functional.resample(audio, sr, 16000).mean(dim=0).unsqueeze(0).to(device)
     with torch.no_grad():
+         emb = speaker_model.encode_batch(audio)
+         emb = torch.nn.functional.normalize(emb, dim=2).squeeze()
+     torch.save(emb.cpu(), pt_path)
+     return emb

+ embedding_female = get_embedding("caasho.wav", "/tmp/female_embedding.pt")

+ # Number words dictionary (Somali)
 number_words = {
+     0: "eber",
+     1: "kow",
+     2: "laba",
+     3: "saddex",
+     4: "afar",
+     5: "shan",
+     6: "lix",
+     7: "toddoba",
+     8: "siddeed",
+     9: "sagaal",
+     10: "toban",
+     11: "kow iyo toban",
+     12: "laba iyo toban",
+     13: "saddex iyo toban",
+     14: "afar iyo toban",
+     15: "shan iyo toban",
+     16: "lix iyo toban",
+     17: "toddoba iyo toban",
+     18: "siddeed iyo toban",
+     19: "sagaal iyo toban",
+     20: "labaatan",
+     30: "soddon",
+     40: "afaratan",
+     50: "konton",
+     60: "lixdan",
+     70: "toddobaatan",
+     80: "siddeetan",
+     90: "sagaashan",
+     100: "boqol",
+     1000: "kun"
 }

 def number_to_words(n):
+     try:
+         if n in number_words:
+             return number_words[n]
+         if n < 100:
+             return number_words[n // 10 * 10] + (" iyo " + number_words[n % 10] if n % 10 else "")
+         if n < 1000:
+             return (number_words[n // 100] + " boqol" if n // 100 > 1 else "boqol") + (
+                 " iyo " + number_to_words(n % 100) if n % 100 else "")
+         if n < 1_000_000:
+             return (number_to_words(n // 1000) + " kun" if n // 1000 > 1 else "kun") + (
+                 " iyo " + number_to_words(n % 1000) if n % 1000 else "")
+         if n < 1_000_000_000:
+             return (number_to_words(n // 1_000_000) + " milyan" if n // 1_000_000 > 1 else "milyan") + (
+                 " iyo " + number_to_words(n % 1_000_000) if n % 1_000_000 else "")
+         return str(n)
+     except Exception as e:
+         print(f"Error converting number {n}: {e}")
+         return str(n)

 def replace_numbers_with_words(text):
     return re.sub(r'\b\d+\b', lambda m: number_to_words(int(m.group())), text)

 def normalize_text(text):
     text = text.lower()
     text = replace_numbers_with_words(text)
     text = re.sub(r'[^\w\s\']', '', text)
     return text

+ def split_into_sentences(text):
+     sentence_endings = re.compile(r'(?<=[.!?])\s+')
+     sentences = sentence_endings.split(text)
+     return [s.strip() for s in sentences if s.strip()]

+ def get_speaker_embedding(voice_choice):
+     # For now we only have female embedding loaded
+     # If you have male embedding, load it and select here based on voice_choice
+     return embedding_female

+ def text_to_speech(text, voice_choice):
+     if not text or not voice_choice:
+         # gr.Warning() is undefined in this context - replace or remove as needed
+         print("Fadlan geli qoraal oo dooro cod.")
+         return None

+     speaker_embedding = get_speaker_embedding(voice_choice)

+     paragraphs = text.strip().split("\n")
     audio_chunks = []

+     for para_idx, para in enumerate(paragraphs):
+         para = para.strip()
+         if not para:
             continue

+         sentences = split_into_sentences(para)

+         for sent_idx, sentence in enumerate(sentences):
+             norm_sentence = normalize_text(sentence)
+             inputs = processor(text=norm_sentence, return_tensors="pt").to(device)

+             with torch.no_grad():
+                 speech = model_female.generate(
+                     input_ids=inputs["input_ids"],
+                     speaker_embeddings=speaker_embedding.unsqueeze(0),
+                     do_sample=True,
+                     top_k=50,
+                     temperature=0.75,
+                     repetition_penalty=1.2,
+                     max_new_tokens=512
+                 )
+             audio = vocoder(speech).cpu().squeeze().numpy()
             audio_chunks.append(audio)

+             # Pause 0.5s after each sentence except last
+             if sent_idx < len(sentences) - 1:
+                 pause = np.zeros(int(16000 * 0.5))
+                 audio_chunks.append(pause)

+         # Pause 0.8s after each paragraph except last
+         if para_idx < len(paragraphs) - 1:
+             para_pause = np.zeros(int(16000 * 0.8))
+             audio_chunks.append(para_pause)

     final_audio = np.concatenate(audio_chunks)
+     return (16000, final_audio)

+ class TTSRequest(BaseModel):
+     text: str

+ @app.post("/speak")
+ def speak(payload: TTSRequest):
+     clean_text = normalize_text(payload.text)
+     inputs = processor(text=clean_text, return_tensors="pt").to(device)

+     with torch.no_grad():
+         waveform = model_female.generate_speech(
+             input_ids=inputs["input_ids"],
+             speaker_embeddings=embedding_female.unsqueeze(0),
+             vocoder=vocoder
+         )

+     out_path = f"/tmp/{uuid.uuid4().hex}.wav"
+     sf.write(out_path, waveform.cpu().numpy(), 16000)
+     return FileResponse(out_path, media_type="audio/wav", filename="voice.wav")
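For reference, a minimal client sketch for the new /speak endpoint added in this commit. The base URL is an assumption (a Space typically serves on port 7860 when run locally); only the route, the JSON shape of TTSRequest, and the WAV response come from the code above.

import requests

# Hypothetical base URL -- adjust for wherever this Space is deployed.
BASE_URL = "http://localhost:7860"

# /speak expects a JSON body matching TTSRequest: {"text": ...}.
# Digits in the text are expanded to Somali number words by normalize_text().
resp = requests.post(f"{BASE_URL}/speak", json={"text": "Waxaan arkay 25 qof."})
resp.raise_for_status()

# The endpoint responds with a 16 kHz WAV file (FileResponse).
with open("voice.wav", "wb") as f:
    f.write(resp.content)

Note that the number expansion composes recursively: with the new dictionary, number_to_words(2025) yields "laba kun iyo labaatan iyo shan".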