"""FastAPI service wrapping Piper text-to-speech and faster-whisper speech-to-text.

Routes:
    GET  /          health check
    POST /tts-demo  JSON body -> WAV file download
    POST /tts       form fields -> WAV stream, single synthesis pass
    POST /speech    form fields -> WAV stream, text synthesized in chunks
    POST /stt       audio upload -> transcript JSON
    POST /convert   audio upload -> transcript re-spoken with the default voice
"""

import io
import os
import re
import tempfile
import wave
from datetime import datetime
from enum import Enum
from typing import Literal

from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
from piper import PiperVoice, SynthesisConfig
from faster_whisper import WhisperModel

# -------------------- CONFIG --------------------

VOICE_DIR = "actors"
WHISPER_MODEL_PATH = "./models/faster-whisper-tiny"

# Shared Piper synthesis settings used by every TTS route.
syn_config = SynthesisConfig(
    volume=1.0,
    length_scale=1.15,
    noise_scale=0.55,
    noise_w_scale=0.7,
    normalize_audio=True,
)

# Splits after '.', '!' or '?' followed by whitespace; the punctuation stays
# attached to the preceding sentence.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")

# -------------------- ENUMS --------------------


class VoiceActor(str, Enum):
    """Selectable voices; each value is a model file name inside VOICE_DIR."""

    alba = "en_GB-alba-medium.onnx"
    hfc_female = "en_US-hfc_female-medium.onnx"
    danny = "en_US-danny-low.onnx"
    lessac = "en_US-lessac-high.onnx"
    libritts = "en_US-libritts-high.onnx"
    cori = "en_GB-cori-high.onnx"


class Input(BaseModel):
    """JSON body accepted by /tts-demo."""

    text: str
    # NOTE(review): only four of the six VoiceActor voices are accepted here;
    # confirm whether libritts/cori are excluded on purpose. An explicit null
    # falls back to the default (alba) voice in tts_demo.
    actor: Literal[
        "en_GB-alba-medium.onnx",
        "en_US-hfc_female-medium.onnx",
        "en_US-danny-low.onnx",
        "en_US-lessac-high.onnx",
    ] | None = VoiceActor.alba.value


# -------------------- APP --------------------

app = FastAPI(title="Fast TTS + STT API")

# -------------------- MODEL CACHE --------------------

print("🔹 Loading Whisper model...")
stt_model = WhisperModel(
    WHISPER_MODEL_PATH,
    device="cpu",
    compute_type="int8",
    # os.cpu_count() may return None; 0 is faster-whisper's own default value.
    cpu_threads=os.cpu_count() or 0,
    num_workers=1,
)
print("🔹 Whisper loaded")

# Loaded Piper voices keyed by model file name, filled lazily by get_voice().
voice_cache: dict[str, PiperVoice] = {}


def get_voice(actor: str) -> PiperVoice:
    """Return the Piper voice for model file *actor*, loading it on first use.

    The voice is loaded from ``VOICE_DIR/actor`` and kept in ``voice_cache``
    so later requests for the same actor reuse it.
    """
    if actor not in voice_cache:
        voice_cache[actor] = PiperVoice.load(f"{VOICE_DIR}/{actor}")
    return voice_cache[actor]


def chunk_text(text: str, max_tokens: int = 150) -> list[str]:
    """Group *text* into chunks of at most *max_tokens* whitespace-separated words.

    Whole sentences are packed together while they fit. A single sentence
    longer than the budget is hard-split on word boundaries, so no chunk ever
    exceeds *max_tokens* words and no empty chunk is produced.

    Args:
        text: Input text; surrounding whitespace is ignored.
        max_tokens: Maximum words per chunk (words, not model tokens).

    Returns:
        The list of chunks in order; empty for blank input.

    Raises:
        ValueError: If *max_tokens* is less than 1.
    """
    if max_tokens < 1:
        raise ValueError("max_tokens must be >= 1")

    chunks: list[str] = []
    current: list[str] = []
    for sentence in _SENTENCE_BOUNDARY.split(text.strip()):
        words = sentence.split()
        # Flush the pending chunk only when it holds something; the original
        # version emitted "" here when the first sentence was too long.
        if current and len(current) + len(words) > max_tokens:
            chunks.append(" ".join(current))
            current = []
        # A sentence that alone exceeds the budget is cut into full-size pieces.
        while len(words) > max_tokens:
            chunks.append(" ".join(words[:max_tokens]))
            words = words[max_tokens:]
        current.extend(words)

    if current:
        chunks.append(" ".join(current))
    return chunks


def _configure_wav(wav_file: wave.Wave_write, sample_rate: int) -> None:
    """Set the mono, 16-bit (2 bytes/sample) header used by every WAV written here."""
    wav_file.setnchannels(1)
    wav_file.setsampwidth(2)
    wav_file.setframerate(sample_rate)


def _synthesize_to_buffer(text: str, voice: PiperVoice, config: SynthesisConfig) -> io.BytesIO:
    """Synthesize *text* with *voice* into an in-memory WAV rewound to offset 0."""
    buffer = io.BytesIO()
    # wave.open() on a caller-supplied file object does not close it on exit,
    # so the buffer remains readable afterwards.
    with wave.open(buffer, "wb") as wav_file:
        _configure_wav(wav_file, voice.config.sample_rate)
        voice.synthesize_wav(text, wav_file, syn_config=config)
    buffer.seek(0)
    return buffer


