File size: 3,543 Bytes
c4cc266
 
580ce58
c4cc266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
580ce58
c4cc266
 
 
580ce58
 
c4cc266
 
 
 
 
 
 
580ce58
c4cc266
 
 
 
 
 
 
 
 
 
 
 
580ce58
c4cc266
 
580ce58
c4cc266
 
 
 
 
580ce58
c4cc266
 
 
580ce58
c4cc266
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
Module that exposes functionality for extracting training features from
audio datasets, including support for the DJCM pitch (f0) extraction method.
"""

from __future__ import annotations
from multiprocessing import cpu_count

from ultimate_rvc.core.common import (
    display_progress,
    get_combined_file_hash,
    validate_model,
)
from ultimate_rvc.core.exceptions import (
    Entity,
    ModelAsssociatedEntityNotFoundError,
    Step,
)
from ultimate_rvc.core.train.common import validate_devices
from ultimate_rvc.typing_extra import (
    DeviceType,
    EmbedderModel,
    TrainingF0Method,
)

def extract_features(
    model_name: str,
    f0_method: TrainingF0Method = TrainingF0Method.RMVPE,
    hop_length: int = 128,
    embedder_model: EmbedderModel = EmbedderModel.CONTENTVEC,
    custom_embedder_model: str | None = None,
    include_mutes: int = 2,
    cpu_cores: int = cpu_count(),
    hardware_acceleration: DeviceType = DeviceType.AUTOMATIC,
    gpu_ids: set[int] | None = None,
) -> None:
    """
    Extract training features for a training model's preprocessed dataset.

    Pitch (f0) features and audio embeddings are extracted from the
    ``sliced_audios_16k`` directory of the training model. Afterwards the
    training config and file list are generated.

    Parameters
    ----------
    model_name : str
        Name of the training model whose preprocessed dataset is used.
    f0_method : TrainingF0Method, default=TrainingF0Method.RMVPE
        Method used for pitch extraction. For the CREPE variants the hop
        length is included in the f0 method identifier.
    hop_length : int, default=128
        Hop length forwarded to pitch extraction.
    embedder_model : EmbedderModel, default=EmbedderModel.CONTENTVEC
        Embedder model used to extract audio embeddings.
    custom_embedder_model : str, optional
        Name of the custom embedder model. Only consulted when
        ``embedder_model`` is ``EmbedderModel.CUSTOM``.
    include_mutes : int, default=2
        Forwarded to ``preparing_files.generate_filelist``. Presumably
        the number of mute entries to include; confirm against that
        helper.
    cpu_cores : int, default=cpu_count()
        Number of CPU cores forwarded to the extraction routines. The
        default is evaluated once, when the module is imported.
    hardware_acceleration : DeviceType, default=DeviceType.AUTOMATIC
        Requested device type, validated together with ``gpu_ids``.
    gpu_ids : set[int], optional
        GPU IDs to use. Each validated ID becomes a device string of the
        form ``"{device_type}:{id}"``.

    Raises
    ------
    ModelAsssociatedEntityNotFoundError
        If the ``sliced_audios_16k`` directory of the training model is
        missing or empty, i.e. the dataset has not been preprocessed.

    """
    model_path = validate_model(model_name, Entity.TRAINING_MODEL)
    model_dir = str(model_path)

    # Feature extraction needs the sliced audio produced by the dataset
    # preprocessing step; bail out early if it is absent.
    sliced_dir = model_path / "sliced_audios_16k"
    has_sliced_audio = sliced_dir.is_dir() and any(sliced_dir.iterdir())
    if not has_sliced_audio:
        raise ModelAsssociatedEntityNotFoundError(
            Entity.PREPROCESSED_AUDIO_DATASET_FILES,
            model_name,
            Step.DATASET_PREPROCESSING,
        )

    # Resolve the embedder and its identifier. A custom embedder is
    # identified by a hash of its config and weight files (presumably so
    # the id changes whenever those files change).
    if embedder_model == EmbedderModel.CUSTOM:
        custom_embedder_dir = validate_model(
            custom_embedder_model,
            Entity.CUSTOM_EMBEDDER_MODEL,
        )
        combined_file_hash = get_combined_file_hash(
            [
                custom_embedder_dir / "config.json",
                custom_embedder_dir / "pytorch_model.bin",
            ],
        )
        chosen_embedder_model = str(custom_embedder_dir)
        embedder_model_id = f"custom_{combined_file_hash}"
    else:
        custom_embedder_dir = None
        combined_file_hash = None
        chosen_embedder_model = embedder_model_id = embedder_model

    # Build the f0 method identifier. CREPE variants embed the hop length
    # in the id; DJCM does not need the hop length.
    if f0_method in {TrainingF0Method.CREPE, TrainingF0Method.CREPE_TINY}:
        f0_method_id = f"{f0_method}_{hop_length}"
    elif f0_method == TrainingF0Method.DJCM:
        f0_method_id = "djcm"
    else:
        f0_method_id = f0_method

    device_type, device_ids = validate_devices(hardware_acceleration, gpu_ids)
    if device_ids:
        devices = [f"{device_type}:{device_id}" for device_id in device_ids]
    else:
        devices = [device_type]

    # Imported lazily, only after all validation above has succeeded.
    from ultimate_rvc.rvc.train.extract import extract  # noqa: PLC0415

    file_infos = extract.initialize_extraction(
        model_dir,
        f0_method_id,
        embedder_model_id,
    )
    extract.update_model_info(
        model_dir,
        chosen_embedder_model,
        combined_file_hash,
    )

    display_progress("[~] Extracting pitch features...")
    # Note: the raw f0 method (not f0_method_id) is passed here.
    extract.run_pitch_extraction(file_infos, devices, f0_method, hop_length, cpu_cores)

    display_progress("[~] Extracting audio embeddings...")
    # The embedder enum is passed as-is; a custom embedder's directory is
    # supplied separately.
    custom_embedder_arg = str(custom_embedder_dir) if custom_embedder_dir else None
    extract.run_embedding_extraction(
        file_infos,
        devices,
        embedder_model,
        custom_embedder_arg,
        cpu_cores,
    )

    from ultimate_rvc.rvc.train.extract import preparing_files  # noqa: PLC0415

    preparing_files.generate_config(model_dir)
    preparing_files.generate_filelist(
        model_dir,
        include_mutes,
        f0_method_id,
        embedder_model_id,
    )