Marek Bukowicki commited on
Commit
62da572
·
1 Parent(s): c2e4af2

add reference spectra in GUI

Browse files
Files changed (1) hide show
  1. predict-gui.py +14 -6
predict-gui.py CHANGED
@@ -13,7 +13,10 @@ from predict import Defaults, resample_input_spectrum, resample_output_spectrum,
13
  import warnings
14
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
15
 
16
- def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None):
 
 
 
17
  if input_spectrometer_frequency == 0:
18
  input_spectrometer_frequency = None
19
  # Load configuration and initialize predictor
@@ -51,8 +54,11 @@ def process_file(input_file, config_file, weights_file, input_spectrometer_frequ
51
 
52
  # Create Plotly figure
53
  fig = go.Figure()
54
- fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=input_spectrum, mode='lines', name='Input Spectrum'))
55
- fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=output_prediction, mode='lines', name='Corrected Spectrum'))
 
 
 
56
  fig.update_layout(title="Spectrum Visualization", xaxis_title="Frequency (ppm)", yaxis_title="Intensity")
57
 
58
  return fig, output_file
@@ -84,16 +90,18 @@ with gr.Blocks() as app:
84
  weights_file = gr.File(label="Weights File (.pt)", height=120, value="weights/shimnet_600MHz.pt")
85
 
86
  with gr.Column():
87
- input_file = gr.File(label="Input File (.txt | .csv)", height=150)
88
  input_spectrometer_frequency = gr.Number(label="Input Spectrometer Frequency (MHz) (0 or empty if the same as in the loaded model)", value=None)
89
-
 
 
90
  process_button = gr.Button("Process File")
91
  plot_output = gr.Plot(label="Spectrum Visualization")
92
  download_button = gr.File(label="Download Processed File", interactive=False, height=120)
93
 
94
  process_button.click(
95
  process_file,
96
- inputs=[input_file, config_file, weights_file, input_spectrometer_frequency],
97
  outputs=[plot_output, download_button]
98
  )
99
 
 
13
  import warnings
14
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
15
 
16
+ def fast_normalize(x):
17
+ return x / np.max(x)
18
+
19
+ def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None,reference_spectra=[], normalize_spectra_for_plotting=True):
20
  if input_spectrometer_frequency == 0:
21
  input_spectrometer_frequency = None
22
  # Load configuration and initialize predictor
 
54
 
55
  # Create Plotly figure
56
  fig = go.Figure()
57
+ fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=fast_normalize(input_spectrum) if normalize_spectra_for_plotting else input_spectrum, mode='lines', name='Input Spectrum'))
58
+ fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=fast_normalize(output_prediction) if normalize_spectra_for_plotting else output_prediction, mode='lines', name='Corrected Spectrum'))
59
+ for reference_spectrum_file in reference_spectra:
60
+ reference_data = np.loadtxt(reference_spectrum_file.name)
61
+ fig.add_trace(go.Scatter(x=reference_data[:, 0], y=fast_normalize(reference_data[:, 1]) if normalize_spectra_for_plotting else reference_data[:, 1], mode='lines', name=f'Reference Spectrum {Path(reference_spectrum_file.name).stem}'))
62
  fig.update_layout(title="Spectrum Visualization", xaxis_title="Frequency (ppm)", yaxis_title="Intensity")
63
 
64
  return fig, output_file
 
90
  weights_file = gr.File(label="Weights File (.pt)", height=120, value="weights/shimnet_600MHz.pt")
91
 
92
  with gr.Column():
93
+ input_file = gr.File(label="Input File (.txt | .csv)", height=120)
94
  input_spectrometer_frequency = gr.Number(label="Input Spectrometer Frequency (MHz) (0 or empty if the same as in the loaded model)", value=None)
95
+ gr.Markdown("Upload reference spectra files (optional). Reference spectra will be plotted for comparison.")
96
+ reference_spectra = gr.Files(label="Reference Spectra File(s) (.txt | .csv)", height=120)
97
+ normalize_spectra_for_plotting = gr.Checkbox(label="Normalize Spectra for Plotting", value=True)
98
  process_button = gr.Button("Process File")
99
  plot_output = gr.Plot(label="Spectrum Visualization")
100
  download_button = gr.File(label="Download Processed File", interactive=False, height=120)
101
 
102
  process_button.click(
103
  process_file,
104
+ inputs=[input_file, config_file, weights_file, input_spectrometer_frequency, reference_spectra, normalize_spectra_for_plotting],
105
  outputs=[plot_output, download_button]
106
  )
107