nurfarah57 commited on
Commit
8b39b35
·
verified ·
1 Parent(s): 1730b68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -25
app.py CHANGED
@@ -1,11 +1,4 @@
1
  import os
2
-
3
- # Set cache directories to /tmp to avoid permission issues in the container
4
- os.environ["HF_HOME"] = "/tmp"
5
- os.environ["TRANSFORMERS_CACHE"] = "/tmp"
6
- os.environ["TORCH_HOME"] = "/tmp"
7
- os.environ["XDG_CACHE_HOME"] = "/tmp"
8
-
9
  import io
10
  import re
11
  import numpy as np
@@ -16,15 +9,23 @@ from fastapi.responses import StreamingResponse
16
  from pydantic import BaseModel
17
  from transformers import VitsModel, AutoTokenizer
18
 
 
 
 
 
 
 
19
  app = FastAPI()
20
 
21
- # Load model and tokenizer once at startup
22
  model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
23
  tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")
 
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  model.to(device)
26
  model.eval()
27
 
 
28
  number_words = {
29
  0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
30
  6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
@@ -36,8 +37,7 @@ number_words = {
36
  100: "boqol", 1000: "kun"
37
  }
38
 
39
- def number_to_words(number):
40
- number = int(number)
41
  if number < 20:
42
  return number_words[number]
43
  elif number < 100:
@@ -56,14 +56,7 @@ def number_to_words(number):
56
  words.append("kun")
57
  else:
58
  words.append(number_to_words(thousands) + " kun")
59
- if remainder >= 100:
60
- hundreds, rem2 = divmod(remainder, 100)
61
- if hundreds:
62
- boqol_text = (number_words[hundreds] + " boqol") if hundreds > 1 else "boqol"
63
- words.append(boqol_text)
64
- if rem2:
65
- words.append("iyo " + number_to_words(rem2))
66
- elif remainder:
67
  words.append("iyo " + number_to_words(remainder))
68
  return " ".join(words)
69
  elif number < 1000000000:
@@ -79,10 +72,12 @@ def number_to_words(number):
79
  else:
80
  return str(number)
81
 
82
- def normalize_text(text):
 
83
  numbers = re.findall(r'\d+', text)
84
  for num in numbers:
85
- text = text.replace(num, number_to_words(num))
 
86
  text = text.replace("KH", "qa").replace("Z", "S")
87
  text = text.replace("SH", "SHa'a").replace("DH", "Dha'a")
88
  text = text.replace("ZamZam", "SamSam")
@@ -93,23 +88,38 @@ class TextIn(BaseModel):
93
 
94
  @app.post("/synthesize")
95
  async def synthesize(data: TextIn):
 
96
  text = normalize_text(data.inputs)
 
 
97
  inputs = tokenizer(text, return_tensors="pt").to(device)
 
 
98
  with torch.no_grad():
99
- waveform = model(**inputs).waveform.squeeze().cpu().numpy()
 
100
 
101
- # If stereo or multi-channel, convert to mono by averaging channels
102
  if waveform.ndim > 1:
103
  waveform = waveform.mean(axis=0)
104
 
105
- print(f"Audio length: {waveform.shape[0]}, min: {waveform.min()}, max: {waveform.max()}")
 
106
 
 
107
  waveform = np.clip(waveform, -1.0, 1.0)
108
 
109
- buf = io.BytesIO()
 
110
 
 
 
111
  sample_rate = getattr(model.config, "sampling_rate", 22050)
112
- scipy.io.wavfile.write(buf, rate=sample_rate, data=(waveform * 32767).astype(np.int16))
113
  buf.seek(0)
114
 
 
 
 
 
115
  return StreamingResponse(buf, media_type="audio/wav")
 
1
  import os
 
 
 
 
 
 
 
2
  import io
3
  import re
4
  import numpy as np
 
9
  from pydantic import BaseModel
10
  from transformers import VitsModel, AutoTokenizer
11
 
12
+ # Set environment variables to avoid permission issues in container
13
+ os.environ["HF_HOME"] = "/tmp"
14
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp"
15
+ os.environ["TORCH_HOME"] = "/tmp"
16
+ os.environ["XDG_CACHE_HOME"] = "/tmp"
17
+
18
  app = FastAPI()
19
 
20
+ # Load model and tokenizer at startup
21
  model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
22
  tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")
23
+
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  model.to(device)
26
  model.eval()
27
 
28
+ # Number-to-Somali words dictionary
29
  number_words = {
30
  0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
31
  6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
 
37
  100: "boqol", 1000: "kun"
38
  }
39
 
40
+ def number_to_words(number: int) -> str:
 
41
  if number < 20:
42
  return number_words[number]
43
  elif number < 100:
 
56
  words.append("kun")
57
  else:
58
  words.append(number_to_words(thousands) + " kun")
59
+ if remainder:
 
 
 
 
 
 
 
60
  words.append("iyo " + number_to_words(remainder))
61
  return " ".join(words)
62
  elif number < 1000000000:
 
72
  else:
73
  return str(number)
74
 
75
+ def normalize_text(text: str) -> str:
76
+ # Replace numbers with words
77
  numbers = re.findall(r'\d+', text)
78
  for num in numbers:
79
+ text = text.replace(num, number_to_words(int(num)))
80
+ # Additional Somali text normalization rules
81
  text = text.replace("KH", "qa").replace("Z", "S")
82
  text = text.replace("SH", "SHa'a").replace("DH", "Dha'a")
83
  text = text.replace("ZamZam", "SamSam")
 
88
 
89
  @app.post("/synthesize")
90
  async def synthesize(data: TextIn):
91
+ # Normalize and convert text input
92
  text = normalize_text(data.inputs)
93
+
94
+ # Tokenize and move to device
95
  inputs = tokenizer(text, return_tensors="pt").to(device)
96
+
97
+ # Generate waveform with no_grad
98
  with torch.no_grad():
99
+ output = model(**inputs)
100
+ waveform = output.waveform.squeeze().cpu().numpy()
101
 
102
+ # If multi-channel audio, average channels to mono
103
  if waveform.ndim > 1:
104
  waveform = waveform.mean(axis=0)
105
 
106
+ # Convert to float32 if not already
107
+ waveform = waveform.astype(np.float32)
108
 
109
+ # Clip waveform to [-1.0, 1.0]
110
  waveform = np.clip(waveform, -1.0, 1.0)
111
 
112
+ # Convert to 16-bit PCM
113
+ pcm_waveform = (waveform * 32767).astype(np.int16)
114
 
115
+ # Prepare WAV file in memory buffer
116
+ buf = io.BytesIO()
117
  sample_rate = getattr(model.config, "sampling_rate", 22050)
118
+ scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_waveform)
119
  buf.seek(0)
120
 
121
+ # Debug info
122
+ print(f"Generated audio length: {pcm_waveform.shape[0]} samples, Sample rate: {sample_rate}")
123
+
124
+ # Stream response as WAV audio
125
  return StreamingResponse(buf, media_type="audio/wav")