aarvis committed on
Commit
1c5422d
·
1 Parent(s): cd55f09

Modal Code Update

Browse files
modal/snorTTS_Indic_v0_server.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # * Install Modal.
2
+ # uv run pip install modal
3
+
4
+ # * Setup Modal.
5
+ # uv run modal setup
6
+
7
+ # * Run to deploy the Modal app.
8
+ # uv run modal deploy scripts/modal/snorTTS_Indic_v0_server.py
9
+
10
+ # * Test.
11
+ # curl -X 'POST' \
12
+ # 'https://snorbyte--snortts-indic-v0-server-prod-ttsserver-serve.modal.run/?utterance=%E0%A4%95%E0%A4%B2%20%E0%A4%AE%E0%A5%88%E0%A4%82%E0%A4%A8%E0%A5%87%20%E0%A4%B8%E0%A4%BF%E0%A4%B0%E0%A5%8D%E0%A4%AB%20%E2%82%B9500%20%E0%A4%AE%E0%A5%87%E0%A4%82%20%E0%A4%8F%E0%A4%95%20cool%20headphones%20%E0%A4%B2%E0%A5%87%20%E0%A4%B2%E0%A4%BF%E0%A4%8F%2C%20%E0%A4%AC%E0%A4%B9%E0%A5%81%E0%A4%A4%20%E0%A4%AC%E0%A4%A2%E0%A4%BC%E0%A4%BF%E0%A4%AF%E0%A4%BE%20deal%20%E0%A4%A5%E0%A4%BE%20%E0%A4%AF%E0%A4%BE%E0%A4%B0%21&user_id=159&language=hindi&temperature=0.4&top_p=0.9&repetition_penalty=1.05&speed=1.05&denoise=true&stream=false' \
13
+ # -H 'accept: audio/mpeg' \
14
+ # -d '' \
15
+ # --output outputs/output_non_stream.mp3
16
+
17
+ # curl -X 'POST' \
18
+ # 'https://snorbyte--snortts-indic-v0-server-prod-ttsserver-serve.modal.run/?utterance=%E0%A4%95%E0%A4%B2%20%E0%A4%AE%E0%A5%88%E0%A4%82%E0%A4%A8%E0%A5%87%20%E0%A4%B8%E0%A4%BF%E0%A4%B0%E0%A5%8D%E0%A4%AB%20%E2%82%B9500%20%E0%A4%AE%E0%A5%87%E0%A4%82%20%E0%A4%8F%E0%A4%95%20cool%20headphones%20%E0%A4%B2%E0%A5%87%20%E0%A4%B2%E0%A4%BF%E0%A4%8F%2C%20%E0%A4%AC%E0%A4%B9%E0%A5%81%E0%A4%A4%20%E0%A4%AC%E0%A4%A2%E0%A4%BC%E0%A4%BF%E0%A4%AF%E0%A4%BE%20deal%20%E0%A4%A5%E0%A4%BE%20%E0%A4%AF%E0%A4%BE%E0%A4%B0%21&user_id=159&language=hindi&temperature=0.4&top_p=0.9&repetition_penalty=1.05&speed=1.05&denoise=true&stream=true' \
19
+ # -H 'accept: audio/mpeg' \
20
+ # -d '' \
21
+ # --output outputs/output_stream.mp3
22
+
23
+ # Import Modal.
24
+ import modal
25
+
26
+
27
# Define constants.
APP_NAME = "snorTTS-Indic-v0-server-prod"
# NOTE(review): Modal documents `scaledown_window` and `timeout` in seconds,
# so these are presumably 15 and 10 minutes respectively - confirm.
SCALEDOWN_WINDOW = 15 * 60
TIMEOUT = 10 * 60
MIN_CONTAINERS = 1
MAX_CONTAINERS = 1
MAX_CONCURRENT_REQUESTS = 5


