Marek Bukowicki commited on
Commit
7fcb1aa
·
1 Parent(s): cce557d

add intesity scale and suffix parameters to predict scripts

Browse files
Files changed (2) hide show
  1. predict-gui.py +30 -7
  2. predict.py +7 -2
predict-gui.py CHANGED
@@ -1,4 +1,3 @@
1
- import os
2
  import torch
3
  torch.set_grad_enabled(False)
4
  import numpy as np
@@ -30,7 +29,7 @@ parser.add_argument(
30
  )
31
  args = parser.parse_args()
32
 
33
- def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None,reference_spectrum=None):
34
  if input_spectrometer_frequency == 0:
35
  input_spectrometer_frequency = None
36
  # Load configuration and initialize predictor
@@ -53,7 +52,9 @@ def process_file(input_file, config_file, weights_file, input_spectrometer_frequ
53
 
54
  # Scale and process spectrum
55
  spectrum_tensor = torch.tensor(spectrum).float()
56
- scaling_factor = Defaults.SCALE / spectrum_tensor.max()
 
 
57
  spectrum_tensor *= scaling_factor
58
  prediction = predictor(spectrum_tensor).numpy()
59
  prediction /= scaling_factor
@@ -63,7 +64,9 @@ def process_file(input_file, config_file, weights_file, input_spectrometer_frequ
63
 
64
  # Prepare output data for download
65
  output_data = np.column_stack((input_freqs_input_ppm, output_prediction))
66
- output_file = f"{Path(input_file).stem}_processed{Path(input_file).suffix}"
 
 
67
  np.savetxt(output_file, output_data)
68
 
69
  # Create Plotly figure
@@ -141,6 +144,18 @@ with gr.Blocks() as app:
141
  )
142
  config_file = gr.File(label="Custom Config File (.yaml)", visible=False, height=120)
143
  weights_file = gr.File(label="Custom Weights File (.pt)", visible=False, height=120)
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  with gr.Column():
146
  input_file = gr.File(label="Input File (.txt | .csv)", height=120)
@@ -166,7 +181,7 @@ with gr.Blocks() as app:
166
  )
167
 
168
  # Process button click logic
169
- def process_file_with_model(input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file):
170
  if model_selection == "600 MHz":
171
  config_file = os.path.join(os.path.dirname(__file__), "configs/shimnet_600.yaml")
172
  weights_file = os.path.join(os.path.dirname(__file__), "weights/shimnet_600MHz.pt")
@@ -177,11 +192,19 @@ with gr.Blocks() as app:
177
  config_file = config_file.name
178
  weights_file = weights_file.name
179
 
180
- return process_file(input_file.name, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file.name if reference_spectrum_file else None)
 
 
 
 
 
 
 
 
181
 
182
  process_button.click(
183
  process_file_with_model,
184
- inputs=[input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file],
185
  outputs=[plot_output, download_button]
186
  )
187
 
 
 
1
  import torch
2
  torch.set_grad_enabled(False)
3
  import numpy as np
 
29
  )
30
  args = parser.parse_args()
31
 
32
+ def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None,reference_spectrum=None, scale=None, suffix=None):
33
  if input_spectrometer_frequency == 0:
34
  input_spectrometer_frequency = None
35
  # Load configuration and initialize predictor
 
52
 
53
  # Scale and process spectrum
54
  spectrum_tensor = torch.tensor(spectrum).float()
55
+ if scale is None:
56
+ scale = Defaults.SCALE
57
+ scaling_factor = scale / spectrum_tensor.max()
58
  spectrum_tensor *= scaling_factor
59
  prediction = predictor(spectrum_tensor).numpy()
60
  prediction /= scaling_factor
 
64
 
65
  # Prepare output data for download
66
  output_data = np.column_stack((input_freqs_input_ppm, output_prediction))
67
+ if suffix is None:
68
+ suffix = Defaults.SUFFIX
69
+ output_file = f"{Path(input_file).stem}{suffix}{Path(input_file).suffix}"
70
  np.savetxt(output_file, output_data)
71
 
72
  # Create Plotly figure
 
144
  )
145
  config_file = gr.File(label="Custom Config File (.yaml)", visible=False, height=120)
146
  weights_file = gr.File(label="Custom Weights File (.pt)", visible=False, height=120)
