| import torch |
| import supervised as s |
| import numpy as np |
| import data384 as data |
| import time |
| import os |
| import sys |
|
|
def _detections_line(logits, species, logit_threshold):
    """Summarise one file's logits as ``name,count,name,count,...``.

    ``logits`` is indexed as ``logits[:, i]`` for species ``i``; it may be a
    NumPy array or a torch tensor (whatever ``s.classify1_cpu`` returns --
    presumably one row per spectrogram segment, TODO confirm). A species is
    listed when at least one row's logit is strictly greater than
    ``logit_threshold``; the count is the number of such rows. Species with
    no such row are omitted, so an empty string means nothing was detected.
    """
    detected = []
    for i, name in enumerate(species):
        # .sum() on the boolean column works for NumPy and torch alike; int()
        # turns the 0-d result into a plain int so the output reads "3"
        # rather than e.g. "tensor(3)".
        m = int((logits[:, i] > logit_threshold).sum())
        if m > 0:
            detected.append(f'{name},{m}')
    return ','.join(detected)


def do_classification(wavlist_file, modelfile, speciesfile, logit_threshold, outfile):
    """Classify every wav file in a list and write the detected species per file.

    Args:
        wavlist_file: Path of a file listing wav files (read with
            ``data.read_filelist``).
        modelfile: Path of a saved ``state_dict`` for ``s.Net``, loaded onto
            the CPU.
        speciesfile: Path of a file listing species names (read with
            ``data.read_filelist``). Its length sets the number of model
            classes; its order is assumed to match the model's output
            columns -- TODO confirm against training code.
        logit_threshold: A species counts as present in a row when its logit
            is strictly greater than this value.
        outfile: Output path, overwritten. One line per wav file, in list
            order: the wav path, a TAB, then ``species,count`` pairs joined
            by commas (empty when nothing was detected).

    Returns:
        The number of wav files listed in ``wavlist_file``.
    """
    species = data.read_filelist(speciesfile)
    nspecies = len(species)
    filelist = data.read_filelist(wavlist_file)

    device = torch.device('cpu')
    model = s.Net(nclasses=nspecies)
    # NOTE(review): torch.load unpickles the file, which can execute code;
    # only load trusted model files (newer torch versions accept
    # weights_only=True for plain state dicts).
    model.load_state_dict(torch.load(modelfile, map_location=device))
    # Put dropout/batch-norm layers into inference mode; harmless if
    # s.classify1_cpu already does this.
    model.eval()

    with open(outfile, 'w') as fd:
        for wavfile in filelist:
            print(f'wavfile {wavfile}')
            dat = data.wav2spectrograms(wavfile)
            logits = s.classify1_cpu(dat, model, nspecies)
            line = _detections_line(logits, species, logit_threshold)
            fd.write(f'{wavfile}\t{line}\n')
    return len(filelist)
| |
def main(args):
    """Command-line entry point.

    Usage: ``wavlist_file modelfile speciesfile outfile [logit_threshold]``

    ``logit_threshold`` defaults to 0.0 when omitted; any further arguments
    are ignored. Prints the number of files classified and the elapsed time.

    Returns:
        0 on success, 2 on a usage error (too few arguments or a
        non-numeric threshold).
    """
    if len(args) < 4:
        print('usage: wavlist_file modelfile speciesfile outfile [logit_threshold]',
              file=sys.stderr)
        return 2
    wavlist_file, modelfile, speciesfile, outfile = args[:4]
    try:
        logit_threshold = float(args[4]) if len(args) > 4 else 0.0
    except ValueError:
        print(f'invalid logit_threshold: {args[4]!r}', file=sys.stderr)
        return 2

    # perf_counter is monotonic, so the reported runtime is unaffected by
    # wall-clock adjustments during the run.
    starttime = time.perf_counter()
    n = do_classification(wavlist_file, modelfile, speciesfile, logit_threshold, outfile)
    endtime = time.perf_counter()
    print(f'Classification of {n} wavfiles done. Runtime {endtime-starttime:.1f} seconds')
    return 0
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main(sys.argv[1:]))
|
|