BSG-BAT / original_code /compute_logits.py
tphakala's picture
Add BSG-BAT v0.21 ONNX ensemble (6 checkpoints), labels, original preprocessing code, model card
a4d9f29 verified
Raw
History Blame Contribute Delete
1.31 kB
import torch
import supervised as s
import numpy as np
import data384 as data
import time
import os
import sys
def do_computation(wavlist_file, modelfile, speciesfile, outdir):
    """Compute and save logits for every wav file listed in *wavlist_file*.

    For each listed wav file, spectrograms are produced with
    ``data.wav2spectrograms`` and scored with ``s.classify1_cpu`` using an
    ``s.Net`` loaded from *modelfile*. The result is written with
    ``np.savetxt`` (two decimals) to ``<outdir>/<basename of wav>.logits``.

    Args:
        wavlist_file: path to the list of wav files, read by
            ``data.read_filelist``.
        modelfile: path to a saved ``state_dict`` for ``s.Net``.
        speciesfile: path to the species list, read by
            ``data.read_filelist``; its length is the number of classes.
        outdir: directory receiving the ``.logits`` files; created if it
            does not exist. An empty string means the current directory.

    Returns:
        The number of entries in the wav list.
    """
    species = data.read_filelist(speciesfile)
    nspecies = len(species)
    filelist = data.read_filelist(wavlist_file)
    n = len(filelist)

    device = torch.device('cpu')
    model = s.Net(nclasses=nspecies)
    # NOTE(review): torch.load unpickles the checkpoint, which can execute
    # arbitrary code; only load model files from a trusted source. Newer
    # torch versions accept weights_only=True to restrict this.
    model.load_state_dict(torch.load(modelfile, map_location=device))
    # NOTE(review): model.eval() is not called here; confirm that
    # s.classify1_cpu puts the model into inference mode itself.

    # np.savetxt does not create missing directories.
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    for wavfile in filelist:
        print(f'wavfile {wavfile}')
        dat = data.wav2spectrograms(wavfile)
        oname = os.path.basename(wavfile)
        # s.classify1_cpu is used instead of s.classify_cpu(dat, model) so
        # that memory is not exceeded when dat is long.
        logits = s.classify1_cpu(dat, model, nspecies)
        # The output name depends only on the basename, so wav files with the
        # same name in different directories overwrite each other's output.
        np.savetxt(os.path.join(outdir, oname + '.logits'), logits, fmt='%.2f')
    return n
def main(args):
    """Command-line entry point: compute logits for a list of wav files.

    Args:
        args: command-line arguments without the program name, in the order
            wavlist_file, modelfile, speciesfile, outdir. Any further
            arguments are ignored.

    Returns:
        0 on success, 2 (after printing a usage line to stderr) when fewer
        than four arguments are given.
    """
    if len(args) < 4:
        print('usage: compute_logits.py wavlist_file modelfile speciesfile outdir',
              file=sys.stderr)
        return 2
    wavlist_file, modelfile, speciesfile, outdir = args[0], args[1], args[2], args[3]
    # perf_counter is monotonic, so wall-clock adjustments cannot skew the runtime.
    starttime = time.perf_counter()
    n = do_computation(wavlist_file, modelfile, speciesfile, outdir)
    endtime = time.perf_counter()
    print(f'Computation of logits for {n} wavfiles done. Runtime {endtime-starttime:.1f} seconds')
    return 0
if __name__ == "__main__":
    # Propagate main's return value as the process exit status.
    sys.exit(main(sys.argv[1:]))