147
+
148
+ with gr.Accordion("Advanced", open=False):
149
+ scale_input = gr.Number(
150
+ label="Intensity Scale",
151
+ value=Defaults.SCALE,
152
+ info=f"Adjust the scaling factor for intensity normalization. Default is {Defaults.SCALE}.",
153
+ )
154
+ suffix_input = gr.Textbox(
155
+ label="Output File Suffix",
156
+ value=Defaults.SUFFIX,
157
+ info=f"Suffix to add to processed output filenames. Default is '{Defaults.SUFFIX}'.",
158
+ )
159
 
160
  with gr.Column():
161
  input_file = gr.File(label="Input File (.txt | .csv)", height=120)
 
181
  )
182
 
183
  # Process button click logic
184
+ def process_file_with_model(input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file, scale, suffix):
185
  if model_selection == "600 MHz":
186
  config_file = os.path.join(os.path.dirname(__file__), "configs/shimnet_600.yaml")
187
  weights_file = os.path.join(os.path.dirname(__file__), "weights/shimnet_600MHz.pt")
 
192
  config_file = config_file.name
193
  weights_file = weights_file.name
194
 
195
+ return process_file(
196
+ input_file.name,
197
+ config_file,
198
+ weights_file,
199
+ input_spectrometer_frequency,
200
+ reference_spectrum_file.name if reference_spectrum_file else None,
201
+ scale,
202
+ suffix
203
+ )
204
 
205
  process_button.click(
206
  process_file_with_model,
207
+ inputs=[input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file, scale_input, suffix_input],
208
  outputs=[plot_output, download_button]
209
  )
210
 
predict.py CHANGED
@@ -14,6 +14,9 @@ from shimnet.predict_utils import Defaults, resample_input_spectrum, resample_ou
14
  import warnings
15
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
16
 
 
 
 
17
 
18
  def parse_args():
19
  parser = argparse.ArgumentParser()
@@ -22,6 +25,8 @@ def parse_args():
22
  parser.add_argument("--weights", help="model weights")
23
  parser.add_argument("-o", "--output_dir", default=".", help="Output directory")
24
  parser.add_argument("--input_spectrometer_frequency", default=None, type=float, help="spectrometer frequency in MHz (input sample collection frequency). Empty if the same as in the training data")
 
 
25
  args = parser.parse_args()
26
  return args
27
 
@@ -52,7 +57,7 @@ if __name__ == "__main__":
52
 
53
  spectrum = torch.tensor(spectrum).float()
54
  # scale height of the spectrum
55
- scaling_factor = Defaults.SCALE / spectrum.max()
56
  spectrum *= scaling_factor
57
 
58
  # correct spectrum
@@ -65,7 +70,7 @@ if __name__ == "__main__":
65
  output_prediction = resample_output_spectrum(input_freqs_model_ppm, freqs, prediction)
66
 
67
  # save result
68
- output_file = output_dir / f"{Path(input_file).stem}_processed{Path(input_file).suffix}"
69
 
70
  np.savetxt(output_file, np.column_stack((input_freqs_input_ppm, output_prediction)))
71
  print(f"saved to {output_file}")
 
14
  import warnings
15
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
16
 
17
+ class Defaults:
18
+ SCALE = 16.0
19
+ SUFFIX = "_processed"
20
 
21
  def parse_args():
22
  parser = argparse.ArgumentParser()
 
25
  parser.add_argument("--weights", help="model weights")
26
  parser.add_argument("-o", "--output_dir", default=".", help="Output directory")
27
  parser.add_argument("--input_spectrometer_frequency", default=None, type=float, help="spectrometer frequency in MHz (input sample collection frequency). Empty if the same as in the training data")
28
+ parser.add_argument("--suffix", default=Defaults.SUFFIX, help=f"Output file suffix (default: {Defaults.SUFFIX})")
29
+ parser.add_argument("--intensity_scale", default=Defaults.SCALE, type=float, help=f"Intensity scaling factor (default: {Defaults.SCALE})")
30
  args = parser.parse_args()
31
  return args
32
 
 
57
 
58
  spectrum = torch.tensor(spectrum).float()
59
  # scale height of the spectrum
60
+ scaling_factor = args.intensity_scale / spectrum.max()
61
  spectrum *= scaling_factor
62
 
63
  # correct spectrum
 
70
  output_prediction = resample_output_spectrum(input_freqs_model_ppm, freqs, prediction)
71
 
72
  # save result
73
+ output_file = output_dir / f"{Path(input_file).stem}{args.suffix}{Path(input_file).suffix}"
74
 
75
  np.savetxt(output_file, np.column_stack((input_freqs_input_ppm, output_prediction)))
76
  print(f"saved to {output_file}")