import torch
torch.set_grad_enabled(False)
import numpy as np
import argparse
from pathlib import Path
import sys, os
from omegaconf import OmegaConf

from src.models import ShimNetWithSCRF, Predictor

# silence deprecation warnings
# https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')


class Defaults:
    # Peak height the (resampled) input spectrum is scaled to before it is
    # passed to the predictor; the prediction is divided by the same factor
    # afterwards so the output keeps the input's intensity scale.
    SCALE = 16.0


def parse_args():
    """Parse the command line.

    Returns:
        argparse.Namespace with ``input_files``, ``config``, ``weights``,
        ``output_dir`` and ``input_spectrometer_frequency``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input_files", help="Input files", nargs="+")
    # The config is loaded unconditionally, so demand it up front and get a
    # usage message instead of a crash inside OmegaConf.load(None).
    parser.add_argument("--config", required=True, help="config file .yaml")
    parser.add_argument("--weights", help="model weights")
    parser.add_argument("-o", "--output_dir", default=".", help="Output directory")
    parser.add_argument("--input_spectrometer_frequency", default=None, type=float, help="spectrometer frequency in MHz (input sample collection frequency). Empty if the same as in the training data")
    args = parser.parse_args()
    return args


# functions

def resample_input_spectrum(input_freqs, input_spectrum, Mhz_per_point):
    """Resample the input spectrum onto a uniform grid matching the model's step.

    Args:
        input_freqs: frequency axis of the input spectrum, in any order
            (it may be stored descending).
        input_spectrum: intensities aligned element-wise with ``input_freqs``.
        Mhz_per_point: grid step. Despite the name, this script passes the
            model's ppm-per-point step; it only has to be in the same unit
            as ``input_freqs``.

    Returns:
        ``(freqs, spectrum)`` where ``freqs`` is an ascending uniform grid
        from ``min(input_freqs)`` (inclusive) to ``max(input_freqs)``
        (exclusive) and ``spectrum`` is the input linearly interpolated onto it.
    """
    input_freqs = np.asarray(input_freqs)
    input_spectrum = np.asarray(input_spectrum)
    # np.interp requires increasing sample points and does not check it;
    # an axis stored high-to-low would otherwise be interpolated into garbage.
    order = np.argsort(input_freqs, kind="stable")
    sorted_freqs = input_freqs[order]
    sorted_spectrum = input_spectrum[order]
    freqs = np.arange(sorted_freqs[0], sorted_freqs[-1], Mhz_per_point)
    spectrum = np.interp(freqs, sorted_freqs, sorted_spectrum)
    return freqs, spectrum


def resample_output_spectrum(input_freqs, freqs, prediction):
    """Interpolate the prediction back onto the original input frequency axis.

    ``freqs`` is the ascending grid produced by ``resample_input_spectrum``,
    so it is valid as ``np.interp`` sample points; ``input_freqs`` may be in
    any order and the result follows that order.
    """
    prediction = np.interp(input_freqs, freqs, prediction)
    return prediction


def initialize_predictor(config, weights_file):
    """Build a ``ShimNetWithSCRF`` from ``config.model.kwargs`` and wrap it in a ``Predictor``."""
    model = ShimNetWithSCRF(**config.model.kwargs)
    predictor = Predictor(model, weights_file)
    return predictor


def _process_file(input_file, predictor, model_ppm_per_point, input_spectrometer_frequency,
                  model_spectrometer_frequency, output_dir):
    """Correct one two-column (ppm, intensity) text file and save the result.

    Returns:
        Path of the written file, or None if the file was skipped because
        it has too few points or no positive peak to normalise by.
    """
    # ndmin=2 keeps a single-row file two-dimensional so column indexing works.
    input_data = np.loadtxt(input_file, ndmin=2)
    if input_data.shape[0] < 2 or input_data.shape[1] < 2:
        print(f"skipping {input_file}: need at least two (frequency, intensity) rows", file=sys.stderr)
        return None
    input_freqs_input_ppm, input_spectrum = input_data[:, 0], input_data[:, 1]

    # Convert input ppm to the model's ppm scale. ppm * MHz is an offset in
    # Hz, so this keeps Hz offsets unchanged across spectrometer frequencies;
    # a different point spacing (e.g. zero filling) is handled by resampling.
    if input_spectrometer_frequency is not None:
        input_freqs_model_ppm = input_freqs_input_ppm * input_spectrometer_frequency / model_spectrometer_frequency
    else:
        input_freqs_model_ppm = input_freqs_input_ppm

    freqs, spectrum = resample_input_spectrum(input_freqs_model_ppm, input_spectrum, model_ppm_per_point)

    # The height normalisation divides by the peak; a flat, empty, negative
    # or NaN spectrum would produce inf/NaN output, so skip it explicitly.
    peak = spectrum.max() if spectrum.size else 0.0
    if not peak > 0:
        print(f"skipping {input_file}: spectrum has no positive peak to scale by", file=sys.stderr)
        return None
    scaling_factor = Defaults.SCALE / peak

    # correct spectrum, then undo the height scaling
    model_input = torch.tensor(spectrum * scaling_factor).float()
    prediction = predictor(model_input).numpy() / scaling_factor

    # resample the output to match the input spectrum
    output_prediction = resample_output_spectrum(input_freqs_model_ppm, freqs, prediction)

    input_path = Path(input_file)
    output_file = output_dir / f"{input_path.stem}_processed{input_path.suffix}"
    # Save against the ORIGINAL input ppm axis, not the model-scaled one.
    np.savetxt(output_file, np.column_stack((input_freqs_input_ppm, output_prediction)))
    return output_file


def main():
    """Load the model once and correct every input file given on the command line."""
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    config = OmegaConf.load(args.config)
    model_spectrometer_frequency = config.metadata.spectrometer_frequency
    # grid step used by the model, in ppm per point (frq_step / spectrometer frequency)
    model_ppm_per_point = config.data.frq_step / model_spectrometer_frequency
    predictor = initialize_predictor(config, args.weights)

    for input_file in args.input_files:
        print(f"processing {input_file} ...")
        output_file = _process_file(
            input_file,
            predictor,
            model_ppm_per_point,
            args.input_spectrometer_frequency,
            model_spectrometer_frequency,
            output_dir,
        )
        if output_file is not None:
            print(f"saved to {output_file}")


# run
if __name__ == "__main__":
    main()