| import torch | |
| import supervised as s | |
| import numpy as np | |
| import data384 as data | |
| import time | |
| import os | |
| import sys | |
def do_computation(wavlist_file, modelfile, speciesfile, outdir):
    """Classify every wav file in a list and save the logits as text files.

    Args:
        wavlist_file: path to a list of wav files, read with
            ``data.read_filelist``.
        modelfile: path to a saved ``state_dict`` for ``s.Net``.
        speciesfile: path to the species list, read with
            ``data.read_filelist``; its length is the number of classes.
        outdir: directory where one ``<wav basename>.logits`` file is written
            per input file (created if it does not exist).

    Returns:
        The number of entries in the wav file list.
    """
    species = data.read_filelist(speciesfile)
    nspecies = len(species)
    filelist = data.read_filelist(wavlist_file)
    n = len(filelist)

    device = torch.device('cpu')
    model = s.Net(nclasses=nspecies)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load model files from trusted sources.
    model.load_state_dict(torch.load(modelfile, map_location=device))
    # Inference only: switch dropout/batch-norm layers to evaluation
    # behaviour. Harmless if the classify helper also does this.
    model.eval()

    # Create the output directory up front so a missing directory does not
    # fail only after the model has been loaded.
    os.makedirs(outdir, exist_ok=True)

    for wavfile in filelist:
        print(f'wavfile {wavfile}')
        dat = data.wav2spectrograms(wavfile)
        oname = os.path.basename(wavfile)
        # classify1_cpu is used rather than s.classify_cpu(dat, model) so that
        # long recordings do not exceed memory. no_grad avoids building an
        # autograd graph, which is not needed for inference and costs memory.
        with torch.no_grad():
            logits = s.classify1_cpu(dat, model, nspecies)
        np.savetxt(os.path.join(outdir, oname + '.logits'), logits, fmt='%.2f')
    return n
def main(args):
    """Command-line entry point: compute logits for a list of wav files.

    Args:
        args: command-line arguments without the program name, in the order
            ``wavlist_file modelfile speciesfile outdir``. Extra arguments
            are ignored.

    Returns:
        0 on success.

    Raises:
        SystemExit: if fewer than four arguments are given; the usage
            message is printed to stderr and the exit status is 1.
    """
    if len(args) < 4:
        sys.exit('usage: wavlist_file modelfile speciesfile outdir')
    wavlist_file, modelfile, speciesfile, outdir = args[:4]
    # perf_counter is monotonic, so wall-clock adjustments cannot skew
    # the reported runtime.
    starttime = time.perf_counter()
    n = do_computation(wavlist_file, modelfile, speciesfile, outdir)
    endtime = time.perf_counter()
    print(f'Computation of logits for {n} wavfiles done. Runtime {endtime-starttime:.1f} seconds')
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main(sys.argv[1:]))