File size: 11,092 Bytes
9b0e640
 
 
 
22101a6
9b0e640
 
 
 
ccf578a
 
 
 
 
 
 
9b0e640
 
 
 
 
 
 
 
22101a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b0e640
 
bdee37a
 
 
 
 
 
 
 
0b990cd
 
 
 
 
bdee37a
0b990cd
 
 
 
 
 
 
 
 
 
 
 
bdee37a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b0e640
 
22101a6
0b990cd
 
 
 
 
 
 
 
bdee37a
 
22101a6
 
 
 
 
0b990cd
 
 
 
 
 
 
 
 
22101a6
 
 
 
 
 
 
 
 
 
2ac0444
22101a6
0b990cd
 
 
22101a6
 
2ac0444
22101a6
 
 
 
 
0b990cd
 
 
22101a6
 
 
0b990cd
22101a6
 
 
 
 
0b990cd
 
 
 
22101a6
0b990cd
 
 
 
 
22101a6
0b990cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22101a6
0b990cd
 
 
 
22101a6
 
9b0e640
22101a6
 
 
 
0b990cd
22101a6
 
 
 
 
 
 
 
 
 
0b990cd
 
22101a6
9b0e640
22101a6
0b990cd
9b0e640
22101a6
 
 
 
 
 
 
 
 
 
 
 
 
9b0e640
22101a6
e19bc36
22101a6
9b0e640
22101a6
9b0e640
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
"""
ACE-Step v1.5 - HuggingFace Space Entry Point

This file serves as the entry point for HuggingFace Space deployment.
It initializes the service and launches the Gradio interface.
"""
import os
import sys

# Get current directory (app.py location)
current_dir = os.path.dirname(os.path.abspath(__file__))

# Add nano-vllm to Python path (local package)
nano_vllm_path = os.path.join(current_dir, "acestep", "third_parts", "nano-vllm")
if os.path.exists(nano_vllm_path):
    sys.path.insert(0, nano_vllm_path)

# Disable Gradio analytics
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

# Clear proxy settings that may affect Gradio
for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
    os.environ.pop(proxy_var, None)

import torch
from acestep.handler import AceStepHandler
from acestep.llm_inference import LLMHandler
from acestep.dataset_handler import DatasetHandler
from acestep.gradio_ui import create_gradio_interface


def get_gpu_memory_gb():
    """Return the total memory of CUDA device 0, in gigabytes.

    Returns 0 when no CUDA device is available or when detection fails;
    detection failures are reported on stderr instead of being raised.
    """
    try:
        if not torch.cuda.is_available():
            return 0
        total_bytes = torch.cuda.get_device_properties(0).total_memory
        return total_bytes / (1024 ** 3)
    except Exception as e:
        # Best-effort probe: never let GPU detection crash startup.
        print(f"Warning: Failed to detect GPU memory: {e}", file=sys.stderr)
        return 0


def get_persistent_storage_path():
    """
    Detect and return a writable persistent storage path.

    Resolution order:
    1. ``CHECKPOINT_DIR`` environment variable (local development override).
       The value should be the parent directory of the 'checkpoints' folder;
       if it points at the 'checkpoints' folder itself, its parent is used.
       Trailing path separators are tolerated (fixes the previous behavior
       where ``/path/checkpoints/`` was not recognized).
       Example: CHECKPOINT_DIR=/path/to/checkpoints python app.py
    2. HuggingFace Space persistent storage at /data (must be enabled in
       Space settings; the path is typical for the Docker SDK). Used only
       if it is actually writable.
    3. A non-persistent ``data`` folder next to this file, created on demand.

    Returns:
        str: Path to an existing directory suitable for storage.
    """
    # Check for local checkpoint directory override (for development)
    checkpoint_dir_override = os.environ.get("CHECKPOINT_DIR")
    if checkpoint_dir_override:
        # Normalize trailing separators so "/path/checkpoints/" behaves
        # like "/path/checkpoints". Keep the original value if stripping
        # would empty the string (e.g. CHECKPOINT_DIR="/").
        normalized = checkpoint_dir_override.rstrip("/\\") or checkpoint_dir_override
        # If user specifies the checkpoints folder directly, use its parent
        if os.path.basename(normalized) == "checkpoints":
            normalized = os.path.dirname(normalized)
        if os.path.exists(normalized):
            print(f"Using local checkpoint directory (CHECKPOINT_DIR): {normalized}")
            return normalized
        print(f"Warning: CHECKPOINT_DIR path does not exist: {normalized}")

    # Try HuggingFace Space persistent storage first
    hf_data_path = "/data"

    # Check if /data exists and is writable. Existence alone is not
    # enough: without persistent storage enabled, /data may be read-only.
    if os.path.exists(hf_data_path):
        try:
            test_file = os.path.join(hf_data_path, ".write_test")
            with open(test_file, 'w') as f:
                f.write("test")
            os.remove(test_file)
            print(f"Using HuggingFace persistent storage: {hf_data_path}")
            return hf_data_path
        except (PermissionError, OSError) as e:
            print(f"Warning: /data exists but is not writable: {e}")

    # Fall back to app directory (non-persistent but works without special config)
    fallback_path = os.path.join(current_dir, "data")
    os.makedirs(fallback_path, exist_ok=True)
    print(f"Using local storage (non-persistent): {fallback_path}")
    print("Note: To enable persistent storage, configure it in HuggingFace Space settings")
    return fallback_path