# Define the Docker image.
image = (
    modal.Image.debian_slim(python_version="3.12")
    .apt_install(
        "curl",  # Install curl for downloading files.
        "ffmpeg",  # Install ffmpeg for audio processing.
        "git",  # Install git for version control.
        "libsox-dev",  # Install SoX for audio processing.
    )
    .run_commands(
        "curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",  # Install Rust.
    )
    .env(
        {
            "PATH": "/root/.cargo/bin:${PATH}",  # Add Rust to PATH.
            "HF_HUB_ENABLE_HF_TRANSFER": "1",  # Set `HF_HUB_ENABLE_HF_TRANSFER` for fast model transfers.
        }
    )
    .pip_install(
        # `aiohttp` is imported directly by this module (see `image.imports()`),
        # so declare it explicitly instead of relying on a transitive install.
        "aiohttp",  # Install aiohttp for calling the vLLM server.
        "deepfilternet",  # Install DeepFilterNet for audio denoising.
        "fastapi[standard]",  # Install FastAPI for building the API.
        "hf_transfer",  # Install Hugging Face transfer for fast model transfers.
        "loguru",  # Install Loguru for logging.
        "numpy",  # Install NumPy for numerical operations.
        "pydub",  # Install Pydub for audio processing.
        "snac",  # Install SNAC for audio decoding.
        "torchaudio",  # Install Torchaudio for audio processing.
        "transformers",  # Install Transformers for model handling.
    )
)

