Gong Junmin committed on
Commit
4f80622
·
unverified ·
2 Parent(s): bc7cf7a d54a0fc

Merge pull request #12 from ace-step/fix_api_simple_mode_lang

Browse files
Files changed (2) hide show
  1. acestep/api_server.py +92 -2
  2. acestep/handler.py +1 -1
acestep/api_server.py CHANGED
@@ -53,6 +53,77 @@ from acestep.inference import (
53
  from acestep.gradio_ui.events.results_handlers import _build_generation_info
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  JobStatus = Literal["queued", "running", "succeeded", "failed"]
57
 
58
 
@@ -678,11 +749,30 @@ def create_app() -> FastAPI:
678
  if has_sample_query:
679
  # Use create_sample() with description query
680
  print(f"[api_server] Description mode: generating sample from query: {req.sample_query[:100]}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
681
  sample_result = create_sample(
682
  llm_handler=llm,
683
  query=req.sample_query,
684
- instrumental=False, # Could be extracted from description
685
- vocal_language=req.vocal_language if req.vocal_language != "en" else None,
686
  temperature=req.lm_temperature,
687
  top_k=lm_top_k if lm_top_k > 0 else None,
688
  top_p=lm_top_p if lm_top_p < 1.0 else None,
 
53
  from acestep.gradio_ui.events.results_handlers import _build_generation_info
54
 
55
 
56
+ def _parse_description_hints(description: str) -> tuple[Optional[str], bool]:
57
+ """
58
+ Parse a description string to extract language code and instrumental flag.
59
+
60
+ This function analyzes user descriptions like "Pop rock. English" or "piano solo"
61
+ to detect:
62
+ - Language: Maps language names to ISO codes (e.g., "English" -> "en")
63
+ - Instrumental: Detects patterns indicating instrumental/no-vocal music
64
+
65
+ Args:
66
+ description: User's natural language music description
67
+
68
+ Returns:
69
+ (language_code, is_instrumental) tuple:
70
+ - language_code: ISO language code (e.g., "en", "zh") or None if not detected
71
+ - is_instrumental: True if description indicates instrumental music
72
+ """
73
+ import re
74
+
75
+ if not description:
76
+ return None, False
77
+
78
+ description_lower = description.lower().strip()
79
+
80
+ # Language mapping: input patterns -> ISO code
81
+ language_mapping = {
82
+ 'english': 'en', 'en': 'en',
83
+ 'chinese': 'zh', '中文': 'zh', 'zh': 'zh', 'mandarin': 'zh',
84
+ 'japanese': 'ja', '日本語': 'ja', 'ja': 'ja',
85
+ 'korean': 'ko', '한국어': 'ko', 'ko': 'ko',
86
+ 'spanish': 'es', 'español': 'es', 'es': 'es',
87
+ 'french': 'fr', 'français': 'fr', 'fr': 'fr',
88
+ 'german': 'de', 'deutsch': 'de', 'de': 'de',
89
+ 'italian': 'it', 'italiano': 'it', 'it': 'it',
90
+ 'portuguese': 'pt', 'português': 'pt', 'pt': 'pt',
91
+ 'russian': 'ru', 'русский': 'ru', 'ru': 'ru',
92
+ 'bengali': 'bn', 'bn': 'bn',
93
+ 'hindi': 'hi', 'hi': 'hi',
94
+ 'arabic': 'ar', 'ar': 'ar',
95
+ 'thai': 'th', 'th': 'th',
96
+ 'vietnamese': 'vi', 'vi': 'vi',
97
+ 'indonesian': 'id', 'id': 'id',
98
+ 'turkish': 'tr', 'tr': 'tr',
99
+ 'dutch': 'nl', 'nl': 'nl',
100
+ 'polish': 'pl', 'pl': 'pl',
101
+ }
102
+
103
+ # Detect language
104
+ detected_language = None
105
+ for lang_name, lang_code in language_mapping.items():
106
+ if len(lang_name) <= 2:
107
+ pattern = r'(?:^|\s|[.,;:!?])' + re.escape(lang_name) + r'(?:$|\s|[.,;:!?])'
108
+ else:
109
+ pattern = r'\b' + re.escape(lang_name) + r'\b'
110
+
111
+ if re.search(pattern, description_lower):
112
+ detected_language = lang_code
113
+ break
114
+
115
+ # Detect instrumental
116
+ is_instrumental = False
117
+ if 'instrumental' in description_lower:
118
+ is_instrumental = True
119
+ elif 'pure music' in description_lower or 'pure instrument' in description_lower:
120
+ is_instrumental = True
121
+ elif description_lower.endswith(' solo') or description_lower == 'solo':
122
+ is_instrumental = True
123
+
124
+ return detected_language, is_instrumental
125
+
126
+
127
  JobStatus = Literal["queued", "running", "succeeded", "failed"]
128
 
129
 
 
749
  if has_sample_query:
750
  # Use create_sample() with description query
751
  print(f"[api_server] Description mode: generating sample from query: {req.sample_query[:100]}")
752
+
753
+ # Parse description for language and instrumental hints (aligned with feishu_bot)
754
+ parsed_language, parsed_instrumental = _parse_description_hints(req.sample_query)
755
+ print(f"[api_server] Parsed from description: language={parsed_language}, instrumental={parsed_instrumental}")
756
+
757
+ # Determine vocal_language with priority:
758
+ # 1. User-specified vocal_language (if not default "en") - highest priority
759
+ # 2. Language parsed from description
760
+ # 3. None (no constraint)
761
+ if req.vocal_language and req.vocal_language not in ("en", "unknown", ""):
762
+ # User explicitly specified a non-default language, use it
763
+ sample_language = req.vocal_language
764
+ print(f"[api_server] Using user-specified vocal_language: {sample_language}")
765
+ else:
766
+ # Fall back to language parsed from description
767
+ sample_language = parsed_language
768
+ if sample_language:
769
+ print(f"[api_server] Using language from description: {sample_language}")
770
+
771
  sample_result = create_sample(
772
  llm_handler=llm,
773
  query=req.sample_query,
774
+ instrumental=parsed_instrumental,
775
+ vocal_language=sample_language,
776
  temperature=req.lm_temperature,
777
  top_k=lm_top_k if lm_top_k > 0 else None,
778
  top_p=lm_top_p if lm_top_p < 1.0 else None,
acestep/handler.py CHANGED
@@ -2110,7 +2110,7 @@ class AceStepHandler:
2110
 
2111
  return outputs
2112
 
2113
- def tiled_decode(self, latents, chunk_size=512, overlap=64, offload_wav_to_cpu=True):
2114
  """
2115
  Decode latents using tiling to reduce VRAM usage.
2116
  Uses overlap-discard strategy to avoid boundary artifacts.
 
2110
 
2111
  return outputs
2112
 
2113
+ def tiled_decode(self, latents, chunk_size=512, overlap=64, offload_wav_to_cpu=False):
2114
  """
2115
  Decode latents using tiling to reduce VRAM usage.
2116
  Uses overlap-discard strategy to avoid boundary artifacts.