Somalitts committed on
Commit
9d4c38e
·
verified ·
1 Parent(s): acbd995

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -66
app.py CHANGED
@@ -4,92 +4,180 @@ import uuid
4
  import torch
5
  import torchaudio
6
  import soundfile as sf
7
- from fastapi import FastAPI
8
  from fastapi.responses import FileResponse
9
  from pydantic import BaseModel
 
 
 
10
  from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
11
- from speechbrain.inference.speaker import EncoderClassifier
 
 
 
 
12
 
13
- app = FastAPI()
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
- CACHE_DIR = "/tmp/hf-cache"
16
-
17
- # Load models
18
- processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts", cache_dir=CACHE_DIR)
19
- vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan", cache_dir=CACHE_DIR).to(device)
20
- model_male = SpeechT5ForTextToSpeech.from_pretrained("Somalitts/5aad", cache_dir=CACHE_DIR).to(device)
21
- model_female = SpeechT5ForTextToSpeech.from_pretrained("Somalitts/8aad", cache_dir=CACHE_DIR).to(device)
22
-
23
- # Speaker encoder
24
- speaker_model = EncoderClassifier.from_hparams(
25
- source="speechbrain/spkrec-xvect-voxceleb",
26
- run_opts={"device": device},
27
- savedir="/tmp/spk_model"
28
- )
29
-
30
- # Load speaker embeddings
31
- def get_embedding(wav_path, pt_path):
32
- if os.path.exists(pt_path):
33
- return torch.load(pt_path).to(device)
34
- audio, sr = torchaudio.load(wav_path)
35
- audio = torchaudio.functional.resample(audio, sr, 16000).mean(dim=0).unsqueeze(0).to(device)
36
- with torch.no_grad():
37
- emb = speaker_model.encode_batch(audio)
38
- emb = torch.nn.functional.normalize(emb, dim=2).squeeze()
39
- torch.save(emb.cpu(), pt_path)
40
- return emb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- embedding_male = get_embedding("Hussein.wav", "/tmp/male_embedding.pt")
43
- embedding_female = get_embedding("caasho.wav", "/tmp/female_embedding.pt")
 
 
 
 
44
 
45
- # Text normalization
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  number_words = {
47
- 0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
48
- 6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
49
- 20: "labaatan", 30: "sodon", 40: "afartan", 50: "konton",
50
- 60: "lixdan", 70: "todobaatan", 80: "sideetan", 90: "sagaashan",
51
- 100: "boqol", 1000: "kun"
 
 
 
52
  }
53
 
54
- def number_to_words(n):
55
- if n < 20:
56
- return number_words.get(n, str(n))
57
- elif n < 100:
58
- tens, unit = divmod(n, 10)
59
- return number_words[tens * 10] + (" " + number_words[unit] if unit else "")
60
- elif n < 1000:
61
- hundreds, rem = divmod(n, 100)
62
- return (number_words[hundreds] + " boqol" if hundreds > 1 else "boqol") + (" " + number_to_words(rem) if rem else "")
63
- elif n < 1_000_000:
64
- th, rem = divmod(n, 1000)
65
- return (number_to_words(th) + " kun") + (" " + number_to_words(rem) if rem else "")
66
- else:
67
- return str(n)
68
 
69
  def replace_numbers_with_words(text):
70
- return re.sub(r'\b\d+\b', lambda m: number_to_words(int(m.group())), text)
71
 
72
  def normalize_text(text):
73
  text = text.lower()
74
  text = replace_numbers_with_words(text)
75
- text = re.sub(r'[^\w\s]', '', text)
 
76
  return text
77
 
78
- # API request schema
79
  class TTSRequest(BaseModel):
80
  text: str
81
- voice: str # "Male" or "Female"
82
 
83
- @app.post("/speak")
84
- def speak(payload: TTSRequest):
85
- clean_text = normalize_text(payload.text)
86
- inputs = processor(text=clean_text, return_tensors="pt").to(device)
87
- model = model_male if payload.voice.lower() == "male" else model_female
88
- embedding = embedding_male if payload.voice.lower() == "male" else embedding_female
 
