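# Source: BSG-BAT / original_code / data384.py (6.51 kB)
# Last commit a4d9f29 (verified), shown next to tphakala's picture:
#   "Add BSG-BAT v0.21 ONNX ensemble (6 checkpoints), labels, original
#   preprocessing code, model card"
# Note: the file-viewer navigation labels captured with this header
# ("Raw", "History", "Blame", "Contribute", "Delete") are not part of the code.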
import os
import librosa
import numpy as np
import matplotlib.pyplot as plt
# Example species dictionary: maps a species name (as written in the label
# files) to its integer class label. It is also the default mapping used by
# read_filelist_and_labels().
species_dict = {
    'Eptesicus_nilssonii': 0,
    'Pipistrellus_nathusii': 1,
    'Pipistrellus_pipistrellus': 2,
    'Pipistrellus_pygmaeus': 3,
}
def read_species_dict(file):
    """Read a species-name -> integer-label mapping from a text file.

    Each non-blank line must start with two whitespace-separated fields:
    the species name and its integer label. Any further fields on the line
    are ignored. Blank (or whitespace-only) lines are skipped.

    Args:
        file: Path of the text file to read.

    Returns:
        dict mapping species name (str) to label (int). If a species occurs
        more than once, the last occurrence wins.

    Raises:
        ValueError: If a non-blank line has fewer than two fields or the
            label is not an integer.
    """
    mapping = {}
    with open(file) as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate empty lines (e.g. a trailing blank line).
                continue
            species, label = fields[:2]
            mapping[species] = int(label)
    return mapping
def read_filelist(file):
    """Read file names from the first column of a text file.

    Only the first whitespace-separated field of each line is used; any
    remaining fields are ignored. Blank (or whitespace-only) lines are
    skipped.

    Args:
        file: Path of the text file to read.

    Returns:
        list of file names (str), in file order.
    """
    with open(file) as f:
        # str.split() yields an empty list for blank lines, which is falsy.
        return [fields[0] for fields in map(str.split, f) if fields]
def read_filelist_and_labels(file, species_dict=species_dict, flim=True):
    """Read an annotation list: file name, start time, duration, species.

    Each non-blank line holds whitespace-separated fields:

        fname start dur species [f1 f2]

    The two trailing fields are required when ``flim`` is true and must be
    absent otherwise. Blank (or whitespace-only) lines are skipped.

    Args:
        file: Path of the annotation text file.
        species_dict: Mapping from species name to integer label. Species
            not present in the mapping get the label -1.
        flim: If true, every line also carries two numeric fields that are
            returned as ``fmins`` and ``fmaxs`` (presumably the lower and
            upper frequency limits of the call -- confirm against the data).

    Returns:
        ``(filelist, start_times, durations, labels, fmins, fmaxs)`` when
        ``flim`` is true, otherwise ``(filelist, start_times, durations,
        labels)``. All elements are lists with one entry per annotation;
        times, durations and frequency limits are floats.

    Raises:
        ValueError: If a non-blank line has the wrong number of fields or a
            numeric field cannot be parsed.
    """
    filelist = []
    start_times = []
    durations = []
    labels = []
    fmins = []
    fmaxs = []
    with open(file) as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate empty lines (e.g. a trailing blank line).
                continue
            if flim:
                fname, start, dur, sp, f1, f2 = fields
            else:
                fname, start, dur, sp = fields
            filelist.append(fname)
            start_times.append(float(start))
            durations.append(float(dur))
            # Unknown species are kept, marked with label -1.
            labels.append(species_dict.get(sp, -1))
            if flim:
                fmins.append(float(f1))
                fmaxs.append(float(f2))
    if flim:
        return filelist, start_times, durations, labels, fmins, fmaxs
    return filelist, start_times, durations, labels
def wav2image(wavfile, start_time, dur=1.5, ntime=750, nfreq=128):
    """Compute a log-mel spectrogram image for one excerpt of a wav file.

    The excerpt is loaded (resampled to 384 kHz if necessary) and turned
    into a mel power spectrogram with a 1024-sample FFT and a 768-sample hop
    (2 ms per frame at 384 kHz), covering 9-150 kHz.

    Args:
        wavfile: Path of the audio file.
        start_time: Offset of the excerpt from the start of the file, in
            seconds.
        dur: Length of the excerpt in seconds.
        ntime: Maximum number of time frames to keep.
        nfreq: Number of mel bands.

    Returns:
        Array of shape (t, nfreq) holding log10(power + 1e-6), where t is at
        most ``ntime``. No padding is done: if the excerpt yields fewer than
        ``ntime`` frames (e.g. the file ends early), t is smaller than
        ``ntime``.
    """
    y, sr = librosa.load(wavfile, sr=384000, offset=start_time, duration=dur)
    # librosa returns a freq-by-time matrix -> transpose to time-by-freq.
    S = librosa.feature.melspectrogram(
        y=y, sr=sr, n_fft=1024, hop_length=768, n_mels=nfreq,
        fmin=9000, fmax=150000,
    ).T
    # The 1e-6 offset keeps log10 finite for zero-power bins.
    return np.log10(S[:ntime, :] + 1e-6)
def normalize(data):
    """Standardize a spectrogram and remove the per-band background level.

    The data is shifted and scaled to zero mean and unit standard deviation
    (computed over the whole array), then the median along axis 0 is
    subtracted from every column and the result is clipped to [0, 6].

    A constant input (standard deviation exactly 0, e.g. digital silence)
    yields an all-zero result instead of NaNs from a 0/0 division.

    Args:
        data: Array, typically (time, frequency) log-power values.

    Returns:
        Array of the same shape with values in [0.0, 6.0].
    """
    centered = data - np.mean(data)
    spread = np.std(data)
    if spread == 0:
        # Every value equals the mean: the standardized data is all zeros.
        x = np.zeros_like(centered)
    else:
        x = centered / spread
    return np.clip(x - np.median(x, axis=0), 0.0, 6.0)
def freq2mel(f, f2mel):
    """Map a frequency (or array of frequencies) to a mel-band index.

    Args:
        f: Frequency value(s) to look up, in the same unit as ``f2mel``.
        f2mel: Ascending table of mel-band frequencies.

    Returns:
        Index (or array of indices) of the first table entry that is not
        smaller than ``f``; equals ``len(f2mel)`` when ``f`` is above every
        entry.
    """
    # Example of how such a table can be produced and stored:
    #   a = librosa.mel_frequencies(n_mels=128, fmin=9000, fmax=150000)
    #   np.savetxt('mel128_freq9k_150k.txt', a, fmt='%.1f')
    #   f2mel = np.loadtxt('mel128_freq9k_150k.txt')
    mel_index = np.searchsorted(f2mel, f, side='left')
    return mel_index
def extract_band(S, fmin, fmax, f2mel):
    """Return a copy of a spectrogram with everything outside a band flattened.

    The frequency limits are converted to mel-band indices with freq2mel().
    Bands below the lower index and from the upper index onwards are
    overwritten with a single fill value: the 10 % quantile of the values
    inside the band.

    Args:
        S: Spectrogram indexed as (time, mel band).
        fmin: Lower frequency limit of the band.
        fmax: Upper frequency limit of the band.
        f2mel: Ascending table of mel-band frequencies (see freq2mel()).

    Returns:
        New array of the same shape as ``S``; ``S`` itself is not modified.
    """
    lo = freq2mel(fmin, f2mel)
    hi = freq2mel(fmax, f2mel)
    # Low quantile of the in-band values serves as the background level.
    background = np.quantile(S[:, lo:hi], 0.1)
    banded = np.copy(S)
    banded[:, :lo] = background
    banded[:, hi:] = background
    return banded
def compute_spectrograms(filelist, start_times, durations, ntime=750, nfreq=128, fmin=(), fmax=(), f2mel=()):
    """Compute normalized spectrogram images for a list of audio excerpts.

    For every entry the excerpt is converted with wav2image() and
    normalize(). If frequency limits are given, everything outside the
    entry's band is additionally flattened with extract_band().

    Args:
        filelist: Sequence of audio file paths.
        start_times: Excerpt start offsets in seconds, one per file.
        durations: Excerpt lengths in seconds, one per file.
        ntime: Number of time frames per image.
        nfreq: Number of mel bands per image.
        fmin: Optional per-entry lower frequency limits. Band extraction is
            enabled when this sequence is non-empty.
        fmax: Optional per-entry upper frequency limits (used with ``fmin``).
        f2mel: Ascending table of mel-band frequencies with ``nfreq``
            entries; required when ``fmin`` is given.

    Returns:
        float32 array of shape (len(filelist), ntime, nfreq).

    Raises:
        ValueError: If ``fmin`` is given and ``len(f2mel) != nfreq``. Also
            raised by the array assignment when an excerpt yields fewer than
            ``ntime`` frames, since wav2image() does not pad short excerpts.
    """
    # len() rather than truthiness so that numpy arrays are accepted too.
    use_band = len(fmin) > 0
    if use_band and len(f2mel) != nfreq:
        # Explicit raise (not assert) so the check survives `python -O`.
        raise ValueError(f"length of f2mel must equal to nfreq {nfreq}")
    n = len(filelist)
    # Every slot is filled in the loop below, so no initialization is needed.
    data = np.empty((n, ntime, nfreq), dtype='float32')
    for i, ifile in enumerate(filelist):
        data[i] = normalize(wav2image(ifile, start_times[i], durations[i], ntime, nfreq))
        if use_band:
            data[i] = extract_band(data[i], fmin[i], fmax[i], f2mel)
    return data
def wav2spectrograms(wavfile, ntime=512, nhop=250, nfreq=128):
    """Cut a whole recording into overlapping, normalized log-mel segments.

    The file is loaded as mono at 384 kHz and converted to a mel power
    spectrogram (1024-sample FFT, 768-sample hop = 2 ms per frame,
    9-150 kHz). The spectrogram is then chopped into segments of ``ntime``
    frames, one every ``nhop`` frames; each segment is log10-scaled and
    passed through normalize().

    Args:
        wavfile: Path of the audio file.
        ntime: Segment length in frames.
        nhop: Step between segment starts in frames (default 250 frames
            == 0.5 s).
        nfreq: Number of mel bands.

    Returns:
        float32 array of shape (n, ntime, nfreq) with n >= 1.

        * Recording shorter than ``ntime`` frames: one segment, zero-padded
          at the end (in the power domain, before the log).
        * Otherwise segments start at 0, nhop, 2*nhop, ...; if the final
          step would run past the end of the recording, the last segment is
          instead aligned with the end of the recording, so the tail is
          always covered.
    """
    # Note: if the original sampling rate is not 384000 Hz, librosa's default
    # resampling method is slow; 'kaiser_fast' is used as a faster method.
    y, sr = librosa.load(wavfile, sr=384000, mono=True, res_type='kaiser_fast')
    # librosa returns a freq-by-time matrix -> transpose to time-by-freq.
    S = librosa.feature.melspectrogram(
        y=y, sr=sr, n_fft=1024, hop_length=768, n_mels=nfreq,
        fmin=9000, fmax=150000,
    ).T
    nframes = len(S)

    if nframes < ntime:
        # Recording shorter than the desired segment length: zero padding.
        # The buffer must be ntime frames long (not a fixed 512).
        X = np.zeros((ntime, nfreq), dtype='float32')
        X[:nframes] = S
        data = np.empty((1, ntime, nfreq), dtype='float32')
        data[0] = normalize(np.log10(X + 1e-6))
        return data

    # Latest possible segment start that still fits inside the recording.
    last_start = nframes - ntime
    # Starts 0, nhop, ... strictly below last_start, plus one final segment
    # that ends exactly at the end of the recording.
    n = int(np.ceil(last_start / nhop)) + 1
    data = np.empty((n, ntime, nfreq), dtype='float32')
    for i in range(n):
        # Last segment too short -> include data from the left instead.
        start_i = min(i * nhop, last_start)
        data[i] = normalize(np.log10(S[start_i:start_i + ntime] + 1e-6))
    return data
def plot_spectrograms(data, labels=(), ny=1, nx=1, start_index=0, vmi=None, vma=None, num=None):
    """Show a grid of spectrograms, each with its own colorbar.

    The grid is filled row by row with ``data[start_index]``,
    ``data[start_index + 1]``, ...; cells beyond the end of ``data`` stay
    empty. Each image is drawn transposed with the origin at the bottom.

    Args:
        data: Sequence/array of 2-D spectrograms.
        labels: Optional sequence indexed like ``data``; when non-empty,
            each panel is titled ``"<index> (<label>)"``.
        ny: Number of grid rows.
        nx: Number of grid columns.
        start_index: Index of the first spectrogram to show.
        vmi: Lower color limit; ``None`` means the minimum of each image.
            An explicit 0 is honored.
        vma: Upper color limit; ``None`` means the maximum of each image.
            An explicit 0 is honored.
        num: Figure identifier. If given, that figure is cleared and redrawn
            instead of creating a new one.

    Returns:
        None. The figure is shown without blocking.
    """
    cm = 'gray_r'
    # squeeze=False always yields a 2-D axes array, so a single indexing
    # scheme works for every grid shape (1x1, single row, single column).
    if num is None:
        fig, ax = plt.subplots(ny, nx, squeeze=False)
    else:
        # Don't create a new figure but overdraw the existing one.
        fig, ax = plt.subplots(ny, nx, num=num, clear=True, squeeze=False)
    k = start_index
    for j in range(ny):
        for i in range(nx):
            if k < len(data):
                # `is None` (not truthiness) so that a limit of 0 is respected.
                v1 = np.min(data[k]) if vmi is None else vmi
                v2 = np.max(data[k]) if vma is None else vma
                panel = ax[j, i]
                img = panel.imshow(data[k].T, origin='lower', cmap=cm, aspect='auto', vmin=v1, vmax=v2)
                fig.colorbar(img, ax=panel)
                if len(labels):
                    panel.set_title(f'{k} ({labels[k]})')
            k += 1
    # tight_layout sometimes good sometimes not
    plt.tight_layout()
    plt.show(block=False)
def plot_probabilities(logits, species=(), num=None):
    """Show per-class probabilities as a grayscale image.

    The logits are converted to probabilities with the logistic sigmoid and
    drawn transposed, so axis 0 of ``logits`` runs along the x-axis and
    axis 1 along the y-axis. The color scale is fixed to [0, 1].

    Args:
        logits: 2-D array-like of logits.
        species: Optional sequence of tick labels for the y-axis, one per
            column of ``logits``.
        num: Figure identifier. If given, that figure is cleared and used;
            otherwise the image is drawn into the current pyplot figure.

    Returns:
        None. The figure is shown without blocking.
    """
    if num is not None:
        # Don't create a new figure but overdraw the existing one.
        plt.subplots(1, 1, num=num, clear=True)
    # asarray lets plain nested lists be passed as well as arrays.
    X = 1 / (1 + np.exp(-np.asarray(logits)))
    cm = 'gray_r'
    plt.imshow(X.T, origin='lower', cmap=cm, aspect='auto', vmin=0, vmax=1)
    if len(species) > 0:
        plt.yticks(np.arange(X.shape[1]), species)
    plt.tight_layout()
    plt.show(block=False)