def main():
    """Main entry point for HuggingFace Space.

    Orchestrates the full service bring-up:
    1. Reads DEBUG_UI and SERVICE_MODE_* environment variables.
    2. Picks a storage path and auto-enables CPU offload for GPUs < 16 GB.
    3. Initializes one or two DiT handlers and the 5Hz LM handler
       (skipped entirely in DEBUG_UI mode).
    4. Builds the Gradio UI with pre-initialized handlers and launches it
       on 0.0.0.0:7860.

    Initialization failures are reported on stderr and reflected in the
    UI status text rather than raised; the server launches regardless.
    """
    
    # Check for DEBUG_UI mode (skip model initialization for UI development)
    debug_ui = os.environ.get("DEBUG_UI", "").lower() in ("1", "true", "yes")
    if debug_ui:
        print("=" * 60)
        print("DEBUG_UI mode enabled - skipping model initialization")
        print("UI will be fully functional but generation is disabled")
        print("=" * 60)
    
    # Get persistent storage path (auto-detect)
    persistent_storage_path = get_persistent_storage_path()
    
    # Detect GPU memory for auto-configuration: small GPUs (< 16 GB) get
    # CPU offload; gpu_memory_gb == 0 means no GPU, so offload stays off.
    gpu_memory_gb = get_gpu_memory_gb()
    auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < 16
    
    if not debug_ui:
        if auto_offload:
            print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (< 16GB)")
            print("Auto-enabling CPU offload to reduce GPU memory usage")
        elif gpu_memory_gb > 0:
            print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (>= 16GB)")
            print("CPU offload disabled by default")
        else:
            print("No GPU detected, running on CPU")
    
    # Create handler instances (construction is cheap; model weights are
    # loaded later via the initialize calls below)
    print("Creating handlers...")
    dit_handler = AceStepHandler(persistent_storage_path=persistent_storage_path)
    llm_handler = LLMHandler(persistent_storage_path=persistent_storage_path)
    dataset_handler = DatasetHandler()
    
    # Service mode configuration from environment variables
    config_path = os.environ.get(
        "SERVICE_MODE_DIT_MODEL",
        "acestep-v15-turbo"
    )
    # Second DiT model - default to turbo-shift3 for two-model setup
    config_path_2 = os.environ.get("SERVICE_MODE_DIT_MODEL_2", "acestep-v15-turbo-shift3").strip()
    
    lm_model_path = os.environ.get(
        "SERVICE_MODE_LM_MODEL",
        "acestep-5Hz-lm-1.7B"
    )
    backend = os.environ.get("SERVICE_MODE_BACKEND", "vllm")
    device = "auto"
    
    print(f"Service mode configuration:")
    print(f"  DiT model 1: {config_path}")
    if config_path_2:
        print(f"  DiT model 2: {config_path_2}")
    print(f"  LM model: {lm_model_path}")
    print(f"  Backend: {backend}")
    print(f"  Offload to CPU: {auto_offload}")
    print(f"  DEBUG_UI: {debug_ui}")
    
    # Determine flash attention availability (queried from the handler so
    # the same setting is applied to every model initialized below)
    use_flash_attention = dit_handler.is_flash_attention_available()
    print(f"  Flash Attention: {use_flash_attention}")
    
    # Initialize models (skip in DEBUG_UI mode)
    init_status = ""
    enable_generate = False
    dit_handler_2 = None
    
    if debug_ui:
        # In DEBUG_UI mode, skip all model initialization
        init_status = "⚠️ DEBUG_UI mode - models not loaded\nUI is functional but generation is disabled"
        enable_generate = False
        print("Skipping model initialization (DEBUG_UI mode)")
    else:
        # Initialize primary DiT model
        print(f"Initializing DiT model 1: {config_path}...")
        init_status, enable_generate = dit_handler.initialize_service(
            project_root=current_dir,
            config_path=config_path,
            device=device,
            use_flash_attention=use_flash_attention,
            compile_model=False,
            offload_to_cpu=auto_offload,
            offload_dit_to_cpu=False
        )
        
        # A failed primary model is reported but not fatal: the UI still
        # launches with generation disabled.
        if not enable_generate:
            print(f"Warning: DiT model 1 initialization issue: {init_status}", file=sys.stderr)
        else:
            print("DiT model 1 initialized successfully")
        
        # Initialize second DiT model if configured
        if config_path_2:
            print(f"Initializing DiT model 2: {config_path_2}...")
            dit_handler_2 = AceStepHandler(persistent_storage_path=persistent_storage_path)
            
            # Share VAE, text_encoder, and silence_latent from the first handler to save memory
            init_status_2, enable_generate_2 = dit_handler_2.initialize_service(
                project_root=current_dir,
                config_path=config_path_2,
                device=device,
                use_flash_attention=use_flash_attention,
                compile_model=False,
                offload_to_cpu=auto_offload,
                offload_dit_to_cpu=False,
                # Share components from first handler
                shared_vae=dit_handler.vae,
                shared_text_encoder=dit_handler.text_encoder,
                shared_text_tokenizer=dit_handler.text_tokenizer,
                shared_silence_latent=dit_handler.silence_latent,
            )
            
            # Model 2 failure is non-fatal: fold its status into the
            # overall status string shown in the UI.
            if not enable_generate_2:
                print(f"Warning: DiT model 2 initialization issue: {init_status_2}", file=sys.stderr)
                init_status += f"\n⚠️ DiT model 2 failed: {init_status_2}"
            else:
                print("DiT model 2 initialized successfully")
                init_status += f"\n✅ DiT model 2: {config_path_2}"
        
        # Initialize LM model, reusing the checkpoint dir and dtype
        # resolved by the primary DiT handler
        checkpoint_dir = dit_handler._get_checkpoint_dir()
        print(f"Initializing 5Hz LM: {lm_model_path}...")
        lm_status, lm_success = llm_handler.initialize(
            checkpoint_dir=checkpoint_dir,
            lm_model_path=lm_model_path,
            backend=backend,
            device=device,
            offload_to_cpu=auto_offload,
            dtype=dit_handler.dtype
        )
        
        if lm_success:
            print("5Hz LM initialized successfully")
            init_status += f"\n{lm_status}"
        else:
            print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
            init_status += f"\n{lm_status}"
    
    # Build available models list for UI; model 2 is listed only if its
    # handler was actually created
    available_dit_models = [config_path]
    if config_path_2 and dit_handler_2 is not None:
        available_dit_models.append(config_path_2)
    
    # Prepare initialization parameters for UI
    init_params = {
        'pre_initialized': True,
        'service_mode': True,
        'checkpoint': None,
        'config_path': config_path,
        'config_path_2': config_path_2 if config_path_2 else None,
        'device': device,
        'init_llm': True,
        'lm_model_path': lm_model_path,
        'backend': backend,
        'use_flash_attention': use_flash_attention,
        'offload_to_cpu': auto_offload,
        'offload_dit_to_cpu': False,
        'init_status': init_status,
        'enable_generate': enable_generate,
        'dit_handler': dit_handler,
        'dit_handler_2': dit_handler_2,
        'available_dit_models': available_dit_models,
        'llm_handler': llm_handler,
        'language': 'en',
        'persistent_storage_path': persistent_storage_path,
        'debug_ui': debug_ui,
    }
    
    print("Service initialization completed!")
    
    # Create Gradio interface with pre-initialized handlers
    print("Creating Gradio interface...")
    demo = create_gradio_interface(
        dit_handler, 
        llm_handler, 
        dataset_handler, 
        init_params=init_params, 
        language='en'
    )
    
    # Enable queue for multi-user support; concurrency is capped at 1 so
    # generation jobs never contend for the GPU
    print("Enabling queue for multi-user support...")
    demo.queue(max_size=20, default_concurrency_limit=1)
    
    # Launch (blocks until the server is shut down)
    print("Launching server on 0.0.0.0:7860...")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )


if __name__ == "__main__":
    main()