89
 
90
- with torch.no_grad():
91
- waveform = model.generate_speech(inputs["input_ids"], embedding.unsqueeze(0), vocoder=vocoder)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
- out_path = f"/tmp/{uuid.uuid4().hex}.wav"
94
- sf.write(out_path, waveform.cpu().numpy(), 16000)
95
- return FileResponse(out_path, media_type="audio/wav", filename="voice.wav")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import torch
5
  import torchaudio
6
  import soundfile as sf
7
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
8
  from fastapi.responses import FileResponse
9
  from pydantic import BaseModel
10
+ import logging
11
+ import tempfile
12
+
13
  from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
14
+ from speechbrain.pretrained import EncoderClassifier
15
+
16
# --- Settings and configuration ---
logging.basicConfig(level=logging.INFO)
app = FastAPI(title="Multi-Voice Somali Text-to-Speech API")

# Select the compute device: GPU when CUDA is available, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
logging.info(f"Using device: {device}")

# Reference voice recordings (add your own .wav files here).
# Make sure these files sit in the same directory as this script.
VOICE_SAMPLE_FILES = ["1.wav"]
EMBEDDING_DIR = "speaker_embeddings"
os.makedirs(EMBEDDING_DIR, exist_ok=True)

# --- Model handles (module-level globals, assigned by the startup handler) ---
processor = model = vocoder = speaker_model = None
# Maps a reference .wav path to its speaker embedding.
speaker_embeddings_cache = {}
36
+
37
@app.on_event("startup")
async def startup_event():
    """Load the models and pre-compute speaker embeddings once at startup.

    Assigns the module-level ``processor``, ``model``, ``vocoder`` and
    ``speaker_model`` globals, then calls ``get_speaker_embedding`` for every
    entry of ``VOICE_SAMPLE_FILES`` so the embeddings are cached before the
    first request is served.

    Raises:
        FileNotFoundError: If a reference audio file listed in
            ``VOICE_SAMPLE_FILES`` does not exist.
        RuntimeError: If any model fails to load; the original exception is
            chained as the cause.
    """
    global processor, model, vocoder, speaker_model

    # Fail fast: verify the reference recordings before spending time
    # loading the models.
    for voice_file in VOICE_SAMPLE_FILES:
        if not os.path.exists(voice_file):
            raise FileNotFoundError(f"Reference audio file not found: {voice_file}. Make sure it's in the same directory.")

    logging.info("Loading models...")
    try:
        processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
        model = SpeechT5ForTextToSpeech.from_pretrained("Somalitts/8aad").to(device)
        vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
        speaker_model = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-xvect-voxceleb",
            run_opts={"device": device},
            savedir=os.path.join("pretrained_models", "spkrec-xvect-voxceleb"),
        )
        logging.info("Models loaded successfully.")
    except Exception as e:
        # logging.exception records the traceback alongside the message.
        logging.exception(f"Error loading models: {e}")
        raise RuntimeError(f"Could not load models: {e}") from e

    logging.info("Pre-caching speaker embeddings...")
    for voice_file in VOICE_SAMPLE_FILES:
        get_speaker_embedding(voice_file)
    logging.info("Embeddings cached. Application is ready to serve requests.")
65
+
66
 
