|
|
import os |
|
|
import io |
|
|
from fastapi import FastAPI |
|
|
from pydantic import BaseModel |
|
|
from transformers import AutoProcessor, VitsForConditionalGeneration |
|
|
import torch |
|
|
from fastapi.responses import StreamingResponse |
|
|
|
|
|
|
|
|
# Route the Hugging Face cache to /tmp — the only writable path in many
# container / serverless deployments with read-only root filesystems.
# NOTE(review): huggingface_hub resolves HF_HOME when it is first imported,
# which already happened via the `transformers` import above — so this
# assignment may arrive too late to take effect. Prefer setting HF_HOME in
# the deployment environment (or before any transformers import); confirm.
os.environ["HF_HOME"] = "/tmp"

app = FastAPI()

# NOTE(review): the file originally imported `VitsForConditionalGeneration`,
# but transformers does not export such a class. The documented VITS
# text-to-speech API is `VitsModel` plus the checkpoint's tokenizer, with
# speech produced by a plain forward pass.
from transformers import AutoTokenizer, VitsModel

model_name = "Somali-tts/somali_tts_model"

# Kept under the original name `processor` so the /synthesize handler below
# is unaffected; AutoTokenizer resolves to the VITS tokenizer for this
# checkpoint and is called the same way (`processor(text, return_tensors=...)`).
processor = AutoTokenizer.from_pretrained(model_name)
model = VitsModel.from_pretrained(model_name)

# Prefer GPU when available; each request's tensors are moved to the same
# device inside the handler.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference-only service: disable dropout / training behavior
|
|
|
|
|
class TextInput(BaseModel):
    """Request body for POST /synthesize."""

    # Text to synthesize (presumably Somali, given the checkpoint — confirm).
    inputs: str
|
|
|
|
|
@app.post("/synthesize")
async def synthesize_tts(data: TextInput):
    """Synthesize speech for ``data.inputs`` and stream it back as WAV audio.

    Returns a ``StreamingResponse`` with media type ``audio/wav``.
    """
    # Local import preserved from the original: the module still loads (and
    # other routes still work) if soundfile is missing until first use.
    import soundfile as sf

    inputs = processor(data.inputs, return_tensors="pt").to(device)

    # VITS generates speech with a plain forward pass; the audio lives on the
    # `.waveform` attribute of the model output. The original called
    # `model.generate(**inputs)`, which is not supported for VITS TTS.
    with torch.no_grad():
        output = model(**inputs)
    audio = output.waveform.squeeze().cpu().numpy()

    # Use the checkpoint's own sampling rate when the config provides one
    # (VITS configs do); fall back to the previous hard-coded 22050 Hz so
    # behavior is unchanged for configs without it.
    sample_rate = getattr(model.config, "sampling_rate", 22050)

    buf = io.BytesIO()
    sf.write(buf, audio, samplerate=sample_rate, format="WAV")
    buf.seek(0)

    return StreamingResponse(buf, media_type="audio/wav")
|
|
|