File size: 2,608 Bytes
55518dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# utils.py
"""
Utility functions for the music separation project
"""

import os
import torch
import logging
from pathlib import Path
from config import LOG_DIR, MODEL_DIR
from datetime import datetime

def setup_logging():
    """Configure root logging to a timestamped file plus the console.

    Creates ``LOG_DIR`` (including parents) if it does not exist. It then calls
    :func:`logging.basicConfig` at INFO level with two handlers:

    - a file handler writing to
      ``LOG_DIR / "music_separator_<YYYYmmdd_HHMMSS>.log"``, opened as UTF-8;
    - a stream handler for the console.

    Note:
        ``logging.basicConfig`` does nothing if the root logger already has
        handlers, so only the first configuration in a process takes effect.

    Returns:
        logging.Logger: the logger named after this module.
    """
    log_dir = Path(LOG_DIR)
    # FileHandler does not create missing directories; without this the very
    # first run on a fresh checkout fails with FileNotFoundError.
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"music_separator_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            # Explicit UTF-8 so non-ASCII messages do not depend on the
            # platform's locale encoding.
            logging.FileHandler(log_file, encoding="utf-8"),
            logging.StreamHandler(),
        ],
    )

    return logging.getLogger(__name__)

def save_model(model, model_name):
    """Save a model's weights to ``MODEL_DIR / f"{model_name}.pth"``.

    Only ``model.state_dict()`` is written, not the full module. Missing
    parent directories are created first, so ``MODEL_DIR`` need not exist
    and ``model_name`` may contain sub-paths (e.g. ``"runs/best"``).

    Args:
        model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        model_name: file stem; ``.pth`` is appended.

    Returns:
        bool: True on success, False if any exception occurred. Errors are
        printed rather than raised.
    """
    try:
        model_path = Path(MODEL_DIR) / f"{model_name}.pth"
        # torch.save does not create missing parent directories.
        model_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), model_path)
    except Exception as e:
        print(f"❌ Error saving model: {e}")
        return False
    else:
        print(f"✅ Model saved to {model_path}")
        return True

def load_model(model, model_name):
    """Restore weights into ``model`` from ``MODEL_DIR / f"{model_name}.pth"``.

    The checkpoint is loaded onto the CPU (``map_location='cpu'``) and passed
    to ``model.load_state_dict``.

    Args:
        model: object exposing ``load_state_dict()`` (e.g. a ``torch.nn.Module``).
        model_name: file stem; ``.pth`` is appended.

    Returns:
        bool: True if the weights were loaded. False if the file is missing
        or any exception occurred; the reason is printed, not raised.
    """
    try:
        checkpoint = MODEL_DIR / f"{model_name}.pth"
        if not checkpoint.exists():
            print(f"⚠️ Model file {checkpoint} not found")
            return False

        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True (supported on newer PyTorch;
        # the default changed across releases) — confirm the minimum version.
        state = torch.load(checkpoint, map_location='cpu')
        model.load_state_dict(state)
        print(f"✅ Model loaded from {checkpoint}")
        return True
    except Exception as err:
        print(f"❌ Error loading model: {str(err)}")
        return False

def get_system_info():
    """Collect PyTorch / CUDA environment details into a flat dict.

    Always present:
        'pytorch_version': ``torch.__version__``.
        'cuda_available': result of ``torch.cuda.is_available()``.
        'cuda_version': ``torch.version.cuda``, or ``'N/A'`` without CUDA.
        'device_count': number of CUDA devices, or 0 without CUDA.

    When CUDA is available, each device index ``i`` also adds:
        'cuda_device_<i>': the device name.
        'cuda_memory_<i>': total memory in GiB, formatted like ``'7.8GB'``.
    """
    has_cuda = torch.cuda.is_available()
    info = {
        'pytorch_version': torch.__version__,
        'cuda_available': has_cuda,
        'cuda_version': torch.version.cuda if has_cuda else 'N/A',
        'device_count': torch.cuda.device_count() if has_cuda else 0,
    }

    if not has_cuda:
        return info

    gib = 1024 ** 3
    for idx in range(torch.cuda.device_count()):
        info[f'cuda_device_{idx}'] = torch.cuda.get_device_name(idx)
        total_bytes = torch.cuda.get_device_properties(idx).total_memory
        info[f'cuda_memory_{idx}'] = f"{total_bytes / gib:.1f}GB"
    return info

def format_time(seconds):
    """Format a duration in seconds as a short human-readable string.

    Uses one decimal place and picks the unit by magnitude:

    - seconds ("s") below one minute;
    - minutes ("m") below one hour;
    - hours ("h") otherwise.

    The unit is chosen on the value *as it will be displayed* (rounded to one
    decimal). Inputs just under a boundary therefore roll over to the next
    unit instead of printing "60.0s" or "60.0m".

    >>> format_time(42)
    '42.0s'
    >>> format_time(59.96)
    '1.0m'
    >>> format_time(5400)
    '1.5h'
    """
    if round(seconds, 1) < 60:
        return f"{seconds:.1f}s"
    minutes = seconds / 60
    if round(minutes, 1) < 60:
        return f"{minutes:.1f}m"
    return f"{seconds / 3600:.1f}h"

def get_audio_duration(file_path):
    """Return the duration of an audio file in seconds, or 0 on failure.

    ``soundfile`` is imported inside the function, so it is only required
    when this helper is actually called.

    Args:
        file_path: path to the audio file (anything ``soundfile.info`` accepts).

    Returns:
        ``soundfile.info(file_path).duration`` on success. ``0`` if soundfile
        cannot be imported or the file cannot be read.
    """
    try:
        import soundfile as sf
        return sf.info(file_path).duration
    except Exception as exc:
        # Best-effort by design: callers get 0 instead of an exception.
        # Unlike the old bare `except:`, this no longer swallows Ctrl-C /
        # SystemExit, and the cause is logged for debugging.
        logging.getLogger(__name__).debug(
            "Could not read duration of %s: %s", file_path, exc
        )
        return 0