# NewsShots / modal_tts.py
# Author: SwikarG
# Last change: "Update modal_tts.py" (commit c361854, verified)
import io
import modal
from pydantic import BaseModel
class TTSRequest(BaseModel):
    """JSON request body accepted by the TTS web endpoint."""

    # Text to be synthesized into speech.
    prompt: str
    # When true (the default), the background-music track is mixed under the speech.
    use_music: bool = True
# Container image shared by every Modal function in this app.
# - ffmpeg (apt) is the system binary pydub shells out to for decoding and
#   MP3 export. The PyPI package that is also named "ffmpeg" is an unrelated
#   project that does not ship this binary, so it is deliberately NOT
#   pip-installed.
# - chatterbox-tts supplies the TTS model; torchaudio (used by the endpoint) is
#   presumably pulled in as one of its dependencies — confirm if pins change.
# - fastapi[standard] backs the @modal.fastapi_endpoint web handler.
image = (
    modal.Image.debian_slim(python_version="3.10")
    .apt_install("ffmpeg")  # system binary required by pydub
    .pip_install("chatterbox-tts==0.1.1", "fastapi[standard]", "pydub")
)

# Existing Modal Volume named "background-music"; expected to hold the
# background track that is mixed under generated speech.
volume = modal.Volume.from_name("background-music")

# Modal application; functions/classes registered on it run in `image`.
app = modal.App("NewsShots_TTS_", image=image)
# TTS Class
# gpu: runs on an A10G; scaledown_window: idle containers are kept for 10 minutes;
# volumes: the "background-music" Volume is mounted at /music.
@app.cls(gpu="a10g", scaledown_window=60 * 10, volumes={"/music": volume})
class ChatterboxWithMusic:
    """Chatterbox text-to-speech service with optional background music.

    The model and the background track are loaded once per container in
    ``load``. ``generate`` is exposed as a POST web endpoint that returns the
    synthesized speech as an MP3 stream.
    """

    # Background track on the mounted Volume.
    MUSIC_PATH = "/music/music/download.mp3"
    # Attenuation (dB) applied to the music so it sits under the speech.
    MUSIC_ATTENUATION_DB = 15

    @modal.enter()
    def load(self):
        """Load the TTS model and, if present, decode the background track once."""
        from chatterbox.tts import ChatterboxTTS
        from pydub import AudioSegment

        self.tts_model = ChatterboxTTS.from_pretrained(device="cuda")
        self.AudioSegment = AudioSegment
        # Decoding the MP3 shells out to ffmpeg; doing it here instead of per
        # request removes that cost from every call.
        self.background_music = self._load_background_music()

    def _load_background_music(self):
        """Return the attenuated background track, or None if the file is missing."""
        try:
            music = self.AudioSegment.from_file(self.MUSIC_PATH, format="mp3")
        except FileNotFoundError:
            # Music is optional: without the file, speech is returned unmixed.
            return None
        return music - self.MUSIC_ATTENUATION_DB

    def _mix_background(self, speech):
        """Overlay the cached background track under `speech`.

        The track is looped and trimmed to match the speech length. `speech` is
        returned unchanged when no usable (non-empty) track is available.
        """
        music = self.background_music
        # An empty track would make the loop-count division below divide by zero.
        if music is None or len(music) == 0:
            return speech
        if len(music) < len(speech):
            # Repeat the track enough times to cover the whole utterance.
            music = music * (len(speech) // len(music) + 1)
        return speech.overlay(music[: len(speech)])

    @modal.fastapi_endpoint(method="POST")
    def generate(self, request: TTSRequest):
        """Synthesize ``request.prompt`` and stream the result back as MP3.

        Args:
            request: JSON body; ``prompt`` is the text to speak and
                ``use_music`` toggles mixing in the background track.

        Returns:
            A StreamingResponse with media type ``audio/mpeg``.

        Raises:
            HTTPException: 400 if the prompt is empty or whitespace-only.
        """
        import torchaudio
        from fastapi import HTTPException
        from fastapi.responses import StreamingResponse

        if not request.prompt.strip():
            raise HTTPException(status_code=400, detail="prompt must not be empty")

        # Round-trip the model's waveform through an in-memory WAV so pydub
        # can mix and re-encode it.
        wav_tensor = self.tts_model.generate(request.prompt)
        wav_buffer = io.BytesIO()
        torchaudio.save(wav_buffer, wav_tensor, self.tts_model.sr, format="wav")
        wav_buffer.seek(0)
        speech = self.AudioSegment.from_file(wav_buffer, format="wav")

        final_audio = self._mix_background(speech) if request.use_music else speech

        mp3_buffer = io.BytesIO()
        final_audio.export(mp3_buffer, format="mp3")
        mp3_buffer.seek(0)
        return StreamingResponse(mp3_buffer, media_type="audio/mpeg")