nurfarah57 committed on
Commit
23b3803
·
verified ·
1 Parent(s): 26173ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -1,20 +1,24 @@
1
  import os
 
 
 
 
 
 
 
2
  import io
3
  import re
4
  import numpy as np
5
  import scipy.io.wavfile
 
6
  from fastapi import FastAPI
7
- from pydantic import BaseModel
8
  from fastapi.responses import StreamingResponse
9
- import torch
10
  from transformers import VitsModel, AutoTokenizer
11
 
12
- # Use /tmp for cache to avoid permission errors
13
- os.environ["HF_HOME"] = "/tmp"
14
-
15
  app = FastAPI()
16
 
17
- # Load model and tokenizer once
18
  model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
19
  tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -94,7 +98,6 @@ async def synthesize(data: TextIn):
94
  with torch.no_grad():
95
  waveform = model(**inputs).waveform.squeeze().cpu().numpy()
96
 
97
- # Convert waveform to WAV bytes
98
  buf = io.BytesIO()
99
  scipy.io.wavfile.write(buf, rate=model.config.sampling_rate, data=(waveform * 32767).astype(np.int16))
100
  buf.seek(0)
 
1
  import os
2
+
3
+ # Set cache directories to /tmp to avoid permission issues in the container
4
+ os.environ["HF_HOME"] = "/tmp"
5
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp"
6
+ os.environ["TORCH_HOME"] = "/tmp"
7
+ os.environ["XDG_CACHE_HOME"] = "/tmp"
8
+
9
  import io
10
  import re
11
  import numpy as np
12
  import scipy.io.wavfile
13
+ import torch
14
  from fastapi import FastAPI
 
15
  from fastapi.responses import StreamingResponse
16
+ from pydantic import BaseModel
17
  from transformers import VitsModel, AutoTokenizer
18
 
 
 
 
19
  app = FastAPI()
20
 
21
+ # Load model and tokenizer once at startup
22
  model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
23
  tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
98
  with torch.no_grad():
99
  waveform = model(**inputs).waveform.squeeze().cpu().numpy()
100
 
 
101
  buf = io.BytesIO()
102
  scipy.io.wavfile.write(buf, rate=model.config.sampling_rate, data=(waveform * 32767).astype(np.int16))
103
  buf.seek(0)