# Demusics / utils.py
# Author: Kremon96 (verified) — commit 55518dd, "Create utils.py"
# utils.py
"""
Utility functions for the music separation project
"""
import os
import torch
import logging
from pathlib import Path
from config import LOG_DIR, MODEL_DIR
from datetime import datetime
def setup_logging(level=logging.INFO):
    """Configure root logging to a timestamped file in LOG_DIR and the console.

    The log file is named ``music_separator_YYYYmmdd_HHMMSS.log``. Like
    ``logging.basicConfig`` itself, this only configures the root logger the
    first time it is called; later calls leave existing handlers untouched.

    Args:
        level: Root logging level applied on first configuration
            (default ``logging.INFO``).

    Returns:
        The logger for this module (``logging.getLogger(__name__)``).
    """
    root = logging.getLogger()
    # basicConfig() does nothing once the root logger has handlers. Check first
    # so we don't open a FileHandler that would never be attached (it would
    # stay open and leave an empty stray log file behind).
    if not root.handlers:
        # FileHandler raises FileNotFoundError if the target directory is missing.
        LOG_DIR.mkdir(parents=True, exist_ok=True)
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = LOG_DIR / f"music_separator_{stamp}.log"
        logging.basicConfig(
            level=level,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                # Explicit UTF-8 so non-ASCII messages don't depend on the locale.
                logging.FileHandler(log_file, encoding='utf-8'),
                logging.StreamHandler(),
            ],
        )
    return logging.getLogger(__name__)
def save_model(model, model_name):
    """Save a model's ``state_dict`` to ``MODEL_DIR/<model_name>.pth``.

    Best-effort: failures are reported on stdout and signalled through the
    return value instead of being raised.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        model_name: File stem for the checkpoint; ``.pth`` is appended.

    Returns:
        True if the weights were written, False otherwise.
    """
    try:
        model_path = MODEL_DIR / f"{model_name}.pth"
        # torch.save does not create parent directories.
        MODEL_DIR.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), model_path)
    except Exception as e:
        print(f"❌ Error saving model: {e}")
        return False
    else:
        print(f"✅ Model saved to {model_path}")
        return True
def load_model(model, model_name, map_location='cpu'):
    """Load weights from ``MODEL_DIR/<model_name>.pth`` into ``model`` in place.

    Best-effort: a missing file or any loading error is reported on stdout
    and signalled through the return value instead of being raised.

    Args:
        model: Object exposing ``load_state_dict()`` (e.g. a ``torch.nn.Module``).
        model_name: File stem of the checkpoint; ``.pth`` is appended.
        map_location: Forwarded to ``torch.load``. The default ``'cpu'`` lets
            checkpoints saved from GPU tensors load on CPU-only machines.

    Returns:
        True if the weights were loaded, False otherwise.
    """
    try:
        model_path = MODEL_DIR / f"{model_name}.pth"
        if not model_path.exists():
            print(f"⚠️ Model file {model_path} not found")
            return False
        # torch.load is pickle-based; weights_only=True restricts it to tensors
        # and plain containers, which is all a state_dict needs, and refuses
        # arbitrary objects from a tampered file. (Requires torch >= 1.13.)
        state_dict = torch.load(
            model_path, map_location=map_location, weights_only=True
        )
        model.load_state_dict(state_dict)
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return False
    else:
        print(f"✅ Model loaded from {model_path}")
        return True
def get_system_info():
    """Summarize the PyTorch / CUDA runtime as a flat dictionary.

    Always contains ``pytorch_version``, ``cuda_available``, ``cuda_version``
    ('N/A' without CUDA) and ``device_count`` (0 without CUDA). For each
    visible GPU ``i`` it also adds ``cuda_device_i`` (device name) and
    ``cuda_memory_i`` (total memory in GiB, formatted like ``"7.8GB"``).
    """
    has_cuda = torch.cuda.is_available()
    gpu_total = torch.cuda.device_count() if has_cuda else 0

    summary = {
        'pytorch_version': torch.__version__,
        'cuda_available': has_cuda,
        'cuda_version': torch.version.cuda if has_cuda else 'N/A',
        'device_count': gpu_total,
    }

    for idx in range(gpu_total):
        summary[f'cuda_device_{idx}'] = torch.cuda.get_device_name(idx)
        total_gib = torch.cuda.get_device_properties(idx).total_memory / 1024**3
        summary[f'cuda_memory_{idx}'] = f"{total_gib:.1f}GB"

    return summary
def format_time(seconds):
    """Render a duration given in seconds as a compact one-decimal string.

    Values below 60 are shown in seconds ("s"), values below 3600 in minutes
    ("m"), and everything else in hours ("h"), e.g. 90 -> "1.5m".
    """
    for upper_bound, divisor, suffix in ((60, 1, "s"), (3600, 60, "m")):
        if seconds < upper_bound:
            return f"{seconds / divisor:.1f}{suffix}"
    return f"{seconds / 3600:.1f}h"
def get_audio_duration(file_path):
    """Return an audio file's duration in seconds, or 0 if it cannot be read.

    Best-effort: any ordinary failure (``soundfile`` not installed, missing
    file, unreadable or unsupported format) yields 0 instead of raising.

    Args:
        file_path: Path of the audio file to inspect.

    Returns:
        The ``duration`` reported by ``soundfile.info``, or 0 on failure.
    """
    try:
        # Imported inside the try so a missing soundfile install yields 0
        # rather than an ImportError.
        import soundfile as sf
        return sf.info(file_path).duration
    except Exception as exc:
        # A bare ``except:`` would also swallow KeyboardInterrupt/SystemExit;
        # limit to ordinary errors and leave a trace for debugging.
        logging.getLogger(__name__).debug(
            "Could not read duration of %s: %s", file_path, exc
        )
        return 0