"""FastAPI service that synthesizes Somali speech from text with a VITS TTS model."""

import io
import os

# HF_HOME must be set BEFORE transformers/huggingface_hub are imported:
# the hub library resolves its cache directory at import time, so setting
# it after the import (as the original code did) has no effect and the
# permission errors it was meant to avoid would still occur.
os.environ["HF_HOME"] = "/tmp"

import soundfile as sf
import torch
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoProcessor, VitsForConditionalGeneration

app = FastAPI()

# Load processor and model once at module import so every request reuses them.
model_name = "Somali-tts/somali_tts_model"
processor = AutoProcessor.from_pretrained(model_name)
# NOTE(review): mainline transformers exposes `VitsModel` (forward pass returns
# `.waveform`) rather than a `VitsForConditionalGeneration` with `generate()` —
# confirm this class and method actually exist for this checkpoint.
model = VitsForConditionalGeneration.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference only: disable dropout and other training-time behavior

# The checkpoint config carries the true output sampling rate; fall back to the
# previously hard-coded 22050 Hz if the config omits it (backward compatible).
SAMPLE_RATE = getattr(model.config, "sampling_rate", 22050)


class TextInput(BaseModel):
    # Raw text to synthesize, e.g. {"inputs": "Salaan!"}.
    inputs: str


@app.post("/synthesize")
def synthesize_tts(data: TextInput) -> StreamingResponse:
    """Synthesize ``data.inputs`` to speech and stream it back as a WAV file.

    Declared as a plain ``def`` (not ``async def``) so FastAPI executes it in a
    worker thread: ``model.generate`` is a long blocking call and, inside an
    ``async def``, would stall the event loop for every concurrent client.

    Returns:
        StreamingResponse with media type ``audio/wav`` containing the
        in-memory encoded waveform.
    """
    inputs = processor(data.inputs, return_tensors="pt").to(device)
    with torch.no_grad():
        audio = model.generate(**inputs)
    waveform = audio.squeeze().cpu().numpy()

    # Encode the waveform as WAV entirely in memory; no temp files on disk.
    buf = io.BytesIO()
    sf.write(buf, waveform, samplerate=SAMPLE_RATE, format="WAV")
    buf.seek(0)
    return StreamingResponse(buf, media_type="audio/wav")