File size: 2,772 Bytes
599236e
b70b9c3
 
a3ea780
b70b9c3
a3ea780
3e97a5d
 
 
 
b70b9c3
 
 
 
 
 
3e97a5d
 
 
 
 
 
 
 
b70b9c3
a3ea780
b70b9c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3ea780
b70b9c3
 
 
 
 
 
 
 
3e97a5d
b70b9c3
 
 
 
 
a3ea780
 
 
 
 
 
 
 
031f538
 
a3ea780
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

@dataclass
class ProcessingConfig:
    """Settings for audio preprocessing and augmentation.

    Every attribute is a real dataclass field, so each one can be overridden
    through the constructor and shows up in ``repr``/``==``/``asdict``.
    The list-valued augmentation settings use ``default_factory`` so that each
    instance owns an independent copy; mutating one config never leaks into
    another.
    """

    # Filesystem locations (relative to the working directory).
    audio_path: Path = Path("data/audio/0")
    augmented_path: Path = Path("data/audio/")
    log_mel_path: Path = Path("data/preprocessed")

    # Spectrogram / framing parameters.
    # NOTE(review): units are not visible here — sizes are presumably in
    # samples and sample_rate in Hz; confirm against the consumer.
    n_bands: int = 128
    n_mels: int = 128
    frame_size: int = 1024
    hop_size: int = 1024
    sample_rate: int = 44100
    fft_size: int = 8192
    target_seconds: float = 5.0

    # Augmentation settings.
    # NOTE(review): each inner list has three entries, presumably one
    # probability per augmentation kind configured below (time stretch,
    # pitch shift, DRC) — confirm the order against the augmentation code.
    augmentation_probability_lists: List[List[float]] = field(
        default_factory=lambda: [
            [0.0, 1.0, 1.0], [1.0, 1.0, 0.0], [1.0, 0.0, 1.0],
            [0.0, 0.0, 0.0], [0.5, 0.5, 0.5],
        ]
    )
    time_stretch_rates: List[float] = field(
        default_factory=lambda: [0.81, 0.93, 1.07, 1.23]
    )
    pitch_shift_rates: List[float] = field(
        default_factory=lambda: [-3.5, -2.5, -2, -1, 1, 2.5, 3, 3.5]
    )
    drc_types: List[str] = field(
        default_factory=lambda: ["radio", "filmstandard", "musicstandard", "speech"]
    )

def _default_esc50_labels() -> List[str]:
    """Return a freshly built list of the 50 ESC-50 class names.

    A new list is created on every call so that each ``DatasetConfig``
    instance gets its own copy.
    """
    return [
        # indices 0-9
        "dog", "rooster", "pig", "cow", "frog",
        "cat", "hen", "insects", "sheep", "crow",
        # indices 10-19
        "rain", "sea_waves", "crackling_fire", "crickets", "chirping_birds",
        "water_drops", "wind", "pouring_water", "toilet_flush", "thunderstorm",
        # indices 20-29
        "crying_baby", "sneezing", "clapping", "breathing", "coughing",
        "footsteps", "laughing", "brushing_teeth", "snoring", "drinking_sipping",
        # indices 30-39
        "door_wood_knock", "mouse_click", "keyboard_typing", "door_wood_creaks", "can_opening",
        "washing_machine", "vacuum_cleaner", "clock_alarm", "clock_tick", "glass_breaking",
        # indices 40-49
        "helicopter", "chainsaw", "siren", "car_horn", "engine",
        "train", "church_bells", "airplane", "fireworks", "hand_saw",
    ]


@dataclass
class DatasetConfig:
    """Dataset-level settings: CNN input length, sample rate and class names.

    NOTE(review): the position of a name in ``esc50_labels`` presumably is its
    integer class id — confirm against the code that maps targets to labels.
    """

    cnn_input_length: int = 128
    sample_rate: int = 44100
    esc50_labels: List[str] = field(default_factory=_default_esc50_labels)

@dataclass
class DownloadConfig:
    """Where the ESC-50 archive is downloaded from and where its parts go.

    All path attributes are ``pathlib.Path`` objects. ``repo_dst_dir`` and
    ``extracted_dir`` may also be passed as plain strings; they are converted
    in ``__post_init__``.

    Attributes:
        repo_url: URL of the zip archive to download.
        repo_dst_dir: Base directory for the download.
        audio_dst_dir: ``repo_dst_dir / "audio" / "0"``; derived, not a
            constructor argument.
        extracted_dir: Directory of the unpacked archive.
        audio_src_dir: ``extracted_dir / "audio"``; derived per instance.
        paths_to_delete: Names of files/directories listed for deletion.
    """

    repo_url: str = "https://github.com/karolpiczak/ESC-50/archive/refs/heads/master.zip"
    repo_dst_dir: Path = Path("data")
    audio_dst_dir: Path = field(init=False)
    # NOTE: this default is computed once from the default repo_dst_dir above;
    # when relocating repo_dst_dir, pass a matching extracted_dir as well.
    extracted_dir: Path = repo_dst_dir / "ESC-50-master"
    # Class-level default (not a dataclass field). __post_init__ shadows it on
    # every instance so it always follows that instance's extracted_dir.
    audio_src_dir = extracted_dir / "audio"
    paths_to_delete: List[str] = field(default_factory=lambda: [
        ".gitignore", "esc50.gif", "LICENSE", "pytest.ini", "README.md",
        "requirements.txt", "tests", "meta", ".github", ".circleci", "ESC-50-master"
    ])

    def __post_init__(self) -> None:
        """Normalize path inputs and compute the derived directories."""
        # Accept str as well as Path so the "/" joins below cannot fail.
        self.repo_dst_dir = Path(self.repo_dst_dir)
        self.extracted_dir = Path(self.extracted_dir)
        self.audio_dst_dir = self.repo_dst_dir / "audio" / "0"
        self.audio_src_dir = self.extracted_dir / "audio"

@dataclass
class TrainConfig:
    """Training hyperparameters and checkpoint settings.

    Attributes:
        epochs: Number of training epochs.
        batch_size: Number of samples per batch.
        lr: Learning rate.
        use_all_patches: Flag read by the training code (semantics defined
            by the consumer).
        samples_per_epoch_fraction: Fraction of samples used per epoch
            (defaults to 1/8).
        checkpoint_dir_cnn: Checkpoint directory for the CNN model.
        checkpoint_dir_transformer: Checkpoint directory for the transformer
            model.
        save_every_n_epoch: Checkpoint interval, in epochs.
        resume_from: Optional checkpoint to resume from; ``None`` by default.
        device: Device identifier string, ``"cuda"`` by default.
    """

    epochs: int = 50
    batch_size: int = 100
    lr: float = 0.001
    use_all_patches: bool = True
    samples_per_epoch_fraction: float = 1/8
    checkpoint_dir_cnn: str = "models/cnn/checkpoints"
    checkpoint_dir_transformer: str = "models/transformer/checkpoints"
    save_every_n_epoch: int = 1
    resume_from: Optional[str] = None
    # Annotated so it is a real field (settable via the constructor, included
    # in repr/eq). Declared last so the positional order of the pre-existing
    # fields is unchanged.
    device: str = "cuda"