Marek Bukowicki commited on
Commit
a86e7e6
·
1 Parent(s): 9e390be

"allow modular models, move get_model_ppm_per_point to utils"

Browse files
Files changed (3) hide show
  1. predict-gui.py +2 -2
  2. predict.py +1 -1
  3. shimnet/predict_utils.py +8 -1
predict-gui.py CHANGED
@@ -7,7 +7,7 @@ from omegaconf import OmegaConf
7
  import gradio as gr
8
  import plotly.graph_objects as go
9
 
10
- from shimnet.predict_utils import Defaults, resample_input_spectrum, resample_output_spectrum, initialize_predictor
11
 
12
  # silent deprecation warnings
13
  import warnings
@@ -35,7 +35,7 @@ def process_file(input_file, config_file, weights_file, input_spectrometer_frequ
35
  input_spectrometer_frequency = None
36
  # Load configuration and initialize predictor
37
  config = OmegaConf.load(config_file)
38
- model_ppm_per_point = config.data.frq_step / config.metadata.spectrometer_frequency
39
  predictor = initialize_predictor(config, weights_file)
40
 
41
  # Load input data
 
7
  import gradio as gr
8
  import plotly.graph_objects as go
9
 
10
+ from shimnet.predict_utils import Defaults, resample_input_spectrum, resample_output_spectrum, initialize_predictor, get_model_ppm_per_point
11
 
12
  # silent deprecation warnings
13
  import warnings
 
35
  input_spectrometer_frequency = None
36
  # Load configuration and initialize predictor
37
  config = OmegaConf.load(config_file)
38
+ model_ppm_per_point = get_model_ppm_per_point(config)
39
  predictor = initialize_predictor(config, weights_file)
40
 
41
  # Load input data
predict.py CHANGED
@@ -33,7 +33,7 @@ if __name__ == "__main__":
33
  output_dir.mkdir(exist_ok=True, parents=True)
34
 
35
  config = OmegaConf.load(args.config)
36
- model_ppm_per_point = config.data.get("frq_step", config.metadata.get("frq_step")) / config.metadata.spectrometer_frequency
37
  predictor = initialize_predictor(config, args.weights)
38
 
39
  for input_file in args.input_files:
 
33
  output_dir.mkdir(exist_ok=True, parents=True)
34
 
35
  config = OmegaConf.load(args.config)
36
+ model_ppm_per_point = get_model_ppm_per_point(config)
37
  predictor = initialize_predictor(config, args.weights)
38
 
39
  for input_file in args.input_files:
shimnet/predict_utils.py CHANGED
@@ -1,4 +1,5 @@
1
  import numpy as np
 
2
 
3
  from .models import ShimNetWithSCRF, Predictor
4
 
@@ -19,6 +20,12 @@ def resample_output_spectrum(input_freqs, freqs, prediction):
19
  return prediction
20
 
21
  def initialize_predictor(config, weights_file):
22
- model = ShimNetWithSCRF(**config.model.kwargs)
 
 
 
23
  predictor = Predictor(model, weights_file)
24
  return predictor
 
 
 
 
1
  import numpy as np
2
+ from hydra.utils import instantiate
3
 
4
  from .models import ShimNetWithSCRF, Predictor
5
 
 
20
  return prediction
21
 
22
  def initialize_predictor(config, weights_file):
23
+ if "_target_" in config.model:
24
+ model = instantiate(config.model)
25
+ else:
26
+ model = ShimNetWithSCRF(**config.model.kwargs)
27
  predictor = Predictor(model, weights_file)
28
  return predictor
29
+
30
+ def get_model_ppm_per_point(config):
31
+ return config.data.get("frq_step", config.metadata.get("frq_step")) / config.metadata.spectrometer_frequency