Spaces:
Paused
Paused
| from fastapi import FastAPI | |
| import schemas | |
| import uvicorn | |
| from starlette.middleware.cors import CORSMiddleware | |
| from functions import * | |
| import base64 | |
| import os | |
| import traceback | |
| from bark import SAMPLE_RATE, generate_audio, preload_models | |
| import soundfile as sf | |
| import wave | |
| import numpy as np | |
| import nltk | |
# Port the FastAPI server listens on.
server_port = 7860

# Preload the Bark models at import time, before the app object is created.
preload_models()

# Application instance; both documentation URLs are explicitly set to None.
app = FastAPI(redoc_url=None, docs_url=None)

# Origins allowed to make cross-domain requests; "*" means every origin.
origins = ["*"]
def concatenate_wavs(wav_files, output_file, silence_duration=0.3):
    """Join several WAV files into one, adding a pause after each clip.

    The output uses the channel count, sample width and frame rate of the
    first input file. The remaining inputs are assumed to share that format;
    this is not verified, and their raw frames are copied as-is.

    Args:
        wav_files: Iterable of paths (or file objects) readable by
            ``wave.open``. Must contain at least one entry.
        output_file: Path (or file object) the combined WAV is written to.
        silence_duration: Length in seconds of the pause appended after
            every clip, including the last one.

    Raises:
        ValueError: If ``wav_files`` is empty.
        wave.Error: If an input is not a WAV file the ``wave`` module
            can read.
    """
    wav_files = list(wav_files)
    if not wav_files:
        raise ValueError("concatenate_wavs requires at least one input file")

    # Read every clip up front and close each reader as soon as it is drained,
    # so no file handles stay open while the output is written.
    clips = []
    audio_format = None
    for source in wav_files:
        with wave.open(source, "rb") as reader:
            if audio_format is None:
                audio_format = (
                    reader.getnchannels(),
                    reader.getsampwidth(),
                    reader.getframerate(),
                )
            clips.append(reader.readframes(reader.getnframes()))
    nchannels, sampwidth, framerate = audio_format

    # Size the pause from the real format instead of assuming 16-bit stereo,
    # so mono input also gets its pause. 8-bit PCM is unsigned, so its
    # silence level is 0x80; wider samples are signed and silent at zero.
    silence_sample = b"\x80" if sampwidth == 1 else b"\x00"
    silence_frames = int(silence_duration * framerate)
    silence = silence_sample * (silence_frames * nchannels * sampwidth)

    # The frame count in the header is filled in by the wave module when the
    # writer is closed, so only the format needs to be declared here.
    with wave.open(output_file, "wb") as writer:
        writer.setnchannels(nchannels)
        writer.setsampwidth(sampwidth)
        writer.setframerate(framerate)
        for frames in clips:
            writer.writeframes(frames)
            writer.writeframes(silence)
# Register the CORS middleware so browsers on other domains can call the API.
# NOTE(review): credentials are allowed together with wildcard origins,
# methods and headers; confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # origins permitted to make cross-domain requests
    allow_credentials=True,
    allow_methods=["*"],  # every HTTP method (GET, POST, PUT, ...) is accepted
    allow_headers=["*"],  # every request header is accepted
)
# NOTE(review): no route decorator is visible above this handler in the source
# as received; confirm it is registered on `app` (e.g. via `@app.post(...)`).
async def tts_bark(item: schemas.generate_web):
    """Synthesize speech for ``item.text`` and return it as base64 Ogg/Opus.

    The text is split into sentences with NLTK, each sentence is rendered
    with Bark, the clips are joined by ``concatenate_wavs`` and the result is
    transcoded with the ``ffmpeg`` command-line tool.

    All intermediate files live in a per-request temporary directory, so
    concurrent requests cannot overwrite each other's clips and nothing is
    left on disk whether the request succeeds or fails.

    Args:
        item: Request payload; only its ``text`` attribute is read here.

    Returns:
        On success, a dict with ``file_base64`` (the encoded audio),
        ``audio_text`` (the input text) and ``file_name`` (the generated
        ``out-<timestamp>.ogg`` name). On any failure, a dict with
        ``code`` 9, ``msg``, ``err`` and ``traceback``; exceptions are not
        propagated to the caller.
    """
    # Imported locally so the handler does not depend on names that happen to
    # be re-exported by the wildcard import from `functions`.
    import subprocess
    import tempfile
    import time

    time_start = time.time()
    text = item.text
    print(f"{text=}")
    try:
        with tempfile.TemporaryDirectory(prefix="tts-bark-") as work_dir:
            wavs = []
            for idx, sentence in enumerate(nltk.sent_tokenize(text), start=1):
                audio_array = generate_audio(
                    sentence,
                    history_prompt="en_speaker_8",
                    text_temp=0.6,
                    waveform_temp=0.6,
                )
                fname = os.path.join(work_dir, f"tmp-{idx}.wav")
                sf.write(fname, audio_array, SAMPLE_RATE)
                wavs.append(fname)

            file_name_pre = f"out-{time.time()}"
            file_name_ogg = file_name_pre + ".ogg"
            wav_path = os.path.join(work_dir, file_name_pre + ".wav")
            ogg_path = os.path.join(work_dir, file_name_ogg)
            concatenate_wavs(wavs, wav_path)

            # Convert to OGG. An argument list avoids shell parsing of the
            # paths, and check=True turns an ffmpeg failure into an exception
            # instead of a missing-file error on the read below.
            subprocess.run(
                ["ffmpeg", "-i", wav_path, "-c:a", "libopus", "-b:a", "64k", "-y", ogg_path],
                check=True,
            )
            with open(ogg_path, "rb") as f:
                audio_content = f.read()

        base64_audio = base64.b64encode(audio_content).decode("utf-8")
        res = {
            "file_base64": base64_audio,
            "audio_text": text,
            "file_name": file_name_ogg,
        }
        print_log(item, res, time_start)
        return res
    except Exception as err:
        # Top-level request boundary: report the failure in the response body.
        # NOTE(review): the traceback is returned to the client; confirm that
        # exposing it is acceptable.
        res = {"code": 9, "msg": "api error", "err": str(err), "traceback": traceback.format_exc()}
        print_log(item, res, time_start)
        return res
if __name__ == "__main__":
    # Script entry point: call print_env with the port, then start uvicorn
    # on all interfaces with auto-reload turned off.
    print_env(server_port)
    uvicorn.run(
        app="main:app",
        host="0.0.0.0",
        port=server_port,
        reload=False,
    )