Spaces:
Running on Zero
Running on Zero
| import os | |
| import sys | |
| import re | |
| import json | |
| import base64 | |
| import uuid | |
| import tempfile | |
| import traceback | |
| from datetime import datetime, timezone | |
| import numpy as np | |
| import soundfile as sf | |
# ── CRITICAL: import spaces BEFORE torch and acestep ─────────────────────────
# `spaces` is optional; HAS_SPACES records whether the import succeeded so the
# rest of the module can branch on it.
try:
    import spaces
    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False

# Clear proxies that may interfere
for _v in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
    os.environ.pop(_v, None)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

# Fix PermissionError on ZeroGPU: /home/user/.cache is not writable.
# setdefault keeps any value that was already provided by the environment.
os.environ.setdefault("HF_MODULES_CACHE", "/tmp/hf_modules")
os.environ.setdefault("MPLCONFIGDIR", "/tmp/matplotlib")

# Add bundled nano-vllm to path (only when the directory ships with the app).
# It is inserted at position 0 so it takes precedence over other sys.path entries.
_current_dir = os.path.dirname(os.path.abspath(__file__))
_nano_vllm = os.path.join(_current_dir, "acestep", "third_parts", "nano-vllm")
if os.path.exists(_nano_vllm):
    sys.path.insert(0, _nano_vllm)
| import io | |
| import random | |
| import torch | |
| from PIL import Image | |
| from acestep.handler import AceStepHandler | |
| from gradio import Server | |
| from fastapi.responses import HTMLResponse | |
| from openai import OpenAI | |
# ── Model Loading ────────────────────────────────────────────────────────────
def _ensure_symlink(link_path: str, source: str) -> bool:
    """Create ``link_path -> source`` unless ``link_path`` already resolves.

    A dangling symlink (for example one that points at a cache snapshot that
    no longer exists) is replaced. An existing file, directory, or live
    symlink is left untouched.

    Returns:
        True when a new link was created, False when nothing was changed.
    """
    if os.path.exists(link_path):
        return False
    if os.path.islink(link_path):
        # exists() follows the link and reports False for a dangling one, but
        # os.symlink would still fail with FileExistsError; drop it first.
        os.unlink(link_path)
    os.symlink(source, link_path)
    return True


