"""Minimal Flask server that streams Pocket TTS speech as a WAV response.

GET /stream?text=... returns a 16-bit mono PCM WAV stream whose header is
sent immediately and whose samples follow as the model generates them.
"""
import logging
import struct

import torch
import numpy as np
from flask import Flask, Response, request, stream_with_context
from pocket_tts import TTSModel

logger = logging.getLogger(__name__)

app = Flask(__name__)

DEFAULT_TEXT = 'Streaming real-time audio with Pocket TTS.'

# Sentinel size for "length unknown": a 0 data size would literally declare an
# empty data chunk (some players stop right after the header), whereas
# 0xFFFFFFFF is the common convention for open-ended WAV streams.
_WAV_UNKNOWN_SIZE = 0xFFFFFFFF

# Load model globally
print("Loading TTS Model...")
model = TTSModel.load_model()

# Pre-load voice state
voice_state = model.get_state_for_audio_prompt(
    "hf://kyutai/tts-voices/alba-mackenna/casual.wav"
)


def generate_wav_header(sample_rate, num_channels=1, bits_per_sample=16,
                        data_size=None):
    """Build a 44-byte canonical PCM WAV (RIFF) header.

    Args:
        sample_rate: Samples per second (per channel).
        num_channels: Interleaved channel count; defaults to mono.
        bits_per_sample: PCM sample width in bits; defaults to 16.
        data_size: Byte length of the PCM payload, or None when streaming and
            the length is unknown (a 0xFFFFFFFF sentinel is written instead).

    Returns:
        The header as ``bytes``, little-endian as required by RIFF.
    """
    sample_rate = int(sample_rate)
    block_align = num_channels * bits_per_sample // 8
    byte_rate = sample_rate * block_align
    if data_size is None:
        data_size = _WAV_UNKNOWN_SIZE
        riff_size = _WAV_UNKNOWN_SIZE
    else:
        # RIFF size = everything after the 8-byte "RIFF<size>" prefix; clamp so
        # it still fits in the 32-bit field.
        riff_size = min(36 + data_size, _WAV_UNKNOWN_SIZE)
    return struct.pack(
        '<4sI4s4sIHHIIHH4sI',
        b'RIFF', riff_size, b'WAVE',
        b'fmt ', 16,            # fmt chunk size for plain PCM
        1,                      # audio format 1 = PCM
        num_channels, sample_rate, byte_rate, block_align, bits_per_sample,
        b'data', data_size,
    )


def _chunk_to_pcm16(chunk):
    """Convert one float audio tensor chunk to little-endian int16 PCM bytes."""
    # NOTE(review): assumes each chunk is 1-D mono samples in [-1, 1] (or a
    # shape whose C-order flattening is playback order) — confirm with pocket_tts.
    # .detach() drops any autograd graph; .float() because .numpy() cannot
    # handle bfloat16 tensors; .cpu() keeps this device-agnostic.
    audio = chunk.detach().cpu().float().clamp(-1.0, 1.0).numpy()
    # Round instead of truncating toward zero; '<i2' forces little-endian as
    # WAV requires, regardless of host byte order.
    return np.round(audio * 32767.0).astype('<i2').tobytes()


@app.route('/stream')
def stream_audio():
    """Stream synthesized speech for the ``text`` query parameter as WAV."""
    text = request.args.get('text', DEFAULT_TEXT)
    if not text.strip():
        # Reject before any bytes are sent, while a status code can still be set.
        return Response("Query parameter 'text' must not be empty.\n",
                        status=400, mimetype="text/plain")

    def generate():
        # The header goes out first so clients can start decoding immediately.
        yield generate_wav_header(model.sample_rate)
        try:
            for chunk in model.generate_audio_stream(voice_state, text):
                yield _chunk_to_pcm16(chunk)
        except Exception:
            # The 200 status and part of the body are already on the wire, so
            # the best we can do is record the full traceback and end the stream.
            logger.exception("Error during streaming (text=%r)", text[:80])

    return Response(
        stream_with_context(generate()),
        mimetype="audio/wav",
        headers={"Content-Disposition": "inline; filename=output.wav"}
    )


if __name__ == '__main__':
    # Use Gunicorn/Uvicorn for production; threaded is fine for dev.
    # NOTE(review): with threaded=True the global `model` and `voice_state` are
    # shared by concurrent requests — confirm pocket_tts generation is
    # thread-safe and does not mutate voice_state, otherwise serialize access.
    app.run(host='0.0.0.0', port=7860, threaded=True)