Sayoyo committed on
Commit
7ef4a67
·
1 Parent(s): 574b7a9

feat: add Description, Format, Model Select

Browse files
Files changed (1) hide show
  1. acestep/api_server.py +203 -24
acestep/api_server.py CHANGED
@@ -14,7 +14,6 @@ from __future__ import annotations
14
  import asyncio
15
  import json
16
  import os
17
- import re
18
  import sys
19
  import time
20
  import traceback
@@ -48,6 +47,8 @@ from acestep.inference import (
48
  GenerationParams,
49
  GenerationConfig,
50
  generate_music,
 
 
51
  )
52
  from acestep.gradio_ui.events.results_handlers import _build_generation_info
53
 
@@ -66,6 +67,12 @@ class GenerateMusicRequest(BaseModel):
66
  thinking: bool = False
67
  # Sample-mode requests auto-generate caption/lyrics/metas via LM (no user prompt).
68
  sample_mode: bool = False
 
 
 
 
 
 
69
 
70
  bpm: Optional[int] = None
71
  # Accept common client keys via manual parsing (see _build_req_from_mapping).
@@ -233,6 +240,22 @@ def _get_project_root() -> str:
233
  return os.path.dirname(os.path.dirname(current_file))
234
 
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  def _load_project_env() -> None:
237
  if load_dotenv is None:
238
  return
@@ -377,6 +400,25 @@ def create_app() -> FastAPI:
377
  app.state._llm_init_error = None
378
  app.state._llm_init_lock = Lock()
379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
  max_workers = int(os.getenv("ACESTEP_API_WORKERS", "1"))
381
  executor = ThreadPoolExecutor(max_workers=max_workers)
382
 
@@ -425,6 +467,7 @@ def create_app() -> FastAPI:
425
  offload_to_cpu = _env_bool("ACESTEP_OFFLOAD_TO_CPU", False)
426
  offload_dit_to_cpu = _env_bool("ACESTEP_OFFLOAD_DIT_TO_CPU", False)
