File size: 3,547 Bytes
64b4096
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
torch.set_grad_enabled(False)
import numpy as np
import argparse
from pathlib import Path
import sys, os
from omegaconf import OmegaConf

from src.models import ShimNetWithSCRF, Predictor

# silence deprecation warnings
# https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

class Defaults:
    """Namespace for fixed constants used by this script."""

    # Target peak height: the script multiplies each spectrum by
    # SCALE / max(spectrum) before prediction and divides it out afterwards.
    SCALE: float = 16.0

def parse_args(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``input_files`` (list of
        strings), ``config``, ``weights``, ``output_dir`` and
        ``input_spectrometer_frequency`` (float, or ``None`` when not given).

    Raises:
        SystemExit: Raised by argparse when the arguments are invalid, e.g.
            when no input file or no ``--config`` is supplied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input_files", help="Input files", nargs="+")
    # The script always loads the config file and reads values from it, so
    # reject a missing --config here instead of failing later on load.
    parser.add_argument("--config", required=True, help="config file .yaml")
    parser.add_argument("--weights", help="model weights")
    parser.add_argument("-o", "--output_dir", default=".", help="Output directory")
    parser.add_argument("--input_spectrometer_frequency", default=None, type=float, help="spectrometer frequency in MHz (input sample collection frequency). Empty if the same as in the training data")
    return parser.parse_args(argv)
    
# functions
def resample_input_spectrum(input_freqs, input_spectrum, Mhz_per_point):
    """Resample a spectrum onto a uniformly spaced frequency grid.

    The grid starts at the smallest input frequency and advances in steps of
    ``Mhz_per_point`` up to, but not including, the largest input frequency.
    Intensities on the grid are obtained by linear interpolation.

    Args:
        input_freqs: 1-D sequence of frequency coordinates, in any order.
        input_spectrum: Intensities, one per entry of ``input_freqs``.
        Mhz_per_point: Grid spacing, in the same units as ``input_freqs``.

    Returns:
        Tuple ``(freqs, spectrum)``: the new grid and the interpolated
        intensities at each grid point.
    """
    input_freqs = np.asarray(input_freqs)
    input_spectrum = np.asarray(input_spectrum)

    # np.interp does not verify that its sample coordinates are increasing and
    # returns meaningless values otherwise, so order the samples by frequency
    # first (a stable sort leaves already-ascending input untouched).
    order = np.argsort(input_freqs, kind="stable")
    sorted_freqs = input_freqs[order]
    sorted_spectrum = input_spectrum[order]

    freqs = np.arange(input_freqs.min(), input_freqs.max(), Mhz_per_point)
    spectrum = np.interp(freqs, sorted_freqs, sorted_spectrum)
    return freqs, spectrum

def resample_output_spectrum(input_freqs, freqs, prediction):
    """Interpolate a prediction back onto the input frequency axis.

    ``prediction`` is sampled at ``freqs``; the returned array holds its
    linearly interpolated value at every point of ``input_freqs``.
    """
    return np.interp(x=input_freqs, xp=freqs, fp=prediction)

def initialize_predictor(config, weights_file):
    """Build the predictor used for inference.

    A ``ShimNetWithSCRF`` is constructed from the keyword arguments stored
    under ``config.model.kwargs`` and passed, together with ``weights_file``,
    to ``Predictor``. The resulting ``Predictor`` is returned.
    """
    network = ShimNetWithSCRF(**config.model.kwargs)
    return Predictor(network, weights_file)

# run
if __name__ == "__main__":
    cli = parse_args()

    # create the destination directory (and any missing parents) up front
    destination = Path(cli.output_dir)
    destination.mkdir(parents=True, exist_ok=True)

    # model configuration: the resampling step is the configured frequency
    # step divided by the spectrometer frequency stored in the config
    config = OmegaConf.load(cli.config)
    frq_step = config.data.frq_step
    training_frequency = config.metadata.spectrometer_frequency
    model_ppm_per_point = frq_step / training_frequency
    predictor = initialize_predictor(config, cli.weights)

    for input_file in cli.input_files:
        print(f"processing {input_file} ...")

        # first column is used as the frequency axis, second as the intensities
        columns = np.loadtxt(input_file)
        input_freqs_input_ppm = columns[:, 0]
        input_spectrum = columns[:, 1]

        # map the input axis onto the model's axis when the input was recorded
        # at a different spectrometer frequency than the one in the config
        if cli.input_spectrometer_frequency is None:
            input_freqs_model_ppm = input_freqs_input_ppm
        else:
            input_freqs_model_ppm = input_freqs_input_ppm * cli.input_spectrometer_frequency / training_frequency

        freqs, spectrum = resample_input_spectrum(input_freqs_model_ppm, input_spectrum, model_ppm_per_point)

        # normalise the height so the maximum equals Defaults.SCALE
        spectrum = torch.tensor(spectrum).float()
        scaling_factor = Defaults.SCALE / spectrum.max()
        spectrum *= scaling_factor

        # run the model, then undo the height normalisation in place
        prediction = predictor(spectrum).numpy()
        prediction /= scaling_factor

        # bring the prediction back onto the points of the input file
        output_prediction = resample_output_spectrum(input_freqs_model_ppm, freqs, prediction)

        # write "<stem>_processed<suffix>" with the original frequency column
        source_path = Path(input_file)
        output_file = destination / f"{source_path.stem}_processed{source_path.suffix}"
        np.savetxt(output_file, np.column_stack((input_freqs_input_ppm, output_prediction)))
        print(f"saved to {output_file}")