| """ |
| ACE-Step v1.5 - HuggingFace Space Entry Point |
| |
| This file serves as the entry point for HuggingFace Space deployment. |
| It initializes the service and launches the Gradio interface. |
| """ |
| import os |
| import sys |
|
|
| |
# Absolute directory containing this file; used as the project root for
# checkpoint lookup and the non-persistent storage fallback.
current_dir = os.path.dirname(os.path.abspath(__file__))

# Make the vendored nano-vllm package importable when it is bundled with
# the repo (third_parts checkout); skipped silently when absent.
nano_vllm_path = os.path.join(current_dir, "acestep", "third_parts", "nano-vllm")
if os.path.exists(nano_vllm_path):
    sys.path.insert(0, nano_vllm_path)

# Disable Gradio telemetry before Gradio is imported below.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

# Scrub any inherited proxy settings (both lower- and upper-case variants);
# presumably so in-container/loopback HTTP traffic is not routed through a
# proxy — NOTE(review): confirm this is still required in the Space image.
for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
    os.environ.pop(proxy_var, None)
|
|
| import torch |
| from acestep.handler import AceStepHandler |
| from acestep.llm_inference import LLMHandler |
| from acestep.dataset_handler import DatasetHandler |
| from acestep.gradio_ui import create_gradio_interface |
|
|
|
|
def get_gpu_memory_gb():
    """Return the total memory of CUDA device 0, in gigabytes.

    Returns 0 when CUDA is unavailable or detection raises; detection
    failures are logged to stderr rather than propagated.
    """
    try:
        if not torch.cuda.is_available():
            return 0
        device_props = torch.cuda.get_device_properties(0)
        return device_props.total_memory / (1024 ** 3)
    except Exception as e:
        # Best-effort probe: never let GPU detection crash startup.
        print(f"Warning: Failed to detect GPU memory: {e}", file=sys.stderr)
        return 0
|
|
|
|
def get_persistent_storage_path():
    """
    Detect and return a writable persistent storage path.

    Resolution order:
      1. ``CHECKPOINT_DIR`` environment variable (local development
         override). The value may point either at the parent directory of
         the ``checkpoints`` folder or at the ``checkpoints`` folder itself
         (with or without a trailing slash); in the latter case its parent
         directory is used.
      2. ``/data`` — HuggingFace Space persistent storage (Docker SDK).
         Must be enabled in the Space settings; a write probe confirms it
         is actually writable, not just mounted read-only.
      3. ``<current_dir>/data`` — non-persistent fallback next to this
         file, created on demand.

    Returns:
        str: path to an existing directory.
    """
    checkpoint_dir_override = os.environ.get("CHECKPOINT_DIR")
    if checkpoint_dir_override:
        # Normalize away trailing separators so ".../checkpoints/" behaves
        # like ".../checkpoints" on both POSIX and Windows.
        normalized = checkpoint_dir_override.rstrip("/\\")
        parent = os.path.dirname(normalized)
        # If the user pointed directly at the 'checkpoints' folder, use its
        # parent: downstream code expects the parent directory. The `parent`
        # guard keeps a bare relative "checkpoints" untouched.
        if os.path.basename(normalized) == "checkpoints" and parent:
            checkpoint_dir_override = parent
        if os.path.exists(checkpoint_dir_override):
            print(f"Using local checkpoint directory (CHECKPOINT_DIR): {checkpoint_dir_override}")
            return checkpoint_dir_override
        else:
            print(f"Warning: CHECKPOINT_DIR path does not exist: {checkpoint_dir_override}")

    # HuggingFace Spaces mount persistent storage at /data (Docker SDK).
    hf_data_path = "/data"
    if os.path.exists(hf_data_path):
        try:
            # Existence does not imply writability (e.g. storage disabled
            # => read-only mount), so probe with a throwaway file.
            test_file = os.path.join(hf_data_path, ".write_test")
            with open(test_file, 'w') as f:
                f.write("test")
            os.remove(test_file)
            print(f"Using HuggingFace persistent storage: {hf_data_path}")
            return hf_data_path
        except (PermissionError, OSError) as e:
            print(f"Warning: /data exists but is not writable: {e}")

    # Last resort: a local, non-persistent directory next to this file.
    fallback_path = os.path.join(current_dir, "data")
    os.makedirs(fallback_path, exist_ok=True)
    print(f"Using local storage (non-persistent): {fallback_path}")
    print("Note: To enable persistent storage, configure it in HuggingFace Space settings")
    return fallback_path
