nurfarah57 committed on
Commit
1ad4fd7
·
verified ·
1 Parent(s): 830736b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -14
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
 
3
- # === IMPORTANT ===
4
- # Set cache directories BEFORE any imports that use Hugging Face or PyTorch caching
5
  os.environ["HF_HOME"] = "/tmp"
6
  os.environ["TRANSFORMERS_CACHE"] = "/tmp"
7
  os.environ["TORCH_HOME"] = "/tmp"
@@ -9,17 +8,18 @@ os.environ["XDG_CACHE_HOME"] = "/tmp"
9
 
10
  import io
11
  import re
 
12
  import numpy as np
13
  import scipy.io.wavfile
14
  import torch
15
- from fastapi import FastAPI
16
  from fastapi.responses import StreamingResponse
17
  from pydantic import BaseModel
18
  from transformers import VitsModel, AutoTokenizer
19
 
20
  app = FastAPI()
21
 
22
- # Load model and tokenizer once at startup
23
  model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
24
  tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")
25
 
@@ -27,7 +27,6 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
  model.to(device)
28
  model.eval()
29
 
30
- # Somali number words dictionary for normalization
31
  number_words = {
32
  0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
33
  6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
@@ -75,11 +74,9 @@ def number_to_words(number: int) -> str:
75
  return str(number)
76
 
77
  def normalize_text(text: str) -> str:
78
- # Replace digits with Somali words
79
  numbers = re.findall(r'\d+', text)
80
  for num in numbers:
81
  text = text.replace(num, number_to_words(int(num)))
82
- # Additional Somali text normalizations
83
  text = text.replace("KH", "qa").replace("Z", "S")
84
  text = text.replace("SH", "SHa'a").replace("DH", "Dha'a")
85
  text = text.replace("ZamZam", "SamSam")
@@ -89,28 +86,60 @@ class TextIn(BaseModel):
89
  inputs: str
90
 
91
  @app.post("/synthesize")
92
- async def synthesize(data: TextIn):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  text = normalize_text(data.inputs)
 
 
94
  inputs = tokenizer(text, return_tensors="pt").to(device)
95
 
 
96
  with torch.no_grad():
97
  output = model(**inputs)
98
- waveform = output.waveform.squeeze().cpu().numpy()
99
 
100
- # Mono conversion if multi-channel
101
- if waveform.ndim > 1:
102
- waveform = waveform.mean(axis=0)
 
 
 
 
 
 
 
 
 
103
 
104
  waveform = waveform.astype(np.float32)
105
  waveform = np.clip(waveform, -1.0, 1.0)
106
 
107
  pcm_waveform = (waveform * 32767).astype(np.int16)
108
 
 
 
 
109
  buf = io.BytesIO()
110
  sample_rate = getattr(model.config, "sampling_rate", 22050)
 
 
111
  scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_waveform)
112
  buf.seek(0)
113
 
114
- print(f"Generated audio length: {pcm_waveform.shape[0]} samples, Sample rate: {sample_rate}")
115
-
116
  return StreamingResponse(buf, media_type="audio/wav")
 
1
  import os
2
 
3
+ # Set cache dirs before imports to fix permission errors
 
4
  os.environ["HF_HOME"] = "/tmp"
5
  os.environ["TRANSFORMERS_CACHE"] = "/tmp"
6
  os.environ["TORCH_HOME"] = "/tmp"
 
8
 
9
  import io
10
  import re
11
+ import math
12
  import numpy as np
13
  import scipy.io.wavfile
14
  import torch
15
+ from fastapi import FastAPI, Query
16
  from fastapi.responses import StreamingResponse
17
  from pydantic import BaseModel
18
  from transformers import VitsModel, AutoTokenizer
19
 
20
  app = FastAPI()
21
 
22
+ # Load model/tokenizer once at startup
23
  model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
24
  tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")
25
 
 
27
  model.to(device)
28
  model.eval()
29
 
 
30
  number_words = {
31
  0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
32
  6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
 
74
  return str(number)
75
 
76
  def normalize_text(text: str) -> str:
 
77
  numbers = re.findall(r'\d+', text)
78
  for num in numbers:
79
  text = text.replace(num, number_to_words(int(num)))
 
80
  text = text.replace("KH", "qa").replace("Z", "S")
81
  text = text.replace("SH", "SHa'a").replace("DH", "Dha'a")
82
  text = text.replace("ZamZam", "SamSam")
 
86
  inputs: str
87
 
88
  @app.post("/synthesize")
89
+ async def synthesize(data: TextIn, test: bool = Query(False, description="Set true to generate test tone instead of TTS")):
90
+ if test:
91
+ # Generate 2-second 440Hz sine wave for testing playback
92
+ duration_s = 2.0
93
+ sample_rate = 22050
94
+ t = np.linspace(0, duration_s, int(sample_rate*duration_s), endpoint=False)
95
+ freq = 440
96
+ waveform = 0.5 * np.sin(2 * math.pi * freq * t).astype(np.float32)
97
+ pcm_waveform = (waveform * 32767).astype(np.int16)
98
+
99
+ buf = io.BytesIO()
100
+ scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_waveform)
101
+ buf.seek(0)
102
+
103
+ print(f"[TEST MODE] Generated test tone: {pcm_waveform.shape[0]} samples, Sample rate: {sample_rate}")
104
+
105
+ return StreamingResponse(buf, media_type="audio/wav")
106
+
107
+ # Normalize input text
108
  text = normalize_text(data.inputs)
109
+
110
+ # Tokenize and move to device
111
  inputs = tokenizer(text, return_tensors="pt").to(device)
112
 
113
+ # Generate waveform
114
  with torch.no_grad():
115
  output = model(**inputs)
 
116
 
117
+ print("Raw waveform shape:", output.waveform.shape)
118
+
119
+ waveform = output.waveform.cpu().numpy()
120
+
121
+ # Process waveform dimensions
122
+ if waveform.ndim == 3:
123
+ waveform = waveform[0] # batch dimension
124
+ if waveform.ndim == 2:
125
+ waveform = waveform.mean(axis=0) # average channels to mono
126
+
127
+ print("Processed waveform shape:", waveform.shape)
128
+ print("Waveform min/max before clip:", waveform.min(), waveform.max())
129
 
130
  waveform = waveform.astype(np.float32)
131
  waveform = np.clip(waveform, -1.0, 1.0)
132
 
133
  pcm_waveform = (waveform * 32767).astype(np.int16)
134
 
135
+ print("PCM waveform shape:", pcm_waveform.shape)
136
+ print("PCM waveform min/max:", pcm_waveform.min(), pcm_waveform.max())
137
+
138
  buf = io.BytesIO()
139
  sample_rate = getattr(model.config, "sampling_rate", 22050)
140
+ print("Sample rate:", sample_rate)
141
+
142
  scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_waveform)
143
  buf.seek(0)
144
 
 
 
145
  return StreamingResponse(buf, media_type="audio/wav")