MohamedRashad commited on
Commit
d96db07
·
1 Parent(s): d3a0030

Add warmup function to initialize CUDA graphs and improve performance in get_models

Browse files
Files changed (1)
  1. app.py +29 -26
app.py CHANGED
@@ -89,6 +89,11 @@ def get_models():
89
  other_mimi.streaming_forever(1)
90
  lm_gen.streaming_forever(1)
91
 
 
 
 
 
 
92
  _model_cache.update({
93
  "mimi": mimi,
94
  "other_mimi": other_mimi,
@@ -101,6 +106,23 @@ def get_models():
101
  return _model_cache
102
 
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  def wrap_with_system_tags(text: str) -> str:
106
  """Add system tags as PersonaPlex expects."""
@@ -149,6 +171,13 @@ def generate_response(audio_input, persona: str, voice: str):
149
  import sphn
150
  audio = sphn.resample(audio, sr, mimi.sample_rate)
151
 
 
 
 
 
 
 
 
152
  # Add channel dimension: (T,) -> (1, T)
153
  if audio.ndim == 1:
154
  audio = audio[None, :]
@@ -200,32 +229,6 @@ def generate_response(audio_input, persona: str, voice: str):
200
  if text_token not in (0, 3): # Skip special tokens
201
  text_piece = text_tokenizer.id_to_piece(text_token).replace("▁", " ")
202
  generated_text.append(text_piece)
203
-
204
- # Continue generating with silence to let the model finish speaking
205
- # Add extra frames (approximately 10 seconds of continuation)
206
- extra_frames = int(10 * mimi.frame_rate)
207
-
208
- # Use the correct SINE_TOKENS from lm.py for user audio (simulates silence/background)
209
- # These represent a 440Hz sine wave encoded by Mimi - official PersonaPlex constants
210
- SINE_TOKENS = [430, 1268, 381, 1611, 1095, 1495, 56, 472]
211
- sine_input = torch.tensor(SINE_TOKENS, dtype=torch.long, device=DEVICE).view(1, 8, 1)
212
-
213
- for _ in range(extra_frames):
214
- # Pass sine tokens as user input to simulate silence on user side
215
- tokens = lm_gen.step(sine_input)
216
-
217
- if tokens is None:
218
- continue
219
-
220
- # Decode agent audio
221
- pcm = decode_tokens_to_pcm(mimi, other_mimi, tokens)
222
- generated_frames.append(pcm)
223
-
224
- # Decode text token
225
- text_token = tokens[0, 0, 0].item()
226
- if text_token not in (0, 3): # Skip special tokens
227
- text_piece = text_tokenizer.id_to_piece(text_token).replace("▁", " ")
228
- generated_text.append(text_piece)
229
 
230
  if not generated_frames:
231
  return None, "No audio generated. Try speaking more clearly."
 
89
  other_mimi.streaming_forever(1)
90
  lm_gen.streaming_forever(1)
91
 
92
+ # Run warmup to initialize CUDA graphs (improves performance)
93
+ print("Running warmup...")
94
+ _warmup_models(mimi, other_mimi, lm_gen, frame_size)
95
+ print("Warmup complete.")
96
+
97
  _model_cache.update({
98
  "mimi": mimi,
99
  "other_mimi": other_mimi,
 
106
  return _model_cache
107
 
108
 
109
def _warmup_models(mimi, other_mimi, lm_gen, frame_size):
    """Run a few dummy inference passes to warm up the models.

    Feeding zero-valued audio chunks through the encode -> lm step ->
    decode pipeline lets the backend build CUDA graphs / autotune kernels
    once up front, so the first real request is not paying that cost.

    Args:
        mimi: streaming audio codec used for the agent channel
            (provides ``encode``/``decode``/``reset_streaming``).
        other_mimi: second codec instance for the user channel.
        lm_gen: streaming language-model generator (``step``/``reset_streaming``).
        frame_size: number of audio samples per codec frame.

    Returns:
        None. Streaming state of all three models is reset afterwards so
        warmup does not leak into real inference.
    """
    for _ in range(4):
        # Silent audio chunk shaped (batch=1, channels=1, frame_size).
        chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=DEVICE)
        codes = mimi.encode(chunk)
        _ = other_mimi.encode(chunk)
        for c in range(codes.shape[-1]):
            tokens = lm_gen.step(codes[:, :, c:c+1])
            if tokens is not None:
                # tokens[:, 1:9] are the 8 audio codebooks (index 0 is text).
                _ = mimi.decode(tokens[:, 1:9])
                _ = other_mimi.decode(tokens[:, 1:9])
    # Only synchronize when CUDA is actually available; calling
    # torch.cuda.synchronize() on a CPU-only build raises.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # Reset streaming state after warmup so real requests start clean.
    mimi.reset_streaming()
    other_mimi.reset_streaming()
    lm_gen.reset_streaming()
125
+
126
 
127
  def wrap_with_system_tags(text: str) -> str:
128
  """Add system tags as PersonaPlex expects."""
 
171
  import sphn
172
  audio = sphn.resample(audio, sr, mimi.sample_rate)
173
 
174
+ # PAD INPUT WITH SILENCE to give the model time to respond
175
+ # This is critical because PersonaPlex output duration = input duration
176
+ # Adding ~8 seconds of silence allows the model to complete its response
177
+ silence_duration = 8 # seconds
178
+ silence = np.zeros(int(silence_duration * mimi.sample_rate), dtype=np.float32)
179
+ audio = np.concatenate([audio, silence])
180
+
181
  # Add channel dimension: (T,) -> (1, T)
182
  if audio.ndim == 1:
183
  audio = audio[None, :]
 
229
  if text_token not in (0, 3): # Skip special tokens
230
  text_piece = text_tokenizer.id_to_piece(text_token).replace("▁", " ")
231
  generated_text.append(text_piece)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
  if not generated_frames:
234
  return None, "No audio generated. Try speaking more clearly."