pedroapfilho committed on
Commit
b0a0560
·
unverified ·
1 Parent(s): 0db63ea

Fix meta tensor crash: force CPU device context during model init

Browse files

ZeroGPU redirects all tensor creation to meta device at the torch
level, which breaks ResidualFSQ assertions. Wrapping from_pretrained
in torch.device('cpu') overrides this redirection.

Also removes boot diagnostics from app.py.

Files changed (2) hide show
  1. acestep/handler.py +12 -8
  2. app.py +16 -30
acestep/handler.py CHANGED
@@ -488,14 +488,18 @@ class AceStepHandler:
488
  for candidate in attn_candidates:
489
  try:
490
  logger.info(f"[initialize_service] Attempting to load model with attention implementation: {candidate}")
491
- self.model = AutoModel.from_pretrained(
492
- acestep_v15_checkpoint_path,
493
- trust_remote_code=True,
494
- attn_implementation=candidate,
495
- torch_dtype=self.dtype,
496
- low_cpu_mem_usage=False, # Disable memory-efficient weight loading
497
- _fast_init=False, # Disable fast initialization (prevents meta device)
498
- )
 
 
 
 
499
  attn_implementation = candidate
500
  break
501
  except Exception as e:
 
488
  for candidate in attn_candidates:
489
  try:
490
  logger.info(f"[initialize_service] Attempting to load model with attention implementation: {candidate}")
491
+ # Force CPU device context to override ZeroGPU's meta device
492
+ # redirection. ResidualFSQ asserts on tensor values during
493
+ # __init__, which fails on meta tensors.
494
+ with torch.device("cpu"):
495
+ self.model = AutoModel.from_pretrained(
496
+ acestep_v15_checkpoint_path,
497
+ trust_remote_code=True,
498
+ attn_implementation=candidate,
499
+ torch_dtype=self.dtype,
500
+ low_cpu_mem_usage=False,
501
+ _fast_init=False,
502
+ )
503
  attn_implementation = candidate
504
  break
505
  except Exception as e:
app.py CHANGED
@@ -6,36 +6,22 @@ A comprehensive music generation system with three main interfaces:
6
  3. LoRA Training Studio
7
  """
8
 
9
- import sys
10
- print("[BOOT] app.py starting imports...", flush=True)
11
-
12
- try:
13
- import gradio as gr
14
- print("[BOOT] gradio OK", flush=True)
15
- import torch
16
- import numpy as np
17
- from pathlib import Path
18
- import json
19
- from typing import Optional, List, Tuple
20
- import spaces
21
- print("[BOOT] stdlib + spaces OK", flush=True)
22
-
23
- from src.ace_step_engine import ACEStepEngine
24
- from src.timeline_manager import TimelineManager
25
- from src.lora_trainer import download_hf_dataset
26
- from src.audio_processor import AudioProcessor
27
- from src.utils import setup_logging, load_config
28
- print("[BOOT] src imports OK", flush=True)
29
-
30
- from acestep.training.dataset_builder import DatasetBuilder
31
- from acestep.training.configs import LoRAConfig, TrainingConfig
32
- from acestep.training.trainer import LoRATrainer as FabricLoRATrainer
33
- print("[BOOT] acestep.training imports OK", flush=True)
34
- except Exception as e:
35
- print(f"[BOOT] IMPORT FAILED: {e}", flush=True)
36
- import traceback
37
- traceback.print_exc()
38
- sys.exit(1)
39
 
40
  # Setup
41
  logger = setup_logging()
 
6
  3. LoRA Training Studio
7
  """
8
 
9
+ import gradio as gr
10
+ import torch
11
+ import numpy as np
12
+ from pathlib import Path
13
+ import json
14
+ from typing import Optional, List, Tuple
15
+ import spaces
16
+
17
+ from src.ace_step_engine import ACEStepEngine
18
+ from src.timeline_manager import TimelineManager
19
+ from src.lora_trainer import download_hf_dataset
20
+ from src.audio_processor import AudioProcessor
21
+ from src.utils import setup_logging, load_config
22
+ from acestep.training.dataset_builder import DatasetBuilder
23
+ from acestep.training.configs import LoRAConfig, TrainingConfig
24
+ from acestep.training.trainer import LoRATrainer as FabricLoRATrainer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  # Setup
27
  logger = setup_logging()