|
|
|
|
def main():
    """Main entry point for HuggingFace Space"""
    # DEBUG_UI=1/true/yes skips all model loading so the Gradio UI can be
    # exercised quickly; generation stays disabled in this mode.
    debug_ui = os.environ.get("DEBUG_UI", "").lower() in ("1", "true", "yes")
    if debug_ui:
        print("=" * 60)
        print("DEBUG_UI mode enabled - skipping model initialization")
        print("UI will be fully functional but generation is disabled")
        print("=" * 60)

    # Resolve where checkpoints/output live (CHECKPOINT_DIR override,
    # /data persistent storage, or a local fallback).
    persistent_storage_path = get_persistent_storage_path()

    # Auto-enable CPU offload on GPUs with less than 16 GB of VRAM;
    # gpu_memory_gb == 0 means no GPU, which leaves offload disabled.
    gpu_memory_gb = get_gpu_memory_gb()
    auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < 16

    if not debug_ui:
        if auto_offload:
            print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (< 16GB)")
            print("Auto-enabling CPU offload to reduce GPU memory usage")
        elif gpu_memory_gb > 0:
            print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (>= 16GB)")
            print("CPU offload disabled by default")
        else:
            print("No GPU detected, running on CPU")

    # Construct handlers up front; heavyweight model loading happens later
    # through their initialize methods.
    print("Creating handlers...")
    dit_handler = AceStepHandler(persistent_storage_path=persistent_storage_path)
    llm_handler = LLMHandler(persistent_storage_path=persistent_storage_path)
    dataset_handler = DatasetHandler()

    # Model/backend selection, all overridable via environment variables.
    config_path = os.environ.get(
        "SERVICE_MODE_DIT_MODEL",
        "acestep-v15-turbo"
    )
    # Optional second DiT model; an empty string disables it.
    config_path_2 = os.environ.get("SERVICE_MODE_DIT_MODEL_2", "acestep-v15-turbo-shift3").strip()

    lm_model_path = os.environ.get(
        "SERVICE_MODE_LM_MODEL",
        "acestep-5Hz-lm-1.7B"
    )
    backend = os.environ.get("SERVICE_MODE_BACKEND", "vllm")
    device = "auto"

    print(f"Service mode configuration:")
    print(f"  DiT model 1: {config_path}")
    if config_path_2:
        print(f"  DiT model 2: {config_path_2}")
    print(f"  LM model: {lm_model_path}")
    print(f"  Backend: {backend}")
    print(f"  Offload to CPU: {auto_offload}")
    print(f"  DEBUG_UI: {debug_ui}")

    # Flash attention availability is probed by the handler itself.
    use_flash_attention = dit_handler.is_flash_attention_available()
    print(f"  Flash Attention: {use_flash_attention}")

    # Defaults for the no-model path; overwritten below when models load.
    init_status = ""
    enable_generate = False
    dit_handler_2 = None

    if debug_ui:
        init_status = "⚠️ DEBUG_UI mode - models not loaded\nUI is functional but generation is disabled"
        enable_generate = False
        print("Skipping model initialization (DEBUG_UI mode)")
    else:
        # --- DiT model 1 (primary) ---
        print(f"Initializing DiT model 1: {config_path}...")
        init_status, enable_generate = dit_handler.initialize_service(
            project_root=current_dir,
            config_path=config_path,
            device=device,
            use_flash_attention=use_flash_attention,
            compile_model=False,
            offload_to_cpu=auto_offload,
            offload_dit_to_cpu=False
        )

        if not enable_generate:
            # Non-fatal: the UI still launches with generation disabled.
            print(f"Warning: DiT model 1 initialization issue: {init_status}", file=sys.stderr)
        else:
            print("DiT model 1 initialized successfully")

        # --- DiT model 2 (optional) ---
        if config_path_2:
            print(f"Initializing DiT model 2: {config_path_2}...")
            dit_handler_2 = AceStepHandler(persistent_storage_path=persistent_storage_path)

            # Model 2 reuses model 1's VAE, text encoder/tokenizer and
            # silence latent via the shared_* kwargs — presumably to avoid
            # loading duplicate copies; confirm in initialize_service.
            init_status_2, enable_generate_2 = dit_handler_2.initialize_service(
                project_root=current_dir,
                config_path=config_path_2,
                device=device,
                use_flash_attention=use_flash_attention,
                compile_model=False,
                offload_to_cpu=auto_offload,
                offload_dit_to_cpu=False,

                shared_vae=dit_handler.vae,
                shared_text_encoder=dit_handler.text_encoder,
                shared_text_tokenizer=dit_handler.text_tokenizer,
                shared_silence_latent=dit_handler.silence_latent,
            )

            if not enable_generate_2:
                # Keep going with model 1 only; surface the failure in the UI status.
                print(f"Warning: DiT model 2 initialization issue: {init_status_2}", file=sys.stderr)
                init_status += f"\n⚠️ DiT model 2 failed: {init_status_2}"
            else:
                print("DiT model 2 initialized successfully")
                init_status += f"\n✅ DiT model 2: {config_path_2}"

        # --- 5Hz language model ---
        # Reuses DiT model 1's resolved checkpoint dir and dtype.
        checkpoint_dir = dit_handler._get_checkpoint_dir()
        print(f"Initializing 5Hz LM: {lm_model_path}...")
        lm_status, lm_success = llm_handler.initialize(
            checkpoint_dir=checkpoint_dir,
            lm_model_path=lm_model_path,
            backend=backend,
            device=device,
            offload_to_cpu=auto_offload,
            dtype=dit_handler.dtype
        )

        if lm_success:
            print("5Hz LM initialized successfully")
            init_status += f"\n{lm_status}"
        else:
            # Non-fatal: status is appended either way so the UI shows it.
            print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
            init_status += f"\n{lm_status}"

    # Models selectable in the UI: model 2 only if its handler was created.
    available_dit_models = [config_path]
    if config_path_2 and dit_handler_2 is not None:
        available_dit_models.append(config_path_2)

    # Everything the Gradio UI needs to reflect the pre-initialized state.
    init_params = {
        'pre_initialized': True,
        'service_mode': True,
        'checkpoint': None,
        'config_path': config_path,
        'config_path_2': config_path_2 if config_path_2 else None,
        'device': device,
        'init_llm': True,
        'lm_model_path': lm_model_path,
        'backend': backend,
        'use_flash_attention': use_flash_attention,
        'offload_to_cpu': auto_offload,
        'offload_dit_to_cpu': False,
        'init_status': init_status,
        'enable_generate': enable_generate,
        'dit_handler': dit_handler,
        'dit_handler_2': dit_handler_2,
        'available_dit_models': available_dit_models,
        'llm_handler': llm_handler,
        'language': 'en',
        'persistent_storage_path': persistent_storage_path,
        'debug_ui': debug_ui,
    }

    print("Service initialization completed!")

    print("Creating Gradio interface...")
    demo = create_gradio_interface(
        dit_handler,
        llm_handler,
        dataset_handler,
        init_params=init_params,
        language='en'
    )

    # Queue requests; concurrency 1 serializes generation on the single GPU.
    print("Enabling queue for multi-user support...")
    demo.queue(max_size=20, default_concurrency_limit=1)

    # 0.0.0.0:7860 is the standard bind address/port for HF Spaces.
    print("Launching server on 0.0.0.0:7860...")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|