# Create the Modal app.
app = modal.App(APP_NAME, image=image)
69
+
70
+ with image.imports():
71
+ # Import necessary libraries for the remote function.
72
+ from typing import Any
73
+ import aiohttp
74
+ import io
75
+ import json
76
+
77
+ from df.enhance import init_df, enhance
78
+ from fastapi.responses import StreamingResponse
79
+ from loguru import logger
80
+ from pydub import AudioSegment
81
+ from snac import SNAC
82
+ from transformers import AutoTokenizer
83
+ import numpy as np
84
+ import torch
85
+ import torchaudio as ta
86
+
87
+
88
@app.cls(
    cpu=4.0,  # Set number of CPU cores.
    memory=8192,  # Set memory in MiB.
    scaledown_window=SCALEDOWN_WINDOW,  # Set how long we should stay up with no requests.
    timeout=TIMEOUT,  # Set the timeout for the function.
    enable_memory_snapshot=True,  # Enable memory snapshot for better cold boot times.
    min_containers=MIN_CONTAINERS,  # Minimum number of containers to keep running.
    max_containers=MAX_CONTAINERS,  # Maximum number of containers to run.
)
@modal.concurrent(
    max_inputs=MAX_CONCURRENT_REQUESTS
)  # Limit the number of concurrent requests.
class TTSServer:
    """Text-to-speech web endpoint.

    Sends a prompt to the vLLM completions server, collects the generated
    audio tokens, decodes them with SNAC and returns MP3 bytes.
    """

    @modal.enter()
    def load(self) -> None:
        """Load the tokenizer and the audio models when the container starts."""
        # Load the tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained("snorbyte/snorTTS-Indic-v0")
        logger.success("Loaded tokenizer from snorbyte/snorTTS-Indic-v0.")

        # Token related bookkeeping.
        # Set the tokenizer length.
        self.tokeniser_length = 128256
        logger.success("Set tokenizer length.")

        # Set the end of speech ID, pad token ID, and audio start ID.
        self.end_of_speech_id = self.tokeniser_length + 2
        self.pad_token_id = self.tokeniser_length + 7
        self.audio_start_id = self.tokeniser_length + 10
        logger.success("Set end of speech ID, pad token ID, and audio start ID.")

        # Decode the pad token.
        self.pad_token = self.tokenizer.decode([self.pad_token_id])
        logger.success("Decoded pad token.")

        # Set the padding token and padding side.
        self.tokenizer.pad_token = self.pad_token
        self.tokenizer.padding_side = "left"
        logger.success("Set padding token and padding side for the tokenizer.")

        # Models.
        # Load the SNAC model for audio decoding.
        self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        logger.success("Loaded SNAC model for audio decoding.")

        # Initialize the DF model for audio denoising.
        self.df_model, self.df_state, _ = init_df()
        logger.success("Initialized DF model for audio denoising.")

    async def _decode_audio(self, audio_ids: list[int], speed: float, denoise: bool):
        """Decode audio token IDs into a waveform.

        Args:
            audio_ids: Generated audio token IDs, grouped in frames of 7
                tokens. Trailing tokens that do not fill a frame are ignored.
            speed: Tempo factor; the SoX `tempo` effect is applied unless
                the factor is (almost) exactly 1.0.
            denoise: Whether to run DeepFilterNet on the decoded audio.

        Returns:
            The waveform as a NumPy array, or None when there is no complete
            frame or when decoding, tempo change or denoising fails.
        """
        frame_size = 7
        codebook_size = 4096
        sample_rate = 24_000

        # Only complete frames can be decoded. Floor division is essential:
        # rounding up would index past the end of `audio_ids`.
        num_frames = len(audio_ids) // frame_size
        if num_frames == 0:
            logger.warning("No complete SNAC frame to decode.")
            return None

        # Prepare the codes for SNAC decoding.
        # ! Please note: codes cannot be negative. If the model generates incorrect codes
        # ! at the wrong positions, audio generation will fail.
        codes = [[], [], []]
        for frame_index in range(num_frames):
            start = frame_size * frame_index
            # Offset the audio tokens by the audio start ID.
            frame = [
                token - self.audio_start_id
                for token in audio_ids[start : start + frame_size]
            ]
            # Each position in the frame carries its own multiple of the
            # codebook size, which is removed before handing codes to SNAC.
            codes[0].append(frame[0])
            codes[1].append(frame[1] - codebook_size)
            codes[2].append(frame[2] - (2 * codebook_size))
            codes[2].append(frame[3] - (3 * codebook_size))
            codes[1].append(frame[4] - (4 * codebook_size))
            codes[2].append(frame[5] - (5 * codebook_size))
            codes[2].append(frame[6] - (6 * codebook_size))
        codes = [torch.tensor(layer).unsqueeze(0) for layer in codes]

        try:
            # Decode the audio using SNAC. No gradients are needed for inference.
            with torch.no_grad():
                audio = self.snac_model.decode(codes).reshape(1, -1)
            logger.success(f"Decoded {num_frames * frame_size} SNAC tokens to audio.")
        except Exception as e:
            logger.error(f"Error decoding audio: {e}")
            return None

        # Speed up or slow down the audio.
        # NOTE(review): `torchaudio.sox_effects` is deprecated in recent
        # torchaudio releases and the image does not pin a version - confirm.
        if abs(speed - 1.0) > 1e-4:
            try:
                audio, _ = ta.sox_effects.apply_effects_tensor(
                    audio, sample_rate, effects=[["tempo", f"{speed}"]]
                )
                logger.success(
                    f"Applied speed effect to audio with speed factor {speed}."
                )
            except Exception as e:
                logger.error(f"Error applying speed effect: {e}")
                return None

        # Denoise the audio. The audio is resampled to 48 kHz for
        # DeepFilterNet and back to the original rate afterwards.
        if denoise:
            try:
                audio = ta.transforms.Resample(orig_freq=sample_rate, new_freq=48_000)(audio)
                audio = enhance(self.df_model, self.df_state, audio)
                audio = ta.transforms.Resample(orig_freq=48_000, new_freq=sample_rate)(audio)
                logger.success("Denoised audio using DeepFilterNet.")
            except Exception as e:
                logger.error(f"Error denoising audio: {e}")
                return None

        # Move the audio to CPU and convert to numpy array.
        return audio.detach().squeeze().cpu().numpy()

    def _encode_mp3(self, audio) -> bytes:
        """Encode a float waveform (NumPy array) as 24 kHz mono MP3 bytes."""
        # Convert to int16 PCM format expected by AudioSegment. Clip first so
        # samples outside [-1, 1] saturate instead of wrapping around.
        audio_int16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)

        # Create raw audio segment.
        raw_audio = AudioSegment(
            audio_int16.tobytes(),
            frame_rate=24000,
            sample_width=2,
            channels=1,
        )

        # Export the audio to a fresh in-memory buffer in MP3 format.
        buffer = io.BytesIO()
        raw_audio.export(buffer, format="mp3", bitrate="96k")
        return buffer.getvalue()

    async def _generate(
        self,
        utterance: str,
        user_id: str = "159",
        language: str = "hindi",
        temperature: float = 0.4,
        top_p: float = 0.9,
        repetition_penalty: float = 1.05,
        speed: float = 1.05,
        denoise: bool = False,
        stream: bool = True,
    ):
        """Generate speech for an utterance and yield it as MP3 byte chunks.

        Args:
            utterance: Text to speak; only the first 50 space-separated
                words are used.
            user_id: Speaker identifier interpolated into the prompt.
            language: Language name interpolated into the prompt.
            temperature: Sampling temperature forwarded to the vLLM server.
            top_p: Nucleus sampling value forwarded to the vLLM server.
            repetition_penalty: Repetition penalty forwarded to the vLLM server.
            speed: Tempo factor applied to the decoded audio.
            denoise: Whether to denoise the decoded audio.
            stream: When True, audio is yielded as soon as more than 168
                audio tokens are buffered; otherwise one chunk is yielded
                after generation finishes.

        Yields:
            MP3-encoded audio chunks. Errors are logged and end the stream
            rather than being raised.
        """
        # Number of tokens per SNAC frame and the streaming flush threshold.
        frame_size = 7
        stream_threshold = 168

        try:
            # Limit the utterance length to 50 words.
            utterance = " ".join(utterance.split(" ")[:50])

            logger.info(
                f"Generating audio for utterance, {utterance}, user_id, {user_id}, language, {language}, "
                f"temperature, {temperature}, top_p, {top_p}, repetition_penalty, {repetition_penalty}, "
                f"speed, {speed}, denoise, {denoise} and stream, {stream}."
            )

            # Create the prompt.
            prompt = f"<custom_token_3><|begin_of_text|>{language}{user_id}: {utterance}<|eot_id|><custom_token_4><custom_token_5><custom_token_1>"

            # Tokenize the prompt.
            inputs = self.tokenizer(prompt, add_special_tokens=False)

            # Set max audio tokens to generate.
            max_tokens = 2048 - len(inputs.input_ids)

            # Prepare the payload for the vLLM server. The sampling values
            # come from the caller instead of being hard-coded.
            # NOTE(review): the original author states the vLLM server does
            # not recognize the request without this type hint - unverified.
            payload: dict[str, Any] = {
                "prompt": prompt,
                "model": "llm",
                "stream": True,
                "temperature": temperature,
                "top_p": top_p,
                "max_tokens": max_tokens,
                "repetition_penalty": repetition_penalty,
                "add_special_tokens": False,
                "stop_token_ids": [self.end_of_speech_id],
            }

            # Set the headers for the request.
            headers = {
                "Content-Type": "application/json",
                "Accept": "text/event-stream",
            }

            # Initialize the audio tokens list.
            audio_ids = []

            # Send the request to the vLLM server to generate audio.
            async with aiohttp.ClientSession(
                base_url="https://snorbyte--snortts-indic-v0-vllm-prod-serve.modal.run"
            ) as session:
                async with session.post(
                    "/v1/completions",
                    json=payload,
                    headers=headers,
                    timeout=aiohttp.ClientTimeout(total=1 * 60),
                ) as resp:
                    # Check once, before streaming, that the response is successful.
                    resp.raise_for_status()

                    # Stream the vLLM response.
                    async for raw in resp.content:
                        # Decode bytes.
                        line = raw.decode().strip()

                        # Skip empty lines and end of stream.
                        if not line or line == "data: [DONE]":
                            continue

                        # Remove the "data: " prefix if present.
                        line = line.removeprefix("data: ")

                        # Parse the JSON response.
                        chunk = json.loads(line)

                        # Tokenize the generated text.
                        output = self.tokenizer(
                            chunk["choices"][0]["text"], add_special_tokens=False
                        ).input_ids

                        # Extract audio tokens from the output.
                        audio_ids.extend(
                            token for token in output if token >= self.audio_start_id
                        )

                        # If streaming is enabled and more than 168 tokens are
                        # buffered, decode and yield the audio.
                        # ! This will lead to jittering in the audio stream.
                        if stream and len(audio_ids) > stream_threshold:
                            audio = await self._decode_audio(audio_ids, speed, denoise)
                            if audio is not None:
                                yield self._encode_mp3(audio)

                            # Keep only the last incomplete frame. Frames that
                            # failed to decode are dropped, not retried.
                            complete = len(audio_ids) - len(audio_ids) % frame_size
                            audio_ids = audio_ids[complete:]

            # Check if there are any remaining audio tokens to process.
            if audio_ids:
                audio = await self._decode_audio(audio_ids, speed, denoise)
                if audio is not None:
                    yield self._encode_mp3(audio)
        except Exception as e:
            logger.exception(f"Error during audio generation: {e}")

    @modal.fastapi_endpoint(
        docs=True, method="POST"
    )  # Define a FastAPI endpoint for TTS.
    async def serve(
        self,
        utterance: str,
        user_id: str = "159",
        language: str = "hindi",
        temperature: float = 0.4,
        top_p: float = 0.9,
        repetition_penalty: float = 1.05,
        speed: float = 1.05,
        denoise: bool = False,
        stream: bool = True,
    ):
        """Synthesize speech for `utterance` and return it as an MP3 stream."""
        # Stream the generated audio as an MP3 response.
        return StreamingResponse(
            self._generate(
                utterance,
                user_id=user_id,
                language=language,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                speed=speed,
                denoise=denoise,
                stream=stream,
            ),
            media_type="audio/mpeg",
        )
