ASesYusuf1 committed on
Commit
c419df5
·
verified ·
1 Parent(s): 6d177f8

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +13 -25
utils.py CHANGED
@@ -13,14 +13,7 @@ from omegaconf import OmegaConf
13
  from tqdm.auto import tqdm
14
  from typing import Dict, List, Tuple, Any, Union
15
  import loralib as lora
16
- import gc # For garbage collection to free memory
17
-
18
- # ZeroGPU-specific imports
19
- try:
20
- from spaces import GPU
21
- IS_ZEROGPU = True
22
- except ImportError:
23
- IS_ZEROGPU = False
24
 
25
  def load_config(model_type: str, config_path: str) -> Union[ConfigDict, OmegaConf]:
26
  try:
@@ -37,8 +30,6 @@ def load_config(model_type: str, config_path: str) -> Union[ConfigDict, OmegaCon
37
 
38
  def get_model_from_config(model_type: str, config_path: str) -> Tuple[nn.Module, Any]:
39
  config = load_config(model_type, config_path)
40
-
41
- # Initialize model based on type
42
  model = None
43
  if model_type == 'mdx23c':
44
  from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net
@@ -49,7 +40,6 @@ def get_model_from_config(model_type: str, config_path: str) -> Tuple[nn.Module,
49
  # Add other model types as needed...
50
  else:
51
  raise ValueError(f"Unknown model type: {model_type}")
52
-
53
  return model, config
54
 
55
  def read_audio_transposed(path: str, instr: str = None, skip_err: bool = False) -> Tuple[np.ndarray, int]:
@@ -80,7 +70,7 @@ def apply_tta(
80
  device: str,
81
  model_type: str
82
  ) -> Dict[str, torch.Tensor]:
83
- track_proc_list = [mix[::-1].clone(), -mix.clone()] # Use clone to avoid in-place ops
84
  for i, augmented_mix in enumerate(track_proc_list):
85
  waveforms = demix(config, model, augmented_mix, device, model_type=model_type, pbar=False)
86
  for el in waveforms:
@@ -89,7 +79,9 @@ def apply_tta(
89
  else:
90
  waveforms_orig[el] -= waveforms[el]
91
  del waveforms, augmented_mix
92
- gc.collect() # Free memory after each augmentation
 
 
93
  for el in waveforms_orig:
94
  waveforms_orig[el] /= (len(track_proc_list) + 1)
95
  return waveforms_orig
@@ -102,8 +94,6 @@ def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor:
102
  window[:fade_size] = fadein
103
  return window
104
 
105
- if IS_ZEROGPU:
106
- @GPU
107
  def demix(
108
  config: ConfigDict,
109
  model: nn.Module,
@@ -113,9 +103,8 @@ def demix(
113
  pbar: bool = False
114
  ) -> Dict[str, np.ndarray]:
115
  mix = torch.tensor(mix, dtype=torch.float16, device='cpu') # Start on CPU with FP16
116
-
117
  mode = 'demucs' if model_type == 'htdemucs' else 'generic'
118
-
119
  # Processing parameters
120
  if mode == 'demucs':
121
  chunk_size = config.training.samplerate * config.training.segment
@@ -136,10 +125,12 @@ def demix(
136
 
137
  batch_size = getattr(config.inference, 'batch_size', 1) # Default to 1 for low memory
138
 
139
- # Use autocast for mixed precision
140
- scaler = torch.cuda.amp.GradScaler(enabled=True) if device.startswith('cuda') else None
141
- with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16):
142
- with torch.no_grad(): # No gradients for inference
 
 
143
  req_shape = (num_instruments,) + mix.shape
144
  result = torch.zeros(req_shape, dtype=torch.float16, device='cpu')
145
  counter = torch.zeros(req_shape, dtype=torch.float16, device='cpu')
@@ -212,7 +203,6 @@ def load_not_compatible_weights(model: nn.Module, weights: str, verbose: bool =
212
  old_model = old_model['state']
213
  if 'state_dict' in old_model:
214
  old_model = old_model['state_dict']
215
-
216
  for el in new_model:
217
  if el in old_model and new_model[el].shape == old_model[el].shape:
218
  new_model[el] = old_model[el]
@@ -236,7 +226,6 @@ def load_start_checkpoint(args: argparse.Namespace, model: nn.Module, type_='tra
236
  def bind_lora_to_model(config: Dict[str, Any], model: nn.Module) -> nn.Module:
237
  if 'lora' not in config:
238
  raise ValueError("Configuration must contain the 'lora' key with parameters for LoRA.")
239
-
240
  replaced_layers = 0
241
  for name, module in model.named_modules():
242
  hierarchy = name.split('.')
@@ -259,7 +248,6 @@ def bind_lora_to_model(config: Dict[str, Any], model: nn.Module) -> nn.Module:
259
  replaced_layers += 1
260
  except Exception as e:
261
  print(f"Error replacing layer {name}: {e}")
262
-
263
  print(f"Number of layers replaced with LoRA: {replaced_layers}")
264
  return model
265
 
@@ -276,4 +264,4 @@ def draw_spectrogram(waveform, sample_rate, length, output_file):
276
  fig.colorbar(img, ax=ax, format="%+2.f dB")
277
  if output_file:
278
  plt.savefig(output_file)
279
- plt.close() # Close plot to free memory
 
13
  from tqdm.auto import tqdm
14
  from typing import Dict, List, Tuple, Any, Union
15
  import loralib as lora
16
+ import gc # For garbage collection
 
 
 
 
 
 
 
17
 
18
  def load_config(model_type: str, config_path: str) -> Union[ConfigDict, OmegaConf]:
19
  try:
 
30
 
31
  def get_model_from_config(model_type: str, config_path: str) -> Tuple[nn.Module, Any]:
32
  config = load_config(model_type, config_path)
 
 
33
  model = None
34
  if model_type == 'mdx23c':
35
  from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net
 
40
  # Add other model types as needed...
41
  else:
42
  raise ValueError(f"Unknown model type: {model_type}")
 
43
  return model, config
44
 
45
  def read_audio_transposed(path: str, instr: str = None, skip_err: bool = False) -> Tuple[np.ndarray, int]:
 
70
  device: str,
71
  model_type: str
72
  ) -> Dict[str, torch.Tensor]:
73
+ track_proc_list = [mix[::-1].clone(), -mix.clone()]
74
  for i, augmented_mix in enumerate(track_proc_list):
75
  waveforms = demix(config, model, augmented_mix, device, model_type=model_type, pbar=False)
76
  for el in waveforms:
 
79
  else:
80
  waveforms_orig[el] -= waveforms[el]
81
  del waveforms, augmented_mix
82
+ gc.collect()
83
+ if device.startswith('cuda'):
84
+ torch.cuda.empty_cache()
85
  for el in waveforms_orig:
86
  waveforms_orig[el] /= (len(track_proc_list) + 1)
87
  return waveforms_orig
 
94
  window[:fade_size] = fadein
95
  return window
96
 
 
 
97
  def demix(
98
  config: ConfigDict,
99
  model: nn.Module,
 
103
  pbar: bool = False
104
  ) -> Dict[str, np.ndarray]:
105
  mix = torch.tensor(mix, dtype=torch.float16, device='cpu') # Start on CPU with FP16
 
106
  mode = 'demucs' if model_type == 'htdemucs' else 'generic'
107
+
108
  # Processing parameters
109
  if mode == 'demucs':
110
  chunk_size = config.training.samplerate * config.training.segment
 
125
 
126
  batch_size = getattr(config.inference, 'batch_size', 1) # Default to 1 for low memory
127
 
128
+ # Move model to device
129
+ model = model.to(device)
130
+ model.eval()
131
+
132
+ with torch.no_grad(): # No gradients for inference
133
+ with torch.cuda.amp.autocast(enabled=device.startswith('cuda'), dtype=torch.float16):
134
  req_shape = (num_instruments,) + mix.shape
135
  result = torch.zeros(req_shape, dtype=torch.float16, device='cpu')
136
  counter = torch.zeros(req_shape, dtype=torch.float16, device='cpu')
 
203
  old_model = old_model['state']
204
  if 'state_dict' in old_model:
205
  old_model = old_model['state_dict']
 
206
  for el in new_model:
207
  if el in old_model and new_model[el].shape == old_model[el].shape:
208
  new_model[el] = old_model[el]
 
226
  def bind_lora_to_model(config: Dict[str, Any], model: nn.Module) -> nn.Module:
227
  if 'lora' not in config:
228
  raise ValueError("Configuration must contain the 'lora' key with parameters for LoRA.")
 
229
  replaced_layers = 0
230
  for name, module in model.named_modules():
231
  hierarchy = name.split('.')
 
248
  replaced_layers += 1
249
  except Exception as e:
250
  print(f"Error replacing layer {name}: {e}")
 
251
  print(f"Number of layers replaced with LoRA: {replaced_layers}")
252
  return model
253
 
 
264
  fig.colorbar(img, ax=ax, format="%+2.f dB")
265
  if output_file:
266
  plt.savefig(output_file)
267
+ plt.close()