Spaces:
Sleeping
Sleeping
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -13,14 +13,7 @@ from omegaconf import OmegaConf
|
|
| 13 |
from tqdm.auto import tqdm
|
| 14 |
from typing import Dict, List, Tuple, Any, Union
|
| 15 |
import loralib as lora
|
| 16 |
-
import gc # For garbage collection
|
| 17 |
-
|
| 18 |
-
# ZeroGPU-specific imports
|
| 19 |
-
try:
|
| 20 |
-
from spaces import GPU
|
| 21 |
-
IS_ZEROGPU = True
|
| 22 |
-
except ImportError:
|
| 23 |
-
IS_ZEROGPU = False
|
| 24 |
|
| 25 |
def load_config(model_type: str, config_path: str) -> Union[ConfigDict, OmegaConf]:
|
| 26 |
try:
|
|
@@ -37,8 +30,6 @@ def load_config(model_type: str, config_path: str) -> Union[ConfigDict, OmegaCon
|
|
| 37 |
|
| 38 |
def get_model_from_config(model_type: str, config_path: str) -> Tuple[nn.Module, Any]:
|
| 39 |
config = load_config(model_type, config_path)
|
| 40 |
-
|
| 41 |
-
# Initialize model based on type
|
| 42 |
model = None
|
| 43 |
if model_type == 'mdx23c':
|
| 44 |
from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net
|
|
@@ -49,7 +40,6 @@ def get_model_from_config(model_type: str, config_path: str) -> Tuple[nn.Module,
|
|
| 49 |
# Add other model types as needed...
|
| 50 |
else:
|
| 51 |
raise ValueError(f"Unknown model type: {model_type}")
|
| 52 |
-
|
| 53 |
return model, config
|
| 54 |
|
| 55 |
def read_audio_transposed(path: str, instr: str = None, skip_err: bool = False) -> Tuple[np.ndarray, int]:
|
|
@@ -80,7 +70,7 @@ def apply_tta(
|
|
| 80 |
device: str,
|
| 81 |
model_type: str
|
| 82 |
) -> Dict[str, torch.Tensor]:
|
| 83 |
-
track_proc_list = [mix[::-1].clone(), -mix.clone()]
|
| 84 |
for i, augmented_mix in enumerate(track_proc_list):
|
| 85 |
waveforms = demix(config, model, augmented_mix, device, model_type=model_type, pbar=False)
|
| 86 |
for el in waveforms:
|
|
@@ -89,7 +79,9 @@ def apply_tta(
|
|
| 89 |
else:
|
| 90 |
waveforms_orig[el] -= waveforms[el]
|
| 91 |
del waveforms, augmented_mix
|
| 92 |
-
gc.collect()
|
|
|
|
|
|
|
| 93 |
for el in waveforms_orig:
|
| 94 |
waveforms_orig[el] /= (len(track_proc_list) + 1)
|
| 95 |
return waveforms_orig
|
|
@@ -102,8 +94,6 @@ def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor:
|
|
| 102 |
window[:fade_size] = fadein
|
| 103 |
return window
|
| 104 |
|
| 105 |
-
if IS_ZEROGPU:
|
| 106 |
-
@GPU
|
| 107 |
def demix(
|
| 108 |
config: ConfigDict,
|
| 109 |
model: nn.Module,
|
|
@@ -113,9 +103,8 @@ def demix(
|
|
| 113 |
pbar: bool = False
|
| 114 |
) -> Dict[str, np.ndarray]:
|
| 115 |
mix = torch.tensor(mix, dtype=torch.float16, device='cpu') # Start on CPU with FP16
|
| 116 |
-
|
| 117 |
mode = 'demucs' if model_type == 'htdemucs' else 'generic'
|
| 118 |
-
|
| 119 |
# Processing parameters
|
| 120 |
if mode == 'demucs':
|
| 121 |
chunk_size = config.training.samplerate * config.training.segment
|
|
@@ -136,10 +125,12 @@ def demix(
|
|
| 136 |
|
| 137 |
batch_size = getattr(config.inference, 'batch_size', 1) # Default to 1 for low memory
|
| 138 |
|
| 139 |
-
#
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
| 143 |
req_shape = (num_instruments,) + mix.shape
|
| 144 |
result = torch.zeros(req_shape, dtype=torch.float16, device='cpu')
|
| 145 |
counter = torch.zeros(req_shape, dtype=torch.float16, device='cpu')
|
|
@@ -212,7 +203,6 @@ def load_not_compatible_weights(model: nn.Module, weights: str, verbose: bool =
|
|
| 212 |
old_model = old_model['state']
|
| 213 |
if 'state_dict' in old_model:
|
| 214 |
old_model = old_model['state_dict']
|
| 215 |
-
|
| 216 |
for el in new_model:
|
| 217 |
if el in old_model and new_model[el].shape == old_model[el].shape:
|
| 218 |
new_model[el] = old_model[el]
|
|
@@ -236,7 +226,6 @@ def load_start_checkpoint(args: argparse.Namespace, model: nn.Module, type_='tra
|
|
| 236 |
def bind_lora_to_model(config: Dict[str, Any], model: nn.Module) -> nn.Module:
|
| 237 |
if 'lora' not in config:
|
| 238 |
raise ValueError("Configuration must contain the 'lora' key with parameters for LoRA.")
|
| 239 |
-
|
| 240 |
replaced_layers = 0
|
| 241 |
for name, module in model.named_modules():
|
| 242 |
hierarchy = name.split('.')
|
|
@@ -259,7 +248,6 @@ def bind_lora_to_model(config: Dict[str, Any], model: nn.Module) -> nn.Module:
|
|
| 259 |
replaced_layers += 1
|
| 260 |
except Exception as e:
|
| 261 |
print(f"Error replacing layer {name}: {e}")
|
| 262 |
-
|
| 263 |
print(f"Number of layers replaced with LoRA: {replaced_layers}")
|
| 264 |
return model
|
| 265 |
|
|
@@ -276,4 +264,4 @@ def draw_spectrogram(waveform, sample_rate, length, output_file):
|
|
| 276 |
fig.colorbar(img, ax=ax, format="%+2.f dB")
|
| 277 |
if output_file:
|
| 278 |
plt.savefig(output_file)
|
| 279 |
-
plt.close()
|
|
|
|
| 13 |
from tqdm.auto import tqdm
|
| 14 |
from typing import Dict, List, Tuple, Any, Union
|
| 15 |
import loralib as lora
|
| 16 |
+
import gc # For garbage collection
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def load_config(model_type: str, config_path: str) -> Union[ConfigDict, OmegaConf]:
|
| 19 |
try:
|
|
|
|
| 30 |
|
| 31 |
def get_model_from_config(model_type: str, config_path: str) -> Tuple[nn.Module, Any]:
|
| 32 |
config = load_config(model_type, config_path)
|
|
|
|
|
|
|
| 33 |
model = None
|
| 34 |
if model_type == 'mdx23c':
|
| 35 |
from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net
|
|
|
|
| 40 |
# Add other model types as needed...
|
| 41 |
else:
|
| 42 |
raise ValueError(f"Unknown model type: {model_type}")
|
|
|
|
| 43 |
return model, config
|
| 44 |
|
| 45 |
def read_audio_transposed(path: str, instr: str = None, skip_err: bool = False) -> Tuple[np.ndarray, int]:
|
|
|
|
| 70 |
device: str,
|
| 71 |
model_type: str
|
| 72 |
) -> Dict[str, torch.Tensor]:
|
| 73 |
+
track_proc_list = [mix[::-1].clone(), -mix.clone()]
|
| 74 |
for i, augmented_mix in enumerate(track_proc_list):
|
| 75 |
waveforms = demix(config, model, augmented_mix, device, model_type=model_type, pbar=False)
|
| 76 |
for el in waveforms:
|
|
|
|
| 79 |
else:
|
| 80 |
waveforms_orig[el] -= waveforms[el]
|
| 81 |
del waveforms, augmented_mix
|
| 82 |
+
gc.collect()
|
| 83 |
+
if device.startswith('cuda'):
|
| 84 |
+
torch.cuda.empty_cache()
|
| 85 |
for el in waveforms_orig:
|
| 86 |
waveforms_orig[el] /= (len(track_proc_list) + 1)
|
| 87 |
return waveforms_orig
|
|
|
|
| 94 |
window[:fade_size] = fadein
|
| 95 |
return window
|
| 96 |
|
|
|
|
|
|
|
| 97 |
def demix(
|
| 98 |
config: ConfigDict,
|
| 99 |
model: nn.Module,
|
|
|
|
| 103 |
pbar: bool = False
|
| 104 |
) -> Dict[str, np.ndarray]:
|
| 105 |
mix = torch.tensor(mix, dtype=torch.float16, device='cpu') # Start on CPU with FP16
|
|
|
|
| 106 |
mode = 'demucs' if model_type == 'htdemucs' else 'generic'
|
| 107 |
+
|
| 108 |
# Processing parameters
|
| 109 |
if mode == 'demucs':
|
| 110 |
chunk_size = config.training.samplerate * config.training.segment
|
|
|
|
| 125 |
|
| 126 |
batch_size = getattr(config.inference, 'batch_size', 1) # Default to 1 for low memory
|
| 127 |
|
| 128 |
+
# Move model to device
|
| 129 |
+
model = model.to(device)
|
| 130 |
+
model.eval()
|
| 131 |
+
|
| 132 |
+
with torch.no_grad(): # No gradients for inference
|
| 133 |
+
with torch.cuda.amp.autocast(enabled=device.startswith('cuda'), dtype=torch.float16):
|
| 134 |
req_shape = (num_instruments,) + mix.shape
|
| 135 |
result = torch.zeros(req_shape, dtype=torch.float16, device='cpu')
|
| 136 |
counter = torch.zeros(req_shape, dtype=torch.float16, device='cpu')
|
|
|
|
| 203 |
old_model = old_model['state']
|
| 204 |
if 'state_dict' in old_model:
|
| 205 |
old_model = old_model['state_dict']
|
|
|
|
| 206 |
for el in new_model:
|
| 207 |
if el in old_model and new_model[el].shape == old_model[el].shape:
|
| 208 |
new_model[el] = old_model[el]
|
|
|
|
| 226 |
def bind_lora_to_model(config: Dict[str, Any], model: nn.Module) -> nn.Module:
|
| 227 |
if 'lora' not in config:
|
| 228 |
raise ValueError("Configuration must contain the 'lora' key with parameters for LoRA.")
|
|
|
|
| 229 |
replaced_layers = 0
|
| 230 |
for name, module in model.named_modules():
|
| 231 |
hierarchy = name.split('.')
|
|
|
|
| 248 |
replaced_layers += 1
|
| 249 |
except Exception as e:
|
| 250 |
print(f"Error replacing layer {name}: {e}")
|
|
|
|
| 251 |
print(f"Number of layers replaced with LoRA: {replaced_layers}")
|
| 252 |
return model
|
| 253 |
|
|
|
|
| 264 |
fig.colorbar(img, ax=ax, format="%+2.f dB")
|
| 265 |
if output_file:
|
| 266 |
plt.savefig(output_file)
|
| 267 |
+
plt.close()
|