File size: 6,439 Bytes
859076c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
import os
import sys
import traceback
# MPS-related PyTorch settings, applied before the torch import that follows.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

# Positional command-line arguments:
#   7 argv entries: device n_part i_part exp_dir version is_half
#   otherwise:      device n_part i_part i_gpu exp_dir version is_half
device = sys.argv[1]
n_part, i_part = int(sys.argv[2]), int(sys.argv[3])
if len(sys.argv) == 7:
    exp_dir, version = sys.argv[4], sys.argv[5]
    is_half = sys.argv[6].lower() == "true"
else:
    i_gpu, exp_dir = sys.argv[4], sys.argv[5]
    # Restrict CUDA to the requested GPU; done here, ahead of the torch import.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(i_gpu)
    version = sys.argv[6]
    is_half = sys.argv[7].lower() == "true"
import fairseq
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
# Pick the compute device. A requested "privateuseone" device is mapped to the
# default DirectML device; any other request is replaced by cuda, then mps,
# then cpu, depending on what torch reports as available.
if "privateuseone" in device:
    import torch_directml

    device = torch_directml.device(torch_directml.default_device())

    def forward_dml(ctx, x, scale):
        """Stand-in for GradMultiply.forward: store ``scale`` on ``ctx`` and return a detached copy of ``x``."""
        ctx.scale = scale
        return x.clone().detach()

    # Patch fairseq so GradMultiply uses the replacement forward above.
    fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
elif torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# Append-mode handle on this experiment's extraction log (the file is created if absent).
# NOTE(review): no visible code writes to or closes this handle - confirm it is still needed.
extract_log_path = "%s/extract_f0_feature.log" % exp_dir
f = open(extract_log_path, "a+")
def printt(strr, log_path="/content/log.txt"):
    """Echo a message to stdout and append it as one line to a log file.

    Args:
        strr: Message to emit. It is rendered with ``str()`` semantics, so
            non-string values (including tuples) are accepted.
        log_path: File the line is appended to. Defaults to the path this
            script has always written to, so one-argument calls behave as
            before.

    Raises:
        OSError: If ``log_path`` cannot be opened for appending.
    """
    print(strr)  # still print to output
    # Use a dedicated handle name so it is never confused with the module-level ``f``.
    with open(log_path, "a", encoding="utf-8") as log_file:
        # Wrapping in a 1-tuple keeps ``%`` formatting safe when strr is itself a tuple.
        log_file.write("%s\n" % (strr,))
# Record the full command line, then derive the input and output directories.
printt(" ".join(sys.argv))
model_path = "assets/hubert/hubert_base.pt"
printt("exp_dir: " + exp_dir)
wavPath = "%s/1_16k_wavs" % exp_dir
# The feature directory name depends on the model version.
if version == "v1":
    outPath = "%s/3_feature256" % exp_dir
else:
    outPath = "%s/3_feature768" % exp_dir
os.makedirs(outPath, exist_ok=True)
# wave must be 16k, hop_size=320
def readwave(wav_path, normalize=False):
    """Load a 16 kHz audio file as a float tensor of shape (1, num_samples).

    Two-dimensional input (multi-channel audio) is averaged over its last
    axis. When ``normalize`` is true, layer normalisation over the whole
    signal is applied before the leading dimension is added.

    Raises:
        AssertionError: If the sample rate is not 16000 or the signal is not
            one-dimensional after channel averaging.
    """
    samples, sample_rate = sf.read(wav_path)
    assert sample_rate == 16000
    signal = torch.from_numpy(samples).float()
    if signal.dim() == 2:  # double channels
        signal = signal.mean(-1)
    assert signal.dim() == 1, signal.dim()
    if normalize:
        with torch.no_grad():
            signal = F.layer_norm(signal, signal.shape)
    return signal.view(1, -1)
# HuBERT model
printt("load model(s) from {}".format(model_path))
# Stop early when the HuBERT checkpoint file is missing.
if not os.access(model_path, os.F_OK):
    printt(
        "Error: Extracting is shut down because %s does not exist, you may download it from https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main"
        % model_path
    )
    # NOTE(review): this error path exits with status 0.
    exit(0)
models, saved_cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
    [model_path],
    suffix="",
)
# Only the first model of the returned ensemble is used.
model = models[0].to(device)
printt("move model to %s" % device)
# Half precision is applied only when requested and the device is neither mps nor cpu.
if is_half and device not in ["mps", "cpu"]:
    model = model.half()
model.eval()
# Step 1: install bcrypt quietly, then import what the credential check needs.
import os

os.system("pip install bcrypt > /dev/null 2>&1")

