Spaces:
Sleeping
Sleeping
Vivek Vaddina commited on
initial working commit
Browse files- .gitignore +212 -0
- .pixi/config.toml +4 -0
- app.py +71 -0
- models/checkpoint.pth +3 -0
- pixi.lock +0 -0
- pixi.toml +19 -0
- requirements.txt +6 -0
- samples/corvus_corone_XC592284.mp3 +3 -0
- samples/scolopax_rusticola_XC795042.mp3 +3 -0
- src/__init__.py +0 -0
- src/audio.py +17 -0
- src/config.py +90 -0
- src/modeling.py +87 -0
- src/processing.py +52 -0
- src/utils.py +30 -0
.gitignore
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,visualstudiocode
|
| 2 |
+
# Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks,visualstudiocode
|
| 3 |
+
|
| 4 |
+
### JupyterNotebooks ###
|
| 5 |
+
# gitignore template for Jupyter Notebooks
|
| 6 |
+
# website: http://jupyter.org/
|
| 7 |
+
|
| 8 |
+
.ipynb_checkpoints
|
| 9 |
+
*/.ipynb_checkpoints/*
|
| 10 |
+
|
| 11 |
+
# IPython
|
| 12 |
+
profile_default/
|
| 13 |
+
ipython_config.py
|
| 14 |
+
|
| 15 |
+
# Remove previous ipynb_checkpoints
|
| 16 |
+
# git rm -r .ipynb_checkpoints/
|
| 17 |
+
|
| 18 |
+
### Python ###
|
| 19 |
+
# Byte-compiled / optimized / DLL files
|
| 20 |
+
__pycache__/
|
| 21 |
+
*.py[cod]
|
| 22 |
+
*$py.class
|
| 23 |
+
|
| 24 |
+
# C extensions
|
| 25 |
+
*.so
|
| 26 |
+
|
| 27 |
+
# Distribution / packaging
|
| 28 |
+
.Python
|
| 29 |
+
build/
|
| 30 |
+
develop-eggs/
|
| 31 |
+
dist/
|
| 32 |
+
downloads/
|
| 33 |
+
eggs/
|
| 34 |
+
.eggs/
|
| 35 |
+
lib/
|
| 36 |
+
lib64/
|
| 37 |
+
parts/
|
| 38 |
+
sdist/
|
| 39 |
+
var/
|
| 40 |
+
wheels/
|
| 41 |
+
share/python-wheels/
|
| 42 |
+
*.egg-info/
|
| 43 |
+
.installed.cfg
|
| 44 |
+
*.egg
|
| 45 |
+
MANIFEST
|
| 46 |
+
|
| 47 |
+
# PyInstaller
|
| 48 |
+
# Usually these files are written by a python script from a template
|
| 49 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 50 |
+
*.manifest
|
| 51 |
+
*.spec
|
| 52 |
+
|
| 53 |
+
# Installer logs
|
| 54 |
+
pip-log.txt
|
| 55 |
+
pip-delete-this-directory.txt
|
| 56 |
+
|
| 57 |
+
# Unit test / coverage reports
|
| 58 |
+
htmlcov/
|
| 59 |
+
.tox/
|
| 60 |
+
.nox/
|
| 61 |
+
.coverage
|
| 62 |
+
.coverage.*
|
| 63 |
+
.cache
|
| 64 |
+
nosetests.xml
|
| 65 |
+
coverage.xml
|
| 66 |
+
*.cover
|
| 67 |
+
*.py,cover
|
| 68 |
+
.hypothesis/
|
| 69 |
+
.pytest_cache/
|
| 70 |
+
cover/
|
| 71 |
+
|
| 72 |
+
# Translations
|
| 73 |
+
*.mo
|
| 74 |
+
*.pot
|
| 75 |
+
|
| 76 |
+
# Django stuff:
|
| 77 |
+
*.log
|
| 78 |
+
local_settings.py
|
| 79 |
+
db.sqlite3
|
| 80 |
+
db.sqlite3-journal
|
| 81 |
+
|
| 82 |
+
# Flask stuff:
|
| 83 |
+
instance/
|
| 84 |
+
.webassets-cache
|
| 85 |
+
|
| 86 |
+
# Scrapy stuff:
|
| 87 |
+
.scrapy
|
| 88 |
+
|
| 89 |
+
# Sphinx documentation
|
| 90 |
+
docs/_build/
|
| 91 |
+
|
| 92 |
+
# PyBuilder
|
| 93 |
+
.pybuilder/
|
| 94 |
+
target/
|
| 95 |
+
|
| 96 |
+
# Jupyter Notebook
|
| 97 |
+
|
| 98 |
+
# IPython
|
| 99 |
+
|
| 100 |
+
# pyenv
|
| 101 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 102 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 103 |
+
# .python-version
|
| 104 |
+
|
| 105 |
+
# pipenv
|
| 106 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 107 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 108 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 109 |
+
# install all needed dependencies.
|
| 110 |
+
#Pipfile.lock
|
| 111 |
+
|
| 112 |
+
# poetry
|
| 113 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 114 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 115 |
+
# commonly ignored for libraries.
|
| 116 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 117 |
+
#poetry.lock
|
| 118 |
+
|
| 119 |
+
# pdm
|
| 120 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 121 |
+
#pdm.lock
|
| 122 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 123 |
+
# in version control.
|
| 124 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 125 |
+
.pdm.toml
|
| 126 |
+
|
| 127 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 128 |
+
__pypackages__/
|
| 129 |
+
|
| 130 |
+
# Celery stuff
|
| 131 |
+
celerybeat-schedule
|
| 132 |
+
celerybeat.pid
|
| 133 |
+
|
| 134 |
+
# SageMath parsed files
|
| 135 |
+
*.sage.py
|
| 136 |
+
|
| 137 |
+
# Environments
|
| 138 |
+
.env
|
| 139 |
+
.venv
|
| 140 |
+
env/
|
| 141 |
+
venv/
|
| 142 |
+
ENV/
|
| 143 |
+
env.bak/
|
| 144 |
+
venv.bak/
|
| 145 |
+
|
| 146 |
+
# Spyder project settings
|
| 147 |
+
.spyderproject
|
| 148 |
+
.spyproject
|
| 149 |
+
|
| 150 |
+
# Rope project settings
|
| 151 |
+
.ropeproject
|
| 152 |
+
|
| 153 |
+
# mkdocs documentation
|
| 154 |
+
/site
|
| 155 |
+
|
| 156 |
+
# mypy
|
| 157 |
+
.mypy_cache/
|
| 158 |
+
.dmypy.json
|
| 159 |
+
dmypy.json
|
| 160 |
+
|
| 161 |
+
# Pyre type checker
|
| 162 |
+
.pyre/
|
| 163 |
+
|
| 164 |
+
# pytype static type analyzer
|
| 165 |
+
.pytype/
|
| 166 |
+
|
| 167 |
+
# Cython debug symbols
|
| 168 |
+
cython_debug/
|
| 169 |
+
|
| 170 |
+
# PyCharm
|
| 171 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 172 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 173 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 174 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 175 |
+
#.idea/
|
| 176 |
+
|
| 177 |
+
### Python Patch ###
|
| 178 |
+
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
| 179 |
+
poetry.toml
|
| 180 |
+
|
| 181 |
+
# ruff
|
| 182 |
+
.ruff_cache/
|
| 183 |
+
|
| 184 |
+
# LSP config files
|
| 185 |
+
pyrightconfig.json
|
| 186 |
+
|
| 187 |
+
### VisualStudioCode ###
|
| 188 |
+
.vscode/*
|
| 189 |
+
!.vscode/settings.json
|
| 190 |
+
!.vscode/tasks.json
|
| 191 |
+
!.vscode/launch.json
|
| 192 |
+
!.vscode/extensions.json
|
| 193 |
+
!.vscode/*.code-snippets
|
| 194 |
+
|
| 195 |
+
# Local History for Visual Studio Code
|
| 196 |
+
.history/
|
| 197 |
+
|
| 198 |
+
# Built Visual Studio Code Extensions
|
| 199 |
+
*.vsix
|
| 200 |
+
|
| 201 |
+
### VisualStudioCode Patch ###
|
| 202 |
+
# Ignore all local history of files
|
| 203 |
+
.history
|
| 204 |
+
.ionide
|
| 205 |
+
|
| 206 |
+
# End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,visualstudiocode
|
| 207 |
+
|
| 208 |
+
## Custom
|
| 209 |
+
data/
|
| 210 |
+
# pixi environments
|
| 211 |
+
.pixi/*
|
| 212 |
+
!.pixi/config.toml
|
.pixi/config.toml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run-post-link-scripts = "insecure"
|
| 2 |
+
|
| 3 |
+
[shell]
|
| 4 |
+
change-ps1 = false
|
app.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from src.config import CKPT_PATH
|
| 5 |
+
from src.modeling import Model
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# -------------------------------------------------
# Load model once at startup
# -------------------------------------------------
# Module-level singleton so Gradio requests reuse one loaded network.
# CPU-only: this Space runs without a GPU (pixi pulls the cpu torch wheels).
MODEL = Model(device="cpu")
# Restore trained weights from the git-lfs checkpoint (see src/config.CKPT_PATH).
MODEL.load_from_chkpt(Path(CKPT_PATH))
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# -------------------------------------------------
# Inference function used by Gradio
# -------------------------------------------------
def run_inference(audio_file):
    """Identify the bird species in an uploaded recording.

    Parameters
    ----------
    audio_file : str | None
        Filesystem path supplied by the ``gr.Audio`` component
        (``type="filepath"``), or ``None`` when nothing was uploaded.

    Returns
    -------
    str
        Markdown rendered in the single ``output_text`` component.
    """
    # This handler is wired to exactly ONE output component, so every branch
    # must return a single value. The previous `return None, ""` returned a
    # 2-tuple, which Gradio would have rendered into output_text verbatim.
    if audio_file is None:
        return ""

    # audio_file is a filepath provided by Gradio
    audio_fp = Path(audio_file)

    result = MODEL.make_preds(audio_fp)
    # e.g. 'corvus_corone' -> 'CORVUS CORONE'
    name = ' '.join(result.upper().split('_'))
    # Space after the colon so the heading doesn't run into the bold name.
    return f"# 🐦 Identified species: **{name}**"
|
| 28 |
+
|
| 29 |
+
def clear_outputs():
    """Reset both the audio widget and the result text (Clear button handler)."""
    return (None, "")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# -------------------------------------------------
# Gradio UI
# -------------------------------------------------
with gr.Blocks(title="Bird Species Identification") as demo:
    gr.Markdown(
        """
        ### 🐦 Bird Species Identification
        Upload an audio recording of a bird call to identify the species.
        """
    )

    # Upload-only audio widget; type="filepath" hands run_inference a path string.
    audio_input = gr.Audio(
        sources=["upload"],
        type="filepath",
        label="Upload bird audio"
    )

    # Prediction is rendered as markdown (run_inference returns a heading).
    output_text = gr.Markdown(
        label="Identified species",
    )

    with gr.Row():
        submit_btn = gr.Button("Identify")
        clear_btn = gr.Button("Clear")

    # Identify: one input (the file path), one output (the markdown text).
    submit_btn.click(
        fn=run_inference,
        inputs=audio_input,
        outputs=output_text
    )

    # Clear resets both the uploaded audio and the result text;
    # clear_outputs returns a matching (None, "") pair.
    clear_btn.click(
        fn=clear_outputs,
        outputs=[audio_input, output_text]
    )


if __name__ == "__main__":
    demo.launch()
|
models/checkpoint.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7f58c1413fe19595def2cbcb4ba01fced3bd84418874b253ba5529510a677550
|
| 3 |
+
size 85613285
|
pixi.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
pixi.toml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[workspace]
|
| 2 |
+
channels = ["conda-forge", "pytorch"]
|
| 3 |
+
name = "munich_bird_identifier"
|
| 4 |
+
platforms = ["linux-64"]
|
| 5 |
+
version = "0.1.0"
|
| 6 |
+
|
| 7 |
+
[tasks]
|
| 8 |
+
|
| 9 |
+
[dependencies]
|
| 10 |
+
python = "3.12.*"
|
| 11 |
+
librosa = ">=0.11.0,<0.12"
|
| 12 |
+
click = ">=8.3.1,<9"
|
| 13 |
+
gradio = ">=6.2.0,<7"
|
| 14 |
+
ipython = ">=9.9.0,<10"
|
| 15 |
+
|
| 16 |
+
[pypi-dependencies]
|
| 17 |
+
torch = { version = "*", index = "https://download.pytorch.org/whl/cpu" }
|
| 18 |
+
torchvision = { version = "*", index = "https://download.pytorch.org/whl/cpu" }
|
| 19 |
+
torchaudio = { version = "*", index = "https://download.pytorch.org/whl/cpu" }
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
torch
|
| 3 |
+
torchvision
|
| 4 |
+
torchaudio
|
| 5 |
+
librosa
|
| 6 |
+
click
|
samples/corvus_corone_XC592284.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:db23dd9478e7dbd8dfd878fb08d900fad694ea93556c2013ab4b954553507957
|
| 3 |
+
size 180652
|
samples/scolopax_rusticola_XC795042.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c1273f5edb7971248c08c78d7bc25ee30a6f9c475e27d204be9fc223c016faac
|
| 3 |
+
size 628162
|
src/__init__.py
ADDED
|
File without changes
|
src/audio.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import librosa
|
| 2 |
+
from src.config import N_MELS, SR
|
| 3 |
+
|
| 4 |
+
# chosen to be able to use for modeling downstream
|
| 5 |
+
|
| 6 |
+
def load_audio(audio_fp, sr=None, res_type='soxr_hq'):
    """Load an audio file via librosa.

    sr=None keeps the file's native sample rate; otherwise librosa resamples
    using `res_type`. Returns (waveform, sample_rate).
    """
    return librosa.load(audio_fp, sr=sr, res_type=res_type)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_melspec(y, sr=None, plot=False):
    """Compute a log-scaled (dB) mel spectrogram of waveform `y`.

    Falls back to the project default sample rate SR when `sr` is not given.
    `plot` is accepted but currently a no-op (plotting not implemented).
    """
    sample_rate = sr if sr else SR  # default
    power_spec = librosa.feature.melspectrogram(
        y=y, sr=sample_rate, n_mels=N_MELS, hop_length=1000
    )
    db_spec = librosa.power_to_db(power_spec)
    if plot:
        pass
    return db_spec
|
src/config.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
|
| 6 |
+
load_dotenv()
|
| 7 |
+
|
| 8 |
+
def get_logger(LOG_LEVEL="INFO"):
    """Return the application logger, writing to ``logs.log``.

    Parameters
    ----------
    LOG_LEVEL : str | int
        Level applied to both the logger and its file handler.

    Returns
    -------
    logging.Logger
    """
    LOG_PATH = Path("logs.log")
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    # Use getLogger() so the logger is registered and repeated calls return
    # the same instance. (The old code instantiated logging.Logger directly —
    # bypassing the registry — and named it "agentic_search", a leftover from
    # an unrelated project.)
    log = logging.getLogger("munich_bird_identifier")
    log.setLevel(LOG_LEVEL)

    # Attach the file handler only once; otherwise every call would stack a
    # duplicate handler and each message would be written multiple times.
    if not log.handlers:
        file_handler = logging.FileHandler(LOG_PATH)
        file_handler.setLevel(LOG_LEVEL)
        file_handler.setFormatter(formatter)
        log.addHandler(file_handler)

    return log
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Package-wide logger (DEBUG so everything is captured to logs.log).
log = get_logger("DEBUG")

# Trained model checkpoint (stored via git-lfs in this repo).
CKPT_PATH = Path('models/checkpoint.pth')
# Number of mel bands in the spectrogram features.
N_MELS = 256
# Target sample rate in Hz that all audio is resampled to.
SR = 32_000

# these are the bird species that the model has trained on
# Class index -> species code. The order fixes the model's output-class
# layout, so entries must NOT be reordered without retraining the checkpoint.
IDX2CODE = {
    0: 'accipiter_gentilis',
    1: 'acrocephalus_scirpaceus',
    2: 'aegolius_funereus',
    3: 'alauda_arvensis',
    4: 'anthus_cervinus',
    5: 'anthus_trivialis',
    6: 'asio_otus',
    7: 'charadrius_dubius',
    8: 'chloris_chloris',
    9: 'coccothraustes_coccothraustes',
    10: 'corvus_corone',
    11: 'corvus_frugilegus',
    12: 'crex_crex',
    13: 'cuculus_canorus',
    14: 'curruca_communis',
    15: 'cyanistes_caeruleus',
    16: 'dendrocopos_major',
    17: 'dryocopus_martius',
    18: 'emberiza_citrinella',
    19: 'erithacus_rubecula',
    20: 'falco_peregrinus',
    21: 'fringilla_coelebs',
    22: 'garrulus_glandarius',
    23: 'lanius_collurio',
    24: 'larus_michahellis',
    25: 'linaria_cannabina',
    26: 'locustella_fluviatilis',
    27: 'locustella_naevia',
    28: 'lullula_arborea',
    29: 'luscinia_megarhynchos',
    30: 'mareca_penelope',
    31: 'motacilla_flava',
    32: 'muscicapa_striata',
    33: 'nucifraga_caryocatactes',
    34: 'nycticorax_nycticorax',
    35: 'nymphicus_hollandicus',
    36: 'parus_major',
    37: 'perdix_perdix',
    38: 'periparus_ater',
    39: 'phoenicurus_phoenicurus',
    40: 'phylloscopus_collybita',
    41: 'phylloscopus_sibilatrix',
    42: 'phylloscopus_trochilus',
    43: 'picus_canus',
    44: 'picus_viridis',
    45: 'poecile_montanus',
    46: 'poecile_palustris',
    47: 'prunella_modularis',
    48: 'saxicola_rubicola',
    49: 'scolopax_rusticola',
    50: 'serinus_serinus',
    51: 'strix_aluco',
    52: 'sylvia_atricapilla',
    53: 'sylvia_borin',
    54: 'troglodytes_troglodytes',
    55: 'turdus_merula',
    56: 'turdus_philomelos'
}
# Reverse lookup: species code -> class index.
CODE2IDX = {v:k for k,v in IDX2CODE.items()}
|
src/modeling.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn, optim
|
| 3 |
+
from torchvision.models import resnet34, ResNet34_Weights
|
| 4 |
+
from src.processing import generate_test_images
|
| 5 |
+
from src.config import IDX2CODE
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BirdNet(nn.Module):
    """ResNet-34 adapted to classify 1-channel mel-spectrogram "images".

    Parameters
    ----------
    n_out : int
        Number of output classes (defaults to the number of known species).
    pretrained : bool
        Start from ImageNet weights; the stem conv's RGB filters are averaged
        into a single grayscale filter.
    freeze_backbone : bool
        Freeze all parameters except the final classification layer.
    dropout : float
        Accepted for interface compatibility; the current single-Linear head
        does not use it.
    """

    def __init__(self, n_out=len(IDX2CODE.keys()), pretrained=True, freeze_backbone=True, dropout=.25):
        super().__init__()
        self.model = resnet34(weights=ResNet34_Weights.DEFAULT if pretrained else None)

        # Modify first convolution layer to accept 1-channel grayscale input
        # (the stock ResNet-34 stem expects 3-channel RGB).
        original_conv1 = self.model.conv1
        self.model.conv1 = nn.Conv2d(
            in_channels=1,  # grayscale melspectrogram input
            out_channels=original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            # nn.Conv2d's `bias` argument is a bool. Passing the old layer's
            # `.bias` attribute (a Parameter or None) only worked by accident
            # of truthiness and would raise for a multi-element Parameter.
            bias=original_conv1.bias is not None
        )

        if pretrained:
            with torch.no_grad():
                # Average the pretrained RGB filters into one grayscale filter
                # so the adapted stem keeps its learned feature detectors.
                self.model.conv1.weight.data = original_conv1.weight.data.mean(dim=1, keepdim=True)

        # Replace the 1000-way ImageNet head with an n_out-way classifier.
        self.model.fc = nn.Linear(self.model.fc.in_features, n_out)

        # Optional: freeze the backbone for fine-tuning (train only the head).
        if freeze_backbone:
            for param in self.model.parameters():
                param.requires_grad = False
            # Unfreeze the final layer
            for param in self.model.fc.parameters():
                param.requires_grad = True

    def forward(self, x):
        """Run the adapted ResNet on a (batch, 1, H, W) spectrogram tensor."""
        return self.model(x)
|
| 52 |
+
|
| 53 |
+
class Model:
    """Training/inference wrapper around BirdNet.

    Owns the network, optimizer, loss function and per-epoch metric history,
    and knows how to restore all of them from a checkpoint dict.
    """

    def __init__(self, device, n_out=len(IDX2CODE.keys()), loss_fn=nn.CrossEntropyLoss(),
                 pretrained=True, freeze_backbone=True, dropout=.1):
        # device: torch device string, e.g. "cpu" or "cuda".
        self.n_out = n_out
        self.device = device
        self.model = BirdNet(self.n_out, pretrained=pretrained,
                             freeze_backbone=freeze_backbone, dropout=dropout).to(self.device)
        self.lr = 5e-3
        self.loss_fn = loss_fn
        self.opt = optim.Adam(self.model.parameters(), lr=self.lr)
        # self.lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.opt, mode='min', factor=.5, patience=3, min_lr=1e-5)
        # Per-epoch metric history; repopulated by load_from_chkpt().
        self.epoch_train_losses = []
        self.epoch_val_losses = []
        self.epoch_train_accs = []
        self.epoch_val_accs = []
        self.epoch = 0

    def load_from_chkpt(self, chkpt_path):
        """Restore model/optimizer state and metric history from a checkpoint.

        weights_only=False because the checkpoint also stores plain Python
        objects (epoch counter, metric lists) — only load trusted files.
        map_location keeps GPU-saved checkpoints loadable on self.device.
        """
        chkpt = torch.load(chkpt_path, weights_only=False, map_location=torch.device(self.device))
        self.epoch = chkpt['epoch']
        self.model.load_state_dict(chkpt['model'])
        self.opt.load_state_dict(chkpt['optim'])
        self.epoch_train_losses = chkpt['train_losses']
        self.epoch_val_losses = chkpt['valid_losses']
        self.epoch_train_accs = chkpt['train_accs']
        self.epoch_val_accs = chkpt['valid_accs']

    def make_preds(self, fp):
        """Predict the species code for audio file `fp` by majority vote.

        Samples several random spectrogram crops (generate_test_images),
        classifies each crop, and returns the IDX2CODE code of the most
        frequently predicted class.
        """
        arrs = generate_test_images(fp)
        self.model.eval();
        with torch.no_grad():
            out = self.model(arrs.to(self.device).float())
        labels = out.argmax(dim=1)
        # (unique_labels, counts) — pick the label with the highest count.
        vc = labels.unique(return_counts=True)
        return IDX2CODE[vc[0][vc[1].argmax()].item()]
|
src/processing.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import soundfile as sf
|
| 4 |
+
from src.audio import load_audio, get_melspec
|
| 5 |
+
from src.config import SR
|
| 6 |
+
from src.utils import get_idx, to_square
|
| 7 |
+
|
| 8 |
+
# https://www.kaggle.com/code/tarunpaparaju/birdcall-identification-spectrogram-loader
def to_imagenet(X, mean=None, std=None, norm_max=None, norm_min=None, eps=1e-6):
    """Standardise a spectrogram and min-max scale it into [0, 1].

    Parameters
    ----------
    X : np.ndarray
        Input spectrogram (any float array).
    mean, std : float, optional
        Statistics to standardise with; computed from X when None.
    norm_max, norm_min : float, optional
        Clamp bounds applied before scaling; computed from X when None.
    eps : float
        Guards division by zero and detects near-constant input.

    Returns
    -------
    np.ndarray
        Values in [0, 1], or a zero array for near-constant input.
    """
    # `is None` tests (not `x or default`): a caller-supplied 0.0 is a valid
    # statistic/bound and must not be silently replaced by a computed one —
    # the old truthiness idiom dropped every falsy argument.
    mean = X.mean() if mean is None else mean
    X = X - mean
    std = X.std() if std is None else std
    Xstd = X / (std + eps)
    _min, _max = Xstd.min(), Xstd.max()
    norm_max = _max if norm_max is None else norm_max
    norm_min = _min if norm_min is None else norm_min
    if (_max - _min) > eps:
        # Clamp to [norm_min, norm_max], then scale to [0, 1].
        V = Xstd
        V[V < norm_min] = norm_min
        V[V > norm_max] = norm_max
        V = (V - norm_min) / (norm_max - norm_min)
    else:
        # Near-constant input: just zeros (uint8 dtype kept for compatibility
        # with the original implementation; callers multiply by 255 anyway).
        V = np.zeros_like(Xstd, dtype=np.uint8)
    return V
|
| 27 |
+
|
| 28 |
+
def extract_melspec_as_imgarr(fp, n_secs=8, random_chunk=True, convert_to_int8=False):
    """Load audio file `fp` and return one n_secs-second mel-spectrogram crop
    as a float image array scaled to [0, 255], flipped so low frequencies are
    at the bottom.

    random_chunk picks a random window within the clip; convert_to_int8
    additionally rounds through uint8 before the final float cast.
    """
    info = sf.info(fp)
    y, _ = load_audio(fp, SR)  # resampled to the project rate
    # Retry until the chosen window overlaps actual audio — a randomly
    # chosen window can start past the end of short files.
    while True:
        start, end = get_idx(info.duration, n_secs, random_chunk=random_chunk)
        y2 = y[start:end]
        if len(y2):
            y = y2
            break
    mel_dB = to_square(get_melspec(y, SR))
    try:
        normalised_db = to_imagenet(mel_dB)  # replaced minmax_scale
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; any normalisation failure falls back to zeros.
        normalised_db = torch.zeros_like(torch.as_tensor(mel_dB))
    db_array = np.asarray(normalised_db) * 255
    if convert_to_int8:
        db_array = db_array.astype(np.uint8)
    return db_array[::-1].astype(float)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def generate_test_images(fp, n=10):
    """Sample `n` random spectrogram crops from audio file `fp`.

    Returns a (n, 1, H, W) tensor ready for the 1-channel CNN.
    """
    crops = [extract_melspec_as_imgarr(fp) for _ in range(n)]
    return torch.as_tensor(np.array(crops)).unsqueeze(1)
|
src/utils.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from src.config import SR, CODE2IDX
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get_idx(duration, n_secs=5, sr=SR, random_chunk=True):
    """Choose a [start, end) sample-index window of n_secs seconds.

    Parameters
    ----------
    duration : float
        Clip length in seconds.
    n_secs : int
        Window length in seconds.
    sr : int
        Sample rate used to convert seconds to sample indices.
    random_chunk : bool
        Pick a random start when the clip is long enough; otherwise (or for
        short clips) use a fixed small offset.

    Returns
    -------
    tuple[int, int]
        (start, end) with end - start == n_secs * sr. The window may extend
        past the end of short clips; callers slice and check the result.
    """
    num_frames = int(np.ceil(sr * duration))  # total samples (int, not float)
    chunk_len = n_secs * sr
    DEFAULT_OFFSET = 10  # skip the very first samples by default
    # Guard: np.random.randint raises ValueError when high <= low, which the
    # original hit for any clip shorter than n_secs (+ offset). Fall back to
    # the fixed offset instead of crashing on short recordings.
    if random_chunk and num_frames - chunk_len > DEFAULT_OFFSET:
        start = np.random.randint(DEFAULT_OFFSET, num_frames - chunk_len)
    else:
        start = DEFAULT_OFFSET
    return start, start + chunk_len
|
| 13 |
+
|
| 14 |
+
def to_square(arr):
    """Force a 2-D array to shape (rows, rows).

    Columns are zero-padded on the right when there are too few, or cropped
    when there are too many.
    """
    rows, cols = arr.shape
    if cols >= rows:
        return arr[:, :rows]
    return np.pad(arr, ((0, 0), (0, rows - cols)), mode='constant')
|
| 23 |
+
|
| 24 |
+
def to_tensor(data):
    """Convert each array-like in `data` into a float32 torch tensor."""
    tensors = []
    for item in data:
        tensors.append(torch.FloatTensor(item))
    return tensors
|
| 26 |
+
|
| 27 |
+
def one_hot(idx):
    """Return a one-hot float vector over all known species classes."""
    vec = torch.zeros(len(CODE2IDX))
    vec[idx] = 1.
    return vec
|