nurfarah57 committed on
Commit
830736b
·
verified ·
1 Parent(s): 8b39b35

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -22
app.py CHANGED
@@ -1,4 +1,12 @@
1
  import os
 
 
 
 
 
 
 
 
2
  import io
3
  import re
4
  import numpy as np
@@ -9,15 +17,9 @@ from fastapi.responses import StreamingResponse
9
  from pydantic import BaseModel
10
  from transformers import VitsModel, AutoTokenizer
11
 
12
- # Set environment variables to avoid permission issues in container
13
- os.environ["HF_HOME"] = "/tmp"
14
- os.environ["TRANSFORMERS_CACHE"] = "/tmp"
15
- os.environ["TORCH_HOME"] = "/tmp"
16
- os.environ["XDG_CACHE_HOME"] = "/tmp"
17
-
18
  app = FastAPI()
19
 
20
- # Load model and tokenizer at startup
21
  model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
22
  tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")
23
 
@@ -25,7 +27,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  model.to(device)
26
  model.eval()
27
 
28
- # Number-to-Somali words dictionary
29
  number_words = {
30
  0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
31
  6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
@@ -73,11 +75,11 @@ def number_to_words(number: int) -> str:
73
  return str(number)
74
 
75
  def normalize_text(text: str) -> str:
76
- # Replace numbers with words
77
  numbers = re.findall(r'\d+', text)
78
  for num in numbers:
79
  text = text.replace(num, number_to_words(int(num)))
80
- # Additional Somali text normalization rules
81
  text = text.replace("KH", "qa").replace("Z", "S")
82
  text = text.replace("SH", "SHa'a").replace("DH", "Dha'a")
83
  text = text.replace("ZamZam", "SamSam")
@@ -88,38 +90,27 @@ class TextIn(BaseModel):
88
 
89
@app.post("/synthesize")
async def synthesize(data: TextIn):
    """Convert Somali text to speech and stream the result back as a WAV file.

    Pipeline: normalize the input text, run VITS inference, post-process the
    waveform (mono, float32, clipped, 16-bit PCM), and wrap it in an in-memory
    WAV container.
    """
    normalized = normalize_text(data.inputs)
    encoded = tokenizer(normalized, return_tensors="pt").to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        audio = model(**encoded).waveform.squeeze().cpu().numpy()

    # Average channels down to mono when the model emits multi-channel audio.
    if audio.ndim > 1:
        audio = audio.mean(axis=0)

    # float32 in [-1, 1], then scale to 16-bit PCM for the WAV container.
    samples = np.clip(audio.astype(np.float32), -1.0, 1.0)
    pcm = (samples * 32767).astype(np.int16)

    wav_buffer = io.BytesIO()
    # Fall back to 22050 Hz if the model config does not declare a rate.
    rate = getattr(model.config, "sampling_rate", 22050)
    scipy.io.wavfile.write(wav_buffer, rate=rate, data=pcm)
    wav_buffer.seek(0)

    print(f"Generated audio length: {pcm.shape[0]} samples, Sample rate: {rate}")

    return StreamingResponse(wav_buffer, media_type="audio/wav")
 
1
  import os
2
+
3
+ # === IMPORTANT ===
4
+ # Set cache directories BEFORE any imports that use Hugging Face or PyTorch caching
5
+ os.environ["HF_HOME"] = "/tmp"
6
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp"
7
+ os.environ["TORCH_HOME"] = "/tmp"
8
+ os.environ["XDG_CACHE_HOME"] = "/tmp"
9
+
10
  import io
11
  import re
12
  import numpy as np
 
17
  from pydantic import BaseModel
18
  from transformers import VitsModel, AutoTokenizer
19
 
 
 
 
 
 
 
20
  app = FastAPI()
21
 
22
+ # Load model and tokenizer once at startup
23
  model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
24
  tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")
25
 
 
27
  model.to(device)
28
  model.eval()
29
 
30
+ # Somali number words dictionary for normalization
31
  number_words = {
32
  0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
33
  6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
 
75
  return str(number)
76
 
77
  def normalize_text(text: str) -> str:
78
+ # Replace digits with Somali words
79
  numbers = re.findall(r'\d+', text)
80
  for num in numbers:
81
  text = text.replace(num, number_to_words(int(num)))
82
+ # Additional Somali text normalizations
83
  text = text.replace("KH", "qa").replace("Z", "S")
84
  text = text.replace("SH", "SHa'a").replace("DH", "Dha'a")
85
  text = text.replace("ZamZam", "SamSam")
 
90
 
91
@app.post("/synthesize")
async def synthesize(data: TextIn):
    """Synthesize Somali speech for ``data.inputs`` and stream it as WAV audio.

    Fix: model inference is synchronous and CPU/GPU-bound; calling it directly
    inside an ``async def`` endpoint blocks the event loop and stalls every
    concurrent request. The heavy work is therefore pushed to a worker thread
    with ``asyncio.to_thread`` while the endpoint's interface stays unchanged.
    """
    import asyncio  # local import: stdlib, only needed by this endpoint

    text = normalize_text(data.inputs)
    buf, sample_rate, n_samples = await asyncio.to_thread(_render_wav, text)

    print(f"Generated audio length: {n_samples} samples, Sample rate: {sample_rate}")

    return StreamingResponse(buf, media_type="audio/wav")


def _render_wav(text: str):
    """Run TTS inference for *text*; return (wav_buffer, sample_rate, n_samples).

    Synchronous helper intended to run on a worker thread.
    """
    inputs = tokenizer(text, return_tensors="pt").to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        output = model(**inputs)
    waveform = output.waveform.squeeze().cpu().numpy()

    # Collapse multi-channel output to mono.
    if waveform.ndim > 1:
        waveform = waveform.mean(axis=0)

    # float32 in [-1, 1] -> 16-bit PCM for the WAV container.
    waveform = np.clip(waveform.astype(np.float32), -1.0, 1.0)
    pcm_waveform = (waveform * 32767).astype(np.int16)

    buf = io.BytesIO()
    # Fall back to 22050 Hz if the model config does not declare a rate.
    sample_rate = getattr(model.config, "sampling_rate", 22050)
    scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_waveform)
    buf.seek(0)
    return buf, sample_rate, pcm_waveform.shape[0]