Spaces:
Running
on
A100
Running
on
A100
feat: add Description, Format, Model Select
Browse files- acestep/api_server.py +203 -24
acestep/api_server.py
CHANGED
|
@@ -14,7 +14,6 @@ from __future__ import annotations
|
|
| 14 |
import asyncio
|
| 15 |
import json
|
| 16 |
import os
|
| 17 |
-
import re
|
| 18 |
import sys
|
| 19 |
import time
|
| 20 |
import traceback
|
|
@@ -48,6 +47,8 @@ from acestep.inference import (
|
|
| 48 |
GenerationParams,
|
| 49 |
GenerationConfig,
|
| 50 |
generate_music,
|
|
|
|
|
|
|
| 51 |
)
|
| 52 |
from acestep.gradio_ui.events.results_handlers import _build_generation_info
|
| 53 |
|
|
@@ -66,6 +67,12 @@ class GenerateMusicRequest(BaseModel):
|
|
| 66 |
thinking: bool = False
|
| 67 |
# Sample-mode requests auto-generate caption/lyrics/metas via LM (no user prompt).
|
| 68 |
sample_mode: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
bpm: Optional[int] = None
|
| 71 |
# Accept common client keys via manual parsing (see _build_req_from_mapping).
|
|
@@ -233,6 +240,22 @@ def _get_project_root() -> str:
|
|
| 233 |
return os.path.dirname(os.path.dirname(current_file))
|
| 234 |
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
def _load_project_env() -> None:
|
| 237 |
if load_dotenv is None:
|
| 238 |
return
|
|
@@ -377,6 +400,25 @@ def create_app() -> FastAPI:
|
|
| 377 |
app.state._llm_init_error = None
|
| 378 |
app.state._llm_init_lock = Lock()
|
| 379 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
max_workers = int(os.getenv("ACESTEP_API_WORKERS", "1"))
|
| 381 |
executor = ThreadPoolExecutor(max_workers=max_workers)
|
| 382 |
|
|
@@ -425,6 +467,7 @@ def create_app() -> FastAPI:
|
|
| 425 |
offload_to_cpu = _env_bool("ACESTEP_OFFLOAD_TO_CPU", False)
|
| 426 |
offload_dit_to_cpu = _env_bool("ACESTEP_OFFLOAD_DIT_TO_CPU", False)
|
| 427 |
|
|
|
|
| 428 |
status_msg, ok = h.initialize_service(
|
| 429 |
project_root=project_root,
|
| 430 |
config_path=config_path,
|
|
@@ -438,6 +481,48 @@ def create_app() -> FastAPI:
|
|
| 438 |
app.state._init_error = status_msg
|
| 439 |
raise RuntimeError(status_msg)
|
| 440 |
app.state._initialized = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 441 |
|
| 442 |
async def _cleanup_job_temp_files(job_id: str) -> None:
|
| 443 |
async with app.state.job_temp_files_lock:
|
|
@@ -450,12 +535,48 @@ def create_app() -> FastAPI:
|
|
| 450 |
|
| 451 |
async def _run_one_job(job_id: str, req: GenerateMusicRequest) -> None:
|
| 452 |
job_store: _JobStore = app.state.job_store
|
| 453 |
-
h: AceStepHandler = app.state.handler
|
| 454 |
llm: LLMHandler = app.state.llm_handler
|
| 455 |
executor: ThreadPoolExecutor = app.state.executor
|
| 456 |
|
| 457 |
await _ensure_initialized()
|
| 458 |
job_store.mark_running(job_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
|
| 460 |
def _blocking_generate() -> Dict[str, Any]:
|
| 461 |
"""Generate music using unified inference logic from acestep.inference"""
|
|
@@ -526,7 +647,7 @@ def create_app() -> FastAPI:
|
|
| 526 |
if getattr(app.state, "_llm_init_error", None):
|
| 527 |
raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
|
| 528 |
|
| 529 |
-
# Handle sample mode: generate
|
| 530 |
caption = req.caption
|
| 531 |
lyrics = req.lyrics
|
| 532 |
bpm = req.bpm
|
|
@@ -534,31 +655,85 @@ def create_app() -> FastAPI:
|
|
| 534 |
time_signature = req.time_signature
|
| 535 |
audio_duration = req.audio_duration
|
| 536 |
|
| 537 |
-
if
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 542 |
temperature=req.lm_temperature,
|
| 543 |
top_k=lm_top_k if lm_top_k > 0 else None,
|
| 544 |
top_p=lm_top_p if lm_top_p < 1.0 else None,
|
| 545 |
-
repetition_penalty=req.lm_repetition_penalty,
|
| 546 |
use_constrained_decoding=req.constrained_decoding,
|
| 547 |
-
constrained_decoding_debug=req.constrained_decoding_debug,
|
| 548 |
)
|
| 549 |
-
|
| 550 |
-
if not sample_metadata or str(sample_status).startswith("❌"):
|
| 551 |
-
raise RuntimeError(f"Sample generation failed: {sample_status}")
|
| 552 |
-
|
| 553 |
-
# Use generated values with fallback defaults
|
| 554 |
-
caption = sample_metadata.get("caption", "")
|
| 555 |
-
lyrics = sample_metadata.get("lyrics", "")
|
| 556 |
-
bpm = _to_int(sample_metadata.get("bpm"), None) or _to_int(os.getenv("ACESTEP_SAMPLE_DEFAULT_BPM", "120"), 120)
|
| 557 |
-
key_scale = sample_metadata.get("keyscale", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_KEY", "C Major")
|
| 558 |
-
time_signature = sample_metadata.get("timesignature", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_TIMESIGNATURE", "4/4")
|
| 559 |
-
audio_duration = _to_float(sample_metadata.get("duration"), None) or _to_float(os.getenv("ACESTEP_SAMPLE_DEFAULT_DURATION_SECONDS", "120"), 120.0)
|
| 560 |
|
| 561 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 562 |
|
| 563 |
print(f"[api_server] Before GenerationParams: thinking={thinking}, sample_mode={sample_mode}")
|
| 564 |
print(f"[api_server] Caption/Lyrics to use: caption_len={len(caption)}, lyrics_len={len(lyrics)}")
|
|
@@ -701,9 +876,10 @@ def create_app() -> FastAPI:
|
|
| 701 |
return None
|
| 702 |
return s
|
| 703 |
|
| 704 |
-
# Get model information
|
| 705 |
lm_model_name = os.getenv("ACESTEP_LM_MODEL_PATH", "acestep-5Hz-lm-0.6B-v3")
|
| 706 |
-
|
|
|
|
| 707 |
|
| 708 |
return {
|
| 709 |
"first_audio_path": _path_to_audio_url(first_audio) if first_audio else None,
|
|
@@ -835,6 +1011,9 @@ def create_app() -> FastAPI:
|
|
| 835 |
lyrics=str(get("lyrics", "") or ""),
|
| 836 |
thinking=_to_bool(get("thinking"), False),
|
| 837 |
sample_mode=_to_bool(_get_any("sample_mode", "sampleMode"), False),
|
|
|
|
|
|
|
|
|
|
| 838 |
bpm=normalized_bpm,
|
| 839 |
key_scale=normalized_keyscale,
|
| 840 |
time_signature=normalized_timesig,
|
|
|
|
| 14 |
import asyncio
|
| 15 |
import json
|
| 16 |
import os
|
|
|
|
| 17 |
import sys
|
| 18 |
import time
|
| 19 |
import traceback
|
|
|
|
| 47 |
GenerationParams,
|
| 48 |
GenerationConfig,
|
| 49 |
generate_music,
|
| 50 |
+
create_sample,
|
| 51 |
+
format_sample,
|
| 52 |
)
|
| 53 |
from acestep.gradio_ui.events.results_handlers import _build_generation_info
|
| 54 |
|
|
|
|
| 67 |
thinking: bool = False
|
| 68 |
# Sample-mode requests auto-generate caption/lyrics/metas via LM (no user prompt).
|
| 69 |
sample_mode: bool = False
|
| 70 |
+
# Description for sample mode: auto-generate caption/lyrics from description query
|
| 71 |
+
sample_query: str = Field(default="", description="Query/description for sample mode (use create_sample)")
|
| 72 |
+
# Whether to use format_sample() to enhance input caption/lyrics
|
| 73 |
+
use_format: bool = Field(default=False, description="Use format_sample() to enhance input (default: False)")
|
| 74 |
+
# Model name for multi-model support (select which DiT model to use)
|
| 75 |
+
model: Optional[str] = Field(default=None, description="Model name to use (e.g., 'acestep-v15-turbo')")
|
| 76 |
|
| 77 |
bpm: Optional[int] = None
|
| 78 |
# Accept common client keys via manual parsing (see _build_req_from_mapping).
|
|
|
|
| 240 |
return os.path.dirname(os.path.dirname(current_file))
|
| 241 |
|
| 242 |
|
| 243 |
+
def _get_model_name(config_path: str) -> str:
|
| 244 |
+
"""
|
| 245 |
+
Extract model name from config_path.
|
| 246 |
+
|
| 247 |
+
Args:
|
| 248 |
+
config_path: Path like "acestep-v15-turbo" or "/path/to/acestep-v15-turbo"
|
| 249 |
+
|
| 250 |
+
Returns:
|
| 251 |
+
Model name (last directory name from config_path)
|
| 252 |
+
"""
|
| 253 |
+
if not config_path:
|
| 254 |
+
return ""
|
| 255 |
+
normalized = config_path.rstrip("/\\")
|
| 256 |
+
return os.path.basename(normalized)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
def _load_project_env() -> None:
|
| 260 |
if load_dotenv is None:
|
| 261 |
return
|
|
|
|
| 400 |
app.state._llm_init_error = None
|
| 401 |
app.state._llm_init_lock = Lock()
|
| 402 |
|
| 403 |
+
# Multi-model support: secondary DiT handlers
|
| 404 |
+
handler2 = None
|
| 405 |
+
handler3 = None
|
| 406 |
+
config_path2 = os.getenv("ACESTEP_CONFIG_PATH2", "").strip()
|
| 407 |
+
config_path3 = os.getenv("ACESTEP_CONFIG_PATH3", "").strip()
|
| 408 |
+
|
| 409 |
+
if config_path2:
|
| 410 |
+
handler2 = AceStepHandler()
|
| 411 |
+
if config_path3:
|
| 412 |
+
handler3 = AceStepHandler()
|
| 413 |
+
|
| 414 |
+
app.state.handler2 = handler2
|
| 415 |
+
app.state.handler3 = handler3
|
| 416 |
+
app.state._initialized2 = False
|
| 417 |
+
app.state._initialized3 = False
|
| 418 |
+
app.state._config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo-rl")
|
| 419 |
+
app.state._config_path2 = config_path2
|
| 420 |
+
app.state._config_path3 = config_path3
|
| 421 |
+
|
| 422 |
max_workers = int(os.getenv("ACESTEP_API_WORKERS", "1"))
|
| 423 |
executor = ThreadPoolExecutor(max_workers=max_workers)
|
| 424 |
|
|
|
|
| 467 |
offload_to_cpu = _env_bool("ACESTEP_OFFLOAD_TO_CPU", False)
|
| 468 |
offload_dit_to_cpu = _env_bool("ACESTEP_OFFLOAD_DIT_TO_CPU", False)
|
| 469 |
|
| 470 |
+
# Initialize primary model
|
| 471 |
status_msg, ok = h.initialize_service(
|
| 472 |
project_root=project_root,
|
| 473 |
config_path=config_path,
|
|
|
|
| 481 |
app.state._init_error = status_msg
|
| 482 |
raise RuntimeError(status_msg)
|
| 483 |
app.state._initialized = True
|
| 484 |
+
|
| 485 |
+
# Initialize secondary model if configured
|
| 486 |
+
if app.state.handler2 and app.state._config_path2:
|
| 487 |
+
try:
|
| 488 |
+
status_msg2, ok2 = app.state.handler2.initialize_service(
|
| 489 |
+
project_root=project_root,
|
| 490 |
+
config_path=app.state._config_path2,
|
| 491 |
+
device=device,
|
| 492 |
+
use_flash_attention=use_flash_attention,
|
| 493 |
+
compile_model=False,
|
| 494 |
+
offload_to_cpu=offload_to_cpu,
|
| 495 |
+
offload_dit_to_cpu=offload_dit_to_cpu,
|
| 496 |
+
)
|
| 497 |
+
app.state._initialized2 = ok2
|
| 498 |
+
if ok2:
|
| 499 |
+
print(f"[API Server] Secondary model loaded: {_get_model_name(app.state._config_path2)}")
|
| 500 |
+
else:
|
| 501 |
+
print(f"[API Server] Warning: Secondary model failed to load: {status_msg2}")
|
| 502 |
+
except Exception as e:
|
| 503 |
+
print(f"[API Server] Warning: Failed to initialize secondary model: {e}")
|
| 504 |
+
app.state._initialized2 = False
|
| 505 |
+
|
| 506 |
+
# Initialize third model if configured
|
| 507 |
+
if app.state.handler3 and app.state._config_path3:
|
| 508 |
+
try:
|
| 509 |
+
status_msg3, ok3 = app.state.handler3.initialize_service(
|
| 510 |
+
project_root=project_root,
|
| 511 |
+
config_path=app.state._config_path3,
|
| 512 |
+
device=device,
|
| 513 |
+
use_flash_attention=use_flash_attention,
|
| 514 |
+
compile_model=False,
|
| 515 |
+
offload_to_cpu=offload_to_cpu,
|
| 516 |
+
offload_dit_to_cpu=offload_dit_to_cpu,
|
| 517 |
+
)
|
| 518 |
+
app.state._initialized3 = ok3
|
| 519 |
+
if ok3:
|
| 520 |
+
print(f"[API Server] Third model loaded: {_get_model_name(app.state._config_path3)}")
|
| 521 |
+
else:
|
| 522 |
+
print(f"[API Server] Warning: Third model failed to load: {status_msg3}")
|
| 523 |
+
except Exception as e:
|
| 524 |
+
print(f"[API Server] Warning: Failed to initialize third model: {e}")
|
| 525 |
+
app.state._initialized3 = False
|
| 526 |
|
| 527 |
async def _cleanup_job_temp_files(job_id: str) -> None:
|
| 528 |
async with app.state.job_temp_files_lock:
|
|
|
|
| 535 |
|
| 536 |
async def _run_one_job(job_id: str, req: GenerateMusicRequest) -> None:
|
| 537 |
job_store: _JobStore = app.state.job_store
|
|
|
|
| 538 |
llm: LLMHandler = app.state.llm_handler
|
| 539 |
executor: ThreadPoolExecutor = app.state.executor
|
| 540 |
|
| 541 |
await _ensure_initialized()
|
| 542 |
job_store.mark_running(job_id)
|
| 543 |
+
|
| 544 |
+
# Select DiT handler based on user's model choice
|
| 545 |
+
# Default: use primary handler
|
| 546 |
+
selected_handler: AceStepHandler = app.state.handler
|
| 547 |
+
selected_model_name = _get_model_name(app.state._config_path)
|
| 548 |
+
|
| 549 |
+
if req.model:
|
| 550 |
+
model_matched = False
|
| 551 |
+
|
| 552 |
+
# Check if it matches the second model
|
| 553 |
+
if app.state.handler2 and getattr(app.state, "_initialized2", False):
|
| 554 |
+
model2_name = _get_model_name(app.state._config_path2)
|
| 555 |
+
if req.model == model2_name:
|
| 556 |
+
selected_handler = app.state.handler2
|
| 557 |
+
selected_model_name = model2_name
|
| 558 |
+
model_matched = True
|
| 559 |
+
print(f"[API Server] Job {job_id}: Using second model: {model2_name}")
|
| 560 |
+
|
| 561 |
+
# Check if it matches the third model
|
| 562 |
+
if not model_matched and app.state.handler3 and getattr(app.state, "_initialized3", False):
|
| 563 |
+
model3_name = _get_model_name(app.state._config_path3)
|
| 564 |
+
if req.model == model3_name:
|
| 565 |
+
selected_handler = app.state.handler3
|
| 566 |
+
selected_model_name = model3_name
|
| 567 |
+
model_matched = True
|
| 568 |
+
print(f"[API Server] Job {job_id}: Using third model: {model3_name}")
|
| 569 |
+
|
| 570 |
+
if not model_matched:
|
| 571 |
+
available_models = [_get_model_name(app.state._config_path)]
|
| 572 |
+
if app.state.handler2 and getattr(app.state, "_initialized2", False):
|
| 573 |
+
available_models.append(_get_model_name(app.state._config_path2))
|
| 574 |
+
if app.state.handler3 and getattr(app.state, "_initialized3", False):
|
| 575 |
+
available_models.append(_get_model_name(app.state._config_path3))
|
| 576 |
+
print(f"[API Server] Job {job_id}: Model '{req.model}' not found in {available_models}, using primary: {selected_model_name}")
|
| 577 |
+
|
| 578 |
+
# Use selected handler for generation
|
| 579 |
+
h: AceStepHandler = selected_handler
|
| 580 |
|
| 581 |
def _blocking_generate() -> Dict[str, Any]:
|
| 582 |
"""Generate music using unified inference logic from acestep.inference"""
|
|
|
|
| 647 |
if getattr(app.state, "_llm_init_error", None):
|
| 648 |
raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
|
| 649 |
|
| 650 |
+
# Handle sample mode or description: generate caption/lyrics/metas via LM
|
| 651 |
caption = req.caption
|
| 652 |
lyrics = req.lyrics
|
| 653 |
bpm = req.bpm
|
|
|
|
| 655 |
time_signature = req.time_signature
|
| 656 |
audio_duration = req.audio_duration
|
| 657 |
|
| 658 |
+
# Check if sample_query (description) is provided for create_sample
|
| 659 |
+
has_sample_query = bool(req.sample_query and req.sample_query.strip())
|
| 660 |
+
|
| 661 |
+
if sample_mode or has_sample_query:
|
| 662 |
+
if has_sample_query:
|
| 663 |
+
# Use create_sample() with description query
|
| 664 |
+
print(f"[api_server] Description mode: generating sample from query: {req.sample_query[:100]}")
|
| 665 |
+
sample_result = create_sample(
|
| 666 |
+
llm_handler=llm,
|
| 667 |
+
query=req.sample_query,
|
| 668 |
+
instrumental=False, # Could be extracted from description
|
| 669 |
+
vocal_language=req.vocal_language if req.vocal_language != "en" else None,
|
| 670 |
+
temperature=req.lm_temperature,
|
| 671 |
+
top_k=lm_top_k if lm_top_k > 0 else None,
|
| 672 |
+
top_p=lm_top_p if lm_top_p < 1.0 else None,
|
| 673 |
+
use_constrained_decoding=req.constrained_decoding,
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
if not sample_result.success:
|
| 677 |
+
raise RuntimeError(f"create_sample failed: {sample_result.error or sample_result.status_message}")
|
| 678 |
+
|
| 679 |
+
# Use generated sample data
|
| 680 |
+
caption = sample_result.caption
|
| 681 |
+
lyrics = sample_result.lyrics
|
| 682 |
+
bpm = sample_result.bpm
|
| 683 |
+
key_scale = sample_result.keyscale
|
| 684 |
+
time_signature = sample_result.timesignature
|
| 685 |
+
audio_duration = sample_result.duration
|
| 686 |
+
|
| 687 |
+
print(f"[api_server] Sample from description generated: caption_len={len(caption)}, lyrics_len={len(lyrics)}, bpm={bpm}")
|
| 688 |
+
else:
|
| 689 |
+
# Original sample_mode behavior: random generation
|
| 690 |
+
print("[api_server] Sample mode: generating random caption/lyrics via LM")
|
| 691 |
+
sample_metadata, sample_status = llm.understand_audio_from_codes(
|
| 692 |
+
audio_codes="NO USER INPUT",
|
| 693 |
+
temperature=req.lm_temperature,
|
| 694 |
+
top_k=lm_top_k if lm_top_k > 0 else None,
|
| 695 |
+
top_p=lm_top_p if lm_top_p < 1.0 else None,
|
| 696 |
+
repetition_penalty=req.lm_repetition_penalty,
|
| 697 |
+
use_constrained_decoding=req.constrained_decoding,
|
| 698 |
+
constrained_decoding_debug=req.constrained_decoding_debug,
|
| 699 |
+
)
|
| 700 |
+
|
| 701 |
+
if not sample_metadata or str(sample_status).startswith("❌"):
|
| 702 |
+
raise RuntimeError(f"Sample generation failed: {sample_status}")
|
| 703 |
+
|
| 704 |
+
# Use generated values with fallback defaults
|
| 705 |
+
caption = sample_metadata.get("caption", "")
|
| 706 |
+
lyrics = sample_metadata.get("lyrics", "")
|
| 707 |
+
bpm = _to_int(sample_metadata.get("bpm"), None) or _to_int(os.getenv("ACESTEP_SAMPLE_DEFAULT_BPM", "120"), 120)
|
| 708 |
+
key_scale = sample_metadata.get("keyscale", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_KEY", "C Major")
|
| 709 |
+
time_signature = sample_metadata.get("timesignature", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_TIMESIGNATURE", "4/4")
|
| 710 |
+
audio_duration = _to_float(sample_metadata.get("duration"), None) or _to_float(os.getenv("ACESTEP_SAMPLE_DEFAULT_DURATION_SECONDS", "120"), 120.0)
|
| 711 |
+
|
| 712 |
+
print(f"[api_server] Sample generated: caption_len={len(caption)}, lyrics_len={len(lyrics)}, bpm={bpm}, duration={audio_duration}")
|
| 713 |
+
|
| 714 |
+
# Apply format_sample() if use_format is True and caption/lyrics are provided
|
| 715 |
+
if req.use_format and (caption or lyrics):
|
| 716 |
+
print(f"[api_server] Applying format_sample to enhance input...")
|
| 717 |
+
_ensure_llm_ready()
|
| 718 |
+
if getattr(app.state, "_llm_init_error", None):
|
| 719 |
+
raise RuntimeError(f"5Hz LM init failed (needed for format): {app.state._llm_init_error}")
|
| 720 |
+
|
| 721 |
+
format_result = format_sample(
|
| 722 |
+
llm_handler=llm,
|
| 723 |
+
caption=caption,
|
| 724 |
+
lyrics=lyrics,
|
| 725 |
temperature=req.lm_temperature,
|
| 726 |
top_k=lm_top_k if lm_top_k > 0 else None,
|
| 727 |
top_p=lm_top_p if lm_top_p < 1.0 else None,
|
|
|
|
| 728 |
use_constrained_decoding=req.constrained_decoding,
|
|
|
|
| 729 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 730 |
|
| 731 |
+
if format_result.success:
|
| 732 |
+
caption = format_result.caption
|
| 733 |
+
lyrics = format_result.lyrics
|
| 734 |
+
print(f"[api_server] Format applied: new caption_len={len(caption)}, lyrics_len={len(lyrics)}")
|
| 735 |
+
else:
|
| 736 |
+
print(f"[api_server] Warning: format_sample failed: {format_result.error}, using original input")
|
| 737 |
|
| 738 |
print(f"[api_server] Before GenerationParams: thinking={thinking}, sample_mode={sample_mode}")
|
| 739 |
print(f"[api_server] Caption/Lyrics to use: caption_len={len(caption)}, lyrics_len={len(lyrics)}")
|
|
|
|
| 876 |
return None
|
| 877 |
return s
|
| 878 |
|
| 879 |
+
# Get model information
|
| 880 |
lm_model_name = os.getenv("ACESTEP_LM_MODEL_PATH", "acestep-5Hz-lm-0.6B-v3")
|
| 881 |
+
# Use selected_model_name (set at the beginning of _run_one_job)
|
| 882 |
+
dit_model_name = selected_model_name
|
| 883 |
|
| 884 |
return {
|
| 885 |
"first_audio_path": _path_to_audio_url(first_audio) if first_audio else None,
|
|
|
|
| 1011 |
lyrics=str(get("lyrics", "") or ""),
|
| 1012 |
thinking=_to_bool(get("thinking"), False),
|
| 1013 |
sample_mode=_to_bool(_get_any("sample_mode", "sampleMode"), False),
|
| 1014 |
+
sample_query=str(_get_any("sample_query", "sampleQuery", "description", "desc", default="") or ""),
|
| 1015 |
+
use_format=_to_bool(_get_any("use_format", "useFormat", "format"), False),
|
| 1016 |
+
model=str(_get_any("model", "dit_model", "ditModel", default="") or "").strip() or None,
|
| 1017 |
bpm=normalized_bpm,
|
| 1018 |
key_scale=normalized_keyscale,
|
| 1019 |
time_signature=normalized_timesig,
|