Spaces:
Runtime error
Runtime error
| # StyleTTS 2 HTTP Streaming API by @fakerybakery - Copyright (c) 2023 mrfakename. All rights reserved. | |
| # Docs: API_DOCS.md | |
| # To-Do: | |
| # * Support voice cloning | |
| # * Implement authentication, user "credits" system w/ SQLite3 | |
| import io | |
| import markdown | |
| from tortoise.utils.text import split_and_recombine_text | |
| from flask import Flask, Response, request, jsonify | |
| import numpy as np | |
| import ljinference | |
| import torch | |
| import hashlib | |
| from scipy.io.wavfile import read, write | |
| from flask_cors import CORS | |
| import os | |
| import torchaudio | |
def genHeader(sampleRate, bitsPerSample, channels, datasize=2000 * 10**6):
    """Build a 44-byte PCM WAV (RIFF) header.

    The header declares ``datasize`` bytes of sample data. The default is a
    large placeholder (2,000,000,000 bytes); presumably this lets the header be
    sent before the real audio length is known (e.g. for streaming), so
    confirm with callers before relying on the declared size.

    Args:
        sampleRate: Samples per second per channel.
        bitsPerSample: Bits per sample (e.g. 16).
        channels: Number of channels.
        datasize: Byte length recorded in the ``data`` chunk. Defaults to the
            previous hard-coded placeholder, so existing calls are unchanged.

    Returns:
        The header as ``bytes`` (always 44 bytes long).

    Raises:
        ValueError: If ``datasize`` is negative, or if ``datasize + 36`` does
            not fit in the 32-bit little-endian RIFF size field.
    """
    # RIFF chunk size counts everything after the first 8 bytes:
    # "WAVE" (4) + fmt chunk (8 + 16) + data chunk header (8) + data.
    riff_size = datasize + 36
    if datasize < 0 or riff_size > 0xFFFFFFFF:
        raise ValueError(
            f"datasize must be in [0, {0xFFFFFFFF - 36}], got {datasize}"
        )
    fields = [
        b"RIFF",
        riff_size.to_bytes(4, "little"),
        b"WAVE",
        b"fmt ",
        (16).to_bytes(4, "little"),  # fmt chunk body size for plain PCM
        (1).to_bytes(2, "little"),  # audio format 1 = uncompressed PCM
        channels.to_bytes(2, "little"),
        sampleRate.to_bytes(4, "little"),
        # Byte rate; the formula is kept exactly as before (no early rounding).
        (sampleRate * channels * bitsPerSample // 8).to_bytes(4, "little"),
        (channels * bitsPerSample // 8).to_bytes(2, "little"),  # block align
        bitsPerSample.to_bytes(2, "little"),
        b"data",
        datasize.to_bytes(4, "little"),
    ]
    return b"".join(fields)
import phonemizer

# Module-level espeak phonemizer backend: US English, punctuation preserved,
# stress marks included.
# NOTE(review): not referenced in the visible part of this file; confirm
# whether anything still depends on it.
global_phonemizer = phonemizer.backend.EspeakBackend(
    language='en-us',
    with_stress=True,
    preserve_punctuation=True,
)

print("Starting Flask app")
# Flask application, wrapped with flask_cors using its default settings.
app = Flask(__name__)
cors = CORS(app)
def index():
    """Return the contents of ``API_DOCS.md`` rendered from Markdown to HTML.

    The file is resolved relative to the current working directory.

    NOTE(review): no route decorator is visible for this view in this file;
    confirm it is registered on ``app`` (presumably at ``/``).
    """
    # Decode explicitly as UTF-8. The platform default encoding depends on the
    # locale and can fail on non-ASCII characters in the Markdown source.
    with open('API_DOCS.md', 'r', encoding='utf-8') as f:
        docs = f.read()
    return markdown.markdown(docs)
# Directory (relative to the current working directory) holding per-chunk WAV
# files that serve_wav reads and writes as a synthesis cache.
cache_dir = 'cache'
# exist_ok=True already makes this a no-op when the directory exists, so a
# separate (racy) existence check is unnecessary. If a non-directory named
# 'cache' exists, this fails at startup instead of on the first cache write.
os.makedirs(cache_dir, exist_ok=True)
def serve_wav():
    """Synthesize the request's ``text`` and return it as a single WAV file.

    ``text`` comes from the query string for GET requests, or from the form
    body for other methods. If it is absent there, the handler falls back to a
    JSON object body with a ``"text"`` key.

    The text is split with ``split_and_recombine_text``. Each chunk is either
    synthesized with ``ljinference.inference`` or loaded from ``cache_dir``
    when a cached WAV exists. The chunks are then concatenated into one
    24000 Hz WAV.

    Returns:
        A ``Response`` with ``Content-Type: audio/wav`` on success. Otherwise a
        ``(json_error, 400)`` tuple when ``text`` is missing, not a string,
        blank, or yields no chunks to synthesize.
    """
    # NOTE(review): no route decorator is visible for this handler here; since
    # it branches on GET, confirm it is registered on ``app`` for GET and POST.
    sample_rate = 24000

    text = _extract_text()
    if text is None:
        return jsonify({'error': 'Missing required fields. Please include "text" in your request.'}), 400
    if not isinstance(text, str):
        # A JSON number/list/object would otherwise crash on .strip() (HTTP 500).
        return jsonify({'error': 'Invalid "text". Please provide "text" as a string.'}), 400
    text = text.strip()
    if not text:
        return jsonify({'error': 'Empty text. Please ensure "text" is not empty.'}), 400

    chunks = split_and_recombine_text(text)
    if not chunks:
        # np.concatenate([]) raises, so reject explicitly instead of a 500.
        return jsonify({'error': 'Could not find any text to synthesize in "text".'}), 400

    # One noise tensor is drawn per request and passed to every chunk's
    # inference call (presumably so fresh chunks share one voice/style).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    noise = torch.randn(1, 1, 256).to(device)
    audios = [_synthesize_chunk(chunk, noise, sample_rate) for chunk in chunks]

    buffer = io.BytesIO()
    write(buffer, sample_rate, np.concatenate(audios))
    return Response(buffer.getvalue(), content_type='audio/wav')


def _extract_text():
    """Return the raw ``text`` value from query/form data or a JSON body, or None if absent."""
    # GET parameters live in the query string, other methods in the form body.
    # A local variable is used rather than reassigning ``request.form``.
    params = request.args if request.method == 'GET' else request.form
    if 'text' in params:
        return params['text']
    # silent=True: a missing or non-JSON body yields None instead of raising an
    # HTTP error, so a missing field is reported with this API's own 400 message.
    payload = request.get_json(silent=True)
    if isinstance(payload, dict):
        return payload.get('text')
    return None


def _synthesize_chunk(chunk, noise, sample_rate):
    """Return audio samples for one text chunk, reading from or filling the WAV cache."""
    import tempfile  # local import keeps this helper self-contained

    # The cache key is case-insensitive, so chunks differing only in case share
    # one cached file.
    # NOTE(review): casing may affect pronunciation (e.g. acronyms), and cached
    # audio was produced with another request's noise; confirm this is intended.
    key = hashlib.sha256(chunk.lower().encode()).hexdigest()
    path = os.path.join(cache_dir, key + '.wav')
    if os.path.exists(path):
        return read(path)[1]

    audio = ljinference.inference(chunk, noise, diffusion_steps=7, embedding_scale=1)
    # Write to a temp file in the same directory, then rename it into place
    # atomically, so a concurrent request never reads a half-written cache entry.
    fd, tmp_path = tempfile.mkstemp(dir=cache_dir, suffix='.wav.tmp')
    try:
        with os.fdopen(fd, 'wb') as tmp_file:
            write(tmp_file, sample_rate, audio)
        os.replace(tmp_path, path)
    finally:
        # After a successful os.replace the temp path no longer exists; on
        # failure, remove the partial file so it does not accumulate.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return audio
if __name__ == "__main__":
    # Start the Flask server on all network interfaces, port 7860.
    app.run(host="0.0.0.0", port=7860)