Spaces:
Running
on
A100
Running
on
A100
Merge pull request #12 from ace-step/fix_api_simple_mode_lang
Browse files- acestep/api_server.py +92 -2
- acestep/handler.py +1 -1
acestep/api_server.py
CHANGED
|
@@ -53,6 +53,77 @@ from acestep.inference import (
|
|
| 53 |
from acestep.gradio_ui.events.results_handlers import _build_generation_info
|
| 54 |
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
JobStatus = Literal["queued", "running", "succeeded", "failed"]
|
| 57 |
|
| 58 |
|
|
@@ -678,11 +749,30 @@ def create_app() -> FastAPI:
|
|
| 678 |
if has_sample_query:
|
| 679 |
# Use create_sample() with description query
|
| 680 |
print(f"[api_server] Description mode: generating sample from query: {req.sample_query[:100]}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 681 |
sample_result = create_sample(
|
| 682 |
llm_handler=llm,
|
| 683 |
query=req.sample_query,
|
| 684 |
-
instrumental=
|
| 685 |
-
vocal_language=
|
| 686 |
temperature=req.lm_temperature,
|
| 687 |
top_k=lm_top_k if lm_top_k > 0 else None,
|
| 688 |
top_p=lm_top_p if lm_top_p < 1.0 else None,
|
|
|
|
| 53 |
from acestep.gradio_ui.events.results_handlers import _build_generation_info
|
| 54 |
|
| 55 |
|
| 56 |
+
def _parse_description_hints(description: str) -> tuple[Optional[str], bool]:
|
| 57 |
+
"""
|
| 58 |
+
Parse a description string to extract language code and instrumental flag.
|
| 59 |
+
|
| 60 |
+
This function analyzes user descriptions like "Pop rock. English" or "piano solo"
|
| 61 |
+
to detect:
|
| 62 |
+
- Language: Maps language names to ISO codes (e.g., "English" -> "en")
|
| 63 |
+
- Instrumental: Detects patterns indicating instrumental/no-vocal music
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
description: User's natural language music description
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
(language_code, is_instrumental) tuple:
|
| 70 |
+
- language_code: ISO language code (e.g., "en", "zh") or None if not detected
|
| 71 |
+
- is_instrumental: True if description indicates instrumental music
|
| 72 |
+
"""
|
| 73 |
+
import re
|
| 74 |
+
|
| 75 |
+
if not description:
|
| 76 |
+
return None, False
|
| 77 |
+
|
| 78 |
+
description_lower = description.lower().strip()
|
| 79 |
+
|
| 80 |
+
# Language mapping: input patterns -> ISO code
|
| 81 |
+
language_mapping = {
|
| 82 |
+
'english': 'en', 'en': 'en',
|
| 83 |
+
'chinese': 'zh', '中文': 'zh', 'zh': 'zh', 'mandarin': 'zh',
|
| 84 |
+
'japanese': 'ja', '日本語': 'ja', 'ja': 'ja',
|
| 85 |
+
'korean': 'ko', '한국어': 'ko', 'ko': 'ko',
|
| 86 |
+
'spanish': 'es', 'español': 'es', 'es': 'es',
|
| 87 |
+
'french': 'fr', 'français': 'fr', 'fr': 'fr',
|
| 88 |
+
'german': 'de', 'deutsch': 'de', 'de': 'de',
|
| 89 |
+
'italian': 'it', 'italiano': 'it', 'it': 'it',
|
| 90 |
+
'portuguese': 'pt', 'português': 'pt', 'pt': 'pt',
|
| 91 |
+
'russian': 'ru', 'русский': 'ru', 'ru': 'ru',
|
| 92 |
+
'bengali': 'bn', 'bn': 'bn',
|
| 93 |
+
'hindi': 'hi', 'hi': 'hi',
|
| 94 |
+
'arabic': 'ar', 'ar': 'ar',
|
| 95 |
+
'thai': 'th', 'th': 'th',
|
| 96 |
+
'vietnamese': 'vi', 'vi': 'vi',
|
| 97 |
+
'indonesian': 'id', 'id': 'id',
|
| 98 |
+
'turkish': 'tr', 'tr': 'tr',
|
| 99 |
+
'dutch': 'nl', 'nl': 'nl',
|
| 100 |
+
'polish': 'pl', 'pl': 'pl',
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
# Detect language
|
| 104 |
+
detected_language = None
|
| 105 |
+
for lang_name, lang_code in language_mapping.items():
|
| 106 |
+
if len(lang_name) <= 2:
|
| 107 |
+
pattern = r'(?:^|\s|[.,;:!?])' + re.escape(lang_name) + r'(?:$|\s|[.,;:!?])'
|
| 108 |
+
else:
|
| 109 |
+
pattern = r'\b' + re.escape(lang_name) + r'\b'
|
| 110 |
+
|
| 111 |
+
if re.search(pattern, description_lower):
|
| 112 |
+
detected_language = lang_code
|
| 113 |
+
break
|
| 114 |
+
|
| 115 |
+
# Detect instrumental
|
| 116 |
+
is_instrumental = False
|
| 117 |
+
if 'instrumental' in description_lower:
|
| 118 |
+
is_instrumental = True
|
| 119 |
+
elif 'pure music' in description_lower or 'pure instrument' in description_lower:
|
| 120 |
+
is_instrumental = True
|
| 121 |
+
elif description_lower.endswith(' solo') or description_lower == 'solo':
|
| 122 |
+
is_instrumental = True
|
| 123 |
+
|
| 124 |
+
return detected_language, is_instrumental
|
| 125 |
+
|
| 126 |
+
|
| 127 |
JobStatus = Literal["queued", "running", "succeeded", "failed"]
|
| 128 |
|
| 129 |
|
|
|
|
| 749 |
if has_sample_query:
|
| 750 |
# Use create_sample() with description query
|
| 751 |
print(f"[api_server] Description mode: generating sample from query: {req.sample_query[:100]}")
|
| 752 |
+
|
| 753 |
+
# Parse description for language and instrumental hints (aligned with feishu_bot)
|
| 754 |
+
parsed_language, parsed_instrumental = _parse_description_hints(req.sample_query)
|
| 755 |
+
print(f"[api_server] Parsed from description: language={parsed_language}, instrumental={parsed_instrumental}")
|
| 756 |
+
|
| 757 |
+
# Determine vocal_language with priority:
|
| 758 |
+
# 1. User-specified vocal_language (if not default "en") - highest priority
|
| 759 |
+
# 2. Language parsed from description
|
| 760 |
+
# 3. None (no constraint)
|
| 761 |
+
if req.vocal_language and req.vocal_language not in ("en", "unknown", ""):
|
| 762 |
+
# User explicitly specified a non-default language, use it
|
| 763 |
+
sample_language = req.vocal_language
|
| 764 |
+
print(f"[api_server] Using user-specified vocal_language: {sample_language}")
|
| 765 |
+
else:
|
| 766 |
+
# Fall back to language parsed from description
|
| 767 |
+
sample_language = parsed_language
|
| 768 |
+
if sample_language:
|
| 769 |
+
print(f"[api_server] Using language from description: {sample_language}")
|
| 770 |
+
|
| 771 |
sample_result = create_sample(
|
| 772 |
llm_handler=llm,
|
| 773 |
query=req.sample_query,
|
| 774 |
+
instrumental=parsed_instrumental,
|
| 775 |
+
vocal_language=sample_language,
|
| 776 |
temperature=req.lm_temperature,
|
| 777 |
top_k=lm_top_k if lm_top_k > 0 else None,
|
| 778 |
top_p=lm_top_p if lm_top_p < 1.0 else None,
|
acestep/handler.py
CHANGED
|
@@ -2110,7 +2110,7 @@ class AceStepHandler:
|
|
| 2110 |
|
| 2111 |
return outputs
|
| 2112 |
|
| 2113 |
-
def tiled_decode(self, latents, chunk_size=512, overlap=64, offload_wav_to_cpu=
|
| 2114 |
"""
|
| 2115 |
Decode latents using tiling to reduce VRAM usage.
|
| 2116 |
Uses overlap-discard strategy to avoid boundary artifacts.
|
|
|
|
| 2110 |
|
| 2111 |
return outputs
|
| 2112 |
|
| 2113 |
+
def tiled_decode(self, latents, chunk_size=512, overlap=64, offload_wav_to_cpu=False):
|
| 2114 |
"""
|
| 2115 |
Decode latents using tiling to reduce VRAM usage.
|
| 2116 |
Uses overlap-discard strategy to avoid boundary artifacts.
|