# ShimNet / predict.py
# Marek Bukowicki
# add shimnet code
# 64b4096
import torch
torch.set_grad_enabled(False)
import numpy as np
import argparse
from pathlib import Path
import sys, os
from omegaconf import OmegaConf
from src.models import ShimNetWithSCRF, Predictor
# silence the "TypedStorage is deprecated" UserWarning emitted by torch
# https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
class Defaults:
    """Fixed constants used by the prediction script."""

    # Target peak height: each resampled input spectrum is multiplied so that
    # its maximum equals this value before being passed to the model, and the
    # prediction is divided by the same factor afterwards.
    SCALE = 16.0
def parse_args(argv=None):
    """Parse the command-line arguments of the prediction script.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``input_files`` (list of str), ``config``,
        ``weights``, ``output_dir`` and ``input_spectrometer_frequency``
        (float or None).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input_files", help="Input files", nargs="+")
    # The config is always loaded by the script, so fail early with a usage
    # message instead of crashing later inside OmegaConf.load(None).
    parser.add_argument("--config", required=True, help="config file .yaml")
    # NOTE(review): left optional because it is passed straight to Predictor,
    # whose handling of a missing weights file is not visible here.
    parser.add_argument("--weights", help="model weights")
    parser.add_argument("-o", "--output_dir", default=".", help="Output directory")
    parser.add_argument("--input_spectrometer_frequency", default=None, type=float, help="spectrometer frequency in MHz (input sample collection frequency). Empty if the same as in the training data")
    return parser.parse_args(argv)
# functions
def resample_input_spectrum(input_freqs, input_spectrum, Mhz_per_point):
    """Resample the input spectrum onto a uniform grid with the model's step.

    Args:
        input_freqs: 1-D array of frequency coordinates of the input samples,
            in ascending or descending order.
        input_spectrum: 1-D array of intensities aligned with ``input_freqs``.
        Mhz_per_point: Grid step, in the same units as ``input_freqs``
            (despite the name, the prediction script passes a step in ppm).

    Returns:
        ``(freqs, spectrum)``: ``freqs`` is an ascending uniform grid from
        ``min(input_freqs)`` (inclusive) to ``max(input_freqs)`` (exclusive,
        as with ``np.arange``); ``spectrum`` is the input linearly
        interpolated at those points.
    """
    input_freqs = np.asarray(input_freqs)
    input_spectrum = np.asarray(input_spectrum)
    # np.interp requires ascending sample coordinates and silently returns
    # wrong values otherwise; nothing guarantees the input file is ascending,
    # so sort the samples (a no-op for already ascending data).
    order = np.argsort(input_freqs, kind="stable")
    freqs = np.arange(input_freqs.min(), input_freqs.max(), Mhz_per_point)
    spectrum = np.interp(freqs, input_freqs[order], input_spectrum[order])
    return freqs, spectrum
def resample_output_spectrum(input_freqs, freqs, prediction):
    """Interpolate the model prediction back onto the input frequency points.

    Args:
        input_freqs: Frequency coordinates at which to evaluate (any order).
        freqs: Coordinates of ``prediction``, ascending or descending.
        prediction: Values aligned with ``freqs``.

    Returns:
        ``prediction`` linearly interpolated at ``input_freqs``, in the order
        of ``input_freqs``. Points outside the range of ``freqs`` take the
        nearest end value (``np.interp`` clamping).
    """
    freqs = np.asarray(freqs)
    prediction = np.asarray(prediction)
    # np.interp needs ascending xp; sorting is a no-op for the ascending grid
    # produced by np.arange but keeps the helper correct for any caller.
    order = np.argsort(freqs, kind="stable")
    return np.interp(input_freqs, freqs[order], prediction[order])
def initialize_predictor(config, weights_file):
    """Create a Predictor wrapping a ShimNetWithSCRF model.

    The model is constructed from the keyword arguments in
    ``config.model.kwargs`` and handed to ``Predictor`` together with
    ``weights_file``; how the weights are loaded is up to ``Predictor``.
    """
    model_kwargs = config.model.kwargs
    return Predictor(ShimNetWithSCRF(**model_kwargs), weights_file)
# run
if __name__ == "__main__":
    args = parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    config = OmegaConf.load(args.config)
    # Model grid step in ppm: training-data frequency step divided by the
    # training spectrometer frequency (MHz).
    model_ppm_per_point = config.data.frq_step / config.metadata.spectrometer_frequency
    predictor = initialize_predictor(config, args.weights)

    for input_file in args.input_files:
        print(f"processing {input_file} ...")

        # load data: column 0 is the frequency axis (ppm), column 1 the intensity;
        # ndmin=2 keeps the column indexing valid even if loadtxt yields one row
        input_data = np.loadtxt(input_file, ndmin=2)
        input_freqs_input_ppm, input_spectrum = input_data[:, 0], input_data[:, 1]

        # convert input frequencies to the model's frequency scale - correct for
        # zero filling and a different spectrometer frequency
        if args.input_spectrometer_frequency is not None:
            input_freqs_model_ppm = (
                input_freqs_input_ppm
                * args.input_spectrometer_frequency
                / config.metadata.spectrometer_frequency
            )
        else:
            input_freqs_model_ppm = input_freqs_input_ppm

        freqs, spectrum = resample_input_spectrum(input_freqs_model_ppm, input_spectrum, model_ppm_per_point)
        spectrum = torch.tensor(spectrum).float()

        # scale the spectrum so that its maximum equals Defaults.SCALE; an empty
        # grid or a non-positive/NaN maximum would give an infinite, negative or
        # NaN factor and silently corrupt the output, so skip such files
        peak = float(spectrum.max()) if spectrum.numel() > 0 else 0.0
        if not peak > 0:
            print(f"skipping {input_file}: spectrum has no positive intensity")
            continue
        scaling_factor = Defaults.SCALE / peak  # plain float, usable with torch and numpy
        spectrum *= scaling_factor

        # correct spectrum
        prediction = predictor(spectrum).numpy()

        # undo the height scaling
        prediction /= scaling_factor

        # resample the output back onto the input frequency points
        output_prediction = resample_output_spectrum(input_freqs_model_ppm, freqs, prediction)

        # save result next to the original frequency axis
        input_path = Path(input_file)
        output_file = output_dir / f"{input_path.stem}_processed{input_path.suffix}"
        np.savetxt(output_file, np.column_stack((input_freqs_input_ppm, output_prediction)))
        print(f"saved to {output_file}")