marekb-sci commited on
Commit
a69077a
·
verified ·
1 Parent(s): 1c2daf8

make scaling adjustable in GUI

Browse files
Files changed (1) hide show
  1. predict-gui.py +31 -6
predict-gui.py CHANGED
@@ -30,7 +30,7 @@ parser.add_argument(
30
  )
31
  args = parser.parse_args()
32
 
33
- def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None,reference_spectrum=None):
34
  if input_spectrometer_frequency == 0:
35
  input_spectrometer_frequency = None
36
  # Load configuration and initialize predictor
@@ -53,7 +53,9 @@ def process_file(input_file, config_file, weights_file, input_spectrometer_frequ
53
 
54
  # Scale and process spectrum
55
  spectrum_tensor = torch.tensor(spectrum).float()
56
- scaling_factor = Defaults.SCALE / spectrum_tensor.max()
 
 
57
  spectrum_tensor *= scaling_factor
58
  prediction = predictor(spectrum_tensor).numpy()
59
  prediction /= scaling_factor
@@ -129,7 +131,7 @@ def process_file(input_file, config_file, weights_file, input_spectrometer_frequ
129
  # Gradio app
130
  with gr.Blocks() as app:
131
  gr.Markdown("# ShimNet Spectra Correction")
132
- gr.Markdown("[ShimNet: A neural network for post-acquisition improvement of NMR spectra distorted by magnetic-field inhomogeneity](https://pubs.acs.org/doi/full/10.1021/acs.jpcb.5c02632)")
133
  gr.Markdown("Upload your input file, configuration, and weights to process the NMR spectrum.")
134
 
135
  with gr.Row():
@@ -148,6 +150,11 @@ with gr.Blocks() as app:
148
  gr.Markdown("Upload reference spectrum files (optional). Reference spectrum will be plotted for comparison.")
149
  reference_spectrum_file = gr.File(label="Reference Spectra File (.txt | .csv)", height=120)
150
 
 
 
 
 
 
151
  process_button = gr.Button("Process File")
152
  plot_output = gr.Plot(label="Spectrum Visualization")
153
  download_button = gr.File(label="Download Processed File", interactive=False, height=120)
@@ -166,7 +173,7 @@ with gr.Blocks() as app:
166
  )
167
 
168
  # Process button click logic
169
- def process_file_with_model(input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file):
170
  if model_selection == "600 MHz":
171
  config_file = "configs/shimnet_600.yaml"
172
  weights_file = "weights/shimnet_600MHz.pt"
@@ -177,12 +184,30 @@ with gr.Blocks() as app:
177
  config_file = config_file.name
178
  weights_file = weights_file.name
179
 
180
- return process_file(input_file.name, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file.name if reference_spectrum_file else None)
 
 
 
 
 
 
 
181
 
182
  process_button.click(
183
  process_file_with_model,
184
- inputs=[input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file],
185
  outputs=[plot_output, download_button]
186
  )
187
 
188
  app.launch(share=args.share, server_name=args.server_name)
 
 
 
 
 
 
 
 
 
 
 
 
30
  )
31
  args = parser.parse_args()
32
 
33
+ def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None,reference_spectrum=None, scale=None):
34
  if input_spectrometer_frequency == 0:
35
  input_spectrometer_frequency = None
36
  # Load configuration and initialize predictor
 
53
 
54
  # Scale and process spectrum
55
  spectrum_tensor = torch.tensor(spectrum).float()
56
+ if scale is None:
57
+ scale = Defaults.SCALE
58
+ scaling_factor = scale / spectrum_tensor.max()
59
  spectrum_tensor *= scaling_factor
60
  prediction = predictor(spectrum_tensor).numpy()
61
  prediction /= scaling_factor
 
131
  # Gradio app
132
  with gr.Blocks() as app:
133
  gr.Markdown("# ShimNet Spectra Correction")
134
+ gr.Markdown("[ShimNet: A neural network for post-acquisition improvement of NMR spectra distorted by magnetic-field inhomogeneity](https://chemrxiv.org/engage/chemrxiv/article-details/67ef86686dde43c90860d315)")
135
  gr.Markdown("Upload your input file, configuration, and weights to process the NMR spectrum.")
136
 
137
  with gr.Row():
 
150
  gr.Markdown("Upload reference spectrum files (optional). Reference spectrum will be plotted for comparison.")
151
  reference_spectrum_file = gr.File(label="Reference Spectra File (.txt | .csv)", height=120)
152
 
153
+ scale_input = gr.Number(
154
+ label="Scale (Intensity Normalization)",
155
+ value=Defaults.SCALE,
156
+ info="Adjust the scaling factor for intensity normalization. Default is 16.",
157
+ )
158
  process_button = gr.Button("Process File")
159
  plot_output = gr.Plot(label="Spectrum Visualization")
160
  download_button = gr.File(label="Download Processed File", interactive=False, height=120)
 
173
  )
174
 
175
  # Process button click logic
176
+ def process_file_with_model(input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file, scale):
177
  if model_selection == "600 MHz":
178
  config_file = "configs/shimnet_600.yaml"
179
  weights_file = "weights/shimnet_600MHz.pt"
 
184
  config_file = config_file.name
185
  weights_file = weights_file.name
186
 
187
+ return process_file(
188
+ input_file.name,
189
+ config_file,
190
+ weights_file,
191
+ input_spectrometer_frequency,
192
+ reference_spectrum_file.name if reference_spectrum_file else None,
193
+ scale
194
+ )
195
 
196
  process_button.click(
197
  process_file_with_model,
198
+ inputs=[input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file, scale_input],
199
  outputs=[plot_output, download_button]
200
  )
201
 
202
  app.launch(share=args.share, server_name=args.server_name)
203
+
204
+ # '#636efa',
205
+ # '#EF553B',
206
+ # '#00cc96',
207
+ # '#ab63fa',
208
+ # '#FFA15A',
209
+ # '#19d3f3',
210
+ # '#FF6692',
211
+ # '#B6E880',
212
+ # '#FF97FF',
213
+ # '#FECB52'