modal/snorTTS_Indic_v0_vllm.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # * Install Modal.
2
+ # uv run pip install modal
3
+
4
+ # * Setup Modal.
5
+ # uv run modal setup
6
+
7
+ # * Run to deploy the Modal app.
8
+ # uv run modal deploy scripts/modal/snorTTS_Indic_v0_vllm.py
9
+
10
+ # Import Modal.
11
+ import modal
12
+
13
+
14
# Model being served and its serving limits.
MODEL_NAME = "snorbyte/snorTTS-Indic-v0"
MAX_SEQ_LEN = 2048
MAX_CONCURRENT_SEQS = 5

# Modal app and container settings.
APP_NAME = "snorTTS-Indic-v0-vllm-prod"
SCALEDOWN_WINDOW = 15 * 60
TIMEOUT = 10 * 60
VLLM_PORT = 8000
GPU = "T4"
MIN_CONTAINERS = 1
MAX_CONTAINERS = 1
# One Modal input per vLLM sequence slot.
MAX_CONCURRENT_REQUESTS = MAX_CONCURRENT_SEQS

# Docker image: Debian slim with pinned vLLM, Hugging Face Hub and FlashInfer.
image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "vllm==0.9.1",  # Serves the model.
        "huggingface_hub[hf_transfer]==0.32.0",  # Fast model downloads.
        "flashinfer-python==0.2.6.post1",  # Optimized inference kernels.
        extra_index_url="https://download.pytorch.org/whl/cu128",  # PyTorch CUDA 12.8 wheel index.
    )
    .env(
        {
            "HF_HUB_ENABLE_HF_TRANSFER": "1",  # Turn on `hf_transfer` downloads.
        }
    )
)

