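# FastAPI service wrapping the TTS helpers in app.inference: single-shot,
# long-form, per-emotion, and streaming synthesis, all rendered at 24 kHz.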
from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import JSONResponse, StreamingResponse
import torch
import time
import nltk
import io
import base64
import torchaudio
import numpy as np
from app.inference import inference, LFinference, compute_style

nltk.download('punkt')
nltk.download('punkt_tab')

app = FastAPI()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
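# The helpers in app.inference are assumed to load the model and config at
# import time; the 1x1x256 `noise` tensor created per request seeds the
# style diffusion sampler, and all waveforms come back at 24 kHz.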

@app.get("/")  # route path assumed
async def read_root():
    return {"details": "Environment is running OK!"}

@app.get("/synthesize")  # route path assumed
async def synthesize(
    text: str,
    return_base64: bool = True,
    diffusion_steps: int = Query(5, ge=5, le=200),
    embedding_scale: float = Query(1.0, ge=1.0, le=5.0),
):
    try:
        start = time.time()
        noise = torch.randn(1, 1, 256).to(device)
        wav = inference(text, noise, diffusion_steps=diffusion_steps, embedding_scale=embedding_scale)
        rtf = (time.time() - start) / (len(wav) / 24000)  # real-time factor: synthesis time / audio duration
        if return_base64:
            audio_buffer = io.BytesIO()
            torchaudio.save(audio_buffer, torch.tensor(wav).unsqueeze(0), 24000, format="wav")
            audio_buffer.seek(0)
            audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
            return JSONResponse(content={"RTF": rtf, "audio_base64": audio_base64})
        else:
            return JSONResponse(content={"RTF": rtf, "audio": wav.tolist()})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
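
# Example request (host, port, and path assumed; the response is JSON with
# "RTF" and "audio_base64" fields):
#   curl "http://localhost:8000/synthesize?text=Hello%20there&diffusion_steps=5"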

@app.get("/synthesize_longform")  # route path assumed
async def synthesize_longform(
    passage: str,
    return_base64: bool = False,
    alpha: float = Query(0.7, ge=0.0, le=1.0),
    diffusion_steps: int = Query(10, ge=5, le=200),
    embedding_scale: float = Query(1.5, ge=1.0, le=5.0),
):
    try:
        sentences = passage.split('.')  # simple split; see the note on nltk.sent_tokenize below
        wavs = []
        s_prev = None  # previous style vector, carried across sentences for consistent prosody
        start = time.time()
        for text in sentences:
            if text.strip() == "":
                continue
            text += '.'  # restore the delimiter removed by split()
            noise = torch.randn(1, 1, 256).to(device)
            wav, s_prev = LFinference(text, s_prev, noise, alpha=alpha,
                                      diffusion_steps=diffusion_steps,
                                      embedding_scale=embedding_scale)
            wavs.append(wav)
        final_wav = np.concatenate(wavs)  # concatenate all per-sentence wavs
        rtf = (time.time() - start) / (len(final_wav) / 24000)
        audio_buffer = io.BytesIO()
        torchaudio.save(audio_buffer, torch.tensor(final_wav).unsqueeze(0), 24000, format="wav")
        audio_buffer.seek(0)
        if return_base64:
            audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
            return JSONResponse(content={"RTF": rtf, "audio_base64": audio_base64})
        else:
            return StreamingResponse(audio_buffer, media_type="audio/wav")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
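
# The punkt models downloaded above suggest tokenizer-based splitting was
# intended; a more robust alternative to passage.split('.') would be the
# following sketch (not what the endpoint currently does):
#   from nltk.tokenize import sent_tokenize
#   sentences = sent_tokenize(passage)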

@app.post("/synthesize_with_emotion")  # route path assumed; POST because `texts` is a JSON body
async def synthesize_with_emotion(
    texts: dict,
    return_base64: bool = True,
    diffusion_steps: int = Query(100, ge=5, le=200),
    embedding_scale: float = Query(5.0, ge=1.0, le=5.0),
):
    try:
        results = []
        for emotion, text in texts.items():
            noise = torch.randn(1, 1, 256).to(device)
            wav = inference(text, noise, diffusion_steps=diffusion_steps,
                            embedding_scale=embedding_scale)
            if return_base64:
                audio_buffer = io.BytesIO()
                torchaudio.save(audio_buffer, torch.tensor(wav).unsqueeze(0), 24000, format="wav")
                audio_buffer.seek(0)
                audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
                results.append({
                    "emotion": emotion,
                    "audio_base64": audio_base64
                })
            else:
                results.append({
                    "emotion": emotion,
                    "audio": wav.tolist()
                })
        return JSONResponse(content={"results": results})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
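
# Example body (shape inferred from the loop above: emotion label -> text):
#   {"Happy": "We are having a great time!", "Sad": "I have some bad news."}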

@app.get("/synthesize_streaming_audio")  # route path assumed
async def synthesize_streaming_audio(
    text: str,
    return_base64: bool = False,
    diffusion_steps: int = Query(5, ge=5, le=200),
    embedding_scale: float = Query(1.0, ge=1.0, le=5.0),
):
    try:
        start = time.time()
        noise = torch.randn(1, 1, 256).to(device)
        wav = inference(text, noise, diffusion_steps=diffusion_steps, embedding_scale=embedding_scale)
        rtf = (time.time() - start) / (len(wav) / 24000)
        audio_buffer = io.BytesIO()
        torchaudio.save(audio_buffer, torch.tensor(wav).unsqueeze(0), 24000, format="wav")
        audio_buffer.seek(0)
        if return_base64:
            audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
            return JSONResponse(content={"RTF": rtf, "audio_base64": audio_base64})
        else:
            return StreamingResponse(audio_buffer, media_type="audio/wav")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
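
# To run locally (module path assumed from the `app.inference` import):
#   uvicorn app.main:app --host 0.0.0.0 --port 8000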