Spaces:
Sleeping
Sleeping
Update multi_inference.py
Browse files- multi_inference.py +35 -17
multi_inference.py
CHANGED
|
@@ -81,19 +81,25 @@ class MVSEPLESS:
|
|
| 81 |
call_method: str = "cli",
|
| 82 |
selected_stems: list = None
|
| 83 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
if selected_stems is None:
|
| 85 |
selected_stems = []
|
| 86 |
|
| 87 |
if not input_file:
|
| 88 |
print("Please, input path to input file")
|
| 89 |
-
return [("
|
| 90 |
|
| 91 |
if not os.path.exists(input_file):
|
| 92 |
print("Input file not exist")
|
| 93 |
-
return [("
|
| 94 |
|
| 95 |
-
if "STEM" not in template:
|
| 96 |
-
template = template + "
|
|
|
|
|
|
|
| 97 |
|
| 98 |
print(f"Starting inference: {model_type}/{model_name}, bitrate={output_bitrate}, method={call_method}, stems={selected_stems}")
|
| 99 |
os.makedirs(output_dir, exist_ok=True)
|
|
@@ -103,7 +109,7 @@ class MVSEPLESS:
|
|
| 103 |
info = models_data[model_type][model_name]
|
| 104 |
except KeyError:
|
| 105 |
print("Model not exist")
|
| 106 |
-
return [("
|
| 107 |
|
| 108 |
conf, ckpt = download_model(self.models_cache_dir, model_name, model_type,
|
| 109 |
info["checkpoint_url"], info["config_url"])
|
|
@@ -122,13 +128,18 @@ class MVSEPLESS:
|
|
| 122 |
if selected_stems:
|
| 123 |
instruments = " ".join(f'"{s}"' for s in selected_stems)
|
| 124 |
cmd.append(f'--selected_instruments {instruments}')
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
results_path = os.path.join(output_dir, "results.json")
|
| 128 |
if os.path.exists(results_path):
|
| 129 |
with open(results_path, encoding="utf-8") as f:
|
| 130 |
return json.load(f)
|
| 131 |
-
return [("
|
| 132 |
|
| 133 |
elif call_method == "direct":
|
| 134 |
from separator.msst_separator import mvsep_offline
|
|
@@ -141,14 +152,14 @@ class MVSEPLESS:
|
|
| 141 |
)
|
| 142 |
except Exception as e:
|
| 143 |
print(e)
|
| 144 |
-
return [("
|
| 145 |
|
| 146 |
elif model_type in ["vr", "mdx"]:
|
| 147 |
try:
|
| 148 |
info = models_data[model_type][model_name]
|
| 149 |
except KeyError:
|
| 150 |
print("Model not exist")
|
| 151 |
-
return [("
|
| 152 |
|
| 153 |
if model_type == "vr" and info.get("custom_vr", False):
|
| 154 |
conf, ckpt = download_model(self.models_cache_dir, model_name, model_type,
|
|
@@ -165,13 +176,16 @@ class MVSEPLESS:
|
|
| 165 |
if selected_stems:
|
| 166 |
instruments = " ".join(f'"{s}"' for s in selected_stems)
|
| 167 |
cmd.append(f'--selected_instruments {instruments}')
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
| 170 |
results_path = os.path.join(output_dir, "results.json")
|
| 171 |
if os.path.exists(results_path):
|
| 172 |
with open(results_path, encoding="utf-8") as f:
|
| 173 |
return json.load(f)
|
| 174 |
-
return [("
|
| 175 |
|
| 176 |
elif call_method == "direct":
|
| 177 |
from separator.uvr_sep import custom_vr_separate
|
|
@@ -185,7 +199,7 @@ class MVSEPLESS:
|
|
| 185 |
)
|
| 186 |
except Exception as e:
|
| 187 |
print(e)
|
| 188 |
-
return [("
|
| 189 |
else:
|
| 190 |
if call_method == "cli":
|
| 191 |
cmd = ["python", "-m", "separator.uvr_sep", "uvr",
|
|
@@ -197,13 +211,17 @@ class MVSEPLESS:
|
|
| 197 |
if selected_stems:
|
| 198 |
instruments = " ".join(f'"{s}"' for s in selected_stems)
|
| 199 |
cmd.append(f'--selected_instruments {instruments}')
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
results_path = os.path.join(output_dir, "results.json")
|
| 203 |
if os.path.exists(results_path):
|
| 204 |
with open(results_path, encoding="utf-8") as f:
|
| 205 |
return json.load(f)
|
| 206 |
-
return [("
|
| 207 |
|
| 208 |
elif call_method == "direct":
|
| 209 |
from separator.uvr_sep import non_custom_uvr_inference
|
|
@@ -217,10 +235,10 @@ class MVSEPLESS:
|
|
| 217 |
)
|
| 218 |
except Exception as e:
|
| 219 |
print(e)
|
| 220 |
-
return [("
|
| 221 |
|
| 222 |
print("Unsupported model type")
|
| 223 |
-
return [("
|
| 224 |
|
| 225 |
def parse_args():
|
| 226 |
parser = argparse.ArgumentParser(description="Multi-inference for separation audio in Google Colab")
|
|
|
|
| 81 |
call_method: str = "cli",
|
| 82 |
selected_stems: list = None
|
| 83 |
):
|
| 84 |
+
|
| 85 |
+
if output_dir is None:
|
| 86 |
+
output_dir = os.getcwd()
|
| 87 |
+
|
| 88 |
if selected_stems is None:
|
| 89 |
selected_stems = []
|
| 90 |
|
| 91 |
if not input_file:
|
| 92 |
print("Please, input path to input file")
|
| 93 |
+
return [("Input path is none", "/none/none.mp3")]
|
| 94 |
|
| 95 |
if not os.path.exists(input_file):
|
| 96 |
print("Input file not exist")
|
| 97 |
+
return [("Input file not exist", "/none/none.mp3")]
|
| 98 |
|
| 99 |
+
if "STEM" not in template and template is not None:
|
| 100 |
+
template = template + "_STEM_"
|
| 101 |
+
if not template:
|
| 102 |
+
template = "mvsepless_NAME_(STEM)"
|
| 103 |
|
| 104 |
print(f"Starting inference: {model_type}/{model_name}, bitrate={output_bitrate}, method={call_method}, stems={selected_stems}")
|
| 105 |
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
| 109 |
info = models_data[model_type][model_name]
|
| 110 |
except KeyError:
|
| 111 |
print("Model not exist")
|
| 112 |
+
return [("Model not exist", "/none/none.mp3")]
|
| 113 |
|
| 114 |
conf, ckpt = download_model(self.models_cache_dir, model_name, model_type,
|
| 115 |
info["checkpoint_url"], info["config_url"])
|
|
|
|
| 128 |
if selected_stems:
|
| 129 |
instruments = " ".join(f'"{s}"' for s in selected_stems)
|
| 130 |
cmd.append(f'--selected_instruments {instruments}')
|
| 131 |
+
try:
|
| 132 |
+
subprocess.run(" ".join(cmd), shell=True, check=True)
|
| 133 |
+
except Exception as e:
|
| 134 |
+
print(e)
|
| 135 |
+
return [("Error", "/none/none.mp3")]
|
| 136 |
+
|
| 137 |
|
| 138 |
results_path = os.path.join(output_dir, "results.json")
|
| 139 |
if os.path.exists(results_path):
|
| 140 |
with open(results_path, encoding="utf-8") as f:
|
| 141 |
return json.load(f)
|
| 142 |
+
return [("Error", "/none/none.mp3")]
|
| 143 |
|
| 144 |
elif call_method == "direct":
|
| 145 |
from separator.msst_separator import mvsep_offline
|
|
|
|
| 152 |
)
|
| 153 |
except Exception as e:
|
| 154 |
print(e)
|
| 155 |
+
return [("Error", "/none/none.mp3")]
|
| 156 |
|
| 157 |
elif model_type in ["vr", "mdx"]:
|
| 158 |
try:
|
| 159 |
info = models_data[model_type][model_name]
|
| 160 |
except KeyError:
|
| 161 |
print("Model not exist")
|
| 162 |
+
return [("Model not exist", "/none/none.mp3")]
|
| 163 |
|
| 164 |
if model_type == "vr" and info.get("custom_vr", False):
|
| 165 |
conf, ckpt = download_model(self.models_cache_dir, model_name, model_type,
|
|
|
|
| 176 |
if selected_stems:
|
| 177 |
instruments = " ".join(f'"{s}"' for s in selected_stems)
|
| 178 |
cmd.append(f'--selected_instruments {instruments}')
|
| 179 |
+
try:
|
| 180 |
+
subprocess.run(" ".join(cmd), shell=True, check=True)
|
| 181 |
+
except Exception as e:
|
| 182 |
+
print(e)
|
| 183 |
+
return [("Error", "/none/none.mp3")]
|
| 184 |
results_path = os.path.join(output_dir, "results.json")
|
| 185 |
if os.path.exists(results_path):
|
| 186 |
with open(results_path, encoding="utf-8") as f:
|
| 187 |
return json.load(f)
|
| 188 |
+
return [("Error", "/none/none.mp3")]
|
| 189 |
|
| 190 |
elif call_method == "direct":
|
| 191 |
from separator.uvr_sep import custom_vr_separate
|
|
|
|
| 199 |
)
|
| 200 |
except Exception as e:
|
| 201 |
print(e)
|
| 202 |
+
return [("Error", "/none/none.mp3")]
|
| 203 |
else:
|
| 204 |
if call_method == "cli":
|
| 205 |
cmd = ["python", "-m", "separator.uvr_sep", "uvr",
|
|
|
|
| 211 |
if selected_stems:
|
| 212 |
instruments = " ".join(f'"{s}"' for s in selected_stems)
|
| 213 |
cmd.append(f'--selected_instruments {instruments}')
|
| 214 |
+
try:
|
| 215 |
+
subprocess.run(" ".join(cmd), shell=True, check=True)
|
| 216 |
+
except Exception as e:
|
| 217 |
+
print(e)
|
| 218 |
+
return [("Error", "/none/none.mp3")]
|
| 219 |
|
| 220 |
results_path = os.path.join(output_dir, "results.json")
|
| 221 |
if os.path.exists(results_path):
|
| 222 |
with open(results_path, encoding="utf-8") as f:
|
| 223 |
return json.load(f)
|
| 224 |
+
return [("Error", "/none/none.mp3")]
|
| 225 |
|
| 226 |
elif call_method == "direct":
|
| 227 |
from separator.uvr_sep import non_custom_uvr_inference
|
|
|
|
| 235 |
)
|
| 236 |
except Exception as e:
|
| 237 |
print(e)
|
| 238 |
+
return [("Error", "/none/none.mp3")]
|
| 239 |
|
| 240 |
print("Unsupported model type")
|
| 241 |
+
return [("Unsupported model type", "/none/none.mp3")]
|
| 242 |
|
| 243 |
def parse_args():
|
| 244 |
parser = argparse.ArgumentParser(description="Multi-inference for separation audio in Google Colab")
|