# BSG-BAT / original_code / classify.py
# Uploaded by tphakala in commit a4d9f29:
#   "Add BSG-BAT v0.21 ONNX ensemble (6 checkpoints), labels, original
#    preprocessing code, model card"
# File size: 1.74 kB
import torch
import supervised as s
import numpy as np
import data384 as data
import time
import os
import sys
def _summarize_detections(logits, species, logit_threshold):
    """Format the per-species detection counts for one wav file.

    For every species index ``i``, count the rows of ``logits`` whose column
    ``i`` is strictly greater than ``logit_threshold``. Species with at least
    one such row are emitted as ``"<species>,<count>"``, and all entries are
    joined with commas (e.g. ``"sp_a,3,sp_b,1"``). Returns an empty string
    when nothing exceeds the threshold.

    Args:
        logits: 2-D array indexable as ``logits[:, i]`` and supporting
            elementwise ``>`` and ``.sum(axis)``. This is whatever
            ``s.classify1_cpu`` returns (numpy array or torch tensor — TODO
            confirm).
        species: species labels; entry ``i`` labels column ``i`` of ``logits``.
        logit_threshold: a logit strictly above this value counts as a detection.
    """
    # Count exceedances for all columns in one vectorized pass instead of a
    # Python-level sum per element. int() turns the numpy/torch scalar into a
    # plain int so the output reads "3" rather than e.g. "tensor(3)".
    counts = (logits > logit_threshold).sum(0)
    detected = []
    for i, name in enumerate(species):
        count = int(counts[i])
        if count > 0:
            detected.append(f'{name},{count}')
    return ','.join(detected)


def do_classification(wavlist_file, modelfile, speciesfile, logit_threshold, outfile):
    """Classify every wav file in a list and write per-file species detection counts.

    Args:
        wavlist_file: list file parsed by ``data.read_filelist`` (presumably one
            wav path per line).
        modelfile: checkpoint holding a ``state_dict`` for ``s.Net``.
        speciesfile: list file of species labels, parsed by
            ``data.read_filelist``. Its length sets the network's number of
            output classes.
        logit_threshold: a logit strictly greater than this counts as a detection.
        outfile: output path (overwritten). One line is written per wav file:
            ``<wavfile>\\t<species>,<count>,<species>,<count>...``

    Returns:
        The number of wav files listed in ``wavlist_file``.
    """
    species = data.read_filelist(speciesfile)
    nspecies = len(species)
    filelist = data.read_filelist(wavlist_file)

    device = torch.device('cpu')
    model = s.Net(nclasses=nspecies)
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(modelfile, map_location=device))
    # Switch to inference mode so any dropout / batch-norm layers behave
    # deterministically. This is harmless if s.classify1_cpu also does it.
    model.eval()

    # Summarize detections in each wav file by the number of detections per
    # species. no_grad avoids building autograd graphs during inference,
    # which saves memory.
    with open(outfile, 'w') as fd, torch.no_grad():
        for wavfile in filelist:
            print(f'wavfile {wavfile}')
            dat = data.wav2spectrograms(wavfile)
            # s.classify1_cpu is used instead of s.classify_cpu so that memory
            # is not exceeded when dat is long.
            logits = s.classify1_cpu(dat, model, nspecies)
            line = _summarize_detections(logits, species, logit_threshold)
            fd.write(wavfile + '\t' + line + '\n')
    return len(filelist)
def main(args):
    """Command-line entry point: parse arguments, run classification, report runtime.

    Args:
        args: ``[wavlist_file, modelfile, speciesfile, outfile]`` with an
            optional fifth element ``logit_threshold`` (a float; default 0.0).

    Returns:
        0 on success.

    Raises:
        SystemExit: with a usage message when fewer than four arguments are
            given or the threshold cannot be parsed as a float.
    """
    usage = ('usage: classify.py WAVLIST_FILE MODELFILE SPECIESFILE OUTFILE '
             '[LOGIT_THRESHOLD]')
    if len(args) < 4:
        # sys.exit(str) prints to stderr and exits with status 1.
        sys.exit(usage)
    wavlist_file, modelfile, speciesfile, outfile = args[:4]

    logit_threshold = 0.0
    if len(args) > 4:
        try:
            logit_threshold = float(args[4])
        except ValueError:
            sys.exit(f'invalid LOGIT_THRESHOLD {args[4]!r}\n{usage}')

    # perf_counter is monotonic, so it is safe for measuring elapsed time.
    starttime = time.perf_counter()
    n = do_classification(wavlist_file, modelfile, speciesfile, logit_threshold, outfile)
    elapsed = time.perf_counter() - starttime
    print(f'Classification of {n} wavfiles done. Runtime {elapsed:.1f} seconds')
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main(sys.argv[1:]))