Update app.py
Browse files
app.py
CHANGED
|
@@ -1,20 +1,24 @@
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import io
|
| 3 |
import re
|
| 4 |
import numpy as np
|
| 5 |
import scipy.io.wavfile
|
|
|
|
| 6 |
from fastapi import FastAPI
|
| 7 |
-
from pydantic import BaseModel
|
| 8 |
from fastapi.responses import StreamingResponse
|
| 9 |
-
import
|
| 10 |
from transformers import VitsModel, AutoTokenizer
|
| 11 |
|
| 12 |
-
# Use /tmp for cache to avoid permission errors
|
| 13 |
-
os.environ["HF_HOME"] = "/tmp"
|
| 14 |
-
|
| 15 |
app = FastAPI()
|
| 16 |
|
| 17 |
-
# Load model and tokenizer once
|
| 18 |
model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
|
| 19 |
tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")
|
| 20 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
@@ -94,7 +98,6 @@ async def synthesize(data: TextIn):
|
|
| 94 |
with torch.no_grad():
|
| 95 |
waveform = model(**inputs).waveform.squeeze().cpu().numpy()
|
| 96 |
|
| 97 |
-
# Convert waveform to WAV bytes
|
| 98 |
buf = io.BytesIO()
|
| 99 |
scipy.io.wavfile.write(buf, rate=model.config.sampling_rate, data=(waveform * 32767).astype(np.int16))
|
| 100 |
buf.seek(0)
|
|
|
|
| 1 |
import os

# Redirect every cache location to /tmp BEFORE importing anything from the
# HF/torch stack: the container's default home directory is not writable, so
# model downloads would otherwise fail with permission errors.
for _cache_env in ("HF_HOME", "TRANSFORMERS_CACHE", "TORCH_HOME", "XDG_CACHE_HOME"):
    os.environ[_cache_env] = "/tmp"

import io
import re

import numpy as np
import scipy.io.wavfile
import torch
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import VitsModel, AutoTokenizer

app = FastAPI()

# Load the VITS model and its tokenizer exactly once at import time so every
# request reuses the same in-memory instances.
# NOTE(review): the tokenizer comes from a different repo than the model
# weights — presumably compatible MMS-style Somali TTS artifacts; confirm.
model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")

# Prefer GPU when one is present; the inference code further down the file is
# expected to place tensors (and the model) on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
| 98 |
with torch.no_grad():
|
| 99 |
waveform = model(**inputs).waveform.squeeze().cpu().numpy()
|
| 100 |
|
|
|
|
| 101 |
buf = io.BytesIO()
|
| 102 |
scipy.io.wavfile.write(buf, rate=model.config.sampling_rate, data=(waveform * 32767).astype(np.int16))
|
| 103 |
buf.seek(0)
|