Spaces:
Sleeping
Sleeping
Marek Bukowicki commited on
Commit ·
a86e7e6
1
Parent(s): 9e390be
"allow modular models, move get_model_ppm_per_point to utils"
Browse files- predict-gui.py +2 -2
- predict.py +1 -1
- shimnet/predict_utils.py +8 -1
predict-gui.py
CHANGED
|
@@ -7,7 +7,7 @@ from omegaconf import OmegaConf
|
|
| 7 |
import gradio as gr
|
| 8 |
import plotly.graph_objects as go
|
| 9 |
|
| 10 |
-
from shimnet.predict_utils import Defaults, resample_input_spectrum, resample_output_spectrum, initialize_predictor
|
| 11 |
|
| 12 |
# silent deprecation warnings
|
| 13 |
import warnings
|
|
@@ -35,7 +35,7 @@ def process_file(input_file, config_file, weights_file, input_spectrometer_frequ
|
|
| 35 |
input_spectrometer_frequency = None
|
| 36 |
# Load configuration and initialize predictor
|
| 37 |
config = OmegaConf.load(config_file)
|
| 38 |
-
model_ppm_per_point = config
|
| 39 |
predictor = initialize_predictor(config, weights_file)
|
| 40 |
|
| 41 |
# Load input data
|
|
|
|
| 7 |
import gradio as gr
|
| 8 |
import plotly.graph_objects as go
|
| 9 |
|
| 10 |
+
from shimnet.predict_utils import Defaults, resample_input_spectrum, resample_output_spectrum, initialize_predictor, get_model_ppm_per_point
|
| 11 |
|
| 12 |
# silent deprecation warnings
|
| 13 |
import warnings
|
|
|
|
| 35 |
input_spectrometer_frequency = None
|
| 36 |
# Load configuration and initialize predictor
|
| 37 |
config = OmegaConf.load(config_file)
|
| 38 |
+
model_ppm_per_point = get_model_ppm_per_point(config)
|
| 39 |
predictor = initialize_predictor(config, weights_file)
|
| 40 |
|
| 41 |
# Load input data
|
predict.py
CHANGED
|
@@ -33,7 +33,7 @@ if __name__ == "__main__":
|
|
| 33 |
output_dir.mkdir(exist_ok=True, parents=True)
|
| 34 |
|
| 35 |
config = OmegaConf.load(args.config)
|
| 36 |
-
model_ppm_per_point =
|
| 37 |
predictor = initialize_predictor(config, args.weights)
|
| 38 |
|
| 39 |
for input_file in args.input_files:
|
|
|
|
| 33 |
output_dir.mkdir(exist_ok=True, parents=True)
|
| 34 |
|
| 35 |
config = OmegaConf.load(args.config)
|
| 36 |
+
model_ppm_per_point = get_model_ppm_per_point(config)
|
| 37 |
predictor = initialize_predictor(config, args.weights)
|
| 38 |
|
| 39 |
for input_file in args.input_files:
|
shimnet/predict_utils.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import numpy as np
|
|
|
|
| 2 |
|
| 3 |
from .models import ShimNetWithSCRF, Predictor
|
| 4 |
|
|
@@ -19,6 +20,12 @@ def resample_output_spectrum(input_freqs, freqs, prediction):
|
|
| 19 |
return prediction
|
| 20 |
|
| 21 |
def initialize_predictor(config, weights_file):
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
| 23 |
predictor = Predictor(model, weights_file)
|
| 24 |
return predictor
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
+
from hydra.utils import instantiate
|
| 3 |
|
| 4 |
from .models import ShimNetWithSCRF, Predictor
|
| 5 |
|
|
|
|
| 20 |
return prediction
|
| 21 |
|
| 22 |
def initialize_predictor(config, weights_file):
|
| 23 |
+
if "_target_" in config.model:
|
| 24 |
+
model = instantiate(config.model)
|
| 25 |
+
else:
|
| 26 |
+
model = ShimNetWithSCRF(**config.model.kwargs)
|
| 27 |
predictor = Predictor(model, weights_file)
|
| 28 |
return predictor
|
| 29 |
+
|
| 30 |
+
def get_model_ppm_per_point(config):
|
| 31 |
+
return config.data.get("frq_step", config.metadata.get("frq_step")) / config.metadata.spectrometer_frequency
|