Veena commited on
Commit
e5b76b7
·
1 Parent(s): d1c3c57

Remove maya1 directory (using transformers)

Browse files
maya1/__init__.py DELETED
@@ -1,7 +0,0 @@
1
- """
2
- Maya1 TTS Inference System
3
- Open-source inference for description-conditioned TTS with emotion control.
4
- """
5
-
6
- __version__ = "1.0.0"
7
- __author__ = "Maya Research AI"
 
 
 
 
 
 
 
 
maya1/api_v2.py DELETED
@@ -1,342 +0,0 @@
1
- import os
2
- import io
3
- import wave
4
- import time
5
- from typing import Optional
6
- from fastapi import FastAPI, HTTPException
7
- from fastapi.responses import StreamingResponse
8
- from fastapi.middleware.cors import CORSMiddleware
9
- from pydantic import BaseModel, Field
10
- from dotenv import load_dotenv
11
-
12
- from .model_loader import Maya1Model
13
- from .prompt_builder import Maya1PromptBuilder
14
- from .snac_decoder import SNACDecoder
15
- from .pipeline import Maya1Pipeline
16
- from .streaming_pipeline import Maya1SlidingWindowPipeline
17
- from .constants import (
18
- DEFAULT_TEMPERATURE,
19
- DEFAULT_TOP_P,
20
- DEFAULT_MAX_TOKENS,
21
- DEFAULT_REPETITION_PENALTY,
22
- AUDIO_SAMPLE_RATE,
23
- )
24
-
25
- # Timeout settings (seconds)
26
- GENERATE_TIMEOUT = 60
27
-
28
- # Load environment variables
29
- load_dotenv()
30
-
31
- # Initialize FastAPI app
32
- app = FastAPI(
33
- title="Maya1 TTS API",
34
- description="Open source TTS inference for Maya1",
35
- version="1.0.0",
36
- docs_url=None,
37
- redoc_url=None,
38
- )
39
-
40
- app.add_middleware(
41
- CORSMiddleware,
42
- allow_origins=["*"],
43
- allow_credentials=True,
44
- allow_methods=["*"],
45
- allow_headers=["*"],
46
- )
47
-
48
- # Global state
49
- model = None
50
- prompt_builder = None
51
- snac_decoder = None
52
- pipeline = None
53
- streaming_pipeline = None
54
-
55
-
56
- # ============================================================================
57
- # Startup/Shutdown
58
- # ============================================================================
59
-
60
- @app.on_event("startup")
61
- async def startup_event():
62
- """Initialize model on startup."""
63
- global model, prompt_builder, snac_decoder, pipeline, streaming_pipeline
64
-
65
- print("\n" + "="*60)
66
- print(" Starting Maya1 TTS API Server")
67
- print("="*60 + "\n")
68
-
69
- # Initialize components
70
- model = Maya1Model()
71
- prompt_builder = Maya1PromptBuilder(model.tokenizer, model)
72
-
73
- # Initialize SNAC decoder
74
- snac_decoder = SNACDecoder(enable_batching=True, max_batch_size=64, batch_timeout_ms=15)
75
- await snac_decoder.start_batch_processor()
76
-
77
- # Initialize pipelines
78
- pipeline = Maya1Pipeline(model, prompt_builder, snac_decoder)
79
- streaming_pipeline = Maya1SlidingWindowPipeline(model, prompt_builder, snac_decoder)
80
-
81
- print("\n" + "="*60)
82
- print("Maya1 TTS API Server Ready")
83
- print("="*60 + "\n")
84
-
85
-
86
- @app.on_event("shutdown")
87
- async def shutdown_event():
88
- """Cleanup on shutdown."""
89
- print("\nShutting down Maya1 TTS API Server")
90
-
91
- if snac_decoder and snac_decoder.is_running:
92
- await snac_decoder.stop_batch_processor()
93
-
94
-
95
- # ============================================================================
96
- # Utility Functions
97
- # ============================================================================
98
-
99
- def create_wav_header(sample_rate: int = 24000, channels: int = 1, bits_per_sample: int = 16, data_size: int = 0) -> bytes:
100
- """Create WAV file header."""
101
- import struct
102
-
103
- byte_rate = sample_rate * channels * bits_per_sample // 8
104
- block_align = channels * bits_per_sample // 8
105
-
106
- header = struct.pack(
107
- '<4sI4s4sIHHIIHH4sI',
108
- b'RIFF',
109
- 36 + data_size,
110
- b'WAVE',
111
- b'fmt ',
112
- 16,
113
- 1,
114
- channels,
115
- sample_rate,
116
- byte_rate,
117
- block_align,
118
- bits_per_sample,
119
- b'data',
120
- data_size
121
- )
122
-
123
- return header
124
-
125
-
126
- # ============================================================================
127
- # Request/Response Models
128
- # ============================================================================
129
-
130
- class TTSRequest(BaseModel):
131
- """TTS generation request."""
132
- description: str = Field(
133
- ...,
134
- description="Voice description (e.g., 'Male voice in their 30s with american accent')"
135
- )
136
- text: str = Field(
137
- ...,
138
- description="Text to synthesize (can include <emotion> tags)"
139
- )
140
- temperature: Optional[float] = Field(
141
- default=DEFAULT_TEMPERATURE,
142
- description="Sampling temperature"
143
- )
144
- top_p: Optional[float] = Field(
145
- default=DEFAULT_TOP_P,
146
- description="Nucleus sampling"
147
- )
148
- max_tokens: Optional[int] = Field(
149
- default=DEFAULT_MAX_TOKENS,
150
- description="Maximum tokens to generate"
151
- )
152
- repetition_penalty: Optional[float] = Field(
153
- default=DEFAULT_REPETITION_PENALTY,
154
- description="Repetition penalty"
155
- )
156
- seed: Optional[int] = Field(
157
- default=None,
158
- description="Random seed for reproducibility",
159
- ge=0,
160
- )
161
- stream: bool = Field(
162
- default=False,
163
- description="Stream audio (True) or return complete WAV (False)"
164
- )
165
-
166
-
167
- # ============================================================================
168
- # Endpoints
169
- # ============================================================================
170
-
171
- @app.get("/")
172
- async def root():
173
- """Root endpoint."""
174
- return {
175
- "service": "Maya1 TTS API",
176
- "version": "1.0.0",
177
- "status": "running",
178
- "model": "Maya1-Voice (open source)",
179
- "endpoints": {
180
- "generate": "/v1/tts/generate (POST)",
181
- "health": "/health (GET)",
182
- },
183
- }
184
-
185
-
186
- @app.get("/health")
187
- async def health_check():
188
- """Health check endpoint."""
189
- return {
190
- "status": "healthy",
191
- "model": "Maya1-Voice",
192
- "timestamp": time.time(),
193
- }
194
-
195
-
196
- # ============================================================================
197
- # TTS Generation Endpoint
198
- # ============================================================================
199
-
200
- @app.post("/v1/tts/generate")
201
- async def generate_tts(request: TTSRequest):
202
- """Generate TTS audio from description and text."""
203
-
204
- try:
205
- # Route to streaming or non-streaming
206
- if request.stream:
207
- return await _generate_tts_streaming(
208
- description=request.description,
209
- text=request.text,
210
- temperature=request.temperature,
211
- top_p=request.top_p,
212
- max_tokens=request.max_tokens,
213
- repetition_penalty=request.repetition_penalty,
214
- seed=request.seed,
215
- )
216
- else:
217
- return await _generate_tts_complete(
218
- description=request.description,
219
- text=request.text,
220
- temperature=request.temperature,
221
- top_p=request.top_p,
222
- max_tokens=request.max_tokens,
223
- repetition_penalty=request.repetition_penalty,
224
- seed=request.seed,
225
- )
226
-
227
- except HTTPException:
228
- raise
229
- except Exception as e:
230
- print(f" Error: {e}")
231
- raise HTTPException(status_code=500, detail=str(e))
232
-
233
-
234
- async def _generate_tts_complete(
235
- description: str,
236
- text: str,
237
- temperature: float,
238
- top_p: float,
239
- max_tokens: int,
240
- repetition_penalty: float,
241
- seed: Optional[int],
242
- ):
243
- """Generate complete WAV file (non-streaming)."""
244
-
245
- try:
246
- import asyncio
247
-
248
- # Generate audio
249
- audio_bytes = await asyncio.wait_for(
250
- pipeline.generate_speech(
251
- description=description,
252
- text=text,
253
- temperature=temperature,
254
- top_p=top_p,
255
- max_tokens=max_tokens,
256
- repetition_penalty=repetition_penalty,
257
- seed=seed,
258
- ),
259
- timeout=GENERATE_TIMEOUT
260
- )
261
-
262
- if audio_bytes is None:
263
- raise Exception("Audio generation failed")
264
-
265
- # Create WAV file
266
- wav_buffer = io.BytesIO()
267
- with wave.open(wav_buffer, 'wb') as wav_file:
268
- wav_file.setnchannels(1)
269
- wav_file.setsampwidth(2)
270
- wav_file.setframerate(AUDIO_SAMPLE_RATE)
271
- wav_file.writeframes(audio_bytes)
272
-
273
- wav_buffer.seek(0)
274
-
275
- return StreamingResponse(
276
- wav_buffer,
277
- media_type="audio/wav",
278
- headers={"Content-Disposition": "attachment; filename=output.wav"}
279
- )
280
-
281
- except asyncio.TimeoutError:
282
- raise HTTPException(status_code=504, detail="Generation timeout")
283
-
284
-
285
- async def _generate_tts_streaming(
286
- description: str,
287
- text: str,
288
- temperature: float,
289
- top_p: float,
290
- max_tokens: int,
291
- repetition_penalty: float,
292
- seed: Optional[int],
293
- ):
294
- """Generate streaming audio."""
295
- start_time = time.time()
296
- first_audio_time = None
297
-
298
- async def audio_stream_generator():
299
- """Generate audio stream with WAV header."""
300
- nonlocal first_audio_time
301
-
302
- # Send WAV header first
303
- yield create_wav_header(sample_rate=AUDIO_SAMPLE_RATE, channels=1, bits_per_sample=16)
304
-
305
- # Stream audio chunks
306
- async for audio_chunk in streaming_pipeline.generate_speech_stream(
307
- description=description,
308
- text=text,
309
- temperature=temperature,
310
- top_p=top_p,
311
- max_tokens=max_tokens,
312
- repetition_penalty=repetition_penalty,
313
- seed=seed,
314
- ):
315
- if first_audio_time is None:
316
- first_audio_time = time.time()
317
- ttfb_ms = (first_audio_time - start_time) * 1000
318
- print(f"⏱️ TTFB: {ttfb_ms:.1f}ms")
319
-
320
- yield audio_chunk
321
-
322
- try:
323
- return StreamingResponse(
324
- audio_stream_generator(),
325
- media_type="audio/wav",
326
- headers={"Cache-Control": "no-cache"}
327
- )
328
-
329
- except Exception as e:
330
- print(f"Streaming error: {e}")
331
- raise HTTPException(status_code=500, detail=str(e))
332
-
333
-
334
- # For running directly
335
- if __name__ == "__main__":
336
- import uvicorn
337
- uvicorn.run(
338
- app,
339
- host="0.0.0.0",
340
- port=8000,
341
- log_level="info"
342
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
maya1/constants.py DELETED
@@ -1,95 +0,0 @@
1
- """
2
- Maya1 Constants
3
- Token IDs and special tokens used in the model.
4
- Matches training configuration exactly.
5
- """
6
-
7
- # Special control tokens
8
- SOH_ID = 128259 # Start of Human turn
9
- EOH_ID = 128260 # End of Human turn
10
- SOA_ID = 128261 # Start of AI turn
11
- EOA_ID = 128262 # End of AI turn (not used in maya1)
12
- PAD_ID = 128263 # Padding token
13
-
14
- # Text tokens
15
- BOS_ID = 128000 # Begin of sequence (Llama BOS)
16
- TEXT_EOT_ID = 128009 # End of text (appears in prefix, not a stop token!)
17
-
18
- # Audio tokens
19
- CODE_START_TOKEN_ID = 128257 # SOS - Start of Speech
20
- CODE_END_TOKEN_ID = 128258 # EOS - End of Speech (audio stop token)
21
- CODE_TOKEN_OFFSET = 128266 # Start of SNAC codes
22
-
23
- # SNAC token range
24
- SNAC_MIN_ID = 128266
25
- SNAC_MAX_ID = 156937 # 128266 + (7 * 4096) - 1
26
-
27
- # Stop tokens for generation
28
- # CRITICAL: Only use CODE_END_TOKEN_ID (128258) for audio generation
29
- # TEXT_EOT_ID (128009) appears in prefix and should NOT stop generation
30
- TRAINING_STOP_TOKEN_IDS = [CODE_END_TOKEN_ID] # [128258]
31
- ALL_POSSIBLE_STOP_TOKENS = [TEXT_EOT_ID, CODE_END_TOKEN_ID] # For reference only
32
-
33
- # 20 Extended Emotion Tags (must be single tokens)
34
- ALL_EMOTION_TAGS = [
35
- '<angry>',
36
- '<appalled>',
37
- '<chuckle>',
38
- '<cry>',
39
- '<curious>',
40
- '<disappointed>',
41
- '<excited>',
42
- '<exhale>',
43
- '<gasp>',
44
- '<giggle>',
45
- '<gulp>',
46
- '<laugh>',
47
- '<laugh_harder>',
48
- '<mischievous>',
49
- '<sarcastic>',
50
- '<scream>',
51
- '<sigh>',
52
- '<sing>',
53
- '<snort>',
54
- '<whisper>',
55
- ]
56
-
57
- # Model configuration
58
- DEFAULT_MODEL_PATH = "maya-research/maya1"
59
- DEFAULT_CHECKPOINT = "checkpoint-25000"
60
- DEFAULT_MAX_MODEL_LEN = 8192
61
-
62
- # SNAC configuration
63
- SNAC_MODEL_NAME = "hubertsiuzdak/snac_24khz"
64
- SNAC_SAMPLE_RATE = 24000
65
- SNAC_TOKENS_PER_FRAME = 7
66
- SNAC_LEVELS = 3
67
-
68
- # Audio configuration
69
- AUDIO_SAMPLE_RATE = 24000
70
- AUDIO_CHANNELS = 1
71
- AUDIO_BITS_PER_SAMPLE = 16
72
-
73
- # Generation defaults
74
- DEFAULT_TEMPERATURE = 0.4 # Lower temp for more stable generation
75
- DEFAULT_TOP_P = 0.9
76
- DEFAULT_MAX_TOKENS = 2048 # Reasonable default for most use cases
77
- DEFAULT_MIN_TOKENS = 28 # At least 4 SNAC frames
78
- DEFAULT_REPETITION_PENALTY = 1.1
79
- DEFAULT_SEED = None # None = random, set integer for reproducibility
80
-
81
- # IMPORTANT: Emotion tags consume audio time!
82
- # <laugh> = ~4-6 seconds (~300-400 tokens)
83
- # <excited>, <chuckle> = ~1-2 seconds (~50-150 tokens)
84
-
85
- # Recommended max_tokens by use case:
86
- # - Short phrases (< 10 words): 150-250 tokens (~3-5s)
87
- # - Medium text (10-30 words): 250-500 tokens (~5-10s)
88
- # - Long text (30+ words): 500-1500 tokens (~10-30s)
89
- # - Very long text: 1500-2000 tokens (~30-42s)
90
- # Note: 1 second ≈ 48 tokens (7 tokens/frame * 6.86 frames/sec)
91
-
92
- # Streaming configuration
93
- STREAM_BUFFER_SIZE = 28 # 4 frames (process every 28 tokens)
94
- SNAC_BATCH_SIZE = 64
95
- SNAC_BATCH_TIMEOUT_MS = 15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
maya1/model_loader.py DELETED
@@ -1,145 +0,0 @@
1
- """
2
- Maya1 Model Loader
3
- Loads Maya1 model with vLLM engine and validates emotion tags.
4
- """
5
-
6
- import os
7
- from transformers import AutoTokenizer
8
- from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
9
- from .constants import (
10
- ALL_EMOTION_TAGS,
11
- DEFAULT_MAX_MODEL_LEN,
12
- SOH_ID, EOH_ID, SOA_ID, BOS_ID, TEXT_EOT_ID, CODE_START_TOKEN_ID,
13
- )
14
-
15
-
16
- class Maya1Model:
17
- """Maya1 TTS Model with vLLM inference engine."""
18
-
19
- def __init__(
20
- self,
21
- model_path: str = None,
22
- dtype: str = "bfloat16",
23
- max_model_len: int = DEFAULT_MAX_MODEL_LEN,
24
- gpu_memory_utilization: float = 0.85,
25
- tensor_parallel_size: int = 1,
26
- **engine_kwargs
27
- ):
28
- """
29
- Initialize Maya1 model with vLLM.
30
-
31
- Args:
32
- model_path: Path to checkpoint (local or HF repo)
33
- dtype: Model precision (bfloat16 recommended)
34
- max_model_len: Maximum sequence length
35
- gpu_memory_utilization: GPU memory fraction
36
- tensor_parallel_size: Number of GPUs
37
- """
38
- # Use provided path or environment variable or default
39
- if model_path is None:
40
- model_path = os.environ.get(
41
- 'MAYA1_MODEL_PATH',
42
- os.path.expanduser('~/models/maya1-voice')
43
- )
44
-
45
- self.model_path = model_path
46
- self.dtype = dtype
47
-
48
- print(f"Initializing Maya1 Model")
49
- print(f"Model: {model_path}")
50
-
51
- # Load tokenizer
52
- self.tokenizer = AutoTokenizer.from_pretrained(
53
- model_path,
54
- trust_remote_code=True,
55
- )
56
-
57
- print(f"Tokenizer loaded: {len(self.tokenizer)} tokens")
58
-
59
- # Validate emotion tags
60
- self._validate_emotion_tags()
61
-
62
- # Precompute special token strings
63
- self._init_special_tokens()
64
-
65
- # Initialize vLLM engine
66
- print(f"Initializing vLLM engine...")
67
- engine_args = AsyncEngineArgs(
68
- model=model_path,
69
- tokenizer=model_path,
70
- dtype=dtype,
71
- max_model_len=max_model_len,
72
- gpu_memory_utilization=gpu_memory_utilization,
73
- tensor_parallel_size=tensor_parallel_size,
74
- trust_remote_code=True,
75
- disable_log_stats=False,
76
- **engine_kwargs
77
- )
78
-
79
- self.engine = AsyncLLMEngine.from_engine_args(engine_args)
80
-
81
- print(f"Maya1 Model ready\n")
82
-
83
- def _validate_emotion_tags(self):
84
- """Validate that all 20 emotion tags are single tokens."""
85
- failed_tags = []
86
- for tag in ALL_EMOTION_TAGS:
87
- token_ids = self.tokenizer.encode(tag, add_special_tokens=False)
88
- if len(token_ids) != 1:
89
- failed_tags.append((tag, len(token_ids)))
90
-
91
- if failed_tags:
92
- print(f"ERROR: {len(failed_tags)} emotion tags are NOT single tokens!")
93
- raise AssertionError(f"Emotion tags validation failed")
94
-
95
- print(f"All {len(ALL_EMOTION_TAGS)} emotion tags validated")
96
-
97
- def _init_special_tokens(self):
98
- """Precompute special token strings for fast prefix building."""
99
- self.soh_token = self.tokenizer.decode([SOH_ID])
100
- self.bos_token = self.tokenizer.bos_token
101
- self.eot_token = self.tokenizer.decode([TEXT_EOT_ID])
102
- self.eoh_token = self.tokenizer.decode([EOH_ID])
103
- self.soa_token = self.tokenizer.decode([SOA_ID])
104
- self.sos_token = self.tokenizer.decode([CODE_START_TOKEN_ID])
105
-
106
- async def generate(self, prompt: str, sampling_params: SamplingParams):
107
- """
108
- Generate tokens from prompt (non-streaming).
109
- Args:
110
- prompt: Input prompt
111
- sampling_params: vLLM sampling parameters
112
- Returns:
113
- Generated output from vLLM
114
- """
115
- request_id = f"req_{id(prompt)}"
116
-
117
- # Collect results from async generator
118
- final_output = None
119
- async for output in self.engine.generate(
120
- prompt=prompt,
121
- sampling_params=sampling_params,
122
- request_id=request_id
123
- ):
124
- final_output = output
125
-
126
- return [final_output] if final_output else []
127
-
128
- async def generate_stream(self, prompt: str, sampling_params: SamplingParams):
129
- """
130
- Generate tokens from prompt (streaming).
131
- Args:
132
- prompt: Input prompt
133
- sampling_params: vLLM sampling parameters
134
- Yields:
135
- Generated outputs from vLLM
136
- """
137
- request_id = f"req_{id(prompt)}"
138
-
139
- # Stream from engine
140
- async for output in self.engine.generate(
141
- prompt=prompt,
142
- sampling_params=sampling_params,
143
- request_id=request_id
144
- ):
145
- yield output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
maya1/pipeline.py DELETED
@@ -1,128 +0,0 @@
1
- """
2
- Maya1 Generation Pipeline
3
- End-to-end pipeline for TTS generation (non-streaming).
4
- """
5
-
6
- import asyncio
7
- from typing import Optional, List
8
- from vllm import SamplingParams
9
-
10
- from .constants import (
11
- CODE_END_TOKEN_ID,
12
- CODE_START_TOKEN_ID,
13
- SNAC_MIN_ID,
14
- SNAC_MAX_ID,
15
- DEFAULT_TEMPERATURE,
16
- DEFAULT_TOP_P,
17
- DEFAULT_MAX_TOKENS,
18
- DEFAULT_MIN_TOKENS,
19
- DEFAULT_REPETITION_PENALTY,
20
- DEFAULT_SEED,
21
- )
22
-
23
-
24
- class Maya1Pipeline:
25
- """End-to-end TTS pipeline for Maya1."""
26
-
27
- def __init__(self, model, prompt_builder, snac_decoder):
28
- """
29
- Initialize pipeline.
30
- Args:
31
- model: Maya1Model instance
32
- prompt_builder: Maya1PromptBuilder instance
33
- snac_decoder: SNACDecoder instance
34
- """
35
- self.model = model
36
- self.prompt_builder = prompt_builder
37
- self.snac_decoder = snac_decoder
38
- print(f"✅ Maya1Pipeline initialized")
39
-
40
- async def generate_speech(
41
- self,
42
- description: str,
43
- text: str,
44
- temperature: float = DEFAULT_TEMPERATURE,
45
- top_p: float = DEFAULT_TOP_P,
46
- max_tokens: int = DEFAULT_MAX_TOKENS,
47
- repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
48
- seed: Optional[int] = None,
49
- ) -> Optional[bytes]:
50
- """
51
- Generate speech audio (non-streaming).
52
- Args:
53
- description: Voice description
54
- text: Text to synthesize (may include <emotion> tags)
55
- temperature: Sampling temperature
56
- top_p: Nucleus sampling
57
- max_tokens: Max SNAC tokens to generate
58
- repetition_penalty: Prevent loops
59
- seed: Random seed for reproducibility
60
-
61
- Returns:
62
- Audio bytes (int16 PCM, 24kHz mono) or None if failed
63
- """
64
- # Build prompt
65
- prompt = self.prompt_builder.build_prefix(description, text)
66
-
67
- # Configure sampling
68
- sampling_params = SamplingParams(
69
- temperature=temperature,
70
- top_p=top_p,
71
- max_tokens=max_tokens,
72
- min_tokens=DEFAULT_MIN_TOKENS,
73
- repetition_penalty=repetition_penalty,
74
- stop_token_ids=[CODE_END_TOKEN_ID],
75
- seed=seed if seed is not None else DEFAULT_SEED,
76
- )
77
-
78
- # Generate tokens
79
- outputs = await self.model.generate(prompt, sampling_params)
80
-
81
- if not outputs or len(outputs) == 0:
82
- return None
83
-
84
- output = outputs[0]
85
- generated_token_ids = output.outputs[0].token_ids
86
-
87
- # Extract SNAC codes
88
- snac_codes = self._extract_snac_codes(generated_token_ids)
89
-
90
- if not snac_codes:
91
- return None
92
-
93
- # Decode to audio
94
- audio_bytes = await self.snac_decoder.decode_single_async(snac_codes)
95
-
96
- if audio_bytes:
97
- frames = len(snac_codes) // 7
98
- duration_sec = frames / 6.86
99
- print(f" Generated {frames} frames (~{duration_sec:.1f}s audio)")
100
-
101
- return audio_bytes
102
-
103
- def _extract_snac_codes(self, token_ids: List[int]) -> List[int]:
104
- # Find SOS and EOS positions
105
- try:
106
- sos_idx = token_ids.index(CODE_START_TOKEN_ID)
107
- except ValueError:
108
- sos_idx = -1
109
-
110
- try:
111
- eos_idx = token_ids.index(CODE_END_TOKEN_ID)
112
- except ValueError:
113
- eos_idx = len(token_ids)
114
-
115
- # Extract tokens between SOS and EOS
116
- if sos_idx >= 0:
117
- snac_tokens = token_ids[sos_idx + 1:eos_idx]
118
- else:
119
- # If no SOS found, take everything before EOS
120
- snac_tokens = token_ids[:eos_idx]
121
-
122
- # Filter to only valid SNAC token IDs
123
- snac_codes = [
124
- token_id for token_id in snac_tokens
125
- if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID
126
- ]
127
-
128
- return snac_codes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
maya1/prompt_builder.py DELETED
@@ -1,31 +0,0 @@
1
- """
2
- Maya1 Prompt Builder
3
- Builds formatted prompts for description-conditioned TTS.
4
- Format: <SOH><BOS><description="..."> text<EOT><EOH><SOA><SOS>
5
- """
6
-
7
- from .constants import ALL_EMOTION_TAGS
8
-
9
-
10
- class Maya1PromptBuilder:
11
- """Builds prompts in the format expected by Maya1 model."""
12
-
13
- def __init__(self, tokenizer, model):
14
- self.tokenizer = tokenizer
15
- self.model = model
16
-
17
- def build_prefix(self, description: str, text: str) -> str:
18
- # Format as: <description="..."> text
19
- formatted_text = f'<description="{description}"> {text}'
20
- # Build full prefix with special tokens
21
- prompt = (
22
- self.model.soh_token +
23
- self.model.bos_token +
24
- formatted_text +
25
- self.model.eot_token +
26
- self.model.eoh_token +
27
- self.model.soa_token +
28
- self.model.sos_token
29
- )
30
-
31
- return prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
maya1/snac_decoder.py DELETED
@@ -1,515 +0,0 @@
1
- import torch
2
- import numpy as np
3
- import asyncio
4
- from typing import List, Optional, Tuple
5
- from snac import SNAC
6
-
7
- from .constants import (
8
- CODE_END_TOKEN_ID,
9
- CODE_TOKEN_OFFSET,
10
- SNAC_MODEL_NAME,
11
- SNAC_SAMPLE_RATE,
12
- SNAC_TOKENS_PER_FRAME,
13
- )
14
-
15
-
16
- class SNACDecoder:
17
- """
18
- SNAC Decoder for maya1.
19
- Unpacks 7-token SNAC frames and decodes to audio waveforms.
20
- Unpacking logic is the EXACT INVERSE of training preprocessing.
21
- Supports async batching for concurrent requests.
22
- CRITICAL: Any mismatch in unpacking will produce garbage audio.
23
- """
24
-
25
- def __init__(
26
- self,
27
- device: str = "cuda",
28
- compile_decoder: bool = False,
29
- enable_batching: bool = False,
30
- max_batch_size: int = 64,
31
- batch_timeout_ms: int = 15,
32
- ):
33
- """
34
- Initialize SNAC decoder.
35
-
36
- Args:
37
- device: Device for SNAC model (cuda/cpu)
38
- compile_decoder: Use torch.compile for speedup
39
- enable_batching: Enable async batching
40
- max_batch_size: Max sequences to batch together
41
- batch_timeout_ms: Max wait time before processing batch
42
- """
43
- self.device = device
44
- self.enable_batching = enable_batching
45
- self.max_batch_size = max_batch_size
46
- self.batch_timeout_ms = batch_timeout_ms
47
-
48
- print(f"Loading SNAC 24kHz model to {device}...")
49
- self.snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(device)
50
-
51
- if compile_decoder:
52
- print(f"Compiling SNAC decoder with torch.compile...")
53
- self._compile_model()
54
-
55
- # Batching infrastructure
56
- if enable_batching:
57
- self.request_queue = asyncio.Queue()
58
- self.batch_processor_task = None
59
- self._running = False
60
- print(f"Batching enabled (max_batch={max_batch_size}, timeout={batch_timeout_ms}ms)")
61
-
62
- print(f"SNAC decoder initialized")
63
-
64
- def _compile_model(self):
65
- """Compile SNAC decoder with torch.compile"""
66
- # Warm up with various sizes
67
- for frames in [4, 16, 32]:
68
- dummy_codes = [
69
- torch.randint(0, 4096, (1, frames), device=self.device),
70
- torch.randint(0, 4096, (1, frames * 2), device=self.device),
71
- torch.randint(0, 4096, (1, frames * 4), device=self.device),
72
- ]
73
- with torch.inference_mode():
74
- z_q = self.snac_model.quantizer.from_codes(dummy_codes)
75
- _ = self.snac_model.decoder(z_q)
76
-
77
- # Apply compilation
78
- self.snac_model.decoder = torch.compile(
79
- self.snac_model.decoder,
80
- mode="max-autotune"
81
- )
82
- self.snac_model.quantizer = torch.compile(
83
- self.snac_model.quantizer,
84
- mode="reduce-overhead"
85
- )
86
-
87
- print(f"SNAC decoder compiled")
88
-
89
- def unpack_snac_from_7(self, vocab_ids: List[int]) -> List[List[int]]:
90
- """
91
- Unpack 7-token SNAC frames to 3 hierarchical levels.
92
-
93
- This is the EXACT INVERSE of the training preprocessing function
94
- `pack_snac_to_7_and_offset()`.
95
-
96
- Frame structure:
97
- [slot0, slot1, slot2, slot3, slot4, slot5, slot6]
98
-
99
- Unpacking:
100
- - slot0: L1[i]
101
- - slot1: L2[2*i] (even index)
102
- - slot2: L3[4*i + 0]
103
- - slot3: L3[4*i + 1]
104
- - slot4: L2[2*i + 1] (odd index)
105
- - slot5: L3[4*i + 2]
106
- - slot6: L3[4*i + 3]
107
-
108
- Args:
109
- vocab_ids: List of SNAC token IDs (128266-156937)
110
- Must be divisible by 7
111
-
112
- Returns:
113
- [L1, L2, L3] where:
114
- L1: n elements (coarse level)
115
- L2: 2n elements (medium level)
116
- L3: 4n elements (fine level)
117
- """
118
- # Strip EOS token if present
119
- if vocab_ids and vocab_ids[-1] == CODE_END_TOKEN_ID:
120
- vocab_ids = vocab_ids[:-1]
121
-
122
- # Ensure complete frames (divisible by 7)
123
- frames = len(vocab_ids) // SNAC_TOKENS_PER_FRAME
124
- vocab_ids = vocab_ids[:frames * SNAC_TOKENS_PER_FRAME]
125
-
126
- if frames == 0:
127
- return [[], [], []]
128
-
129
- l1, l2, l3 = [], [], []
130
-
131
- for i in range(frames):
132
- # Extract 7 slots for this frame
133
- slots = vocab_ids[i*7:(i+1)*7]
134
-
135
- # Subtract offset (128266) and mod 4096 to get original codes
136
- # Each level uses 4096 codes (0-4095)
137
- l1.append((slots[0] - CODE_TOKEN_OFFSET) % 4096)
138
- l2.extend([
139
- (slots[1] - CODE_TOKEN_OFFSET) % 4096, # Even index
140
- (slots[4] - CODE_TOKEN_OFFSET) % 4096, # Odd index
141
- ])
142
- l3.extend([
143
- (slots[2] - CODE_TOKEN_OFFSET) % 4096,
144
- (slots[3] - CODE_TOKEN_OFFSET) % 4096,
145
- (slots[5] - CODE_TOKEN_OFFSET) % 4096,
146
- (slots[6] - CODE_TOKEN_OFFSET) % 4096,
147
- ])
148
-
149
- return [l1, l2, l3]
150
-
151
- @torch.inference_mode()
152
- def decode(
153
- self,
154
- snac_tokens: List[int],
155
- trim_warmup: bool = True,
156
- trim_amount: Optional[int] = None,
157
- use_sliding_window: bool = False
158
- ) -> Optional[np.ndarray]:
159
- """
160
- Decode SNAC tokens to audio waveform.
161
-
162
- Args:
163
- snac_tokens: List of SNAC token IDs (7*n tokens)
164
- trim_warmup: Whether to trim SNAC warmup samples (default: True)
165
- trim_amount: Number of samples to trim (default: 2048 for first chunk, 0 for others)
166
- Can be set to a smaller value (e.g., 512) for intermediate chunks
167
- use_sliding_window: If True, only return middle 2048 samples (for sliding window streaming)
168
-
169
- Returns:
170
- Audio waveform as numpy array (float32, 24kHz mono)
171
- Shape: (samples,)
172
- Returns None if not enough tokens
173
- """
174
- if len(snac_tokens) < SNAC_TOKENS_PER_FRAME:
175
- print(f"Not enough SNAC tokens: {len(snac_tokens)} < {SNAC_TOKENS_PER_FRAME}")
176
- return None
177
-
178
- # Unpack to 3 levels
179
- levels = self.unpack_snac_from_7(snac_tokens)
180
-
181
- if not levels[0]: # No frames after unpacking
182
- return None
183
-
184
- # Convert to tensors
185
- codes = [
186
- torch.tensor(level, dtype=torch.long, device=self.device).unsqueeze(0)
187
- for level in levels
188
- ]
189
-
190
- # Decode through SNAC
191
- z_q = self.snac_model.quantizer.from_codes(codes)
192
- audio = self.snac_model.decoder(z_q)
193
-
194
- # Extract audio (remove padding if any)
195
- # SNAC decoder outputs: [batch, 1, samples]
196
- audio = audio[0, 0].cpu().numpy()
197
-
198
- # Sliding window mode: only keep middle 2048 samples
199
- # This eliminates popping/cracking when using overlapping 28-token windows
200
- if use_sliding_window:
201
- if len(audio) >= 4096:
202
- audio = audio[2048:4096] # Keep middle portion only
203
- else:
204
- # For shorter audio, keep everything (final chunk)
205
- pass
206
- else:
207
- # Standard mode: trim warm-up samples
208
- # Default: 2048 samples for first chunk, 0 for subsequent chunks
209
- # Can be customized via trim_amount parameter
210
- if trim_warmup:
211
- if trim_amount is None:
212
- trim_amount = 2048 # Default full trim
213
-
214
- if len(audio) > trim_amount:
215
- audio = audio[trim_amount:]
216
-
217
- return audio
218
-
219
- def decode_to_bytes(
220
- self,
221
- snac_tokens: List[int],
222
- trim_warmup: bool = True,
223
- use_sliding_window: bool = False
224
- ) -> Optional[bytes]:
225
- """
226
- Decode SNAC tokens to audio bytes (int16 PCM).
227
-
228
- Args:
229
- snac_tokens: List of SNAC token IDs
230
- trim_warmup: Whether to trim SNAC warmup samples (default: True)
231
- use_sliding_window: If True, only return middle 2048 samples (for sliding window streaming)
232
-
233
- Returns:
234
- Audio as bytes (int16 PCM, 24kHz mono)
235
- Returns None if decode fails
236
- """
237
- audio = self.decode(snac_tokens, trim_warmup=trim_warmup, use_sliding_window=use_sliding_window)
238
-
239
- if audio is None:
240
- return None
241
-
242
- # Convert float32 to int16 PCM
243
- audio_int16 = (audio * 32767).astype(np.int16)
244
-
245
- return audio_int16.tobytes()
246
-
247
- def validate_tokens(self, snac_tokens: List[int]) -> bool:
248
- """
249
- Validate SNAC tokens before decoding.
250
- Args:
251
- snac_tokens: List of SNAC token IDs
252
- Returns:
253
- True if valid, False otherwise
254
- """
255
- # Check minimum length
256
- if len(snac_tokens) < SNAC_TOKENS_PER_FRAME:
257
- print(f"Too few tokens: {len(snac_tokens)}")
258
- return False
259
-
260
- # Check divisibility by 7
261
- if len(snac_tokens) % SNAC_TOKENS_PER_FRAME != 0:
262
- print(f" Warning: Token count {len(snac_tokens)} not divisible by 7")
263
- print(f" Will truncate to {(len(snac_tokens) // 7) * 7}")
264
-
265
- # Check token range
266
- for i, token_id in enumerate(snac_tokens):
267
- if token_id < CODE_TOKEN_OFFSET or token_id > 156937:
268
- print(f" Invalid token at position {i}: {token_id}")
269
- print(f" Expected range: [{CODE_TOKEN_OFFSET}, 156937]")
270
- return False
271
-
272
- return True
273
-
274
- # ========== Async Batching Methods ==========
275
-
276
- @property
277
- def is_running(self) -> bool:
278
- """Check if batch processor is running."""
279
- return self._running if self.enable_batching else False
280
-
281
- async def start_batch_processor(self):
282
- """Start the background batch processor task."""
283
- if not self.enable_batching:
284
- return
285
-
286
- if self._running:
287
- print("Batch processor already running")
288
- return
289
-
290
- self._running = True
291
- self.batch_processor_task = asyncio.create_task(self._batch_processor_loop())
292
- print("Batch processor started")
293
-
294
- async def stop_batch_processor(self):
295
- """Stop the background batch processor task."""
296
- if not self.enable_batching:
297
- return
298
-
299
- if not self._running:
300
- return
301
-
302
- self._running = False
303
-
304
- if self.batch_processor_task:
305
- self.batch_processor_task.cancel()
306
- try:
307
- await self.batch_processor_task
308
- except asyncio.CancelledError:
309
- pass
310
-
311
- print("Batch processor stopped")
312
-
313
- async def decode_single_async(
314
- self,
315
- snac_tokens: List[int],
316
- trim_warmup: bool = True,
317
- use_sliding_window: bool = False
318
- ) -> Optional[bytes]:
319
- """
320
- Async decode for batching support.
321
-
322
- Queues the request and waits for batched processing.
323
-
324
- Args:
325
- snac_tokens: List of SNAC token IDs
326
- trim_warmup: Whether to trim SNAC warmup samples (default: True)
327
- use_sliding_window: If True, only return middle 2048 samples (for sliding window streaming)
328
-
329
- Returns:
330
- Audio bytes or None if decode fails
331
- """
332
- if not self.enable_batching:
333
- # Fallback to synchronous decode
334
- return self.decode_to_bytes(snac_tokens, trim_warmup=trim_warmup, use_sliding_window=use_sliding_window)
335
-
336
- # Create future for result
337
- result_future = asyncio.Future()
338
-
339
- # Add to queue (include trim_warmup and sliding_window flags)
340
- await self.request_queue.put((snac_tokens, trim_warmup, use_sliding_window, result_future))
341
-
342
- # Wait for result
343
- return await result_future
344
-
345
- async def _batch_processor_loop(self):
346
- """Background task that processes batched decode requests."""
347
- while self._running:
348
- try:
349
- # Collect batch
350
- batch = await self._collect_batch()
351
-
352
- if not batch:
353
- continue
354
-
355
- # Process batch
356
- await self._process_batch(batch)
357
-
358
- except asyncio.CancelledError:
359
- break
360
- except Exception as e:
361
- print(f"Batch processor error: {e}")
362
- import traceback
363
- traceback.print_exc()
364
-
365
- async def _collect_batch(self) -> List[Tuple[List[int], bool, bool, asyncio.Future]]:
366
- """
367
- Collect requests into a batch.
368
- Waits for timeout or until batch is full.
369
- Returns:
370
- List of (tokens, trim_warmup, use_sliding_window, future) tuples
371
- """
372
- batch = []
373
- timeout_sec = self.batch_timeout_ms / 1000.0
374
-
375
- try:
376
- # Wait for first request (blocking)
377
- first_item = await asyncio.wait_for(
378
- self.request_queue.get(),
379
- timeout=timeout_sec
380
- )
381
- batch.append(first_item)
382
-
383
- # Collect more requests (non-blocking)
384
- while len(batch) < self.max_batch_size:
385
- try:
386
- item = await asyncio.wait_for(
387
- self.request_queue.get(),
388
- timeout=timeout_sec
389
- )
390
- batch.append(item)
391
- except asyncio.TimeoutError:
392
- break # Timeout reached, process what we have
393
-
394
- except asyncio.TimeoutError:
395
- # No requests in timeout period
396
- pass
397
-
398
- return batch
399
-
400
- @torch.inference_mode()
401
- async def _process_batch(self, batch: List[Tuple[List[int], bool, bool, asyncio.Future]]):
402
- """
403
- Process a batch of decode requests.
404
- Args:
405
- batch: List of (tokens, trim_warmup, use_sliding_window, future) tuples
406
- """
407
- if not batch:
408
- return
409
-
410
- # Extract components
411
- token_sequences = [item[0] for item in batch]
412
- trim_warmup_flags = [item[1] for item in batch]
413
- sliding_window_flags = [item[2] for item in batch]
414
- futures = [item[3] for item in batch]
415
-
416
- lengths = [len(tokens) for tokens in token_sequences]
417
- can_batch_efficiently = len(set(lengths)) == 1
418
-
419
- if can_batch_efficiently and len(batch) > 1:
420
- # Efficient batching: all same length
421
- try:
422
- audio_bytes_list = await self._decode_batch_same_length(
423
- token_sequences, trim_warmup_flags, sliding_window_flags
424
- )
425
-
426
- # Set results
427
- for future, audio_bytes in zip(futures, audio_bytes_list):
428
- if not future.done():
429
- future.set_result(audio_bytes)
430
-
431
- except Exception as e:
432
- # Set exceptions
433
- for future in futures:
434
- if not future.done():
435
- future.set_exception(e)
436
- else:
437
- # Sequential decode (different lengths or single item)
438
- for tokens, trim_warmup, use_sliding_window, future in batch:
439
- try:
440
- audio_bytes = self.decode_to_bytes(
441
- tokens, trim_warmup=trim_warmup, use_sliding_window=use_sliding_window
442
- )
443
- if not future.done():
444
- future.set_result(audio_bytes)
445
- except Exception as e:
446
- if not future.done():
447
- future.set_exception(e)
448
-
449
- async def _decode_batch_same_length(
450
- self,
451
- token_sequences: List[List[int]],
452
- trim_warmup_flags: List[bool],
453
- sliding_window_flags: List[bool]
454
- ) -> List[Optional[bytes]]:
455
- """
456
- Decode multiple sequences with same length in parallel.
457
-
458
- Args:
459
- token_sequences: List of token sequences (all same length)
460
- trim_warmup_flags: List of trim_warmup flags for each sequence
461
- sliding_window_flags: List of use_sliding_window flags for each sequence
462
-
463
- Returns:
464
- List of audio bytes
465
- """
466
- if not token_sequences:
467
- return []
468
-
469
- # Unpack all sequences
470
- unpacked_list = [self.unpack_snac_from_7(tokens) for tokens in token_sequences]
471
-
472
- # Check all have valid frames
473
- valid_indices = [i for i, levels in enumerate(unpacked_list) if levels[0]]
474
-
475
- if not valid_indices:
476
- return [None] * len(token_sequences)
477
-
478
- # Stack into batched tensors
479
- batch_size = len(valid_indices)
480
- frames = len(unpacked_list[valid_indices[0]][0])
481
-
482
- # Build batched codes [batch, frames], [batch, 2*frames], [batch, 4*frames]
483
- codes = [
484
- torch.stack([
485
- torch.tensor(unpacked_list[i][level_idx], dtype=torch.long, device=self.device)
486
- for i in valid_indices
487
- ], dim=0)
488
- for level_idx in range(3)
489
- ]
490
-
491
- # Batched decode
492
- z_q = self.snac_model.quantizer.from_codes(codes)
493
- audio_batch = self.snac_model.decoder(z_q) # [batch, 1, samples]
494
-
495
- # Extract and convert to bytes
496
- audio_bytes_list = [None] * len(token_sequences)
497
-
498
- for batch_idx, orig_idx in enumerate(valid_indices):
499
- audio = audio_batch[batch_idx, 0].detach().cpu().numpy()
500
-
501
- # Apply sliding window or trim warmup based on flags
502
- if sliding_window_flags[orig_idx]:
503
- # Sliding window mode: keep middle 2048 samples only
504
- if len(audio) >= 4096:
505
- audio = audio[2048:4096]
506
- else:
507
- # Standard mode: trim warm-up if requested
508
- if trim_warmup_flags[orig_idx] and len(audio) > 2048:
509
- audio = audio[2048:]
510
-
511
- # Convert to int16
512
- audio_int16 = (audio * 32767).astype(np.int16)
513
- audio_bytes_list[orig_idx] = audio_int16.tobytes()
514
-
515
- return audio_bytes_list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
maya1/streaming_pipeline.py DELETED
@@ -1,159 +0,0 @@
1
- """
2
- Maya1 Streaming Pipeline - Sliding Window Approach
3
- Implements sliding window technique for smooth streaming without artifacts.
4
- """
5
-
6
- import asyncio
7
- from typing import AsyncGenerator, Optional
8
- from vllm import SamplingParams
9
-
10
- from .constants import (
11
- CODE_END_TOKEN_ID,
12
- SNAC_MIN_ID,
13
- SNAC_MAX_ID,
14
- DEFAULT_TEMPERATURE,
15
- DEFAULT_TOP_P,
16
- DEFAULT_MAX_TOKENS,
17
- DEFAULT_MIN_TOKENS,
18
- DEFAULT_REPETITION_PENALTY,
19
- DEFAULT_SEED,
20
- )
21
-
22
-
23
- class Maya1SlidingWindowPipeline:
24
- """
25
- Streaming TTS pipeline using sliding window approach.
26
- Decodes overlapping 28-token windows (4 frames) and keeps only
27
- the middle 2048 samples for smooth audio continuity.
28
- """
29
-
30
- # Sliding window configuration
31
- WINDOW_SIZE = 28 # 4 frames (7 tokens per frame)
32
- YIELD_STRIDE = 7 # Yield every 1 frame
33
- MIDDLE_SAMPLES = 2048 # Keep middle 2048 samples from each decode
34
-
35
- def __init__(self, model, prompt_builder, snac_decoder):
36
- """
37
- Initialize sliding window streaming pipeline.
38
-
39
- Args:
40
- model: Maya1Model instance
41
- prompt_builder: Maya1PromptBuilder instance
42
- snac_decoder: SNACDecoder instance
43
- """
44
- self.model = model
45
- self.prompt_builder = prompt_builder
46
- self.snac_decoder = snac_decoder
47
- print(f"Sliding window pipeline initialized")
48
-
49
- async def generate_speech_stream(
50
- self,
51
- description: str,
52
- text: str,
53
- temperature: float = DEFAULT_TEMPERATURE,
54
- top_p: float = DEFAULT_TOP_P,
55
- max_tokens: int = DEFAULT_MAX_TOKENS,
56
- repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
57
- seed: Optional[int] = None,
58
- ) -> AsyncGenerator[bytes, None]:
59
- """
60
- Generate speech audio with sliding window streaming.
61
-
62
- Args:
63
- description: Voice description
64
- text: Text to synthesize (may include <emotion> tags)
65
- temperature: Sampling temperature
66
- top_p: Nucleus sampling
67
- max_tokens: Max SNAC tokens to generate
68
- repetition_penalty: Prevent loops
69
- seed: Random seed
70
-
71
- Yields:
72
- Audio bytes (int16 PCM, 24kHz mono)
73
- """
74
- # Build prompt
75
- prompt = self.prompt_builder.build_prefix(description, text)
76
-
77
- # Configure sampling
78
- sampling_params = SamplingParams(
79
- temperature=temperature,
80
- top_p=top_p,
81
- max_tokens=max_tokens,
82
- min_tokens=DEFAULT_MIN_TOKENS,
83
- repetition_penalty=repetition_penalty,
84
- stop_token_ids=[CODE_END_TOKEN_ID],
85
- seed=seed if seed is not None else DEFAULT_SEED,
86
- )
87
-
88
- # Stream tokens
89
- snac_buffer = []
90
- last_yield_position = 0
91
- chunk_count = 0
92
- total_tokens_seen = 0
93
-
94
- async for output in self.model.generate_stream(prompt, sampling_params):
95
- # Get latest generated tokens (cumulative list)
96
- generated_token_ids = output.outputs[0].token_ids
97
-
98
- # Process only NEW tokens since last iteration
99
- new_tokens = generated_token_ids[total_tokens_seen:]
100
- total_tokens_seen = len(generated_token_ids)
101
-
102
- # Collect SNAC codes from new tokens
103
- for token_id in new_tokens:
104
- # Stop if we hit EOS
105
- if token_id == CODE_END_TOKEN_ID:
106
- break
107
-
108
- # Only collect valid SNAC tokens
109
- if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID:
110
- snac_buffer.append(token_id)
111
-
112
- # Yield audio when we have enough tokens for a window
113
- while len(snac_buffer) >= last_yield_position + self.WINDOW_SIZE:
114
- # Get window of 28 tokens
115
- window_start = last_yield_position
116
- window_end = window_start + self.WINDOW_SIZE
117
- window = snac_buffer[window_start:window_end]
118
-
119
- if len(window) == self.WINDOW_SIZE:
120
- # Decode window to audio
121
- audio_bytes = await self.snac_decoder.decode_single_async(window)
122
-
123
- if audio_bytes:
124
- # Extract middle portion of audio
125
- audio_samples = len(audio_bytes) // 2
126
- middle_start_sample = (audio_samples - self.MIDDLE_SAMPLES) // 2
127
- middle_end_sample = middle_start_sample + self.MIDDLE_SAMPLES
128
-
129
- # Convert to byte positions
130
- middle_start_byte = middle_start_sample * 2
131
- middle_end_byte = middle_end_sample * 2
132
-
133
- # Extract middle chunk
134
- audio_chunk = audio_bytes[middle_start_byte:middle_end_byte]
135
-
136
- chunk_count += 1
137
- if chunk_count == 1:
138
- print(f" First chunk ready")
139
-
140
- yield audio_chunk
141
-
142
- # Move forward by stride
143
- last_yield_position += self.YIELD_STRIDE
144
-
145
- # Check if generation is done
146
- if CODE_END_TOKEN_ID in new_tokens:
147
- break
148
-
149
- # Final chunk: decode remaining tokens
150
- remaining_tokens = len(snac_buffer) - last_yield_position
151
- if remaining_tokens >= self.WINDOW_SIZE:
152
- window = snac_buffer[-self.WINDOW_SIZE:]
153
- audio_bytes = await self.snac_decoder.decode_single_async(window)
154
- if audio_bytes:
155
- yield audio_bytes[-self.MIDDLE_SAMPLES * 2:]
156
-
157
- frames = len(snac_buffer) // 7
158
- duration = frames / 6.86
159
- print(f"Streamed {chunk_count} chunks (~{duration:.1f}s audio)")