thecollabagepatch committed on
Commit
c8f33f0
·
1 Parent(s): 3fbec97

i think we did it? gradio with fastrtc v1

Browse files
Files changed (3) hide show
  1. Dockerfile +2 -2
  2. app.py +27 -21
  3. fastrtc_magenta.py +228 -254
Dockerfile CHANGED
@@ -134,7 +134,7 @@ RUN uv pip install --system -c /tmp/constraints.txt \
134
  # Ensure compatible protobuf version
135
  RUN uv pip install --system --force-reinstall "protobuf>=5.27.0"
136
 
137
- # RUN uv pip install --system fastrtc
138
 
139
  # Set working directory and create cache
140
  WORKDIR /app
@@ -152,7 +152,7 @@ COPY lil_demo_540p.mp4 /app/
152
  COPY magentaRT_rt_tester.html /app/
153
  COPY magenta_prompts.js /app/
154
  COPY docs/ /app/docs/
155
- # COPY fastrtc_magenta.py /app/
156
 
157
  EXPOSE 7860
158
  CMD ["python", "-m", "uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
 
134
  # Ensure compatible protobuf version
135
  RUN uv pip install --system --force-reinstall "protobuf>=5.27.0"
136
 
137
+ RUN uv pip install --system fastrtc
138
 
139
  # Set working directory and create cache
140
  WORKDIR /app
 
152
  COPY magentaRT_rt_tester.html /app/
153
  COPY magenta_prompts.js /app/
154
  COPY docs/ /app/docs/
155
+ COPY fastrtc_magenta.py /app/
156
 
157
  EXPOSE 7860
158
  CMD ["python", "-m", "uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py CHANGED
@@ -1,5 +1,15 @@
1
  import os
2
 
 
 
 
 
 
 
 
 
 
 
3
  # ---- Space mode gating (place above any JAX import!) ----
4
  SPACE_MODE = os.getenv("SPACE_MODE")
5
  if SPACE_MODE is None:
@@ -1741,27 +1751,23 @@ async def ws_jam(websocket: WebSocket):
1741
  except Exception:
1742
  pass
1743
 
1744
- # --- FastRTC Gradio Integration (optional) ---
1745
- # if FASTRTC_AVAILABLE:
1746
- # try:
1747
- # magenta_stream = create_magenta_stream(
1748
- # get_mrt_fn=get_mrt,
1749
- # build_style_fn=build_style_vector,
1750
- # asset_manager=asset_manager,
1751
- # concurrency_limit=1, # Single GPU
1752
- # time_limit=3600, # 1 hour max per session
1753
- # )
1754
-
1755
- # # Mount FastRTC routes at /rtc
1756
- # magenta_stream.mount(app, path="/rtc")
1757
-
1758
- # # Mount Gradio UI at /gradio
1759
- # import gradio as gr
1760
- # app = gr.mount_gradio_app(app, magenta_stream.ui, path="/gradio")
1761
-
1762
- # print("✓ FastRTC Gradio interface available at /gradio")
1763
- # except Exception as e:
1764
- # print(f"FastRTC integration skipped: {e}")
1765
 
1766
 
1767
  @app.get("/ping")
 
1
  import os
2
 
3
+ # very top of app.py, before importing tensorflow/jax/magenta stuff
4
+ try:
5
+ import pylibsrtp
6
+ # This is the key: call srtp_init via cffi binding
7
+ from pylibsrtp import _binding
8
+ _binding.lib.srtp_init()
9
+ print("SRTP init OK (pre-TF)", flush=True)
10
+ except Exception as e:
11
+ print("SRTP init failed early:", e, flush=True)
12
+
13
  # ---- Space mode gating (place above any JAX import!) ----
14
  SPACE_MODE = os.getenv("SPACE_MODE")
15
  if SPACE_MODE is None:
 
1751
  except Exception:
1752
  pass
1753
 
1754
+ # --- FastRTC Gradio Integration (only in serve mode) ---
1755
+ if SPACE_MODE == "serve":
1756
+ try:
1757
+ from fastrtc_magenta import create_magenta_stream, FASTRTC_AVAILABLE
1758
+ if FASTRTC_AVAILABLE:
1759
+ magenta_stream = create_magenta_stream(
1760
+ get_mrt_fn=get_mrt,
1761
+ build_style_fn=build_style_vector,
1762
+ asset_manager=asset_manager,
1763
+ concurrency_limit=1,
1764
+ time_limit=3600,
1765
+ )
1766
+ magenta_stream.mount(app, path="/rtc")
1767
+ app = gr.mount_gradio_app(app, magenta_stream.ui, path="/gradio")
1768
+ print("✓ FastRTC Gradio interface available at /gradio")
1769
+ except Exception as e:
1770
+ print(f"⚠ FastRTC integration skipped: {e}")
 
 
 
 
1771
 
1772
 
1773
  @app.get("/ping")
fastrtc_magenta.py CHANGED
@@ -2,28 +2,24 @@
2
  FastRTC integration for MagentaRT real-time streaming.
3
 
4
  This module provides a Gradio-native interface for MagentaRT using FastRTC,
5
- enabling real-time audio streaming with live parameter updates through
6
- a proper Gradio UI.
7
-
8
- Usage:
9
- from fastrtc_magenta import create_magenta_stream
10
-
11
- # In your existing FastAPI app:
12
- magenta_stream = create_magenta_stream(get_mrt_fn=get_mrt)
13
- magenta_stream.mount(app, path="/rtc")
14
-
15
- # Or standalone:
16
- magenta_stream.ui.launch()
17
  """
18
 
 
 
19
  import numpy as np
20
  import gradio as gr
21
  from typing import Callable, Optional
22
- from dataclasses import dataclass, field
23
 
24
  # FastRTC imports
25
  try:
26
- from fastrtc import Stream, StreamHandler, AdditionalOutputs
27
  FASTRTC_AVAILABLE = True
28
  except ImportError:
29
  FASTRTC_AVAILABLE = False
@@ -32,7 +28,6 @@ except ImportError:
32
 
33
  @dataclass
34
  class MagentaRTParams:
35
- """Live-updatable parameters for MagentaRT generation."""
36
  temperature: float = 1.1
37
  guidance_weight: float = 1.1
38
  topk: int = 40
@@ -45,97 +40,175 @@ class MagentaRTParams:
45
 
46
  class MagentaRTStreamHandler(StreamHandler):
47
  """
48
- FastRTC StreamHandler for continuous MagentaRT audio generation.
49
-
50
- This handler generates ~2s audio chunks continuously, with support
51
- for live parameter updates via FastRTC's set_input mechanism.
52
-
53
- The MagentaRT system handles crossfading internally, so chunks
54
- can be played back-to-back without client-side processing.
55
  """
56
-
57
  def __init__(
58
  self,
59
  get_mrt_fn: Callable,
60
  build_style_fn: Callable,
61
  asset_manager=None,
62
  ):
63
- # MagentaRT outputs stereo 48kHz audio
64
  super().__init__(
65
  expected_layout="stereo",
66
  output_sample_rate=48000,
67
- input_sample_rate=48000, # Not used in receive-only mode
68
  )
69
-
70
  self.get_mrt_fn = get_mrt_fn
71
  self.build_style_fn = build_style_fn
72
  self.asset_manager = asset_manager
73
-
74
- # Will be initialized in start_up()
75
  self.mrt = None
76
  self.state = None
 
 
77
  self.style_cur = None
78
  self.style_tgt = None
79
- self.params = MagentaRTParams()
80
-
81
- # Track chunk timing for style ramping
82
- self.chunk_duration = 2.0 # Will be updated from mrt config
83
-
 
 
84
  def copy(self) -> "MagentaRTStreamHandler":
85
- """Create a fresh handler for each new connection."""
86
  return MagentaRTStreamHandler(
87
  get_mrt_fn=self.get_mrt_fn,
88
  build_style_fn=self.build_style_fn,
89
  asset_manager=self.asset_manager,
90
  )
91
-
 
 
 
 
92
  def start_up(self) -> None:
93
- """Initialize MagentaRT state when stream starts."""
94
  self.mrt = self.get_mrt_fn()
95
  self.state = self.mrt.init_state()
96
-
97
- # Calculate chunk duration from config
98
  codec_fps = float(self.mrt.codec.frame_rate)
99
  self.chunk_duration = (
100
- self.mrt.config.chunk_length_frames *
101
- self.mrt.config.frame_length_samples
102
  ) / float(self.mrt.sample_rate)
103
-
104
- # Build silent context (10s) tokens
 
 
105
  ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
106
  sr = int(self.mrt.sample_rate)
107
- samples = int(max(1, round(ctx_seconds * sr)))
108
-
109
- # Import here to avoid circular deps
110
- from magenta_rt import audio as au
111
-
112
- silent = au.Waveform(np.zeros((samples, 2), np.float32), sr)
113
  tokens = self.mrt.codec.encode(silent).astype(np.int32)
114
  tokens = tokens[:, :self.mrt.config.decoder_codec_rvq_depth]
115
  self.state.context_tokens = tokens
116
-
117
- # Ensure assets loaded for style building
118
  if self.asset_manager:
119
  self.asset_manager.ensure_assets_loaded(self.mrt)
120
-
121
- # Build initial style
122
  self._rebuild_style()
123
  self.style_cur = self.style_tgt.copy()
124
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def _rebuild_style(self) -> None:
126
- """Rebuild target style vector from current params."""
127
  text_list = [s.strip() for s in self.params.styles.split(",") if s.strip()]
128
-
129
  try:
130
  text_w = [float(x) for x in self.params.style_weights.split(",") if x.strip()]
131
  except ValueError:
132
  text_w = []
133
-
134
  try:
135
  cw = [float(x) for x in self.params.centroid_weights.split(",") if x.strip()]
136
  except ValueError:
137
  cw = []
138
-
139
  self.style_tgt = self.build_style_fn(
140
  self.mrt,
141
  text_styles=text_list,
@@ -145,15 +218,18 @@ class MagentaRTStreamHandler(StreamHandler):
145
  mean_weight=self.params.mean_weight,
146
  centroid_weights=cw,
147
  )
148
-
149
  def _apply_param_updates(self) -> None:
150
- """Check latest_args for parameter updates from Gradio UI."""
151
- # latest_args format: [webrtc_value, temp, guidance, topk, styles, style_weights, mean, centroids, ramp]
152
  args = self.latest_args
153
  if not args or len(args) < 2:
 
154
  return
155
-
156
- # Skip index 0 which is the dummy webrtc value
 
 
 
 
157
  try:
158
  if len(args) > 1 and args[1] is not None:
159
  self.params.temperature = float(args[1])
@@ -172,66 +248,89 @@ class MagentaRTStreamHandler(StreamHandler):
172
  if len(args) > 8 and args[8] is not None:
173
  self.params.style_ramp_seconds = float(args[8])
174
  except (ValueError, TypeError):
175
- pass # Ignore malformed updates
176
-
177
- # Apply to MRT
178
  self.mrt.temperature = self.params.temperature
179
  self.mrt.guidance_weight = self.params.guidance_weight
180
  self.mrt.topk = self.params.topk
181
-
182
- # Rebuild target style
183
- self._rebuild_style()
184
-
185
- def receive(self, frame: tuple[int, np.ndarray]) -> None:
186
- """
187
- Receive incoming audio frame.
188
-
189
- For MagentaRT rt-mode, we ignore input audio - this is output-only.
190
- In the future, we could use input for style conditioning or feedback.
191
- """
192
- pass
193
-
194
- def emit(self) -> Optional[tuple[int, np.ndarray]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  """
196
- Generate and emit the next audio chunk.
197
-
198
- Returns:
199
- Tuple of (sample_rate, audio_array) where audio_array has shape
200
- (num_channels, num_samples) for stereo output.
 
 
 
201
  """
202
- if self.mrt is None or self.state is None:
203
- return None
204
-
205
- # Check for parameter updates from UI
206
- self._apply_param_updates()
207
-
208
- # Ramp style toward target
209
- ramp = self.params.style_ramp_seconds
210
- if ramp <= 0.0:
211
- self.style_cur = self.style_tgt
212
- else:
213
- step = min(1.0, self.chunk_duration / ramp)
214
- self.style_cur = self.style_cur + step * (self.style_tgt - self.style_cur)
215
-
216
- # Generate chunk (crossfading handled internally by MagentaRT)
217
- wav, new_state = self.mrt.generate_chunk(
218
- state=self.state,
219
- style=self.style_cur
220
- )
221
- self.state = new_state
222
-
223
- # Convert to FastRTC format: (sample_rate, array)
224
- # FastRTC expects shape (num_channels, num_samples) for stereo
225
- # MagentaRT outputs shape (num_samples, num_channels)
226
- audio = wav.samples.astype(np.float32)
227
- audio = audio.T # Transpose to (channels, samples)
228
-
229
- return (int(self.mrt.sample_rate), audio)
230
-
231
- def shutdown(self) -> None:
232
- """Clean up when stream closes."""
233
- self.mrt = None
234
- self.state = None
235
 
236
 
237
  def create_magenta_stream(
@@ -241,156 +340,31 @@ def create_magenta_stream(
241
  concurrency_limit: int = 1,
242
  time_limit: Optional[float] = None,
243
  ) -> "Stream":
244
- """
245
- Create a FastRTC Stream for MagentaRT real-time generation.
246
-
247
- Args:
248
- get_mrt_fn: Function that returns the MagentaRT instance
249
- build_style_fn: Function to build style vectors (build_style_vector from app.py)
250
- asset_manager: Optional AssetManager for finetune steering
251
- concurrency_limit: Max concurrent streams (default 1 for single GPU)
252
- time_limit: Optional max stream duration in seconds
253
-
254
- Returns:
255
- FastRTC Stream object that can be mounted or launched
256
-
257
- Example:
258
- stream = create_magenta_stream(get_mrt, build_style_vector, asset_manager)
259
- stream.ui.launch() # Standalone Gradio UI
260
- # or
261
- stream.mount(app, path="/rtc") # Mount on existing FastAPI
262
- """
263
  if not FASTRTC_AVAILABLE:
264
  raise ImportError("FastRTC not installed. Run: pip install fastrtc")
265
-
266
  handler = MagentaRTStreamHandler(
267
  get_mrt_fn=get_mrt_fn,
268
  build_style_fn=build_style_fn,
269
  asset_manager=asset_manager,
270
  )
271
-
272
  stream = Stream(
273
  handler=handler,
274
  modality="audio",
275
- mode="receive", # Server-to-client only (we generate, client listens)
276
  concurrency_limit=concurrency_limit,
277
  time_limit=time_limit,
278
  additional_inputs=[
279
- gr.Slider(
280
- minimum=0.1, maximum=2.0, step=0.01, value=1.1,
281
- label="Temperature",
282
- info="Higher = more random, lower = more deterministic"
283
- ),
284
- gr.Slider(
285
- minimum=0.0, maximum=8.0, step=0.1, value=1.1,
286
- label="Guidance Weight",
287
- info="How strongly to follow the style"
288
- ),
289
- gr.Slider(
290
- minimum=1, maximum=256, step=1, value=40,
291
- label="Top-K",
292
- info="Number of token candidates to sample from"
293
- ),
294
- gr.Textbox(
295
- value="warmup",
296
- label="Styles",
297
- info="Comma-separated style prompts (e.g., 'acid house, dreamy pads')"
298
- ),
299
- gr.Textbox(
300
- value="1.0",
301
- label="Style Weights",
302
- info="Comma-separated weights for each style"
303
- ),
304
- gr.Slider(
305
- minimum=0.0, maximum=2.0, step=0.01, value=0.0,
306
- label="Mean Weight",
307
- info="Weight for finetune mean embedding (if available)"
308
- ),
309
- gr.Textbox(
310
- value="",
311
- label="Centroid Weights",
312
- info="Comma-separated weights for finetune centroids"
313
- ),
314
- gr.Slider(
315
- minimum=0.0, maximum=10.0, step=0.1, value=2.0,
316
- label="Style Ramp (seconds)",
317
- info="How long to transition between style changes"
318
- ),
319
- ],
320
- )
321
-
322
- return stream
323
-
324
-
325
- # -----------------------------------------------------------------------------
326
- # Alternative: Simpler generator-based approach (if StreamHandler is overkill)
327
- # -----------------------------------------------------------------------------
328
-
329
- def create_simple_magenta_stream(
330
- get_mrt_fn: Callable,
331
- build_style_fn: Callable,
332
- asset_manager=None,
333
- ) -> "Stream":
334
- """
335
- Simpler generator-based MagentaRT stream.
336
-
337
- This approach is less flexible but easier to understand.
338
- Parameter updates won't work as smoothly - they'll only apply
339
- when a new stream starts.
340
- """
341
- if not FASTRTC_AVAILABLE:
342
- raise ImportError("FastRTC not installed. Run: pip install fastrtc")
343
-
344
- def generate_audio(
345
- temperature: float = 1.1,
346
- guidance: float = 1.1,
347
- topk: int = 40,
348
- styles: str = "warmup",
349
- ):
350
- """Generator that yields MagentaRT audio chunks."""
351
- from magenta_rt import audio as au
352
-
353
- mrt = get_mrt_fn()
354
- state = mrt.init_state()
355
-
356
- # Set params
357
- mrt.temperature = temperature
358
- mrt.guidance_weight = guidance
359
- mrt.topk = topk
360
-
361
- # Build silent context
362
- codec_fps = float(mrt.codec.frame_rate)
363
- ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
364
- sr = int(mrt.sample_rate)
365
- samples = int(max(1, round(ctx_seconds * sr)))
366
- silent = au.Waveform(np.zeros((samples, 2), np.float32), sr)
367
- tokens = mrt.codec.encode(silent).astype(np.int32)
368
- tokens = tokens[:, :mrt.config.decoder_codec_rvq_depth]
369
- state.context_tokens = tokens
370
-
371
- # Build style
372
- if asset_manager:
373
- asset_manager.ensure_assets_loaded(mrt)
374
- text_list = [s.strip() for s in styles.split(",") if s.strip()]
375
- style = build_style_fn(mrt, text_styles=text_list)
376
-
377
- # Generate forever
378
- while True:
379
- wav, state = mrt.generate_chunk(state=state, style=style)
380
- audio = wav.samples.astype(np.float32).T
381
- yield (int(mrt.sample_rate), audio)
382
-
383
- stream = Stream(
384
- handler=generate_audio,
385
- modality="audio",
386
- mode="receive",
387
- concurrency_limit=1,
388
- additional_inputs=[
389
- gr.Slider(0.1, 2.0, value=1.1, label="Temperature"),
390
- gr.Slider(0.0, 8.0, value=1.1, label="Guidance"),
391
- gr.Slider(1, 256, value=40, step=1, label="Top-K"),
392
  gr.Textbox(value="warmup", label="Styles"),
 
 
 
 
393
  ],
394
  )
395
-
396
- return stream
 
2
  FastRTC integration for MagentaRT real-time streaming.
3
 
4
  This module provides a Gradio-native interface for MagentaRT using FastRTC,
5
+ enabling real-time audio streaming with live parameter updates.
6
+
7
+ Key notes:
8
+ - MagentaRT system handles crossfading internally.
9
+ - Many FastRTC builds assume mono in the outgoing PyAV path; we downmix to mono
10
+ int16 for now (easy to switch once FastRTC stereo output is patched).
 
 
 
 
 
 
11
  """
12
 
13
+ from __future__ import annotations
14
+
15
  import numpy as np
16
  import gradio as gr
17
  from typing import Callable, Optional
18
+ from dataclasses import dataclass
19
 
20
  # FastRTC imports
21
  try:
22
+ from fastrtc import Stream, StreamHandler
23
  FASTRTC_AVAILABLE = True
24
  except ImportError:
25
  FASTRTC_AVAILABLE = False
 
28
 
29
  @dataclass
30
  class MagentaRTParams:
 
31
  temperature: float = 1.1
32
  guidance_weight: float = 1.1
33
  topk: int = 40
 
40
 
41
  class MagentaRTStreamHandler(StreamHandler):
42
  """
43
+ StreamHandler for continuous MagentaRT audio generation (server -> client).
44
+
45
+ FastRTC versions differ in how they consume handlers; some require emit()
46
+ (abstract), so we implement emit() as the canonical “produce next frame” API.
47
+
48
+ We also keep __call__ as a generator adapter because some versions call that.
 
49
  """
50
+
51
  def __init__(
52
  self,
53
  get_mrt_fn: Callable,
54
  build_style_fn: Callable,
55
  asset_manager=None,
56
  ):
 
57
  super().__init__(
58
  expected_layout="stereo",
59
  output_sample_rate=48000,
60
+ input_sample_rate=48000,
61
  )
62
+
63
  self.get_mrt_fn = get_mrt_fn
64
  self.build_style_fn = build_style_fn
65
  self.asset_manager = asset_manager
66
+
 
67
  self.mrt = None
68
  self.state = None
69
+
70
+ self.params = MagentaRTParams()
71
  self.style_cur = None
72
  self.style_tgt = None
73
+
74
+ self.chunk_duration = 2.0
75
+ self.latest_args = None
76
+
77
+ # Internal generator used by emit()
78
+ self._gen = None
79
+
80
  def copy(self) -> "MagentaRTStreamHandler":
 
81
  return MagentaRTStreamHandler(
82
  get_mrt_fn=self.get_mrt_fn,
83
  build_style_fn=self.build_style_fn,
84
  asset_manager=self.asset_manager,
85
  )
86
+
87
+ # -------------------------------------------------------------------------
88
+ # Lifecycle
89
+ # -------------------------------------------------------------------------
90
+
91
  def start_up(self) -> None:
92
+ """Initialize MagentaRT + state."""
93
  self.mrt = self.get_mrt_fn()
94
  self.state = self.mrt.init_state()
95
+
96
+ # Compute chunk duration from MRT config
97
  codec_fps = float(self.mrt.codec.frame_rate)
98
  self.chunk_duration = (
99
+ self.mrt.config.chunk_length_frames * self.mrt.config.frame_length_samples
 
100
  ) / float(self.mrt.sample_rate)
101
+
102
+ # Build silent context tokens
103
+ from magenta_rt import audio as au
104
+
105
  ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
106
  sr = int(self.mrt.sample_rate)
107
+ n = int(max(1, round(ctx_seconds * sr)))
108
+
109
+ silent = au.Waveform(np.zeros((n, 2), np.float32), sr)
 
 
 
110
  tokens = self.mrt.codec.encode(silent).astype(np.int32)
111
  tokens = tokens[:, :self.mrt.config.decoder_codec_rvq_depth]
112
  self.state.context_tokens = tokens
113
+
114
+ # Load assets if needed
115
  if self.asset_manager:
116
  self.asset_manager.ensure_assets_loaded(self.mrt)
117
+
118
+ # Initial style
119
  self._rebuild_style()
120
  self.style_cur = self.style_tgt.copy()
121
+
122
+ # Create internal generator for emit()
123
+ self._gen = self._generate_forever()
124
+
125
+ def shutdown(self) -> None:
126
+ self.mrt = None
127
+ self.state = None
128
+ self.style_cur = None
129
+ self.style_tgt = None
130
+ self.latest_args = None
131
+ self._gen = None
132
+
133
+ # -------------------------------------------------------------------------
134
+ # FastRTC entrypoints
135
+ # -------------------------------------------------------------------------
136
+
137
+ def __call__(self, *args):
138
+ """
139
+ Some FastRTC versions call handler(*ui_args) and expect a generator.
140
+ We provide that by yielding emit() forever.
141
+ """
142
+ self.latest_args = [None, *args]
143
+ self.start_up()
144
+ try:
145
+ while True:
146
+ out = self.emit()
147
+ if out is None:
148
+ continue
149
+ yield out
150
+ finally:
151
+ self.shutdown()
152
+
153
+ def emit(self):
154
+ """
155
+ REQUIRED by some FastRTC versions (abstract method).
156
+ Produce the next (sample_rate, audio) chunk.
157
+ """
158
+ if self._gen is None:
159
+ # If FastRTC calls emit() without calling __call__ first,
160
+ # we still need to be able to start up.
161
+ self.latest_args = self.latest_args or [None]
162
+ self.start_up()
163
+
164
+ try:
165
+ return next(self._gen)
166
+ except StopIteration:
167
+ return None
168
+
169
+ def receive(self, frame: tuple[int, np.ndarray]) -> None:
170
+ # output-only mode
171
+ return
172
+
173
+ # -------------------------------------------------------------------------
174
+ # Core generation loop
175
+ # -------------------------------------------------------------------------
176
+
177
+ def _generate_forever(self):
178
+ """Internal generator that yields audio chunks forever."""
179
+ while True:
180
+ self._apply_param_updates()
181
+ self._ramp_style()
182
+
183
+ wav, self.state = self.mrt.generate_chunk(state=self.state, style=self.style_cur)
184
+
185
+ samples = np.asarray(wav.samples)
186
+ if samples.dtype != np.float32:
187
+ samples = samples.astype(np.float32, copy=False)
188
+
189
+ # Ensure stereo planar float32: (2, N)
190
+ audio_stereo = self._ensure_stereo_planar(samples)
191
+
192
+ # Return (sr, ndarray, layout) so FastRTC sets layout properly
193
+ yield (48000, audio_stereo, "stereo")
194
+
195
+ # -------------------------------------------------------------------------
196
+ # Params + style
197
+ # -------------------------------------------------------------------------
198
+
199
  def _rebuild_style(self) -> None:
 
200
  text_list = [s.strip() for s in self.params.styles.split(",") if s.strip()]
201
+
202
  try:
203
  text_w = [float(x) for x in self.params.style_weights.split(",") if x.strip()]
204
  except ValueError:
205
  text_w = []
206
+
207
  try:
208
  cw = [float(x) for x in self.params.centroid_weights.split(",") if x.strip()]
209
  except ValueError:
210
  cw = []
211
+
212
  self.style_tgt = self.build_style_fn(
213
  self.mrt,
214
  text_styles=text_list,
 
218
  mean_weight=self.params.mean_weight,
219
  centroid_weights=cw,
220
  )
221
+
222
  def _apply_param_updates(self) -> None:
 
 
223
  args = self.latest_args
224
  if not args or len(args) < 2:
225
+ # no UI args yet
226
  return
227
+
228
+ prev_styles = self.params.styles
229
+ prev_style_weights = self.params.style_weights
230
+ prev_mean = self.params.mean_weight
231
+ prev_centroids = self.params.centroid_weights
232
+
233
  try:
234
  if len(args) > 1 and args[1] is not None:
235
  self.params.temperature = float(args[1])
 
248
  if len(args) > 8 and args[8] is not None:
249
  self.params.style_ramp_seconds = float(args[8])
250
  except (ValueError, TypeError):
251
+ return
252
+
253
+ # Apply sampler params
254
  self.mrt.temperature = self.params.temperature
255
  self.mrt.guidance_weight = self.params.guidance_weight
256
  self.mrt.topk = self.params.topk
257
+
258
+ style_changed = (
259
+ self.params.styles != prev_styles or
260
+ self.params.style_weights != prev_style_weights or
261
+ self.params.mean_weight != prev_mean or
262
+ self.params.centroid_weights != prev_centroids
263
+ )
264
+ if style_changed:
265
+ self._rebuild_style()
266
+
267
+ def _ramp_style(self) -> None:
268
+ if self.style_cur is None or self.style_tgt is None:
269
+ return
270
+
271
+ ramp = float(self.params.style_ramp_seconds or 0.0)
272
+ if ramp <= 0.0:
273
+ self.style_cur = self.style_tgt.copy()
274
+ return
275
+
276
+ alpha = min(1.0, max(0.0, self.chunk_duration / ramp))
277
+ self.style_cur = (1.0 - alpha) * self.style_cur + alpha * self.style_tgt
278
+
279
+ # -------------------------------------------------------------------------
280
+ # Audio helpers
281
+ # -------------------------------------------------------------------------
282
+
283
+ @staticmethod
284
+ def _downmix_to_mono(samples: np.ndarray) -> np.ndarray:
285
+ if samples.ndim == 1:
286
+ return samples
287
+ if samples.ndim == 2:
288
+ # assume (num_samples, channels)
289
+ if samples.shape[1] == 1:
290
+ return samples[:, 0]
291
+ return samples.mean(axis=1)
292
+ return samples.reshape(-1)
293
+
294
+ @staticmethod
295
+ def _ensure_stereo_planar(samples: np.ndarray) -> np.ndarray:
296
  """
297
+ Convert waveform samples into PyAV-friendly planar audio (C-contiguous).
298
+
299
+ PyAV expects planar audio for format="fltp":
300
+ - mono: (1, N)
301
+ - stereo: (2, N)
302
+
303
+ We also MUST return a C-contiguous ndarray, or PyAV will raise:
304
+ ValueError: ndarray is not C-contiguous
305
  """
306
+ x = np.asarray(samples, dtype=np.float32)
307
+
308
+ # Mono 1D -> (1, N)
309
+ if x.ndim == 1:
310
+ return np.ascontiguousarray(x.reshape(1, -1))
311
+
312
+ # (N, 1) -> (1, N)
313
+ if x.ndim == 2 and x.shape[1] == 1:
314
+ return np.ascontiguousarray(x[:, 0].reshape(1, -1))
315
+
316
+ # Interleaved stereo (N, 2) -> planar (2, N)
317
+ if x.ndim == 2 and x.shape[1] == 2:
318
+ # x.T is typically non-contiguous, so force contiguous
319
+ return np.ascontiguousarray(x.T)
320
+
321
+ # Already planar stereo (2, N) -> ensure contiguous anyway
322
+ if x.ndim == 2 and x.shape[0] == 2:
323
+ return np.ascontiguousarray(x)
324
+
325
+ # Fallback: flatten to mono
326
+ return np.ascontiguousarray(x.reshape(-1).reshape(1, -1))
327
+
328
+
329
+ @staticmethod
330
+ def _float_to_int16(x: np.ndarray) -> np.ndarray:
331
+ x = np.asarray(x, dtype=np.float32)
332
+ x = np.clip(x, -1.0, 1.0)
333
+ return (x * 32767.0).astype(np.int16)
 
 
 
 
 
334
 
335
 
336
  def create_magenta_stream(
 
340
  concurrency_limit: int = 1,
341
  time_limit: Optional[float] = None,
342
  ) -> "Stream":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  if not FASTRTC_AVAILABLE:
344
  raise ImportError("FastRTC not installed. Run: pip install fastrtc")
345
+
346
  handler = MagentaRTStreamHandler(
347
  get_mrt_fn=get_mrt_fn,
348
  build_style_fn=build_style_fn,
349
  asset_manager=asset_manager,
350
  )
351
+
352
  stream = Stream(
353
  handler=handler,
354
  modality="audio",
355
+ mode="receive",
356
  concurrency_limit=concurrency_limit,
357
  time_limit=time_limit,
358
  additional_inputs=[
359
+ gr.Slider(0.1, 2.0, step=0.01, value=1.1, label="Temperature"),
360
+ gr.Slider(0.0, 8.0, step=0.1, value=1.1, label="Guidance Weight"),
361
+ gr.Slider(1, 256, step=1, value=40, label="Top-K"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
  gr.Textbox(value="warmup", label="Styles"),
363
+ gr.Textbox(value="1.0", label="Style Weights"),
364
+ gr.Slider(0.0, 2.0, step=0.01, value=0.0, label="Mean Weight"),
365
+ gr.Textbox(value="", label="Centroid Weights"),
366
+ gr.Slider(0.0, 10.0, step=0.1, value=2.0, label="Style Ramp (seconds)"),
367
  ],
368
  )
369
+
370
+ return stream