File size: 1,738 Bytes
a4d9f29 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 | import torch
import supervised as s
import numpy as np
import data384 as data
import time
import os
import sys
def do_classification(wavlist_file, modelfile, speciesfile, logit_threshold, outfile):
species = data.read_filelist(speciesfile)
nspecies = len(species)
filelist = data.read_filelist(wavlist_file)
n = len(filelist)
device=torch.device('cpu')
model=s.Net(nclasses=nspecies)
model.load_state_dict(torch.load(modelfile, map_location=device))
# summarize detections in each wavfile by number of species detections
with open(outfile, 'w') as fd:
for wavfile in filelist:
print(f'wavfile {wavfile}')
dat = data.wav2spectrograms(wavfile)
oname = os.path.basename(wavfile)
# if dat is long, use s.classify1_cpu instead of s.classify_cpu so that memory is not exceeded
logits = s.classify1_cpu(dat,model,nspecies)
# logits = s.classify_cpu(dat,model)
detected = [];
for i in range(nspecies):
m=sum(logits[:,i] > logit_threshold)
if m>0:
detected.append(species[i] + ',' + str(m))
line = ','.join(detected)
fd.write(wavfile + '\t' + line + '\n')
return n
def main(args):
    """Command-line entry point.

    Args:
        args: command-line arguments without the program name, in the order
            ``wavlist_file modelfile speciesfile outfile [logit_threshold]``.
            ``logit_threshold`` is optional and defaults to 0.0.

    Returns:
        0 on success; 1 on a usage error (too few arguments or a
        non-numeric ``logit_threshold``).
    """
    if len(args) < 4:
        print('usage: wavlist_file modelfile speciesfile outfile [logit_threshold]',
              file=sys.stderr)
        return 1
    wavlist_file, modelfile, speciesfile, outfile = args[:4]
    try:
        logit_threshold = float(args[4]) if len(args) > 4 else 0.0
    except ValueError:
        print(f'invalid logit_threshold: {args[4]!r}', file=sys.stderr)
        return 1
    starttime = time.time()
    n = do_classification(wavlist_file, modelfile, speciesfile, logit_threshold, outfile)
    elapsed = time.time() - starttime
    print(f'Classification of {n} wavfiles done. Runtime {elapsed:.1f} seconds')
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main(sys.argv[1:]))
|