Kremon96 commited on
Commit
55518dd
·
verified ·
1 Parent(s): 14c0824

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +86 -0
utils.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils.py
2
+ """
3
+ Utility functions for the music separation project
4
+ """
5
+
6
+ import os
7
+ import torch
8
+ import logging
9
+ from pathlib import Path
10
+ from config import LOG_DIR, MODEL_DIR
11
+ from datetime import datetime
12
+
13
+ def setup_logging():
14
+ """Setup logging configuration"""
15
+ log_file = LOG_DIR / f"music_separator_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
16
+
17
+ logging.basicConfig(
18
+ level=logging.INFO,
19
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
20
+ handlers=[
21
+ logging.FileHandler(log_file),
22
+ logging.StreamHandler()
23
+ ]
24
+ )
25
+
26
+ return logging.getLogger(__name__)
27
+
28
+ def save_model(model, model_name):
29
+ """Save model weights"""
30
+ try:
31
+ model_path = MODEL_DIR / f"{model_name}.pth"
32
+ torch.save(model.state_dict(), model_path)
33
+ print(f"✅ Model saved to {model_path}")
34
+ return True
35
+ except Exception as e:
36
+ print(f"❌ Error saving model: {str(e)}")
37
+ return False
38
+
39
+ def load_model(model, model_name):
40
+ """Load model weights"""
41
+ try:
42
+ model_path = MODEL_DIR / f"{model_name}.pth"
43
+ if model_path.exists():
44
+ model.load_state_dict(torch.load(model_path, map_location='cpu'))
45
+ print(f"✅ Model loaded from {model_path}")
46
+ return True
47
+ else:
48
+ print(f"⚠️ Model file {model_path} not found")
49
+ return False
50
+ except Exception as e:
51
+ print(f"❌ Error loading model: {str(e)}")
52
+ return False
53
+
54
+ def get_system_info():
55
+ """Get comprehensive system information"""
56
+ info = {
57
+ 'pytorch_version': torch.__version__,
58
+ 'cuda_available': torch.cuda.is_available(),
59
+ 'cuda_version': torch.version.cuda if torch.cuda.is_available() else 'N/A',
60
+ 'device_count': torch.cuda.device_count() if torch.cuda.is_available() else 0,
61
+ }
62
+
63
+ if torch.cuda.is_available():
64
+ for i in range(torch.cuda.device_count()):
65
+ info[f'cuda_device_{i}'] = torch.cuda.get_device_name(i)
66
+ info[f'cuda_memory_{i}'] = f"{torch.cuda.get_device_properties(i).total_memory / 1024**3:.1f}GB"
67
+
68
+ return info
69
+
70
+ def format_time(seconds):
71
+ """Format seconds into human readable time"""
72
+ if seconds < 60:
73
+ return f"{seconds:.1f}s"
74
+ elif seconds < 3600:
75
+ return f"{seconds/60:.1f}m"
76
+ else:
77
+ return f"{seconds/3600:.1f}h"
78
+
79
+ def get_audio_duration(file_path):
80
+ """Get audio file duration in seconds"""
81
+ try:
82
+ import soundfile as sf
83
+ info = sf.info(file_path)
84
+ return info.duration
85
+ except:
86
+ return 0