Vivek Vaddina committed on
Commit 254b144 · unverified · 1 Parent(s): f240b3a

initial working commit

.gitignore ADDED
@@ -0,0 +1,212 @@
+ # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,visualstudiocode
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks,visualstudiocode
+
+ ### JupyterNotebooks ###
+ # gitignore template for Jupyter Notebooks
+ # website: http://jupyter.org/
+
+ .ipynb_checkpoints
+ */.ipynb_checkpoints/*
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # Remove previous ipynb_checkpoints
+ # git rm -r .ipynb_checkpoints/
+
+ ### Python ###
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+
+ # IPython
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ ### Python Patch ###
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+ poetry.toml
+
+ # ruff
+ .ruff_cache/
+
+ # LSP config files
+ pyrightconfig.json
+
+ ### VisualStudioCode ###
+ .vscode/*
+ !.vscode/settings.json
+ !.vscode/tasks.json
+ !.vscode/launch.json
+ !.vscode/extensions.json
+ !.vscode/*.code-snippets
+
+ # Local History for Visual Studio Code
+ .history/
+
+ # Built Visual Studio Code Extensions
+ *.vsix
+
+ ### VisualStudioCode Patch ###
+ # Ignore all local history of files
+ .history
+ .ionide
+
+ # End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,visualstudiocode
+
+ ## Custom
+ data/
+ # pixi environments
+ .pixi/*
+ !.pixi/config.toml
.pixi/config.toml ADDED
@@ -0,0 +1,4 @@
+ run-post-link-scripts = "insecure"
+
+ [shell]
+ change-ps1 = false
app.py ADDED
@@ -0,0 +1,73 @@
+ import gradio as gr
+ from pathlib import Path
+
+ from src.config import CKPT_PATH
+ from src.modeling import Model
+
+
+ # -------------------------------------------------
+ # Load model once at startup
+ # -------------------------------------------------
+ MODEL = Model(device="cpu")
+ MODEL.load_from_chkpt(Path(CKPT_PATH))
+
+
+ # -------------------------------------------------
+ # Inference function used by Gradio
+ # -------------------------------------------------
+ def run_inference(audio_file):
+     # This handler feeds a single output component, so it returns one value
+     if audio_file is None:
+         return ""
+
+     # audio_file is a filepath provided by Gradio
+     audio_fp = Path(audio_file)
+
+     result = MODEL.make_preds(audio_fp)
+     name = ' '.join(result.upper().split('_'))
+     return f"# 🐦 Identified species: **{name}**"
+
+ def clear_outputs():
+     # One value per output component: the audio input and the Markdown text
+     return None, ""
+
+
+ # -------------------------------------------------
+ # Gradio UI
+ # -------------------------------------------------
+ with gr.Blocks(title="Bird Species Identification") as demo:
+     gr.Markdown(
+         """
+         ### 🐦 Bird Species Identification
+         Upload an audio recording of a bird call to identify the species.
+         """
+     )
+
+     audio_input = gr.Audio(
+         sources=["upload"],
+         type="filepath",
+         label="Upload bird audio"
+     )
+
+     output_text = gr.Markdown(
+         label="Identified species",
+     )
+
+     with gr.Row():
+         submit_btn = gr.Button("Identify")
+         clear_btn = gr.Button("Clear")
+
+     submit_btn.click(
+         fn=run_inference,
+         inputs=audio_input,
+         outputs=output_text
+     )
+
+     clear_btn.click(
+         fn=clear_outputs,
+         outputs=[audio_input, output_text]
+     )
+
+
+ if __name__ == "__main__":
+     demo.launch()
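
With the checkpoint in place, the app runs on the stock Gradio dev server; a minimal local-run sketch (7860 is Gradio's default port):

    python app.py
    # then open http://127.0.0.1:7860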
models/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7f58c1413fe19595def2cbcb4ba01fced3bd84418874b253ba5529510a677550
+ size 85613285
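
This is a Git LFS pointer, not the ~85 MB checkpoint itself; the sample MP3s below are stored the same way. On a fresh clone, the real files are fetched with the standard LFS commands:

    git lfs install
    git lfs pull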
pixi.lock ADDED
The diff for this file is too large to render. See raw diff
 
pixi.toml ADDED
@@ -0,0 +1,20 @@
+ [workspace]
+ channels = ["conda-forge", "pytorch"]
+ name = "munich_bird_identifier"
+ platforms = ["linux-64"]
+ version = "0.1.0"
+
+ [tasks]
+
+ [dependencies]
+ python = "3.12.*"
+ librosa = ">=0.11.0,<0.12"
+ click = ">=8.3.1,<9"
+ gradio = ">=6.2.0,<7"
+ ipython = ">=9.9.0,<10"
+ python-dotenv = "*"  # required by src/config.py (load_dotenv)
+
+ [pypi-dependencies]
+ torch = { version = "*", index = "https://download.pytorch.org/whl/cpu" }
+ torchvision = { version = "*", index = "https://download.pytorch.org/whl/cpu" }
+ torchaudio = { version = "*", index = "https://download.pytorch.org/whl/cpu" }
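
Assuming a working pixi installation, the environment described by this manifest is created and used with the standard pixi commands:

    pixi install            # solve/create the environment from pixi.toml + pixi.lock
    pixi run python app.py  # run the Gradio app inside it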
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio
+ torch
+ torchvision
+ torchaudio
+ librosa
+ click
+ python-dotenv  # required by src/config.py (load_dotenv)
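
For a plain pip setup (e.g. the Hugging Face Spaces runtime), the equivalent is:

    pip install -r requirements.txt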
samples/corvus_corone_XC592284.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:db23dd9478e7dbd8dfd878fb08d900fad694ea93556c2013ab4b954553507957
+ size 180652
samples/scolopax_rusticola_XC795042.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c1273f5edb7971248c08c78d7bc25ee30a6f9c475e27d204be9fc223c016faac
+ size 628162
src/__init__.py ADDED
File without changes
src/audio.py ADDED
@@ -0,0 +1,18 @@
+ import librosa
+ from src.config import N_MELS, SR
+
+ # SR and N_MELS are fixed to match what the downstream model expects
+
+ def load_audio(audio_fp, sr=None, res_type='soxr_hq'):
+     wave, sr = librosa.load(audio_fp, sr=sr, res_type=res_type)
+     return wave, sr
+
+
+ def get_melspec(y, sr=None, plot=False):
+     if not sr:
+         sr = SR  # fall back to the project default
+     mel_power = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=N_MELS, hop_length=1000)
+     mel_dB = librosa.power_to_db(mel_power)
+     if plot:
+         pass  # plotting hook intentionally left unimplemented in the app build
+     return mel_dB
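
A minimal sketch of the two helpers together, assuming the repo root is the working directory and the LFS sample files have been pulled:

    from src.audio import load_audio, get_melspec
    from src.config import SR

    wave, sr = load_audio("samples/corvus_corone_XC592284.mp3", sr=SR)  # resampled to 32 kHz
    mel_dB = get_melspec(wave, sr)
    print(mel_dB.shape)  # (N_MELS, n_frames) = (256, ...)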
src/config.py ADDED
@@ -0,0 +1,90 @@
+ import logging
+
+ from pathlib import Path
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ def get_logger(log_level="INFO"):
+     log_path = Path("logs.log")
+     formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+
+     log = logging.getLogger("munich_bird_identifier")
+     log.setLevel(log_level)
+
+     file_handler = logging.FileHandler(log_path)
+     file_handler.setLevel(log_level)
+     file_handler.setFormatter(formatter)
+
+     log.addHandler(file_handler)
+
+     return log
+
+
+ log = get_logger("DEBUG")
+
+ CKPT_PATH = Path('models/checkpoint.pth')
+ N_MELS = 256
+ SR = 32_000
+
+ # these are the bird species the model was trained on
+ IDX2CODE = {
+     0: 'accipiter_gentilis',
+     1: 'acrocephalus_scirpaceus',
+     2: 'aegolius_funereus',
+     3: 'alauda_arvensis',
+     4: 'anthus_cervinus',
+     5: 'anthus_trivialis',
+     6: 'asio_otus',
+     7: 'charadrius_dubius',
+     8: 'chloris_chloris',
+     9: 'coccothraustes_coccothraustes',
+     10: 'corvus_corone',
+     11: 'corvus_frugilegus',
+     12: 'crex_crex',
+     13: 'cuculus_canorus',
+     14: 'curruca_communis',
+     15: 'cyanistes_caeruleus',
+     16: 'dendrocopos_major',
+     17: 'dryocopus_martius',
+     18: 'emberiza_citrinella',
+     19: 'erithacus_rubecula',
+     20: 'falco_peregrinus',
+     21: 'fringilla_coelebs',
+     22: 'garrulus_glandarius',
+     23: 'lanius_collurio',
+     24: 'larus_michahellis',
+     25: 'linaria_cannabina',
+     26: 'locustella_fluviatilis',
+     27: 'locustella_naevia',
+     28: 'lullula_arborea',
+     29: 'luscinia_megarhynchos',
+     30: 'mareca_penelope',
+     31: 'motacilla_flava',
+     32: 'muscicapa_striata',
+     33: 'nucifraga_caryocatactes',
+     34: 'nycticorax_nycticorax',
+     35: 'nymphicus_hollandicus',
+     36: 'parus_major',
+     37: 'perdix_perdix',
+     38: 'periparus_ater',
+     39: 'phoenicurus_phoenicurus',
+     40: 'phylloscopus_collybita',
+     41: 'phylloscopus_sibilatrix',
+     42: 'phylloscopus_trochilus',
+     43: 'picus_canus',
+     44: 'picus_viridis',
+     45: 'poecile_montanus',
+     46: 'poecile_palustris',
+     47: 'prunella_modularis',
+     48: 'saxicola_rubicola',
+     49: 'scolopax_rusticola',
+     50: 'serinus_serinus',
+     51: 'strix_aluco',
+     52: 'sylvia_atricapilla',
+     53: 'sylvia_borin',
+     54: 'troglodytes_troglodytes',
+     55: 'turdus_merula',
+     56: 'turdus_philomelos'
+ }
+ CODE2IDX = {v: k for k, v in IDX2CODE.items()}
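
The two lookup tables are exact inverses, so a label round-trips between class index and species code:

    from src.config import IDX2CODE, CODE2IDX

    print(IDX2CODE[10])               # corvus_corone
    print(CODE2IDX["corvus_corone"])  # 10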
src/modeling.py ADDED
@@ -0,0 +1,88 @@
+ import torch
+ from torch import nn, optim
+ from torchvision.models import resnet34, ResNet34_Weights
+ from src.processing import generate_test_images
+ from src.config import IDX2CODE
+
+
+ class BirdNet(nn.Module):
+     def __init__(self, n_out=len(IDX2CODE), pretrained=True, freeze_backbone=True, dropout=.25):
+         super().__init__()
+         self.model = resnet34(weights=ResNet34_Weights.DEFAULT if pretrained else None)
+
+         # Modify the first convolution layer to accept 1-channel input:
+         # the original ResNet34 expects 3-channel RGB, but we feed it a
+         # 1-channel grayscale melspectrogram
+         original_conv1 = self.model.conv1
+         self.model.conv1 = nn.Conv2d(
+             in_channels=1,  # grayscale input
+             out_channels=original_conv1.out_channels,
+             kernel_size=original_conv1.kernel_size,
+             stride=original_conv1.stride,
+             padding=original_conv1.padding,
+             bias=original_conv1.bias is not None  # Conv2d expects a bool, not the Parameter
+         )
+
+         if pretrained:
+             with torch.no_grad():
+                 self.model.conv1.weight.data = original_conv1.weight.data.mean(dim=1, keepdim=True)
+
+         # in_features = self.model.fc.in_features
+         # layers = list(self.model.children())[:-2]
+         # layers.append(nn.AdaptiveMaxPool2d(1))
+         # self.encoder = nn.Sequential(*layers)
+
+         self.model.fc = nn.Linear(self.model.fc.in_features, n_out)
+         # self.model.fc = nn.Sequential(
+         #     nn.Linear(self.model.fc.in_features, 256),
+         #     nn.ReLU(),
+         #     nn.Dropout(dropout),
+         #     nn.Linear(256, n_out)
+         # )
+         # Optional: freeze the backbone for fine-tuning (train only the final layer)
+         if freeze_backbone:
+             for param in self.model.parameters():
+                 param.requires_grad = False
+             # Unfreeze the final layer
+             for param in self.model.fc.parameters():
+                 param.requires_grad = True
+
+     def forward(self, x):
+         return self.model(x)
+
+ class Model:
+     def __init__(self, device, n_out=len(IDX2CODE), loss_fn=nn.CrossEntropyLoss(),
+                  pretrained=True, freeze_backbone=True, dropout=.1):
+         self.n_out = n_out
+         self.device = device
+         self.model = BirdNet(self.n_out, pretrained=pretrained,
+                              freeze_backbone=freeze_backbone, dropout=dropout).to(self.device)
+         self.lr = 5e-3
+         self.loss_fn = loss_fn
+         self.opt = optim.Adam(self.model.parameters(), lr=self.lr)
+         # self.lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.opt, mode='min', factor=.5, patience=3, min_lr=1e-5)
+         self.epoch_train_losses = []
+         self.epoch_val_losses = []
+         self.epoch_train_accs = []
+         self.epoch_val_accs = []
+         self.epoch = 0
+
+     def load_from_chkpt(self, chkpt_path):
+         chkpt = torch.load(chkpt_path, weights_only=False, map_location=torch.device(self.device))
+         self.epoch = chkpt['epoch']
+         self.model.load_state_dict(chkpt['model'])
+         self.opt.load_state_dict(chkpt['optim'])
+         self.epoch_train_losses = chkpt['train_losses']
+         self.epoch_val_losses = chkpt['valid_losses']
+         self.epoch_train_accs = chkpt['train_accs']
+         self.epoch_val_accs = chkpt['valid_accs']
+
+     def make_preds(self, fp):
+         # Majority vote over the random crops produced by generate_test_images
+         arrs = generate_test_images(fp)
+         self.model.eval()
+         with torch.no_grad():
+             out = self.model(arrs.to(self.device).float())
+         labels = out.argmax(dim=1)
+         vc = labels.unique(return_counts=True)
+         return IDX2CODE[vc[0][vc[1].argmax()].item()]
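
A minimal sketch of scripted (non-Gradio) inference with this class, assuming the LFS checkpoint has been pulled and the repo root is the working directory:

    from pathlib import Path

    from src.config import CKPT_PATH
    from src.modeling import Model

    model = Model(device="cpu")
    model.load_from_chkpt(Path(CKPT_PATH))
    print(model.make_preds(Path("samples/scolopax_rusticola_XC795042.mp3")))
    # e.g. 'scolopax_rusticola' -- the majority class over 10 random crops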
src/processing.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ import numpy as np
+ import soundfile as sf
+ from src.audio import load_audio, get_melspec
+ from src.config import SR
+ from src.utils import get_idx, to_square
+
+ # https://www.kaggle.com/code/tarunpaparaju/birdcall-identification-spectrogram-loader
+ def to_imagenet(X, mean=None, std=None, norm_max=None, norm_min=None, eps=1e-6):
+     mean = mean or X.mean()
+     X = X - mean
+     std = std or X.std()
+     Xstd = X / (std + eps)
+     _min, _max = Xstd.min(), Xstd.max()
+     norm_max = norm_max or _max
+     norm_min = norm_min or _min
+     if (_max - _min) > eps:
+         # Clip to [norm_min, norm_max], then rescale to [0, 1]
+         V = np.clip(Xstd, norm_min, norm_max)
+         V = (V - norm_min) / (norm_max - norm_min)
+     else:
+         # Degenerate (near-constant) input: just zeros
+         V = np.zeros_like(Xstd, dtype=np.uint8)
+     return V  # np.stack([V]*3, axis=-1) for a 3-channel variant
+
+ def extract_melspec_as_imgarr(fp, n_secs=8, random_chunk=True, convert_to_int8=False):
+     info = sf.info(fp)
+     y, _ = load_audio(fp, SR)  # , offset=start, duration=n_secs
+     # Redraw chunks until a non-empty slice is found
+     while True:
+         start, end = get_idx(info.duration, n_secs, random_chunk=random_chunk)
+         y2 = y[start:end]
+         if len(y2):
+             y = y2
+             break
+     mel_dB = to_square(get_melspec(y, SR))
+     try:
+         normalised_db = to_imagenet(mel_dB)  # replaced minmax_scale
+     except Exception:
+         normalised_db = torch.zeros_like(torch.as_tensor(mel_dB))
+     db_array = np.asarray(normalised_db) * 255
+     if convert_to_int8:
+         db_array = db_array.astype(np.uint8)
+     # Flip vertically so low frequencies end up at the bottom of the image
+     return db_array[::-1].astype(float)
+
+
+ def generate_test_images(fp, n=10):
+     # Stack n crops into an (n, 1, H, W) tensor for the CNN
+     arrs = []
+     for _ in range(n):
+         arrs.append(extract_melspec_as_imgarr(fp))
+     return torch.as_tensor(np.array(arrs)).unsqueeze(1)
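
A quick shape sanity check of the batch that make_preds consumes, assuming pulled samples: with 8-second crops at SR=32000 and hop length 1000, each mel image is squared to 256×256 by to_square:

    from src.processing import generate_test_images

    batch = generate_test_images("samples/corvus_corone_XC592284.mp3", n=4)
    print(batch.shape)  # torch.Size([4, 1, 256, 256])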
src/utils.py ADDED
@@ -0,0 +1,35 @@
+ import torch
+ import numpy as np
+
+ from src.config import SR, CODE2IDX
+
+
+ def get_idx(duration, n_secs=5, sr=SR, random_chunk=True):
+     num_frames = int(np.ceil(sr * duration))
+     chunk_len = n_secs * sr
+     DEFAULT_OFFSET = 10
+     # Guard against clips shorter than the requested chunk, where
+     # np.random.randint would otherwise see an empty range
+     if random_chunk and num_frames - chunk_len > DEFAULT_OFFSET:
+         start = np.random.randint(DEFAULT_OFFSET, num_frames - chunk_len)
+     else:
+         start = DEFAULT_OFFSET
+     return start, start + chunk_len
+
+ def to_square(arr):
+     """Convert an (almost square) array to a square array by padding/truncating columns."""
+     rows, cols = arr.shape
+
+     if cols < rows:
+         pad_width = ((0, 0), (0, rows - cols))
+         return np.pad(arr, pad_width, mode='constant')
+     else:
+         return arr[:, :rows]
+
+ def to_tensor(data):
+     return [torch.FloatTensor(x) for x in data]
+
+ def one_hot(idx):
+     y = torch.zeros(len(CODE2IDX))
+     y[idx] = 1.
+     return y
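
A small illustration of to_square on wide and narrow inputs:

    import numpy as np
    from src.utils import to_square

    print(to_square(np.ones((3, 5))).shape)  # (3, 3) -- extra columns truncated
    print(to_square(np.ones((3, 2))).shape)  # (3, 3) -- missing columns zero-padded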