ford442 committed
Commit 1d253a7 · verified · Parent: 27e3c8b

Update app.py

Files changed (1): app.py (+88, -24)
app.py CHANGED
@@ -31,7 +31,13 @@ from loguru import logger
 from fish_speech.i18n import i18n
 from fish_speech.inference_engine import TTSInferenceEngine
 from fish_speech.models.dac.inference import load_model as load_decoder_model
-from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
+from fish_speech.models.text2semantic.inference import (
+    launch_thread_safe_queue,
+    load_model as load_llama_model,
+    generate_long,
+    GenerateRequest,
+    WrappedGenerateResponse
+)
 from tools.webui.inference import get_inference_wrapper
 from fish_speech.utils.schema import ServeTTSRequest
 
@@ -63,10 +69,10 @@ The model running in this WebUI is OpenAudio S1 Mini.
 TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
 
 try:
-
     GPU_DECORATOR = spaces.GPU
+    IS_SPACES = True
 except ImportError:
-
+    IS_SPACES = False
     def GPU_DECORATOR(func):
         def wrapper(*args, **kwargs):
             return func(*args, **kwargs)
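
The hunk context above cuts off before the fallback decorator returns its wrapper. For reference, a complete, runnable version of this detection pattern looks like the sketch below; `import spaces` succeeds only inside a Hugging Face Space, and `synthesize` is a hypothetical stand-in for the real inference entry point.

# Complete sketch of the try/except feature detection used above.
try:
    import spaces                   # present only on Hugging Face Spaces
    GPU_DECORATOR = spaces.GPU      # schedules the wrapped call on a GPU worker
    IS_SPACES = True
except ImportError:
    IS_SPACES = False

    def GPU_DECORATOR(func):        # no-op fallback for local runs
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)
        return wrapper

@GPU_DECORATOR
def synthesize(text: str) -> str:
    return f"<audio for {text!r}>"  # placeholder for real TTS work

print(synthesize("Hello world."), "| on Spaces:", IS_SPACES)

Keeping the two branches signature-compatible means the rest of the file never has to branch on the environment just to call the function.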
@@ -263,24 +269,69 @@ def parse_args():
 
     return parser.parse_args()
 
+class SynchronousLlamaWorker:
+    def __init__(self, checkpoint_path, precision, compile):
+        self.model, self.decode_one_token = load_llama_model(
+            checkpoint_path, "cpu", precision, compile=compile
+        )
+
+    def put(self, req: GenerateRequest):
+        request_args = req.request
+        response_queue = req.response_queue
+
+        # Move model to CUDA for inference
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model.to(device)
+
+        # Setup caches
+        with torch.device(device):
+            self.model.setup_caches(
+                max_batch_size=1,
+                max_seq_len=self.model.config.max_seq_len,
+                dtype=next(self.model.parameters()).dtype,
+            )
+
+        request_args['device'] = device
+
+        try:
+            for chunk in generate_long(
+                model=self.model, decode_one_token=self.decode_one_token, **request_args
+            ):
+                response_queue.put(
+                    WrappedGenerateResponse(status="success", response=chunk)
+                )
+        except Exception as e:
+            response_queue.put(WrappedGenerateResponse(status="error", response=e))
 
 if __name__ == "__main__":
     args = parse_args()
     args.precision = torch.half if args.half else torch.bfloat16
 
     logger.info("Loading Llama model...")
-    llama_queue = launch_thread_safe_queue(
-        checkpoint_path=args.llama_checkpoint_path,
-        device=args.device,
-        precision=args.precision,
-        compile=args.compile,
-    )
+
+    # If running in a Spaces environment, we use a synchronous worker and lazy loading
+    if IS_SPACES:
+        llama_queue = SynchronousLlamaWorker(
+            checkpoint_path=args.llama_checkpoint_path,
+            precision=args.precision,
+            compile=args.compile,
+        )
+        device = "cpu"
+    else:
+        llama_queue = launch_thread_safe_queue(
+            checkpoint_path=args.llama_checkpoint_path,
+            device=args.device,
+            precision=args.precision,
+            compile=args.compile,
+        )
+        device = args.device
+
 logger.info("Llama model loaded, loading VQ-GAN model...")
 
     decoder_model = load_decoder_model(
         config_name=args.decoder_config_name,
         checkpoint_path=args.decoder_checkpoint_path,
-        device=args.device,
+        device=device,
     )
 
     logger.info("Decoder model loaded, warming up...")
@@ -294,25 +345,38 @@ if __name__ == "__main__":
     )
 
     # Dry run to check if the model is loaded correctly and avoid the first-time latency
-    list(
-        inference_engine.inference(
-            ServeTTSRequest(
-                text="Hello world.",
-                references=[],
-                reference_id=None,
-                max_new_tokens=1024,
-                chunk_length=200,
-                top_p=0.7,
-                repetition_penalty=1.5,
-                temperature=0.7,
-                format="wav",
-            )
-        )
-    )
+    # Skip dry run in Spaces to avoid using quota or crashing due to GPU access in main process
+    if not IS_SPACES:
+        list(
+            inference_engine.inference(
+                ServeTTSRequest(
+                    text="Hello world.",
+                    references=[],
+                    reference_id=None,
+                    max_new_tokens=1024,
+                    chunk_length=200,
+                    top_p=0.7,
+                    repetition_penalty=1.5,
+                    temperature=0.7,
+                    format="wav",
+                )
+            )
+        )
 
     logger.info("Warming up done, launching the web UI...")
 
     inference_fct = get_inference_wrapper(inference_engine)
 
-    app = build_app(inference_fct, args.theme)
+    # Decorate the inference function with GPU access if in Spaces
+    if IS_SPACES:
+        @GPU_DECORATOR
+        def gpu_inference_wrapper(*args, **kwargs):
+            decoder_model.to("cuda")
+            return inference_fct(*args, **kwargs)
+
+        final_inference_fct = gpu_inference_wrapper
+    else:
+        final_inference_fct = inference_fct
+
+    app = build_app(final_inference_fct, args.theme)
     app.queue(api_open=True).launch(show_error=True, show_api=True)
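
All of the new branches serve the same constraint: on ZeroGPU Spaces, CUDA may only be initialized inside a @spaces.GPU-decorated call, never in the main process. Hence the CPU loading (device = "cpu"), the skipped warm-up, and the .to("cuda") moves deferred into the wrapper. A minimal sketch of that lazy-placement pattern, with torch.nn.Linear standing in for the real models:

import torch

# Built on CPU at import time; the main process never touches CUDA.
model = torch.nn.Linear(8, 8)

def infer(x: torch.Tensor) -> torch.Tensor:
    # On Spaces this function would carry the GPU decorator; the device
    # move happens here, so only the GPU worker initializes CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)  # first call pays the one-time transfer cost
    return model(x.to(device)).cpu()

print(infer(torch.randn(1, 8)).shape)  # torch.Size([1, 8])

Since Module.to() is effectively a no-op once the parameters already live on the target device, repeating the move on every request costs little after the first call.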
 