Marek Bukowicki commited on
Commit
7bc8c9d
·
1 Parent(s): 7c6af4e

add output name setting, hide advanced options

Browse files
Files changed (1) hide show
  1. predict-gui.py +39 -24
predict-gui.py CHANGED
@@ -30,7 +30,10 @@ parser.add_argument(
30
  )
31
  args = parser.parse_args()
32
 
33
- def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None,reference_spectrum=None, scale=None):
 
 
 
34
  if input_spectrometer_frequency == 0:
35
  input_spectrometer_frequency = None
36
  # Load configuration and initialize predictor
@@ -65,7 +68,10 @@ def process_file(input_file, config_file, weights_file, input_spectrometer_frequ
65
 
66
  # Prepare output data for download
67
  output_data = np.column_stack((input_freqs_input_ppm, output_prediction))
68
- output_file = f"{Path(input_file).stem}_processed{Path(input_file).suffix}"
 
 
 
69
  np.savetxt(output_file, output_data)
70
 
71
  # Create Plotly figure
@@ -128,11 +134,13 @@ def process_file(input_file, config_file, weights_file, input_spectrometer_frequ
128
 
129
  return fig, output_file
130
 
 
131
  # Gradio app
132
  with gr.Blocks() as app:
133
  gr.Markdown("# ShimNet Spectra Correction")
134
- gr.Markdown("[ShimNet: A neural network for post-acquisition improvement of NMR spectra distorted by magnetic-field inhomogeneity](https://chemrxiv.org/engage/chemrxiv/article-details/67ef86686dde43c90860d315)")
135
- gr.Markdown("Upload your input file, configuration, and weights to process the NMR spectrum.")
 
136
 
137
  with gr.Row():
138
  with gr.Column():
@@ -144,12 +152,14 @@ with gr.Blocks() as app:
144
  config_file = gr.File(label="Custom Config File (.yaml)", visible=False, height=120)
145
  weights_file = gr.File(label="Custom Weights File (.pt)", visible=False, height=120)
146
 
147
- scale_input = gr.Number(
148
- label="Scale (Intensity Normalization)",
149
- value=Defaults.SCALE,
150
- info="Adjust the scaling factor for intensity normalization. Default is 16.",
151
- )
152
-
 
 
153
  with gr.Column():
154
  input_file = gr.File(label="Input File (.txt | .csv)", height=120)
155
  input_spectrometer_frequency = gr.Number(label="Input Spectrometer Frequency (MHz) (0 or empty if the same as in the loaded model)", value=None)
@@ -174,8 +184,23 @@ with gr.Blocks() as app:
174
  outputs=[config_file, weights_file]
175
  )
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  # Process button click logic
178
- def process_file_with_model(input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file, scale):
179
  if model_selection == "600 MHz":
180
  config_file = "configs/shimnet_600.yaml"
181
  weights_file = "weights/shimnet_600MHz.pt"
@@ -192,24 +217,14 @@ with gr.Blocks() as app:
192
  weights_file,
193
  input_spectrometer_frequency,
194
  reference_spectrum_file.name if reference_spectrum_file else None,
195
- scale
 
196
  )
197
 
198
  process_button.click(
199
  process_file_with_model,
200
- inputs=[input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file, scale_input],
201
  outputs=[plot_output, download_button]
202
  )
203
 
204
  app.launch(share=args.share, server_name=args.server_name)
205
-
206
- # '#636efa',
207
- # '#EF553B',
208
- # '#00cc96',
209
- # '#ab63fa',
210
- # '#FFA15A',
211
- # '#19d3f3',
212
- # '#FF6692',
213
- # '#B6E880',
214
- # '#FF97FF',
215
- # '#FECB52'
 
30
  )
31
  args = parser.parse_args()
32
 
33
+ def default_out_filename_from_input_filename(input_file):
34
+ return f"{Path(input_file.name).stem}_processed{Path(input_file.name).suffix}"
35
+
36
+ def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None,reference_spectrum=None, scale=None, output_filename=None):
37
  if input_spectrometer_frequency == 0:
38
  input_spectrometer_frequency = None
39
  # Load configuration and initialize predictor
 
68
 
69
  # Prepare output data for download
70
  output_data = np.column_stack((input_freqs_input_ppm, output_prediction))
71
+ if output_filename is None or output_filename.strip() == "":
72
+ output_file = default_out_filename_from_input_filename(input_file)
73
+ else:
74
+ output_file = output_filename
75
  np.savetxt(output_file, output_data)
76
 
77
  # Create Plotly figure
 
134
 
135
  return fig, output_file
136
 
137
+
138
  # Gradio app
139
  with gr.Blocks() as app:
140
  gr.Markdown("# ShimNet Spectra Correction")
141
+ gr.Markdown("[ShimNet: A neural network for post-acquisition improvement of NMR spectra distorted by magnetic-field inhomogeneity](https://doi.org/10.1021/acs.jpcb.5c02632)")
142
+ gr.Markdown("Upload your input file and select the model to process the NMR spectrum. Select 'Custom' to provide your own configuration and weights files.")
143
+ gr.Markdown("Sample input files: https://huggingface.co/spaces/NMR-CeNT-UW/ShimNet/tree/main/sample_data")
144
 
145
  with gr.Row():
146
  with gr.Column():
 
152
  config_file = gr.File(label="Custom Config File (.yaml)", visible=False, height=120)
153
  weights_file = gr.File(label="Custom Weights File (.pt)", visible=False, height=120)
154
 
155
+ with gr.Accordion("Custom output file name / advanced options", open=False):
156
+ output_filename = gr.Textbox(label="Output File Name", placeholder="set automatically after input file upload", interactive=True)
157
+ scale_input = gr.Number(
158
+ label="Scale (Intensity Normalization)",
159
+ value=Defaults.SCALE,
160
+ info="Adjust the scaling factor for intensity normalization. Default is 16.",
161
+ )
162
+
163
  with gr.Column():
164
  input_file = gr.File(label="Input File (.txt | .csv)", height=120)
165
  input_spectrometer_frequency = gr.Number(label="Input Spectrometer Frequency (MHz) (0 or empty if the same as in the loaded model)", value=None)
 
184
  outputs=[config_file, weights_file]
185
  )
186
 
187
+ # Auto-populate output filename when input file is uploaded
188
+ def generate_output_filename(input_file, output_filename):
189
+ # do not overwrite if user has provided a filename
190
+ if output_filename is not None and output_filename.strip() != "":
191
+ return output_filename
192
+ if input_file is None:
193
+ return ""
194
+ return default_out_filename_from_input_filename(input_file)
195
+
196
+ input_file.change(
197
+ generate_output_filename,
198
+ inputs=[input_file, output_filename],
199
+ outputs=[output_filename]
200
+ )
201
+
202
  # Process button click logic
203
+ def process_file_with_model(input_file, output_filename, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file, scale):
204
  if model_selection == "600 MHz":
205
  config_file = "configs/shimnet_600.yaml"
206
  weights_file = "weights/shimnet_600MHz.pt"
 
217
  weights_file,
218
  input_spectrometer_frequency,
219
  reference_spectrum_file.name if reference_spectrum_file else None,
220
+ scale,
221
+ output_filename
222
  )
223
 
224
  process_button.click(
225
  process_file_with_model,
226
+ inputs=[input_file, output_filename, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file, scale_input],
227
  outputs=[plot_output, download_button]
228
  )
229
 
230
  app.launch(share=args.share, server_name=args.server_name)