# BSG-BAT / original_code / readme.train
# Uploaded by tphakala in commit a4d9f29: "Add BSG-BAT v0.21 ONNX ensemble (6 checkpoints), labels, original preprocessing code, model card"
# Training data is constructed from annotated wav files.
# Each training sample is a 1.5-second segment, which allows random time shifts of up to 0.5 seconds for the 1.0-second input window.
# Example: first lines of filelist_train
# format: filename start_time(seconds) duration(seconds) species min_frequency(Hz) max_frequency(Hz)
train_384kHz/file0001.wav 0.1 1.5 Barbastella_barbastellus 16000.0 53629.6
train_384kHz/file0002.wav 5.0 1.5 Eptesicus_nilssonii 20032.5 61398.4
train_384kHz/file0003.wav 4.0 1.5 Miniopterus_schreibersii 39919.7 120000.0
...
# Example: first lines of filelist_bg1 (background segments have no min/max frequency columns)
bg1_384kHz/file0001.wav 0.0 1.5 Background
bg1_384kHz/file0002.wav 1.5 1.5 Background
bg1_384kHz/file0003.wav 3.0 1.5 Background
...
# Compute training data spectrograms (1.5-second segments) and save each data
# set (training, two background sets, validation) to its own .npz file.
import data384 as data
import numpy as np
f2mel = np.loadtxt('mel128_freq9k_150k.txt')
spd = data.read_species_dict('species21bg')

# Foreground training data: annotations include bounding boxes, so each
# training segment also carries its minimum and maximum frequency.
filelist, start_times, durations, labels, fmin, fmax = data.read_filelist_and_labels('filelist_train', spd)
dat = data.compute_spectrograms(filelist, start_times, durations, fmin=fmin, fmax=fmax, f2mel=f2mel)
np.savez('data_train.npz', dat=dat, labels=labels)

# First background set: non-bat segments have no frequency bounding boxes,
# so the file list is read with flim=False (no fmin/fmax columns).
filelist, start_times, durations, labels = data.read_filelist_and_labels('filelist_bg1', spd, flim=False)
dat = data.compute_spectrograms(filelist, start_times, durations)
np.savez('data_bg1.npz', dat=dat, labels=labels)

# Second background set represents silence; for this simple example it is
# normalized Gaussian noise. These arrays are later concatenated with the
# other training sets, so the per-sample shape (750, 128) must match what
# compute_spectrograms produces for a 1.5-second segment.
n_silence = 1000
silence_shape = (750, 128)
dat = np.empty((n_silence,) + silence_shape, dtype='float32')
# Bug fix: the original looped over range(len(x)), but `x` was never defined.
for i in range(n_silence):
    dat[i] = data.normalize(np.random.normal(size=silence_shape))
# NOTE(review): label 21 is assumed to be the 'Background' class of the
# species21bg dictionary (22 classes in total) -- confirm against that file.
labels = np.full(n_silence, 21, dtype='int')
np.savez('data_bg2.npz', dat=dat, labels=labels)

# Validation data does not include frequency bounding boxes. ntime=512
# presumably sets the number of time frames so that each validation
# spectrogram matches the ~1.0-second model input window -- confirm.
filelist, start_times, durations, labels = data.read_filelist_and_labels('filelist_valid', spd, flim=False)
dat = data.compute_spectrograms(filelist, start_times, durations, ntime=512)
np.savez('data_valid.npz', dat=dat, labels=labels)
# CNN training
import torch
import supervised as s
import numpy as np


def _load_npz(path):
    """Return the ``dat`` and ``labels`` arrays stored in the .npz file *path*.

    The archive is opened in a ``with`` block so its file handle is closed as
    soon as both arrays have been read into memory.
    """
    with np.load(path) as archive:
        return archive['dat'], archive['labels']


fgdat, fglabels = _load_npz('data_train.npz')      # foreground (bat) segments
bg1dat, bg1labels = _load_npz('data_bg1.npz')      # non-bat background segments
bg2dat, bg2labels = _load_npz('data_bg2.npz')      # silence (noise) background
valdat, vallabels = _load_npz('data_valid.npz')

# Sizes of the three training subsets. They are passed to s.Dataset together
# with the concatenated arrays (presumably so it can tell foreground from
# background samples, e.g. for fg/bg mixing augmentation -- confirm in
# supervised.py). The concatenation order must match n1, n2, n3.
n1 = len(fgdat)
n2 = len(bg1dat)
n3 = len(bg2dat)
trdat = np.concatenate((fgdat, bg1dat, bg2dat))
trlabels = np.concatenate((fglabels, bg1labels, bg2labels))

num_classes = 22
ap = s.AugmentationParams(tshift_max=230, tshift_prob=1.0, fshift_max=1, fshift_prob=0.25, fgmix_prob=0.5, bgmix_prob=0.25, scale_min=3.0, scale_max=6.0, scale_prob=0.5)
train_dataset = s.Dataset(trdat, trlabels, n1, n2, n3, ap=ap, eps=0.01, nclasses=num_classes)
train_dataloader = s.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
valid_dataset = s.Dataset(valdat, vallabels, len(valdat), nclasses=num_classes)
valid_dataloader = s.DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=4)

# Use the GPU when one is available; otherwise fall back to the CPU instead of
# failing at startup (training on the CPU will be much slower).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = s.Net(nclasses=num_classes).to(device=device)
# BCEWithLogitsLoss expects float targets of shape (batch, num_classes);
# presumably s.Dataset emits one-hot targets smoothed by eps -- confirm.
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(net.parameters())
# The output model is the one maximizing validation accuracy during training.
tracc, valacc = s.train(net, loss_fn, optimizer, train_dataloader, valid_dataloader, device, nepochs=50, info=1, model_outfile='model1.pt')
# Save the per-epoch training and validation accuracies.
np.savez('model1_trainacc.npz', tracc=tracc, valacc=valacc)