import torch
import supervised as s
import numpy as np
import data384 as data
import time
import os
import sys


def _detection_summary(logits, species, logit_threshold):
    """Format the per-species detection counts for one wav file.

    Args:
        logits: 2-D array-like indexed as ``logits[row, species_index]`` (as the
            original ``logits[:, i]`` indexing implies); presumably a numpy array
            or torch tensor returned by ``s.classify1_cpu`` — TODO confirm.
        species: sequence of species names; position ``i`` labels column ``i``.
        logit_threshold: a row counts as a detection of species ``i`` when
            ``logits[row, i] > logit_threshold``.

    Returns:
        Comma-joined ``name,count`` pairs for every species with at least one
        detection (species with zero detections are omitted), e.g.
        ``"robin,3,wren,1"``; an empty string if nothing was detected.
    """
    # One vectorized count per species column instead of a Python-level sum per
    # column. int() normalizes the count to a plain Python int so the written
    # text is "3" and not "tensor(3)" when logits is a torch tensor.
    counts = (logits > logit_threshold).sum(0)
    detected = []
    for i, name in enumerate(species):
        m = int(counts[i])
        if m > 0:
            detected.append(name + ',' + str(m))
    return ','.join(detected)


def do_classification(wavlist_file, modelfile, speciesfile, logit_threshold, outfile):
    """Run the species classifier over a list of wav files and write a summary.

    Args:
        wavlist_file: file listing the wav files to classify (read with
            ``data.read_filelist``).
        modelfile: path to a saved ``state_dict`` for ``s.Net``.
        speciesfile: file listing species names; its length sets the number of
            model output classes.
        logit_threshold: logit value a row must exceed to count as a detection.
        outfile: output path; one line per wav file of the form
            ``<wavfile>\\t<name>,<count>,<name>,<count>...``.

    Returns:
        The number of wav files listed in ``wavlist_file``.
    """
    species = data.read_filelist(speciesfile)
    nspecies = len(species)
    filelist = data.read_filelist(wavlist_file)

    device = torch.device('cpu')
    model = s.Net(nclasses=nspecies)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code — only load trusted model files (consider weights_only=True on
    # torch >= 1.13).
    model.load_state_dict(torch.load(modelfile, map_location=device))
    # Inference: disable dropout / use running batch-norm statistics. Harmless
    # if s.classify1_cpu already switches the model to eval mode.
    model.eval()

    # Summarize detections in each wav file by number of detections per species.
    # no_grad avoids keeping autograd state around during inference (harmless if
    # the classify helper already disables gradients).
    with torch.no_grad(), open(outfile, 'w') as fd:
        for wavfile in filelist:
            print(f'wavfile {wavfile}')
            dat = data.wav2spectrograms(wavfile)
            # classify1_cpu rather than classify_cpu keeps memory bounded when
            # dat is long.
            logits = s.classify1_cpu(dat, model, nspecies)
            summary = _detection_summary(logits, species, logit_threshold)
            fd.write(wavfile + '\t' + summary + '\n')
    return len(filelist)


def main(args):
    """Command-line entry point.

    Usage: ``wavlist_file modelfile speciesfile outfile [logit_threshold]``.
    ``logit_threshold`` is optional and defaults to 0.0.

    Returns:
        0 on success, 2 if too few arguments were given.
    """
    if len(args) < 4:
        prog = os.path.basename(sys.argv[0]) if sys.argv else 'classify'
        print(f'usage: {prog} wavlist_file modelfile speciesfile outfile '
              '[logit_threshold]', file=sys.stderr)
        return 2
    wavlist_file, modelfile, speciesfile, outfile = args[:4]
    logit_threshold = float(args[4]) if len(args) > 4 else 0.0

    starttime = time.perf_counter()
    n = do_classification(wavlist_file, modelfile, speciesfile, logit_threshold, outfile)
    endtime = time.perf_counter()
    print(f'Classification of {n} wavfiles done. Runtime {endtime-starttime:.1f} seconds')
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))