Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -31,7 +31,13 @@ from loguru import logger
 from fish_speech.i18n import i18n
 from fish_speech.inference_engine import TTSInferenceEngine
 from fish_speech.models.dac.inference import load_model as load_decoder_model
-from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
+from fish_speech.models.text2semantic.inference import (
+    launch_thread_safe_queue,
+    load_model as load_llama_model,
+    generate_long,
+    GenerateRequest,
+    WrappedGenerateResponse
+)
 from tools.webui.inference import get_inference_wrapper
 from fish_speech.utils.schema import ServeTTSRequest
 
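Reviewer note: the widened import pulls in the low-level generation API used by the new worker below. Judging purely from how the names are used in this diff (`req.request`, `req.response_queue`, and `WrappedGenerateResponse(status=..., response=...)`), the two message types behave roughly like this sketch; the real definitions live in `fish_speech.models.text2semantic.inference`, so treat the field layout as inferred, not authoritative:

# Inferred sketch of the queue message types; not the fish_speech source.
from dataclasses import dataclass
from queue import Queue
from typing import Any, Literal


@dataclass
class GenerateRequest:
    request: dict          # kwargs forwarded to generate_long (text, top_p, ...)
    response_queue: Queue  # receives one WrappedGenerateResponse per chunk


@dataclass
class WrappedGenerateResponse:
    status: Literal["success", "error"]
    response: Any  # a generated chunk on success, the raised exception on error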
@@ -63,10 +69,10 @@ The model running in this WebUI is OpenAudio S1 Mini.
 TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
 
 try:
-
     GPU_DECORATOR = spaces.GPU
+    IS_SPACES = True
 except ImportError:
-
+    IS_SPACES = False
     def GPU_DECORATOR(func):
         def wrapper(*args, **kwargs):
             return func(*args, **kwargs)
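Reviewer note: the `try`/`except ImportError` above is a standard import guard. On Hugging Face Spaces the `spaces` package is installed and `spaces.GPU` schedules each decorated call onto a ZeroGPU device; everywhere else the fallback makes `GPU_DECORATOR` a pass-through. A minimal self-contained illustration of the same pattern (`heavy_fn` is a made-up name, and the `return wrapper` shown here falls outside the hunk above):

try:
    import spaces  # present on Hugging Face Spaces

    GPU_DECORATOR = spaces.GPU
    IS_SPACES = True
except ImportError:
    IS_SPACES = False

    def GPU_DECORATOR(func):
        # Pass-through fallback for local runs without the spaces package
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper


@GPU_DECORATOR
def heavy_fn(x):
    return x * 2  # stand-in for GPU-bound work


print(heavy_fn(21))  # prints 42 in both environments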
@@ -263,24 +269,69 @@ def parse_args():
 
     return parser.parse_args()
 
+class SynchronousLlamaWorker:
+    def __init__(self, checkpoint_path, precision, compile):
+        self.model, self.decode_one_token = load_llama_model(
+            checkpoint_path, "cpu", precision, compile=compile
+        )
+
+    def put(self, req: GenerateRequest):
+        request_args = req.request
+        response_queue = req.response_queue
+
+        # Move model to CUDA for inference
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model.to(device)
+
+        # Setup caches
+        with torch.device(device):
+            self.model.setup_caches(
+                max_batch_size=1,
+                max_seq_len=self.model.config.max_seq_len,
+                dtype=next(self.model.parameters()).dtype,
+            )
+
+        request_args['device'] = device
+
+        try:
+            for chunk in generate_long(
+                model=self.model, decode_one_token=self.decode_one_token, **request_args
+            ):
+                response_queue.put(
+                    WrappedGenerateResponse(status="success", response=chunk)
+                )
+        except Exception as e:
+            response_queue.put(WrappedGenerateResponse(status="error", response=e))
 
 if __name__ == "__main__":
     args = parse_args()
     args.precision = torch.half if args.half else torch.bfloat16
 
     logger.info("Loading Llama model...")
-    llama_queue = launch_thread_safe_queue(
-        checkpoint_path=args.llama_checkpoint_path,
-        device=args.device,
-        precision=args.precision,
-        compile=args.compile,
-    )
+
+    # If running in a Spaces environment, we use a synchronous worker and lazy loading
+    if IS_SPACES:
+        llama_queue = SynchronousLlamaWorker(
+            checkpoint_path=args.llama_checkpoint_path,
+            precision=args.precision,
+            compile=args.compile,
+        )
+        device = "cpu"
+    else:
+        llama_queue = launch_thread_safe_queue(
+            checkpoint_path=args.llama_checkpoint_path,
+            device=args.device,
+            precision=args.precision,
+            compile=args.compile,
+        )
+        device = args.device
+
     logger.info("Llama model loaded, loading VQ-GAN model...")
 
     decoder_model = load_decoder_model(
         config_name=args.decoder_config_name,
         checkpoint_path=args.decoder_checkpoint_path,
-        device=args.device,
+        device=device,
     )
 
     logger.info("Decoder model loaded, warming up...")
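Design note: `SynchronousLlamaWorker` deliberately mirrors the `put()` interface of the object returned by `launch_thread_safe_queue`, so the inference engine can drive either one unchanged. The difference is that generation runs inline, inside the GPU-decorated request where ZeroGPU actually attaches a CUDA device, rather than on a background thread, and the model is loaded on CPU once and moved to CUDA lazily per request. A rough sketch of the shared calling convention (the request kwargs are illustrative; the real call site lives inside the engine):

from queue import Queue

# llama_queue may be a SynchronousLlamaWorker or the thread-safe queue;
# both accept a GenerateRequest via .put().
response_queue: Queue = Queue()
llama_queue.put(
    GenerateRequest(
        request={"text": "Hello world.", "max_new_tokens": 1024},  # illustrative
        response_queue=response_queue,
    )
)

result = response_queue.get()  # one WrappedGenerateResponse per generated chunk
if result.status == "error":
    raise result.response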
@@ -294,25 +345,38 @@ if __name__ == "__main__":
     )
 
     # Dry run to check if the model is loaded correctly and avoid the first-time latency
-    list(
-        inference_engine.inference(
-            ServeTTSRequest(
-                text="Hello world.",
-                references=[],
-                reference_id=None,
-                max_new_tokens=1024,
-                chunk_length=200,
-                top_p=0.7,
-                repetition_penalty=1.5,
-                temperature=0.7,
-                format="wav",
+    # Skip dry run in Spaces to avoid using quota or crashing due to GPU access in main process
+    if not IS_SPACES:
+        list(
+            inference_engine.inference(
+                ServeTTSRequest(
+                    text="Hello world.",
+                    references=[],
+                    reference_id=None,
+                    max_new_tokens=1024,
+                    chunk_length=200,
+                    top_p=0.7,
+                    repetition_penalty=1.5,
+                    temperature=0.7,
+                    format="wav",
+                )
             )
         )
-    )
 
     logger.info("Warming up done, launching the web UI...")
 
     inference_fct = get_inference_wrapper(inference_engine)
 
-    app = build_app(inference_fct, args.theme)
+    # Decorate the inference function with GPU access if in Spaces
+    if IS_SPACES:
+        @GPU_DECORATOR
+        def gpu_inference_wrapper(*args, **kwargs):
+            decoder_model.to("cuda")
+            return inference_fct(*args, **kwargs)
+
+        final_inference_fct = gpu_inference_wrapper
+    else:
+        final_inference_fct = inference_fct
+
+    app = build_app(final_inference_fct, args.theme)
     app.queue(api_open=True).launch(show_error=True, show_api=True)
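Reviewer note: the remaining changes all follow from how ZeroGPU (the "Running on Zero" badge above) works: the main process never holds a GPU, and a CUDA device is attached only while a `spaces.GPU`-decorated function executes. That is why the warm-up dry run is skipped on Spaces (it would run undecorated in the main process) and why both the decoder and the Llama model are moved to CUDA inside the decorated call. A minimal standalone ZeroGPU app following the same lifecycle (toy model, illustrative names):

import gradio as gr
import spaces  # ZeroGPU helper; this sketch assumes it runs on a Space
import torch

model = torch.nn.Linear(4, 4)  # loaded on CPU in the main process


@spaces.GPU  # a CUDA device exists only inside this call
def predict(text: str) -> str:
    model.to("cuda")  # lazy move, mirroring decoder_model.to("cuda") above
    x = torch.randn(1, 4, device="cuda")
    return f"{text}: {model(x).norm().item():.4f}"


gr.Interface(predict, gr.Textbox(), gr.Textbox()).launch()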