67
def get_speaker_embedding(wav_file_path):
    """Return the speaker embedding for a reference recording.

    Lookup order: the in-memory ``speaker_embeddings_cache``, then a saved
    ``.pt`` file under ``EMBEDDING_DIR``, and finally a fresh computation
    with ``speaker_model`` whose result is saved to disk and cached.

    Args:
        wav_file_path: Path of the reference ``.wav`` file; also the cache key.

    Returns:
        The embedding tensor, placed on ``device``.

    Raises:
        FileNotFoundError: If no saved embedding exists and the reference
            recording itself is missing.
        HTTPException: Status 500 if the recording cannot be loaded or
            encoded; the underlying error is chained as the cause.
    """
    cached = speaker_embeddings_cache.get(wav_file_path)
    if cached is not None:
        return cached

    embedding_path = os.path.join(EMBEDDING_DIR, f"{os.path.basename(wav_file_path)}.pt")
    if os.path.exists(embedding_path):
        embedding = torch.load(embedding_path, map_location=device)
        speaker_embeddings_cache[wav_file_path] = embedding
        logging.info(f"Loaded cached embedding for {wav_file_path}")
        return embedding

    # Report a missing recording as FileNotFoundError so callers that handle
    # that exception can react to it, instead of it being folded into the
    # generic 500 below.
    if not os.path.exists(wav_file_path):
        raise FileNotFoundError(f"Reference audio file not found: {wav_file_path}")

    try:
        audio, sr = torchaudio.load(wav_file_path)
        if sr != 16000:
            audio = torchaudio.functional.resample(audio, sr, 16000)
        if audio.shape[0] > 1:
            # Average the channels down to a single one.
            audio = torch.mean(audio, dim=0, keepdim=True)

        with torch.no_grad():
            embedding = speaker_model.encode_batch(audio.to(device))
            embedding = torch.nn.functional.normalize(embedding, dim=2).squeeze()

        # Persist a CPU copy so the saved file can be loaded on any device.
        torch.save(embedding.cpu(), embedding_path)
        embedding = embedding.to(device)
        speaker_embeddings_cache[wav_file_path] = embedding
        logging.info(f"Generated and cached new embedding for {wav_file_path}")
        return embedding
    except Exception as e:
        logging.exception(f"Could not process audio file {wav_file_path}. Error: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to process reference audio: {wav_file_path}") from e
99
+
100
+ # --- Shaqooyinka Hagaajinta Qoraalka (Text Processing) ---
101
+ # (Kuwani sidoodii hore ayay u fiican yihiin)
102
# Somali words for the numbers that are spelled directly rather than composed.
number_words = {
    0: "eber", 1: "kow", 2: "labo", 3: "saddex", 4: "afar", 5: "shan",
    6: "lix", 7: "toddobo", 8: "siddeed", 9: "sagaal", 10: "toban",
    11: "kow iyo toban", 12: "labo iyo toban", 13: "saddex iyo toban",
    14: "afar iyo toban", 15: "shan iyo toban", 16: "lix iyo toban",
    17: "toddobo iyo toban", 18: "siddeed iyo toban", 19: "sagaal iyo toban",
    20: "labaatan", 30: "soddon", 40: "afartan", 50: "konton",
    60: "lixdan", 70: "toddobaatan", 80: "siddeetan", 90: "sagaashan",
    100: "boqol", 1000: "kun",
}


def number_to_words_recursive(n):
    """Spell out an integer in Somali words.

    Values found in ``number_words`` are returned directly. Other values
    below one million are composed from tens, hundreds ("boqol") and
    thousands ("kun"), joined with "iyo".

    Args:
        n: The integer to spell out.

    Returns:
        The spelled-out number, or ``str(n)`` when ``n`` is negative or at
        least 1,000,000.
    """
    if n in number_words:
        return number_words[n]
    # Out-of-range values are passed through as digits; without the negative
    # guard the tens lookup below would raise KeyError.
    if n < 0 or n >= 1_000_000:
        return str(n)
    if n < 100:
        tens, unit = divmod(n, 10)
        words = number_words[tens * 10]
        return f"{words} iyo {number_words[unit]}" if unit else words

    if n < 1000:
        count, rest = divmod(n, 100)
        scale_word = "boqol"
    else:
        count, rest = divmod(n, 1000)
        scale_word = "kun"

    # A single hundred/thousand is spoken without a leading "kow", which
    # keeps 1001 ("kun iyo kow") consistent with 1000 ("kun").
    if count == 1:
        head = scale_word
    else:
        head = f"{number_to_words_recursive(count)} {scale_word}"
    if rest:
        return f"{head} iyo {number_to_words_recursive(rest)}"
    return head
 
 
 
 
 
 
 
 
