# backend/config.py
"""
Configuration and Model Management
Implements lazy loading to save memory on CPU environment
"""
import os
import torch
# ============================================================================
# Environment Optimization (CPU)
# ============================================================================
# Cap thread-pool sizes so one request does not oversubscribe the host.
# NOTE(review): OpenMP/MKL read these env vars at native-library load time,
# and torch is already imported above, so for THIS process they may arrive
# too late (they still propagate to any subprocesses). torch's thread count
# is therefore also set explicitly via the API, which works after import.
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["TORCH_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "4"
torch.set_num_threads(4)  # effective even though torch is already imported
# ============================================================================
# Global Model Instance (Lazy Loading)
# ============================================================================
_model = None
def get_model():
    """
    Lazily load and cache the SoulX-Singer model (CPU-only).

    The model is built on first call and cached in the module-global
    ``_model``; subsequent calls return the cached instance. Loading is
    deferred so the process does not pay the memory cost at startup.

    Returns:
        SoulX-Singer model instance, as built by ``cli.inference.build_model``.

    Raises:
        FileNotFoundError: if the model weights are missing and the
            automatic HuggingFace Hub download fails.
    """
    global _model
    if _model is None:
        print("Loading SoulX-Singer model on CPU...")
        # Make the sibling 'soulxsinger' and 'cli' packages importable by
        # putting their parent directory on sys.path.
        import sys
        base_path = os.path.dirname(__file__)
        soulx_path = os.path.join(base_path, '..', 'soulxsinger')
        cli_path = os.path.join(base_path, '..', 'cli')
        for pkg_path in (soulx_path, cli_path):
            if os.path.exists(pkg_path):
                sys.path.insert(0, os.path.dirname(pkg_path))
        from soulxsinger.utils.file_utils import load_config
        from cli.inference import build_model
        # Check for model weights - auto-download from HuggingFace if missing.
        model_weights_path = os.path.join(base_path, '..', 'pretrained_models', 'SoulX-Singer', 'model.pt')
        if not os.path.exists(model_weights_path):
            print("⚠️ Model weights not found!")
            print("🔄 Attempting automatic download from HuggingFace Hub...")
            try:
                # Install huggingface-hub if not already installed. Use the
                # current interpreter's pip (sys.executable -m pip) so the
                # package lands in THIS environment; a bare 'pip' executable
                # may belong to a different Python.
                import subprocess
                subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', 'huggingface-hub'])
                # Download model weights into the expected directory.
                from huggingface_hub import snapshot_download
                model_dir = os.path.join(base_path, '..', 'pretrained_models', 'SoulX-Singer')
                os.makedirs(model_dir, exist_ok=True)
                print("⬇️ Downloading SoulX-Singer model (~1.5GB)...")
                # NOTE: local_dir_use_symlinks is deprecated and ignored by
                # recent huggingface_hub; local_dir downloads are real files.
                snapshot_download(
                    repo_id='Soul-AILab/SoulX-Singer',
                    local_dir=model_dir,
                    ignore_patterns=['*.md', '*.txt', 'LICENSE', 'config/**', 'utils/**', 'scripts/**']
                )
                # Verify the download actually produced the weights file
                # before handing the path to build_model.
                if not os.path.exists(model_weights_path):
                    raise FileNotFoundError(
                        f"Download finished but {model_weights_path} is still missing"
                    )
                print("✅ Model downloaded successfully!")
            except Exception as e:
                print(f"❌ Auto-download failed: {e}")
                print("Please manually download model.pt from:")
                print("https://huggingface.co/Soul-AILab/SoulX-Singer")
                print("And place it at: pretrained_models/SoulX-Singer/model.pt")
                # Chain the original failure so the traceback shows the cause.
                raise FileNotFoundError("Model weights not found and auto-download failed. See instructions above.") from e
        # Load config and build model using the official build_model helper.
        config_path = os.path.join(soulx_path, "config", "soulxsinger.yaml")
        config = load_config(config_path)
        _model = build_model(
            model_path=model_weights_path,
            config=config,
            device='cpu',
            use_fp16=False  # CPU does not support FP16
        )
        print("✅ Model loaded successfully!")
    return _model
def clear_model():
    """
    Release the cached model and trigger a garbage-collection pass.

    Call after generation completes to hand memory back to the system.
    Safe to call when no model is loaded (no-op in that case).
    """
    global _model
    if _model is None:
        return
    # Dropping the only module-level reference makes the model collectable.
    _model = None
    import gc
    gc.collect()
    print("✅ Model memory cleared")
def get_device() -> str:
    """
    Report the compute device used for inference.

    Returns:
        str: always ``'cpu'`` — this deployment targets CPU-only hosts.
    """
    device = 'cpu'
    return device
def get_default_voice_path() -> str:
    """
    Locate the bundled default (child) voice sample directory.

    Returns:
        str: path to the ``DefaultVoice_Child`` directory, one level
        above this module's directory.
    """
    backend_dir = os.path.dirname(__file__)
    return os.path.join(backend_dir, '..', 'DefaultVoice_Child')
def get_cpu_warning() -> str:
    """
    Build the user-facing warning about slow CPU generation.

    Returns:
        str: warning text describing expected per-second generation time.
    """
    warning = "CPU Environment: Generation may take 5-10 min per second of audio"
    return warning
# ============================================================================
# Model Inference Settings
# ============================================================================
# Defaults tuned for CPU inference: fewer diffusion steps trade a little
# quality for tolerable latency on CPU-only hosts.
INFERENCE_CONFIG = {
    'n_steps': 12,           # sampling steps, cut down from the default 32 for CPU
    'cfg': 3.0,              # classifier-free guidance scale
    'control': 'score',      # score-controlled generation mode
    'use_fp16': False,       # half precision is unavailable on CPU
    'segment_duration': 8.0  # maximum audio segment length in seconds
}
def get_inference_config():
    """
    Return a fresh shallow copy of the CPU inference defaults.

    A copy is handed out so callers may mutate their dict freely without
    touching the module-wide defaults (all values are scalars, so a
    shallow copy fully isolates them).
    """
    return dict(INFERENCE_CONFIG)
# ============================================================================
# Audio Configuration
# ============================================================================
# Output audio format shared by the generation pipeline.
AUDIO_CONFIG = {
    'sample_rate': 24000,  # Hz; native sample rate of the SoulX-Singer model
    'channels': 1  # mono output
}