# Named volumes used as caches; created on first use.
hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True)
vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True)

# The Modal app that owns the serving function.
app = modal.App(APP_NAME)
49
+
50
+
51
+ with image.imports():
52
+ # Import necessary libraries for the remote function.
53
+ import subprocess
54
+
55
+
56
# Define the function to start the vLLM server.
@app.function(
    image=image,  # Set the image for the function.
    gpu=GPU,  # Set the GPU type for the instance.
    scaledown_window=SCALEDOWN_WINDOW,  # Set how long we should stay up with no requests.
    timeout=TIMEOUT,  # Set the timeout for the function.
    volumes={
        "/root/.cache/huggingface": hf_cache_vol,
        "/root/.cache/vllm": vllm_cache_vol,
    },  # Set the volumes for cache.
    min_containers=MIN_CONTAINERS,  # Minimum number of containers to keep running.
    max_containers=MAX_CONTAINERS,  # Maximum number of containers to run.
)
@modal.concurrent(
    max_inputs=MAX_CONCURRENT_REQUESTS
)  # Limit the number of concurrent requests.
@modal.web_server(
    port=VLLM_PORT, startup_timeout=TIMEOUT
)  # Expose the vLLM server on the specified port.
def serve():
    """Launch `vllm serve` in the background on `VLLM_PORT`.

    The process is started and not waited on; the function returns
    immediately after spawning it.
    """
    # Create the command to start the vLLM server.
    cmd = [
        "vllm",
        "serve",
        "--uvicorn-log-level=info",
        MODEL_NAME,
        # The model is served under both its full name and the alias "llm".
        "--served-model-name",
        MODEL_NAME,
        "llm",
        "--max-model-len",
        str(MAX_SEQ_LEN),
        "--max-num-seqs",
        str(MAX_CONCURRENT_SEQS),
        "--host",
        "0.0.0.0",
        "--port",
        str(VLLM_PORT),
    ]

    # Start the vLLM server. Passing the argument list directly avoids an
    # intermediate shell and any quoting issues in the arguments.
    subprocess.Popen(cmd)