119
 
120
def replace_numbers_with_words(text):
    """Replace every standalone run of digits in ``text`` with its spelled-out form."""
    def _spell(match):
        # Convert the matched digits to an int and spell it out.
        return number_to_words_recursive(int(match.group()))

    return re.sub(r'\b\d+\b', _spell, text)
122
 
123
def normalize_text(text):
    """Prepare raw input text for the speech model.

    Lower-cases the text, spells out digit runs via
    ``replace_numbers_with_words``, replaces every character that is not a
    word character, whitespace or an apostrophe with a space, and collapses
    runs of whitespace.

    Args:
        text: The raw input string.

    Returns:
        The normalized string, with no leading or trailing whitespace.
    """
    text = text.lower()
    text = replace_numbers_with_words(text)
    # Substitute a space rather than the empty string so tokens separated
    # only by punctuation ("3.5", "a,b") are not glued into one word; the
    # whitespace collapse below removes any resulting extra spaces.
    text = re.sub(r'[^\w\s\']', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
129
 
130
+ # --- Qaabka Codsiga API-ga (Pydantic Model) ---
131
class TTSRequest(BaseModel):
    # Request payload for the text-to-speech endpoint.

    # The text to convert to speech.
    text: str
    # Reference voice file to use; this default applies when the field is not sent.
    voice_choice: str = "1.wav"
134
 
135
+ # --- Endpoints-ka API-ga ---
136
@app.get("/voices", summary="Soo Hel Codadka La Heli Karo")
async def get_available_voices():
    """
    Return the list of reference voice files that are available.
    """
    voices = VOICE_SAMPLE_FILES
    return {"available_voices": voices}
142
 
143
@app.post("/speak", summary="Abuur Cod Qoraal ka timid")
async def text_to_speech_endpoint(payload: TTSRequest, background_tasks: BackgroundTasks):
    """
    Convert text into a .wav audio file.
    - **text**: The text you want converted to speech.
    - **voice_choice**: The reference voice file to use (for example, "1.wav").
    """
    if not payload.text or not payload.text.strip():
        raise HTTPException(status_code=400, detail="Qoraalku ma bannaanaan karo (Text cannot be empty).")
    if payload.voice_choice not in VOICE_SAMPLE_FILES:
        raise HTTPException(status_code=400, detail=f"Codka la doortay '{payload.voice_choice}' lama helin.")

    try:
        speaker_embedding = get_speaker_embedding(payload.voice_choice)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=f"Faylka codka ee '{payload.voice_choice}' lama helin.") from exc

    normalized_text = normalize_text(payload.text)
    if not normalized_text:
        # Input made only of stripped characters (e.g. "!!!") normalizes to
        # an empty string; reject it like empty input instead of running the
        # model on nothing.
        raise HTTPException(status_code=400, detail="Qoraalku ma bannaanaan karo (Text cannot be empty).")

    logging.info(f"Generating speech for: '{normalized_text}' with voice '{payload.voice_choice}'")
    inputs = processor(text=normalized_text, return_tensors="pt").to(device)

    # NOTE(review): inference runs synchronously inside this async handler,
    # which presumably blocks the event loop while it runs -- confirm whether
    # it should be offloaded to a worker thread.
    with torch.no_grad():
        speech = model.generate_speech(
            inputs["input_ids"],
            speaker_embedding.unsqueeze(0),
            vocoder=vocoder,
        )

    # Reserve a temporary path and close the handle before soundfile opens
    # the file by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        tmp_path = tmp_file.name
    try:
        sf.write(tmp_path, speech.cpu().numpy(), 16000)
    except Exception:
        # delete=False means nothing else cleans this file up on failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    # Schedule deletion of the temporary file as a background task.
    background_tasks.add_task(os.remove, tmp_path)

    # Return the audio file; the download name is a fresh random UUID.
    return FileResponse(
        path=tmp_path,
        media_type="audio/wav",
        filename=f"{uuid.uuid4()}.wav",
    )