def _get_storage_path():
    """Model checkpoints — try to reuse preload_from_hub cache via symlinks.

    Creates ``<app dir>/model_cache/checkpoints`` and, for every model found
    in the local Hugging Face cache, symlinks it into that directory.
    Nothing is downloaded here (``local_files_only=True``); a cache miss is
    only logged.

    Returns:
        Path of the ``model_cache`` directory.
    """
    p = os.path.join(_current_dir, "model_cache")
    checkpoint_dir = os.path.join(p, "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # preload_from_hub downloads to HF cache during Docker build.
    # Create symlinks so the handler finds models at the expected paths
    # without re-downloading 20GB on each restart.
    for model_name, repo_id in [
        ("acestep-v15-xl-turbo", "ACE-Step/acestep-v15-xl-turbo"),
    ]:
        target = os.path.join(checkpoint_dir, model_name)
        if os.path.exists(target):
            continue
        try:
            from huggingface_hub import snapshot_download
            cached = snapshot_download(repo_id, local_files_only=True)
            _ensure_symlink(target, cached)
            print(f"[startup] Linked {model_name} → {cached}")
        except Exception as e:
            print(f"[startup] Cache miss for {model_name}, will download: {e}")

    # For the unified repo (ACE-Step/Ace-Step1.5), its subdirs
    # (vae, Qwen3-Embedding-0.6B, etc.) need to appear directly in checkpoint_dir
    try:
        from huggingface_hub import snapshot_download
        cached = snapshot_download("ACE-Step/Ace-Step1.5", local_files_only=True)
        for sub in os.listdir(cached):
            src = os.path.join(cached, sub)
            dst = os.path.join(checkpoint_dir, sub)
            if os.path.isdir(src) and _ensure_symlink(dst, src):
                print(f"[startup] Linked {sub} → {src}")
    except Exception as e:
        print(f"[startup] Cache miss for Ace-Step1.5, will download: {e}")
    return p
# Resolve the checkpoint directory once at import time; the handler below is
# pointed at it through `persistent_storage_path`.
_storage = _get_storage_path()
print(f"[startup] Model storage: {_storage}")
print(f"[startup] Community bucket: /data (mounted)")

# Module-level handler; `_run_inference` calls `handler.generate_music` on it.
handler = AceStepHandler(persistent_storage_path=_storage)
# initialize_service returns a (status, ready) pair. It is only logged here:
# startup continues even when `_ready` is falsy.
_status, _ready = handler.initialize_service(
    project_root=_current_dir,
    config_path="acestep-v15-xl-turbo",
    device="auto",
    use_flash_attention=handler.is_flash_attention_available(),
    compile_model=False,
    offload_to_cpu=False,
    offload_dit_to_cpu=False,
)
print(f"[startup] Handler: ready={_ready} β {_status}")
# ── Z-Image-Turbo (thumbnail generation) ─────────────────────────────────────
# Optional: any failure in this block (import, model load, move to CUDA)
# leaves `_zimage_pipe` as None, which `_generate_thumbnail_impl` treats as
# "no thumbnail".
try:
    from diffusers import ZImagePipeline, FlowMatchEulerDiscreteScheduler
    _zimage_pipe = ZImagePipeline.from_pretrained(
        "Tongyi-MAI/Z-Image-Turbo",
        torch_dtype=torch.bfloat16,
    )
    _zimage_pipe.to("cuda")
    print("[startup] Z-Image-Turbo loaded for thumbnails")
except Exception as e:
    _zimage_pipe = None
    print(f"[startup] Z-Image-Turbo not available: {e}")
# ── LLM Compose ──────────────────────────────────────────────────────────────
# System prompt sent by `_compose`. The `---`-delimited header requested here
# is exactly what `_compose` parses, so keep the field names in sync with it.
COMPOSE_SYSTEM = """You are a Grammy-winning songwriter and music producer. The user will describe a song idea in plain English. Your job is to flesh it out into a complete song specification.
Return EXACTLY this format — no extra text:
---
title: <short catchy song title>
tags: <genre and style tags, comma-separated, 3-6 tags>
bpm: <tempo as integer>
language: <vocal language: en, zh, ja, ko, or "unknown" for instrumental>
---
<song lyrics with [Verse], [Chorus], [Bridge] markers>
<use [Instrumental] alone if the song has no vocals>"""

# Community bucket: public URLs for shared songs are built from BUCKET_URL.
BUCKET_ID = "victor/ace-step-community"
BUCKET_URL = f"https://huggingface.co/buckets/{BUCKET_ID}/resolve"
def _parse_composition(content: str) -> dict:
    """Split an LLM reply into front-matter fields and lyrics.

    Expected layout is a ``---`` delimited header with ``title``, ``tags``,
    ``bpm`` and ``language`` lines, followed by the lyrics. Keys are matched
    case-insensitively and may be indented. When no header is found, the
    whole text is returned as lyrics with default metadata.

    Returns:
        dict with keys ``title``, ``tags``, ``lyrics``, ``bpm``, ``language``.
    """
    title, tags, bpm, language = "Untitled", "", 120, "en"
    lyrics = content
    # The lyrics part is optional so a header-only reply is still recognised
    # instead of being returned verbatim as lyrics.
    m = re.search(r"---\s*\n(.*?)\n---\s*(?:\n(.*))?$", content, re.DOTALL)
    if m:
        header = m.group(1)
        lyrics = (m.group(2) or "").strip()
        for line in header.splitlines():
            field, sep, value = line.partition(":")
            if not sep:
                continue
            field = field.strip().lower()
            value = value.strip()
            if field == "title":
                title = value.strip('"\'') or title
            elif field == "tags":
                tags = value
            elif field == "bpm":
                # Tolerate decorations such as "120 BPM" or "~120".
                number = re.search(r"\d+", value)
                if number:
                    bpm = int(number.group())
            elif field == "language":
                # The prompt shows "unknown" in quotes; models may echo them.
                language = value.strip('"\'') or language
    return {"title": title, "tags": tags, "lyrics": lyrics, "bpm": bpm, "language": language}


def _compose(description: str) -> dict:
    """Call HF Inference Router LLM to generate tags + lyrics from a description.

    Args:
        description: Free-text song idea, sent as the user message.

    Returns:
        dict with keys ``title``, ``tags``, ``lyrics``, ``bpm``, ``language``.

    Raises:
        RuntimeError: if the ``HF_TOKEN`` environment variable is not set.
    """
    key = os.environ.get("HF_TOKEN", "")
    if not key:
        raise RuntimeError("HF_TOKEN not configured")
    client = OpenAI(base_url="https://router.huggingface.co/v1", api_key=key)
    resp = client.chat.completions.create(
        model="openai/gpt-oss-120b:groq",
        messages=[
            {"role": "system", "content": COMPOSE_SYSTEM},
            {"role": "user", "content": description},
        ],
        max_tokens=2000,
        temperature=0.9,
    )
    raw = resp.choices[0].message.content or ""
    # Drop any <think>…</think> reasoning block before parsing.
    content = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip()
    return _parse_composition(content)
# ── Thumbnail Generation ─────────────────────────────────────────────────────
def _get_song_word(title: str, tags: str, lyrics: str, description: str) -> str:
    """Ask LLM for a single evocative word to represent the song visually.

    Never raises: without ``HF_TOKEN``, on any LLM error, or on an empty
    reply it returns a fallback made of the first two words of
    ``description``, else of ``title``, else the literal ``"music"``.

    Args:
        title: Song title, included in the LLM prompt.
        tags: Style tags, included in the LLM prompt.
        lyrics: Song lyrics; only the first 300 characters are sent.
        description: Original user description; used for the fallback only.

    Returns:
        The word chosen by the LLM, or the fallback text.
    """
    # Pick the first source that actually contains a word, so that a
    # whitespace-only description cannot produce an empty fallback.
    fallback = "music"
    for candidate in (description, title):
        words = (candidate or "").split()
        if words:
            fallback = " ".join(words[:2])
            break

    key = os.environ.get("HF_TOKEN", "")
    if not key:
        print(f"[thumbnail] no HF_TOKEN, using fallback: {fallback}")
        return fallback
    try:
        client = OpenAI(base_url="https://router.huggingface.co/v1", api_key=key)
        resp = client.chat.completions.create(
            model="openai/gpt-oss-120b:groq",
            messages=[
                {"role": "system", "content": "Reply with exactly ONE concrete visual noun (a physical object, animal, or natural element) that captures the essence of this song. No explanation, no punctuation, just the single word."},
                {"role": "user", "content": f"Title: {title}\nTags: {tags}\nLyrics: {(lyrics or '')[:300]}"},
            ],
            max_tokens=500,
            temperature=0.7,
        )
        raw = resp.choices[0].message.content or ""
        cleaned = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip()
        # Keep only the first token and trim surrounding punctuation.
        tokens = cleaned.split()
        word = tokens[0].strip('."\'!,') if tokens else ""
        if not word:
            print(f"[thumbnail] LLM returned empty, using fallback: {fallback}")
            return fallback
        print(f"[thumbnail] word: {word}")
        return word
    except Exception as e:
        # Best effort: thumbnail wording must never fail the request.
        print(f"[thumbnail] word extraction failed: {e}, using fallback: {fallback}")
        return fallback
def _generate_thumbnail_impl(word: str) -> bytes | None:
    """Render a 1024x1024 PNG for *word* with the Z-Image-Turbo pipeline.

    Returns the encoded PNG bytes, or None when the pipeline was not loaded
    or anything goes wrong while generating (the error is logged).
    """
    if _zimage_pipe is None:
        return None
    try:
        prompt = f"{word} studio photography close-up black background"
        print(f"[thumbnail] generating: {prompt}")

        # A fresh scheduler is installed on the shared pipeline for each call.
        _zimage_pipe.scheduler = FlowMatchEulerDiscreteScheduler(
            num_train_timesteps=1000,
            shift=3.0,
        )
        rng = torch.Generator("cuda")
        rng = rng.manual_seed(random.randint(1, 1000000))
        output = _zimage_pipe(
            prompt=prompt,
            height=1024,
            width=1024,
            guidance_scale=0.0,
            num_inference_steps=9,
            generator=rng,
            max_sequence_length=512,
        )
        picture = output.images[0]

        sink = io.BytesIO()
        picture.save(sink, format="PNG", optimize=True)
        png_bytes = sink.getvalue()
        print(f"[thumbnail] done ({len(png_bytes) // 1024}KB)")
        return png_bytes
    except Exception as e:
        print(f"[thumbnail] generation failed: {e}")
        return None
# On ZeroGPU the function has to be wrapped by `spaces.GPU` so that a GPU is
# attached for the duration of the call. Without the `spaces` package the
# plain function is used.
# NOTE(review): the default GPU duration is used — confirm it is long enough.
if HAS_SPACES:
    @spaces.GPU
    def _generate_thumbnail(word: str) -> bytes | None:
        """Generate a thumbnail for *word* inside a ZeroGPU session."""
        return _generate_thumbnail_impl(word)
else:
    def _generate_thumbnail(word: str) -> bytes | None:
        """Generate a thumbnail for *word* (no ZeroGPU wrapper available)."""
        return _generate_thumbnail_impl(word)
# ── GPU Inference Function ───────────────────────────────────────────────────
# On ZeroGPU the function has to be wrapped by `spaces.GPU` so that a GPU is
# attached for the duration of the call. Without the `spaces` package the
# plain function is used.
# NOTE(review): the default GPU duration is used — confirm it covers the
# longest `audio_duration` the app accepts.
if HAS_SPACES:
    @spaces.GPU
    def _generate_gpu(prompt, lyrics, audio_duration, infer_steps, seed):
        """Run `_run_inference` inside a ZeroGPU session; returns the WAV path."""
        return _run_inference(prompt, lyrics, audio_duration, infer_steps, seed)
else:
    def _generate_gpu(prompt, lyrics, audio_duration, infer_steps, seed):
        """Run `_run_inference` directly; returns the WAV path."""
        return _run_inference(prompt, lyrics, audio_duration, infer_steps, seed)
def _run_inference(prompt, lyrics, audio_duration, infer_steps, seed) -> str:
    """Core inference using v1.5 AceStepHandler. Returns path to saved WAV.

    Args:
        prompt: Caption / style tags passed to the handler as ``captions``.
        lyrics: Lyrics text.
        audio_duration: Requested duration, forwarded unchanged.
        infer_steps: Number of inference steps.
        seed: Fixed seed; ``None`` or a negative value requests a random seed.

    Returns:
        Path of ``output.wav`` inside a freshly created temporary directory.
        The caller owns that file and directory and should delete them.

    Raises:
        RuntimeError: if the handler reports failure or returns no audio.
    """
    use_random = seed is None or seed < 0
    result = handler.generate_music(
        captions=prompt,
        lyrics=lyrics,
        audio_duration=audio_duration,
        inference_steps=infer_steps,
        guidance_scale=7.0,
        use_random_seed=use_random,
        seed=None if use_random else seed,
        infer_method="ode",
        shift=1.0,
        use_adg=False,
        vocal_language="en",
    )
    if not result.get("success"):
        raise RuntimeError(result.get("error", "generation failed"))

    audios = result.get("audios") or []
    if not audios:
        # Turn an opaque IndexError/KeyError into a meaningful failure.
        raise RuntimeError("generation returned no audio")
    audio_dict = audios[0]
    tensor = audio_dict["tensor"]
    sr = audio_dict["sample_rate"]

    data = tensor.cpu().float().numpy()
    if data.ndim == 2:
        # presumably (channels, samples); transposed before writing — TODO confirm
        data = data.T
        # Only a 2-D array has a second axis to inspect; collapse a single
        # channel to 1-D (mono).
        if data.shape[1] == 1:
            data = data[:, 0]

    # Peak-normalise to 0.95, skipping (near-)silent or empty output.
    peak = np.abs(data).max() if data.size else 0.0
    if peak > 1e-4:
        data = (data / peak * 0.95).astype(np.float32)

    out_path = os.path.join(tempfile.mkdtemp(), "output.wav")
    sf.write(out_path, data, sr)
    return out_path
# ── gr.Server App ────────────────────────────────────────────────────────────
# Application object; it is aliased as `demo` and launched at the end of the file.
app = Server(title="ace-step-jam")
# ── API: One-box create (compose + generate) ─────────────────────────────────
def _publish_song(wav_bytes, thumb_bytes, *, title, description, tags, lyrics, duration) -> dict:
    """Write one song under ``/data/songs/<id>/`` and prepend it to the feed.

    Stores ``<id>.wav``, an optional ``thumb.png`` and ``meta.json``, then
    inserts the metadata at the front of the in-memory ``_feed_songs`` list.

    Returns:
        The metadata dict that was written (includes ``audio_url``).
    """
    song_id = uuid.uuid4().hex[:12]
    song_dir = f"/data/songs/{song_id}"
    os.makedirs(song_dir, exist_ok=True)

    # Save WAV
    wav_name = f"{song_id}.wav"
    with open(f"{song_dir}/{wav_name}", "wb") as f:
        f.write(wav_bytes)

    # Save thumbnail
    has_thumb = bool(thumb_bytes)
    if has_thumb:
        with open(f"{song_dir}/thumb.png", "wb") as f:
            f.write(thumb_bytes)

    # Save metadata to bucket (durability) + memory (instant reads)
    audio_url = f"{BUCKET_URL}/songs/{song_id}/{wav_name}"
    thumb_url = f"{BUCKET_URL}/songs/{song_id}/thumb.png" if has_thumb else None
    meta = {
        "id": song_id,
        "title": title,
        "description": description,
        "tags": tags,
        "lyrics": lyrics,
        "duration": duration,
        "audio_url": audio_url,
        "thumb_url": thumb_url,
        "has_thumb": has_thumb,
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    with open(f"{song_dir}/meta.json", "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    # Prepend to in-memory feed (no re-scan needed)
    _feed_songs.insert(0, meta)
    return meta


def create(
    description: str,
    audio_duration: float = 60.0,
    seed: int = -1,
    community: bool = False,
) -> str:
    """One-box: describe a song → LLM composes tags+lyrics → generates audio.

    Args:
        description: Free-text song idea.
        audio_duration: Requested audio length, forwarded to generation.
        seed: Fixed seed, or a negative value for a random one.
        community: When true, also store the song in the community bucket.

    Returns:
        JSON string: {audio, title, tags, lyrics, thumbnail?, community_url?}.
        ``audio`` and ``thumbnail`` are base64 data URLs.

    Raises:
        Any exception from composing or generating is logged and re-raised.
        Thumbnail and community-upload failures are logged and skipped.
    """
    try:
        # Step 1: LLM compose (no GPU)
        composed = _compose(description)
        title = composed["title"]
        tags = composed["tags"]
        lyrics = composed["lyrics"]
        print(f"[create] title={title} tags={tags[:60]}...")

        # Step 2: GPU generate music
        wav_path = _generate_gpu(tags, lyrics, audio_duration, 8, seed)
        try:
            with open(wav_path, "rb") as f:
                wav_bytes = f.read()
        finally:
            # The WAV lives in its own mkdtemp() directory; remove both so a
            # long-running server does not fill the temp filesystem.
            # os.rmdir only succeeds on an empty directory, so nothing else
            # can be deleted by accident.
            try:
                os.remove(wav_path)
                os.rmdir(os.path.dirname(wav_path))
            except OSError as cleanup_err:
                print(f"[create] temp cleanup failed: {cleanup_err}")
        audio_b64 = f"data:audio/wav;base64,{base64.b64encode(wav_bytes).decode()}"

        # Step 3: Generate thumbnail (separate GPU session via Z-Image-Turbo)
        thumb_bytes = None
        try:
            word = _get_song_word(title, tags, lyrics, description)
            thumb_bytes = _generate_thumbnail(word)
        except Exception as e:
            print(f"[create] thumbnail failed: {e}")

        result = {
            "audio": audio_b64,
            "title": title,
            "tags": tags,
            "lyrics": lyrics,
        }
        if thumb_bytes:
            result["thumbnail"] = f"data:image/png;base64,{base64.b64encode(thumb_bytes).decode()}"

        # Step 4: Community upload (if checked and /data is writable)
        if community:
            try:
                meta = _publish_song(
                    wav_bytes,
                    thumb_bytes,
                    title=title,
                    description=description,
                    tags=tags,
                    lyrics=lyrics,
                    duration=audio_duration,
                )
                result["community_url"] = meta["audio_url"]
                print(f"[create] Shared to community: {meta['audio_url']}")
            except Exception as upload_err:
                # Sharing is best effort; the caller still gets the audio.
                print(f"[create] Community upload failed: {upload_err}")
        return json.dumps(result)
    except Exception as e:
        print(f"[create ERROR] {type(e).__name__}: {e}")
        print(traceback.format_exc())
        raise
# ── API: Direct generate (for advanced/custom mode) ──────────────────────────
def generate(
    prompt: str,
    lyrics: str,
    audio_duration: float = 60.0,
    infer_step: int = 8,
    guidance_scale: float = 7.0,
    seed: int = -1,
    lora_name_or_path: str = "",
    lora_weight: float = 0.8,
) -> str:
    """Direct generate from explicit tags + lyrics. Returns base64 WAV data URL.

    Args:
        prompt: Style tags / caption.
        lyrics: Lyrics text.
        audio_duration: Requested audio length.
        infer_step: Number of inference steps.
        guidance_scale: Accepted for API compatibility; NOT forwarded.
        seed: Fixed seed, or a negative value for a random one.
        lora_name_or_path: Accepted for API compatibility; NOT forwarded.
        lora_weight: Accepted for API compatibility; NOT forwarded.

    Returns:
        ``data:audio/wav;base64,...`` string.

    Raises:
        Any exception from generation is logged and re-raised.
    """
    try:
        if guidance_scale != 7.0 or lora_name_or_path:
            # These options are not wired into _generate_gpu; say so instead
            # of silently dropping them.
            print("[generate] note: guidance_scale and LoRA options are currently ignored")

        wav_path = _generate_gpu(prompt, lyrics, audio_duration, infer_step, seed)
        try:
            with open(wav_path, "rb") as f:
                encoded = base64.b64encode(f.read()).decode()
        finally:
            # The WAV lives in its own mkdtemp() directory; remove both so a
            # long-running server does not fill the temp filesystem.
            # os.rmdir only succeeds on an empty directory.
            try:
                os.remove(wav_path)
                os.rmdir(os.path.dirname(wav_path))
            except OSError as cleanup_err:
                print(f"[generate] temp cleanup failed: {cleanup_err}")
        return f"data:audio/wav;base64,{encoded}"
    except Exception as e:
        print(f"[generate ERROR] {type(e).__name__}: {e}")
        print(traceback.format_exc())
        raise
# ── Community feed (in-memory, loaded once at startup) ───────────────────────
# Newest-first list of song metadata dicts; filled once at startup below.
_feed_songs = []


def _load_feed_from_disk(songs_dir: str = "/data/songs"):
    """One-time scan at startup to populate memory from bucket.

    Every ``<songs_dir>/<song_id>/meta.json`` is loaded, its ``audio_url``
    (and ``thumb_url`` when ``thumb.png`` exists) is rebuilt from
    ``BUCKET_URL``, and the entry is appended to ``_feed_songs``. Unreadable
    or malformed entries are logged and skipped. The list is then sorted
    newest-first by ``created_at``.

    Args:
        songs_dir: Directory holding one sub-directory per song.
    """
    if not os.path.isdir(songs_dir):
        print(f"[feed] {songs_dir} not found, starting with empty feed")
        return
    try:
        entries = os.listdir(songs_dir)
    except OSError as e:
        # An unreadable bucket must not prevent the app from starting.
        print(f"[feed] cannot list {songs_dir}: {e}")
        return

    for song_id in entries:
        meta_path = os.path.join(songs_dir, song_id, "meta.json")
        if not os.path.isfile(meta_path):
            continue
        try:
            with open(meta_path, encoding="utf-8") as f:
                meta = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers invalid JSON and invalid UTF-8.
            print(f"[feed] skipping {song_id}: {e}")
            continue
        if not isinstance(meta, dict):
            print(f"[feed] skipping {song_id}: meta.json is not an object")
            continue
        meta["audio_url"] = f"{BUCKET_URL}/songs/{song_id}/{song_id}.wav"
        thumb_path = os.path.join(songs_dir, song_id, "thumb.png")
        if os.path.isfile(thumb_path):
            meta["thumb_url"] = f"{BUCKET_URL}/songs/{song_id}/thumb.png"
        _feed_songs.append(meta)

    # A missing or null created_at sorts last instead of raising TypeError.
    _feed_songs.sort(key=lambda s: str(s.get("created_at") or ""), reverse=True)
    print(f"[feed] Loaded {len(_feed_songs)} songs into memory")


_load_feed_from_disk()
def community() -> str:
    """Return the 50 first entries of the in-memory feed as a JSON string.

    The feed is read from `_feed_songs` only; no disk access happens here.
    """
    newest = _feed_songs[:50]
    payload = json.dumps(newest)
    return payload
# ── Serve custom HTML frontend ───────────────────────────────────────────────
async def homepage():
    """Return the contents of ``index.html`` located next to this module.

    The path is resolved from this file's directory rather than the current
    working directory, and the file is decoded as UTF-8 regardless of the
    platform default encoding.

    NOTE(review): no route registration is visible in this view — confirm
    this handler is attached to `app`.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(here, "index.html"), "r", encoding="utf-8") as f:
        return f.read()
# `demo` is a second name for the Server instance held in `app`.
demo = app

if __name__ == "__main__":
    # Launch only when executed as a script, not when the module is imported.
    demo.launch(show_error=True, ssr_mode=False)