import hashlib
import json
import sqlite3

from IPython.display import clear_output

# Step 2: read the stored credentials from a JSON file (no interactive prompt).
credentials_path = '/content/RVC/infer/modules/train/credentials.json'
if not os.path.exists(credentials_path):
    # Missing credentials file: stop with status 1 without printing anything.
    exit(1)
with open(credentials_path, 'r') as f:
    credentials = json.load(f)
# Either value is None when its key is absent from the file.
username = credentials.get('username')
password = credentials.get('password')

# Step 3: Download users.db from Google Drive (change file_id if needed)
file_id = "1L6EIBl8WEzrPJw3C3AmlUTCACqCgYcKY"
destination = "/content/RVC/infer/modules/train/users.db"
os.system(f"gdown --id {file_id} -O {destination} > /dev/null 2>&1")

# Step 4: open the downloaded database; verify_user queries through this cursor.
conn = sqlite3.connect(destination)
cursor = conn.cursor()
def verify_user(username, password, db_cursor=None):
    """Check a username/password pair against the ``users`` table.

    Args:
        username: Account name to look up.
        password: Plain-text password; its SHA-256 hex digest is compared
            with the stored value.
        db_cursor: Optional DB-API cursor to run the query on. When omitted,
            the module-level ``cursor`` is used.

    Returns:
        True only when a row for ``username`` exists and its stored hash
        equals the digest of ``password``. False otherwise, including when
        ``password`` is missing or not a string.
    """
    import hmac  # local import: only this function needs it

    # A missing or non-string password can never match; avoid crashing on .encode().
    if not isinstance(password, str):
        return False

    active_cursor = cursor if db_cursor is None else db_cursor
    active_cursor.execute('SELECT * FROM users WHERE username = ?', (username,))
    user = active_cursor.fetchone()
    if not user:
        return False

    stored_hash = user[2]  # password is assumed to be hashed with sha256 (third column)
    if not isinstance(stored_hash, str):
        return False
    entered_hash = hashlib.sha256(password.encode()).hexdigest()
    # Constant-time comparison so timing does not reveal how much of the hash matched.
    return hmac.compare_digest(entered_hash.encode("utf-8"), stored_hash.encode("utf-8"))
# Step 5: run feature extraction only when the stored credentials verify.
# The try/finally guarantees the database connection is closed and the
# downloaded users.db is deleted whether access is granted, denied, or
# verification raises.
try:
    if verify_user(username, password):
        # This process handles every n_part-th file, starting at index i_part.
        todo = sorted(os.listdir(wavPath))[i_part::n_part]
        n = max(1, len(todo) // 10)  # progress is logged about every tenth file
        if not todo:
            printt("no-feature-todo")
        else:
            printt("all-feature-%s" % len(todo))
            for idx, file in enumerate(todo):
                # Failures are per file: log the traceback and move on to the next one.
                try:
                    if not file.endswith(".wav"):
                        continue
                    wav_path = "%s/%s" % (wavPath, file)
                    # Swap only the extension; str.replace("wav", "npy") would
                    # also rewrite any "wav" occurring inside the file stem.
                    out_path = "%s/%s.npy" % (outPath, file[: -len(".wav")])
                    if os.path.exists(out_path):
                        continue  # already extracted

                    feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
                    padding_mask = torch.BoolTensor(feats.shape).fill_(False)
                    use_half = is_half and device not in ["mps", "cpu"]
                    inputs = {
                        "source": feats.half().to(device) if use_half else feats.to(device),
                        "padding_mask": padding_mask.to(device),
                        "output_layer": 9 if version == "v1" else 12,
                    }
                    with torch.no_grad():
                        logits = model.extract_features(**inputs)
                        # v1 additionally passes the features through final_proj.
                        feats = model.final_proj(logits[0]) if version == "v1" else logits[0]

                    feats = feats.squeeze(0).float().cpu().numpy()
                    if np.isnan(feats).sum() == 0:
                        np.save(out_path, feats, allow_pickle=False)
                    else:
                        printt("%s-contains nan" % file)
                    if idx % n == 0:
                        # "now" is the current index, "all" the total count.
                        printt("now-%s,all-%s,%s,%s" % (idx, len(todo), file, feats.shape))
                except Exception:
                    # Exception (not a bare except) lets Ctrl-C and SystemExit propagate.
                    printt(traceback.format_exc())
            printt("all-feature-done")
    else:
        print(" ")
finally:
    conn.close()
    try:
        os.remove(destination)
    except FileNotFoundError:
        # Already gone - presumably another shard (see n_part) removed the
        # shared file first; nothing left to clean up.
        pass
|