# hubert_base / extract.py
# Last updated by Yoshitaka16 in commit 580ce58 ("Update extract.py", verified).
"""
Module which exposes functionality for extracting training features from
audio datasets, now with DJCM support.
"""
from __future__ import annotations
from multiprocessing import cpu_count
from ultimate_rvc.core.common import (
display_progress,
get_combined_file_hash,
validate_model,
)
from ultimate_rvc.core.exceptions import (
Entity,
ModelAsssociatedEntityNotFoundError,
Step,
)
from ultimate_rvc.core.train.common import validate_devices
from ultimate_rvc.typing_extra import (
DeviceType,
EmbedderModel,
TrainingF0Method,
)
def extract_features(
    model_name: str,
    f0_method: TrainingF0Method = TrainingF0Method.RMVPE,
    hop_length: int = 128,
    embedder_model: EmbedderModel = EmbedderModel.CONTENTVEC,
    custom_embedder_model: str | None = None,
    include_mutes: int = 2,
    cpu_cores: int = cpu_count(),
    hardware_acceleration: DeviceType = DeviceType.AUTOMATIC,
    gpu_ids: set[int] | None = None,
) -> None:
    """
    Extract pitch features and audio embeddings for a training model.

    Runs pitch (F0) extraction followed by embedding extraction over the
    model's preprocessed dataset (the ``sliced_audios_16k`` directory),
    then generates the training config and filelist for the model.

    Parameters
    ----------
    model_name : str
        Name of the training model whose dataset should be processed.
    f0_method : TrainingF0Method, default=TrainingF0Method.RMVPE
        Method used to extract pitch features. For the CREPE variants the
        hop length is embedded in the identifier of the extracted pitch
        features; DJCM always uses the identifier ``"djcm"``.
    hop_length : int, default=128
        Hop length forwarded to pitch extraction.
    embedder_model : EmbedderModel, default=EmbedderModel.CONTENTVEC
        Model used to extract audio embeddings.
    custom_embedder_model : str, optional
        Name of a custom embedder model. Only consulted when
        ``embedder_model`` is ``EmbedderModel.CUSTOM``; the combined hash
        of its ``config.json`` and ``pytorch_model.bin`` files forms the
        embedder identifier.
    include_mutes : int, default=2
        Forwarded to filelist generation (presumably the number of mute
        samples to include -- confirm against ``generate_filelist``).
    cpu_cores : int, default=cpu_count()
        Number of CPU cores forwarded to both extraction stages. Note the
        default is evaluated once, when this module is imported.
    hardware_acceleration : DeviceType, default=DeviceType.AUTOMATIC
        Requested hardware type, resolved by ``validate_devices``.
    gpu_ids : set[int], optional
        Ids of the GPUs to use, resolved together with
        ``hardware_acceleration``.

    Raises
    ------
    ModelAsssociatedEntityNotFoundError
        If the model's ``sliced_audios_16k`` directory is missing or empty
        (the error points the user to the dataset preprocessing step).

    """
    model_path = validate_model(model_name, Entity.TRAINING_MODEL)

    sliced_dir = model_path / "sliced_audios_16k"
    # ``and`` short-circuits, so ``iterdir`` only runs on an existing directory.
    if not (sliced_dir.is_dir() and any(sliced_dir.iterdir())):
        raise ModelAsssociatedEntityNotFoundError(
            Entity.PREPROCESSED_AUDIO_DATASET_FILES,
            model_name,
            Step.DATASET_PREPROCESSING,
        )

    # Decide which embedder is recorded in the model info and which id the
    # extracted features are tagged with. A custom embedder is identified by
    # a combined hash of its config and model files.
    if embedder_model == EmbedderModel.CUSTOM:
        custom_dir = validate_model(
            custom_embedder_model,
            Entity.CUSTOM_EMBEDDER_MODEL,
        )
        files_hash = get_combined_file_hash(
            [custom_dir / "config.json", custom_dir / "pytorch_model.bin"],
        )
        embedder_for_info = str(custom_dir)
        embedder_id = f"custom_{files_hash}"
    else:
        custom_dir = None
        files_hash = None
        embedder_for_info = embedder_model
        embedder_id = embedder_model

    # Identifier for the pitch features, shared by extraction setup and
    # filelist generation.
    if f0_method in {TrainingF0Method.CREPE, TrainingF0Method.CREPE_TINY}:
        # The hop length is part of the id for the CREPE variants.
        f0_id = f"{f0_method}_{hop_length}"
    elif f0_method == TrainingF0Method.DJCM:
        # DJCM does not need the hop length.
        f0_id = "djcm"
    else:
        f0_id = f0_method

    device_type, device_ids = validate_devices(hardware_acceleration, gpu_ids)
    if device_ids:
        devices = [f"{device_type}:{gpu_id}" for gpu_id in device_ids]
    else:
        devices = [device_type]

    model_dir = str(model_path)

    # Imported lazily (presumably to defer loading heavy training
    # dependencies until extraction actually runs).
    from ultimate_rvc.rvc.train.extract import extract  # noqa: PLC0415

    file_infos = extract.initialize_extraction(model_dir, f0_id, embedder_id)
    extract.update_model_info(model_dir, embedder_for_info, files_hash)

    display_progress("[~] Extracting pitch features...")
    extract.run_pitch_extraction(file_infos, devices, f0_method, hop_length, cpu_cores)

    display_progress("[~] Extracting audio embeddings...")
    extract.run_embedding_extraction(
        file_infos,
        devices,
        embedder_model,
        str(custom_dir) if custom_dir else None,
        cpu_cores,
    )

    from ultimate_rvc.rvc.train.extract import preparing_files  # noqa: PLC0415

    preparing_files.generate_config(model_dir)
    preparing_files.generate_filelist(
        model_dir,
        include_mutes,
        f0_id,
        embedder_id,
    )