def synthesize_chunked_tts(text: str, voice, syn_config, max_tokens: int = 150) -> io.BytesIO:
    """Synthesize long *text* chunk by chunk and concatenate into one WAV.

    Args:
        text: Text to speak; split with chunk_text().
        voice: A loaded PiperVoice.
        syn_config: Piper SynthesisConfig applied to every chunk.
        max_tokens: Word budget per chunk, forwarded to chunk_text().

    Returns:
        A BytesIO holding the complete WAV, rewound to the start. Blank text
        yields a valid WAV with no frames.
    """
    output = io.BytesIO()
    with wave.open(output, "wb") as out_wav:
        _configure_wav(out_wav, voice.config.sample_rate)
        for chunk in chunk_text(text, max_tokens=max_tokens):
            chunk_buffer = _synthesize_to_buffer(chunk, voice, syn_config)
            # Copy only the PCM frames; each chunk's own header is discarded.
            with wave.open(chunk_buffer, "rb") as chunk_wav:
                out_wav.writeframes(chunk_wav.readframes(chunk_wav.getnframes()))
    output.seek(0)
    return output


def _transcribe_bytes(audio_bytes: bytes):
    """Transcribe uploaded audio with the shared Whisper model.

    Blocking and CPU-bound: async handlers must call it through
    run_in_threadpool so the event loop stays responsive.

    Returns:
        ``(text, info)``, where *text* is the stripped segment texts joined by
        single spaces and *info* is the info object returned by transcribe().
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        f.write(audio_bytes)
        temp_path = f.name
    try:
        segments, info = stt_model.transcribe(
            temp_path,
            beam_size=1,
            language="en",
            vad_filter=True,
            vad_parameters={"min_silence_duration_ms": 500},
        )
        # Consume the segment iterator here, while the file still exists and
        # inside the worker thread. Each piece is stripped so that leading
        # spaces on segment text do not turn into double spaces.
        pieces = (seg.text.strip() for seg in segments)
        text = " ".join(piece for piece in pieces if piece)
    finally:
        # Always remove the temp file, even when decoding fails.
        os.unlink(temp_path)
    return text, info


# -------------------- ROUTES --------------------


@app.get("/")
def root():
    """Health check."""
    return {"status": "ok"}


# -------- TTS (JSON, returns file) --------


@app.post("/tts-demo")
def tts_demo(input: Input):  # noqa: A002 - name kept for interface compatibility
    """Synthesize ``input.text`` and return it as a downloadable WAV file.

    The temp file backing the response is deleted once it has been sent.
    """
    voice = get_voice(input.actor or VoiceActor.alba.value)

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        temp_path = temp_file.name
    try:
        with wave.open(temp_path, "wb") as wav_file:
            _configure_wav(wav_file, voice.config.sample_rate)
            voice.synthesize_wav(input.text, wav_file, syn_config=syn_config)
    except Exception:
        os.unlink(temp_path)
        raise

    # Runs after the response body has been streamed, so the file is not
    # removed while FileResponse is still reading it.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.unlink, temp_path)
    return FileResponse(
        temp_path,
        filename=f"tts-{int(datetime.now().timestamp())}.wav",
        media_type="audio/wav",
        background=cleanup,
    )


# -------- TTS (FORM, STREAMING – FASTEST) --------


@app.post("/tts")
def tts(
    text: str = Form(...),
    actor: VoiceActor = Form(VoiceActor.alba),
):
    """Synthesize *text* in a single pass and stream the WAV back."""
    voice = get_voice(actor.value)
    buffer = _synthesize_to_buffer(text, voice, syn_config)
    return StreamingResponse(buffer, media_type="audio/wav")


# -------- STT ONLY --------


@app.post("/stt")
async def speech_to_text(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    Raises:
        HTTPException: 400 if the upload is empty.
    """
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty")

    text, info = await run_in_threadpool(_transcribe_bytes, audio_bytes)
    return {
        "text": text,
        "language": info.language,
        "duration": info.duration,
    }


@app.post("/speech")
def speech(
    text: str = Form(...),
    actor: VoiceActor = Form(VoiceActor.alba),
):
    """Synthesize possibly long *text* in chunks and stream a single WAV back."""
    voice = get_voice(actor.value)
    audio_buffer = synthesize_chunked_tts(
        text=text,
        voice=voice,
        syn_config=syn_config,
    )
    return StreamingResponse(audio_buffer, media_type="audio/wav")


# -------- STT → TTS --------


@app.post("/convert")
async def convert(file: UploadFile = File(...)):
    """Transcribe an upload, then speak the transcript with the default voice.

    Raises:
        HTTPException: 400 if the upload is empty.
    """
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty")

    text, _ = await run_in_threadpool(_transcribe_bytes, audio_bytes)
    # Voice loading and synthesis are blocking too; keep them off the event loop.
    voice = await run_in_threadpool(get_voice, VoiceActor.alba.value)
    buffer = await run_in_threadpool(_synthesize_to_buffer, text, voice, syn_config)
    return StreamingResponse(buffer, media_type="audio/wav")