SwikarG committed on
Commit
c361854
·
verified ·
1 Parent(s): 1e0a550

Update modal_tts.py

Browse files
Files changed (1) hide show
  1. modal_tts.py +61 -31
modal_tts.py CHANGED
@@ -1,43 +1,73 @@
1
  import io
2
-
3
  import modal
 
 
 
 
 
 
4
 
5
- image = modal.Image.debian_slim(python_version="3.12").pip_install(
6
- "chatterbox-tts==0.1.1", "fastapi[standard]"
 
 
 
7
  )
8
- app = modal.App("chatterbox-api-example", image=image)
9
 
10
- with image.imports():
11
- import torchaudio as ta
12
- from chatterbox.tts import ChatterboxTTS
13
- from fastapi.responses import StreamingResponse
14
 
 
 
15
 
16
- @app.cls(gpu="a10g", scaledown_window=60 * 5, enable_memory_snapshot=True)
17
- @modal.concurrent(max_inputs=10)
18
- class Chatterbox:
19
  @modal.enter()
20
  def load(self):
21
- self.model = ChatterboxTTS.from_pretrained(device="cuda")
22
-
23
- @modal.fastapi_endpoint(docs=True, method="POST")
24
- def generate(self, prompt: str):
25
- # Generate audio waveform from the input text
26
- wav = self.model.generate(prompt)
27
 
28
- # Create an in-memory buffer to store the WAV file
 
 
 
 
 
 
 
 
 
 
29
  buffer = io.BytesIO()
30
-
31
- # Save the generated audio to the buffer in WAV format
32
- # Uses the model's sample rate and WAV format
33
- ta.save(buffer, wav, self.model.sr, format="wav")
34
-
35
- # Reset buffer position to the beginning for reading
36
  buffer.seek(0)
37
-
38
- # Return the audio as a streaming response with appropriate MIME type.
39
- # This allows for browsers to playback audio directly.
40
- return StreamingResponse(
41
- io.BytesIO(buffer.read()),
42
- media_type="audio/wav",
43
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import io
 
2
  import modal
3
+ from pydantic import BaseModel
4
+
5
# JSON request body accepted by the /generate endpoint.
class TTSRequest(BaseModel):
    """Payload for text-to-speech generation.

    prompt    -- text to synthesize into speech
    use_music -- overlay background music under the speech (default True)
    """

    prompt: str
    use_music: bool = True
9
 
10
# Shared container image for all Modal functions.
# NOTE(review): pydub needs the ffmpeg *binary*, which apt_install supplies;
# the PyPI package named "ffmpeg" is an unrelated/abandoned project, so it
# was removed from pip_install.
image = (
    modal.Image.debian_slim(python_version="3.10")
    .pip_install("chatterbox-tts==0.1.1", "fastapi[standard]", "pydub")
    .apt_install("ffmpeg")  # required by pydub for mp3 decode/encode
)
 
16
 
17
# Attach the Modal Volume holding the background-music assets
# (mounted at /music on the class below).
# NOTE(review): from_name fails if the volume does not already exist;
# pass create_if_missing=True if first-run creation is desired — confirm.
volume = modal.Volume.from_name("background-music")
 
 
19
 
20
# Modal App; every function/class registered below runs on the shared image.
app = modal.App("NewsShots_TTS_", image=image)
22
 
23
# GPU-backed TTS service: synthesizes speech, optionally mixed with music.
@app.cls(gpu="a10g", scaledown_window=60 * 10, volumes={"/music": volume})
class ChatterboxWithMusic:
    """Chatterbox TTS endpoint with optional background-music overlay.

    The model is loaded once per container in ``load()``; ``generate()``
    serves a POST endpoint that returns the synthesized audio as MP3.
    """

    # Location of the background track on the mounted Volume.
    _MUSIC_PATH = "/music/music/download.mp3"

    @modal.enter()
    def load(self):
        """Load the TTS model onto CUDA and cache pydub's AudioSegment.

        Runs once per container start; imports are kept local so the
        module can be imported without the heavy deps installed.
        """
        from chatterbox.tts import ChatterboxTTS
        from pydub import AudioSegment

        self.tts_model = ChatterboxTTS.from_pretrained(device="cuda")
        self.AudioSegment = AudioSegment

    @modal.fastapi_endpoint(method="POST")
    def generate(self, request: TTSRequest):
        """Synthesize ``request.prompt`` to speech, optionally over music.

        Returns a StreamingResponse carrying MP3 audio (audio/mpeg).
        Falls back to speech-only output when the music file is missing
        or unusable; never fails just because the track is absent.
        """
        import torchaudio
        from fastapi.responses import StreamingResponse

        # Generate the speech waveform, then round-trip it through an
        # in-memory WAV so pydub can operate on it.
        wav_tensor = self.tts_model.generate(request.prompt)
        buffer = io.BytesIO()
        torchaudio.save(buffer, wav_tensor, self.tts_model.sr, format="wav")
        buffer.seek(0)
        tts_audio = self.AudioSegment.from_file(buffer, format="wav")

        final_audio = tts_audio
        if request.use_music:
            try:
                # from_file accepts a path directly — no need to read the
                # bytes into memory first.
                background = self.AudioSegment.from_file(self._MUSIC_PATH)
            except FileNotFoundError:
                background = None  # no track on the Volume; serve speech only
            # Guard len(background) > 0: an empty/corrupt track would raise
            # ZeroDivisionError in the loop-count computation below.
            if background is not None and len(background) > 0:
                background = background - 15  # duck the music 15 dB under the voice
                if len(background) < len(tts_audio):
                    # Loop the track until it covers the speech duration.
                    background *= (len(tts_audio) // len(background) + 1)
                final_audio = tts_audio.overlay(background[: len(tts_audio)])

        # Export the (possibly mixed) audio as MP3 and stream it back.
        final_buffer = io.BytesIO()
        final_audio.export(final_buffer, format="mp3")
        final_buffer.seek(0)
        return StreamingResponse(final_buffer, media_type="audio/mpeg")