427
 
 
428
  status_msg, ok = h.initialize_service(
429
  project_root=project_root,
430
  config_path=config_path,
@@ -438,6 +481,48 @@ def create_app() -> FastAPI:
438
  app.state._init_error = status_msg
439
  raise RuntimeError(status_msg)
440
  app.state._initialized = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
 
442
  async def _cleanup_job_temp_files(job_id: str) -> None:
443
  async with app.state.job_temp_files_lock:
@@ -450,12 +535,48 @@ def create_app() -> FastAPI:
450
 
451
  async def _run_one_job(job_id: str, req: GenerateMusicRequest) -> None:
452
  job_store: _JobStore = app.state.job_store
453
- h: AceStepHandler = app.state.handler
454
  llm: LLMHandler = app.state.llm_handler
455
  executor: ThreadPoolExecutor = app.state.executor
456
 
457
  await _ensure_initialized()
458
  job_store.mark_running(job_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
 
460
  def _blocking_generate() -> Dict[str, Any]:
461
  """Generate music using unified inference logic from acestep.inference"""
@@ -526,7 +647,7 @@ def create_app() -> FastAPI:
526
  if getattr(app.state, "_llm_init_error", None):
527
  raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
528
 
529
- # Handle sample mode: generate random caption/lyrics first
530
  caption = req.caption
531
  lyrics = req.lyrics
532
  bpm = req.bpm
@@ -534,31 +655,85 @@ def create_app() -> FastAPI:
534
  time_signature = req.time_signature
535
  audio_duration = req.audio_duration
536
 
537
- if sample_mode:
538
- print("[api_server] Sample mode: generating random caption/lyrics via LM")
539
- # Note: understand_audio_from_codes does not support cfg_scale or negative_prompt
540
- sample_metadata, sample_status = llm.understand_audio_from_codes(
541
- audio_codes="NO USER INPUT",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
542
  temperature=req.lm_temperature,
543
  top_k=lm_top_k if lm_top_k > 0 else None,
544
  top_p=lm_top_p if lm_top_p < 1.0 else None,
545
- repetition_penalty=req.lm_repetition_penalty,
546
  use_constrained_decoding=req.constrained_decoding,
547
- constrained_decoding_debug=req.constrained_decoding_debug,
548
  )
549
-
550
- if not sample_metadata or str(sample_status).startswith("❌"):
551
- raise RuntimeError(f"Sample generation failed: {sample_status}")
552
-
553
- # Use generated values with fallback defaults
554
- caption = sample_metadata.get("caption", "")
555
- lyrics = sample_metadata.get("lyrics", "")
556
- bpm = _to_int(sample_metadata.get("bpm"), None) or _to_int(os.getenv("ACESTEP_SAMPLE_DEFAULT_BPM", "120"), 120)
557
- key_scale = sample_metadata.get("keyscale", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_KEY", "C Major")
558
- time_signature = sample_metadata.get("timesignature", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_TIMESIGNATURE", "4/4")
559
- audio_duration = _to_float(sample_metadata.get("duration"), None) or _to_float(os.getenv("ACESTEP_SAMPLE_DEFAULT_DURATION_SECONDS", "120"), 120.0)
560
 
561
- print(f"[api_server] Sample generated: caption_len={len(caption)}, lyrics_len={len(lyrics)}, bpm={bpm}, duration={audio_duration}")
 
 
 
 
 
562
 
563
  print(f"[api_server] Before GenerationParams: thinking={thinking}, sample_mode={sample_mode}")
564
  print(f"[api_server] Caption/Lyrics to use: caption_len={len(caption)}, lyrics_len={len(lyrics)}")
@@ -701,9 +876,10 @@ def create_app() -> FastAPI:
701
  return None
702
  return s
703
 
704
- # Get model information from environment variables
705
  lm_model_name = os.getenv("ACESTEP_LM_MODEL_PATH", "acestep-5Hz-lm-0.6B-v3")
706
- dit_model_name = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo-rl")
 
707
 
708
  return {
709
  "first_audio_path": _path_to_audio_url(first_audio) if first_audio else None,
@@ -835,6 +1011,9 @@ def create_app() -> FastAPI:
835
  lyrics=str(get("lyrics", "") or ""),
836
  thinking=_to_bool(get("thinking"), False),
837
  sample_mode=_to_bool(_get_any("sample_mode", "sampleMode"), False),
 
 
 
838
  bpm=normalized_bpm,
839
  key_scale=normalized_keyscale,
840
  time_signature=normalized_timesig,
 
14
  import asyncio
15
  import json
16
  import os
 
17
  import sys
18
  import time
19
  import traceback
 
47
  GenerationParams,
48
  GenerationConfig,
49
  generate_music,
50
+ create_sample,
51
+ format_sample,
52
  )
53
  from acestep.gradio_ui.events.results_handlers import _build_generation_info
54
 
 
67
  thinking: bool = False
68
  # Sample-mode requests auto-generate caption/lyrics/metas via LM (no user prompt).
69
  sample_mode: bool = False
70
+ # Description for sample mode: auto-generate caption/lyrics from description query
71
+ sample_query: str = Field(default="", description="Query/description for sample mode (use create_sample)")
72
+ # Whether to use format_sample() to enhance input caption/lyrics
73
+ use_format: bool = Field(default=False, description="Use format_sample() to enhance input (default: False)")
74
+ # Model name for multi-model support (select which DiT model to use)
75
+ model: Optional[str] = Field(default=None, description="Model name to use (e.g., 'acestep-v15-turbo')")
76
 
77
  bpm: Optional[int] = None
78
  # Accept common client keys via manual parsing (see _build_req_from_mapping).
 
240
  return os.path.dirname(os.path.dirname(current_file))
241
 
242
 
243
+ def _get_model_name(config_path: str) -> str:
244
+ """
245
+ Extract model name from config_path.
246
+
247
+ Args:
248
+ config_path: Path like "acestep-v15-turbo" or "/path/to/acestep-v15-turbo"
249
+
250
+ Returns:
251
+ Model name (last directory name from config_path)
252
+ """
253
+ if not config_path:
254
+ return ""
255
+ normalized = config_path.rstrip("/\\")
256
+ return os.path.basename(normalized)
257
+
258
+
259
  def _load_project_env() -> None:
260
  if load_dotenv is None:
261
  return
 
400
  app.state._llm_init_error = None
401
  app.state._llm_init_lock = Lock()
402
 
403
+ # Multi-model support: secondary DiT handlers
404
+ handler2 = None
405
+ handler3 = None
406
+ config_path2 = os.getenv("ACESTEP_CONFIG_PATH2", "").strip()
407
+ config_path3 = os.getenv("ACESTEP_CONFIG_PATH3", "").strip()
408
+
409
+ if config_path2:
410
+ handler2 = AceStepHandler()
411
+ if config_path3:
412
+ handler3 = AceStepHandler()
413
+
414
+ app.state.handler2 = handler2
415
+ app.state.handler3 = handler3
416
+ app.state._initialized2 = False
417
+ app.state._initialized3 = False
418
+ app.state._config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo-rl")
419
+ app.state._config_path2 = config_path2
420
+ app.state._config_path3 = config_path3
421
+
422
  max_workers = int(os.getenv("ACESTEP_API_WORKERS", "1"))
423
  executor = ThreadPoolExecutor(max_workers=max_workers)
424
 
 
467
  offload_to_cpu = _env_bool("ACESTEP_OFFLOAD_TO_CPU", False)
468
  offload_dit_to_cpu = _env_bool("ACESTEP_OFFLOAD_DIT_TO_CPU", False)
469
 
470
+ # Initialize primary model
471
  status_msg, ok = h.initialize_service(
472
  project_root=project_root,
473
  config_path=config_path,
 
481
  app.state._init_error = status_msg
482
  raise RuntimeError(status_msg)
483
  app.state._initialized = True
484
+
485
+ # Initialize secondary model if configured
486
+ if app.state.handler2 and app.state._config_path2:
487
+ try:
488
+ status_msg2, ok2 = app.state.handler2.initialize_service(
489
+ project_root=project_root,
490
+ config_path=app.state._config_path2,
491
+ device=device,
492
+ use_flash_attention=use_flash_attention,
493
+ compile_model=False,
494
+ offload_to_cpu=offload_to_cpu,
495
+ offload_dit_to_cpu=offload_dit_to_cpu,
496
+ )
497
+ app.state._initialized2 = ok2
498
+ if ok2:
499
+ print(f"[API Server] Secondary model loaded: {_get_model_name(app.state._config_path2)}")
500
+ else:
501
+ print(f"[API Server] Warning: Secondary model failed to load: {status_msg2}")
502
+ except Exception as e:
503
+ print(f"[API Server] Warning: Failed to initialize secondary model: {e}")
504
+ app.state._initialized2 = False
505
+
506
+ # Initialize third model if configured
507
+ if app.state.handler3 and app.state._config_path3:
508
+ try:
509
+ status_msg3, ok3 = app.state.handler3.initialize_service(
510
+ project_root=project_root,
511
+ config_path=app.state._config_path3,
512
+ device=device,
513
+ use_flash_attention=use_flash_attention,
514
+ compile_model=False,
515
+ offload_to_cpu=offload_to_cpu,
516
+ offload_dit_to_cpu=offload_dit_to_cpu,
517
+ )
518
+ app.state._initialized3 = ok3
519
+ if ok3:
520
+ print(f"[API Server] Third model loaded: {_get_model_name(app.state._config_path3)}")
521
+ else:
522
+ print(f"[API Server] Warning: Third model failed to load: {status_msg3}")
523
+ except Exception as e:
524
+ print(f"[API Server] Warning: Failed to initialize third model: {e}")
525
+ app.state._initialized3 = False
526
 
527
  async def _cleanup_job_temp_files(job_id: str) -> None:
528
  async with app.state.job_temp_files_lock:
 
535
 
536
  async def _run_one_job(job_id: str, req: GenerateMusicRequest) -> None:
537
  job_store: _JobStore = app.state.job_store
 
538
  llm: LLMHandler = app.state.llm_handler
539
  executor: ThreadPoolExecutor = app.state.executor
540
 
541
  await _ensure_initialized()
542
  job_store.mark_running(job_id)
543
+
544
+ # Select DiT handler based on user's model choice
545
+ # Default: use primary handler
546
+ selected_handler: AceStepHandler = app.state.handler
547
+ selected_model_name = _get_model_name(app.state._config_path)
548
+
549
+ if req.model:
550
+ model_matched = False
551
+
552
+ # Check if it matches the second model
553
+ if app.state.handler2 and getattr(app.state, "_initialized2", False):
554
+ model2_name = _get_model_name(app.state._config_path2)
555
+ if req.model == model2_name:
556
+ selected_handler = app.state.handler2
557
+ selected_model_name = model2_name
558
+ model_matched = True
559
+ print(f"[API Server] Job {job_id}: Using second model: {model2_name}")
560
+
561
+ # Check if it matches the third model
562
+ if not model_matched and app.state.handler3 and getattr(app.state, "_initialized3", False):
563
+ model3_name = _get_model_name(app.state._config_path3)
564
+ if req.model == model3_name:
565
+ selected_handler = app.state.handler3
566
+ selected_model_name = model3_name
567
+ model_matched = True
568
+ print(f"[API Server] Job {job_id}: Using third model: {model3_name}")
569
+
570
+ if not model_matched:
571
+ available_models = [_get_model_name(app.state._config_path)]
572
+ if app.state.handler2 and getattr(app.state, "_initialized2", False):
573
+ available_models.append(_get_model_name(app.state._config_path2))
574
+ if app.state.handler3 and getattr(app.state, "_initialized3", False):
575
+ available_models.append(_get_model_name(app.state._config_path3))
576
+ print(f"[API Server] Job {job_id}: Model '{req.model}' not found in {available_models}, using primary: {selected_model_name}")
577
+
578
+ # Use selected handler for generation
579
+ h: AceStepHandler = selected_handler
580
 
581
  def _blocking_generate() -> Dict[str, Any]:
582
  """Generate music using unified inference logic from acestep.inference"""
 
647
  if getattr(app.state, "_llm_init_error", None):
648
  raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
649
 
650
+ # Handle sample mode or description: generate caption/lyrics/metas via LM
651
  caption = req.caption
652
  lyrics = req.lyrics
653
  bpm = req.bpm
 
655
  time_signature = req.time_signature
656
  audio_duration = req.audio_duration
657
 
658
+ # Check if sample_query (description) is provided for create_sample
659
+ has_sample_query = bool(req.sample_query and req.sample_query.strip())
660
+
661
+ if sample_mode or has_sample_query:
662
+ if has_sample_query:
663
+ # Use create_sample() with description query
664
+ print(f"[api_server] Description mode: generating sample from query: {req.sample_query[:100]}")
665
+ sample_result = create_sample(
666
+ llm_handler=llm,
667
+ query=req.sample_query,
668
+ instrumental=False, # Could be extracted from description
669
+ vocal_language=req.vocal_language if req.vocal_language != "en" else None,
670
+ temperature=req.lm_temperature,
671
+ top_k=lm_top_k if lm_top_k > 0 else None,
672
+ top_p=lm_top_p if lm_top_p < 1.0 else None,
673
+ use_constrained_decoding=req.constrained_decoding,
674
+ )
675
+
676
+ if not sample_result.success:
677
+ raise RuntimeError(f"create_sample failed: {sample_result.error or sample_result.status_message}")
678
+
679
+ # Use generated sample data
680
+ caption = sample_result.caption
681
+ lyrics = sample_result.lyrics
682
+ bpm = sample_result.bpm
683
+ key_scale = sample_result.keyscale
684
+ time_signature = sample_result.timesignature
685
+ audio_duration = sample_result.duration
686
+
687
+ print(f"[api_server] Sample from description generated: caption_len={len(caption)}, lyrics_len={len(lyrics)}, bpm={bpm}")
688
+ else:
689
+ # Original sample_mode behavior: random generation
690
+ print("[api_server] Sample mode: generating random caption/lyrics via LM")
691
+ sample_metadata, sample_status = llm.understand_audio_from_codes(
692
+ audio_codes="NO USER INPUT",
693
+ temperature=req.lm_temperature,
694
+ top_k=lm_top_k if lm_top_k > 0 else None,
695
+ top_p=lm_top_p if lm_top_p < 1.0 else None,
696
+ repetition_penalty=req.lm_repetition_penalty,
697
+ use_constrained_decoding=req.constrained_decoding,
698
+ constrained_decoding_debug=req.constrained_decoding_debug,
699
+ )
700
+
701
+ if not sample_metadata or str(sample_status).startswith("❌"):
702
+ raise RuntimeError(f"Sample generation failed: {sample_status}")
703
+
704
+ # Use generated values with fallback defaults
705
+ caption = sample_metadata.get("caption", "")
706
+ lyrics = sample_metadata.get("lyrics", "")
707
+ bpm = _to_int(sample_metadata.get("bpm"), None) or _to_int(os.getenv("ACESTEP_SAMPLE_DEFAULT_BPM", "120"), 120)
708
+ key_scale = sample_metadata.get("keyscale", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_KEY", "C Major")
709
+ time_signature = sample_metadata.get("timesignature", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_TIMESIGNATURE", "4/4")
710
+ audio_duration = _to_float(sample_metadata.get("duration"), None) or _to_float(os.getenv("ACESTEP_SAMPLE_DEFAULT_DURATION_SECONDS", "120"), 120.0)
711
+
712
+ print(f"[api_server] Sample generated: caption_len={len(caption)}, lyrics_len={len(lyrics)}, bpm={bpm}, duration={audio_duration}")
713
+
714
+ # Apply format_sample() if use_format is True and caption/lyrics are provided
715
+ if req.use_format and (caption or lyrics):
716
+ print(f"[api_server] Applying format_sample to enhance input...")
717
+ _ensure_llm_ready()
718
+ if getattr(app.state, "_llm_init_error", None):
719
+ raise RuntimeError(f"5Hz LM init failed (needed for format): {app.state._llm_init_error}")
720
+
721
+ format_result = format_sample(
722
+ llm_handler=llm,
723
+ caption=caption,
724
+ lyrics=lyrics,
725
  temperature=req.lm_temperature,
726
  top_k=lm_top_k if lm_top_k > 0 else None,
727
  top_p=lm_top_p if lm_top_p < 1.0 else None,
 
728
  use_constrained_decoding=req.constrained_decoding,
 
729
  )
 
 
 
 
 
 
 
 
 
 
 
730
 
731
+ if format_result.success:
732
+ caption = format_result.caption
733
+ lyrics = format_result.lyrics
734
+ print(f"[api_server] Format applied: new caption_len={len(caption)}, lyrics_len={len(lyrics)}")
735
+ else:
736
+ print(f"[api_server] Warning: format_sample failed: {format_result.error}, using original input")
737
 
738
  print(f"[api_server] Before GenerationParams: thinking={thinking}, sample_mode={sample_mode}")
739
  print(f"[api_server] Caption/Lyrics to use: caption_len={len(caption)}, lyrics_len={len(lyrics)}")
 
876
  return None
877
  return s
878
 
879
+ # Get model information
880
  lm_model_name = os.getenv("ACESTEP_LM_MODEL_PATH", "acestep-5Hz-lm-0.6B-v3")
881
+ # Use selected_model_name (set at the beginning of _run_one_job)
882
+ dit_model_name = selected_model_name
883
 
884
  return {
885
  "first_audio_path": _path_to_audio_url(first_audio) if first_audio else None,
 
1011
  lyrics=str(get("lyrics", "") or ""),
1012
  thinking=_to_bool(get("thinking"), False),
1013
  sample_mode=_to_bool(_get_any("sample_mode", "sampleMode"), False),
1014
+ sample_query=str(_get_any("sample_query", "sampleQuery", "description", "desc", default="") or ""),
1015
+ use_format=_to_bool(_get_any("use_format", "useFormat", "format"), False),
1016
+ model=str(_get_any("model", "dit_model", "ditModel", default="") or "").strip() or None,
1017
  bpm=normalized_bpm,
1018
  key_scale=normalized_keyscale,
1019
  time_signature=normalized_timesig,