|
|
|
|
|
""" |
|
|
Utility functions for the music separation project |
|
|
""" |
|
|
|
|
|
import os |
|
|
import torch |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from config import LOG_DIR, MODEL_DIR |
|
|
from datetime import datetime |
|
|
|
|
|
def setup_logging():
    """Configure root logging to a timestamped file plus the console.

    On first call, creates ``LOG_DIR`` if needed. It then installs two handlers
    on the root logger at INFO level:

    - a ``FileHandler`` writing to ``LOG_DIR/music_separator_<YYYYmmdd_HHMMSS>.log``;
    - a ``StreamHandler`` (stderr by default).

    If the root logger already has handlers, ``logging.basicConfig`` would be
    a no-op. In that case nothing is changed and no log file is created.

    Returns:
        logging.Logger: the logger named after this module (``__name__``).
    """
    # basicConfig ignores its arguments when root already has handlers, but a
    # FileHandler built up-front would still open (and leak) a new empty file.
    # Only build the handlers when they will actually be installed.
    if not logging.getLogger().handlers:
        # FileHandler opens its file immediately, so the directory must exist.
        LOG_DIR.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = LOG_DIR / f"music_separator_{timestamp}.log"

        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(log_file),
                logging.StreamHandler(),
            ],
        )

    return logging.getLogger(__name__)
|
|
|
|
|
def save_model(model, model_name):
    """Save a model's weights to ``MODEL_DIR/<model_name>.pth``.

    Only ``model.state_dict()`` is written (via ``torch.save``), not the whole
    model object. ``MODEL_DIR`` is created if it does not exist yet.

    Args:
        model: object exposing ``state_dict()``, e.g. a ``torch.nn.Module``.
        model_name: checkpoint file stem; ``.pth`` is appended.

    Returns:
        bool: True if the weights were written, False if saving failed.
        Failures are reported on stdout instead of being raised.
    """
    try:
        model_path = MODEL_DIR / f"{model_name}.pth"
        # Create the checkpoint directory on first use instead of failing.
        model_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), model_path)
    except Exception as e:
        # Deliberate best-effort API: report and signal failure via the bool.
        print(f"❌ Error saving model: {e}")
        return False

    # Keep the try body limited to the save; report success afterwards.
    print(f"✅ Model saved to {model_path}")
    return True
|
|
|
|
|
def load_model(model, model_name):
    """Load weights from ``MODEL_DIR/<model_name>.pth`` into ``model``.

    The checkpoint is deserialized with ``map_location='cpu'``. It is then
    applied to ``model`` through ``model.load_state_dict``.

    Args:
        model: object exposing ``load_state_dict()``, e.g. a ``torch.nn.Module``.
        model_name: checkpoint file stem; ``.pth`` is appended.

    Returns:
        bool: True if the weights were loaded. False if the file does not
        exist or loading failed (the reason is printed to stdout).
    """
    try:
        model_path = MODEL_DIR / f"{model_name}.pth"
        if not model_path.exists():
            print(f"⚠️ Model file {model_path} not found")
            return False

        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. On torch >= 1.13, consider passing weights_only=True.
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict)
    except Exception as e:
        # Deliberate best-effort API: report and signal failure via the bool.
        print(f"❌ Error loading model: {e}")
        return False

    print(f"✅ Model loaded from {model_path}")
    return True
|
|
|
|
|
def get_system_info():
    """Describe the PyTorch build and any visible CUDA devices.

    The returned dict always has four keys:

    - ``pytorch_version``: ``torch.__version__``;
    - ``cuda_available``: bool;
    - ``cuda_version``: ``torch.version.cuda``, or ``'N/A'`` without CUDA;
    - ``device_count``: number of CUDA devices, or 0 without CUDA.

    When CUDA is available, each device index ``i`` also adds:

    - ``cuda_device_<i>``: the device name;
    - ``cuda_memory_<i>``: total memory as ``bytes / 1024**3`` with one
      decimal and a ``"GB"`` suffix.
    """
    cuda_ok = torch.cuda.is_available()
    n_devices = torch.cuda.device_count() if cuda_ok else 0

    info = {
        'pytorch_version': torch.__version__,
        'cuda_available': cuda_ok,
        'cuda_version': torch.version.cuda if cuda_ok else 'N/A',
        'device_count': n_devices,
    }

    # n_devices is 0 without CUDA, so this loop only runs when GPUs exist.
    for idx in range(n_devices):
        info[f'cuda_device_{idx}'] = torch.cuda.get_device_name(idx)
        total_bytes = torch.cuda.get_device_properties(idx).total_memory
        info[f'cuda_memory_{idx}'] = f"{total_bytes / 1024**3:.1f}GB"

    return info
|
|
|
|
|
def format_time(seconds):
    """Render a duration in seconds as a compact string with one decimal.

    Values below 60 use seconds (``"s"``) and values below 3600 use minutes
    (``"m"``); everything else uses hours (``"h"``).
    For example, ``format_time(90) == "1.5m"``.
    """
    if seconds < 60:
        return f"{seconds:.1f}s"

    # Not under a minute: scale to minutes below one hour, otherwise hours.
    divisor, unit = (60, "m") if seconds < 3600 else (3600, "h")
    return f"{seconds / divisor:.1f}{unit}"
|
|
|
|
|
def get_audio_duration(file_path):
    """Return the duration of an audio file in seconds, or 0 if unavailable.

    ``soundfile`` is imported lazily, so this module only needs it when this
    helper is called. Any failure returns 0 and is logged as a warning
    instead of being raised. Possible failures include soundfile not being
    installed and an unreadable or unsupported file.

    Args:
        file_path: path of the audio file, passed to ``soundfile.info``.

    Returns:
        The ``duration`` reported by ``soundfile.info``, or 0 on failure.
    """
    try:
        import soundfile as sf
        return sf.info(file_path).duration
    except Exception as e:
        # Catch Exception rather than a bare `except:` so KeyboardInterrupt
        # and SystemExit still propagate. Keep the best-effort 0, but make
        # the failure visible.
        logging.getLogger(__name__).warning(
            "Could not read audio duration of %s: %s", file_path, e
        )
        return 0