Marek Bukowicki commited on
Commit
48fbb90
·
1 Parent(s): 2ac645e

improve GUI

Browse files
Geraniol_up_1mm_600MHz_processed.csv DELETED
The diff for this file is too large to render. See raw diff
 
predict-gui.py CHANGED
@@ -13,19 +13,33 @@ from predict import Defaults, resample_input_spectrum, resample_output_spectrum,
13
  import warnings
14
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
15
 
16
- def fast_normalize(x):
17
- return x / np.max(x)
18
-
19
- def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None,reference_spectra=[], normalize_spectra_for_plotting=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  if input_spectrometer_frequency == 0:
21
  input_spectrometer_frequency = None
22
  # Load configuration and initialize predictor
23
- config = OmegaConf.load(config_file.name)
24
  model_ppm_per_point = config.data.frq_step / config.metadata.spectrometer_frequency
25
- predictor = initialize_predictor(config, weights_file.name)
26
 
27
  # Load input data
28
- input_data = np.loadtxt(input_file.name)
29
  input_freqs_input_ppm, input_spectrum = input_data[:, 0], input_data[:, 1]
30
 
31
  # Convert input frequencies to model's frequency
@@ -49,60 +63,135 @@ def process_file(input_file, config_file, weights_file, input_spectrometer_frequ
49
 
50
  # Prepare output data for download
51
  output_data = np.column_stack((input_freqs_input_ppm, output_prediction))
52
- output_file = f"{Path(input_file.name).stem}_processed{Path(input_file.name).suffix}"
53
  np.savetxt(output_file, output_data)
54
 
55
  # Create Plotly figure
56
  fig = go.Figure()
57
- fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=fast_normalize(input_spectrum) if normalize_spectra_for_plotting else input_spectrum, mode='lines', name='Input Spectrum'))
58
- fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=fast_normalize(output_prediction) if normalize_spectra_for_plotting else output_prediction, mode='lines', name='Corrected Spectrum'))
59
- for reference_spectrum_file in reference_spectra:
60
- reference_data = np.loadtxt(reference_spectrum_file.name)
61
- fig.add_trace(go.Scatter(x=reference_data[:, 0], y=fast_normalize(reference_data[:, 1]) if normalize_spectra_for_plotting else reference_data[:, 1], mode='lines', name=f'Reference Spectrum {Path(reference_spectrum_file.name).stem}'))
62
- fig.update_layout(title="Spectrum Visualization", xaxis_title="Frequency (ppm)", yaxis_title="Intensity")
63
 
64
- return fig, output_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- # app = gr.Interface(
67
- # fn=process_file,
68
- # inputs=[
69
- # gr.File(label="Input File (.txt | .csv)"),
70
- # gr.File(label="Config File (.yaml)"),
71
- # gr.File(label="Weights File (.pt)"),
72
- # gr.Number(label="Input Spectrometer Frequency (MHz)", value=None)
73
- # ],
74
- # outputs=[
75
- # gr.Plot(label="Spectrum Visualization"),
76
- # gr.File(label="Download Processed File")
77
- # ],
78
- # title="NMR Spectrum Prediction",
79
- # description="Upload your input file, configuration, and weights to process the NMR spectrum."
80
- # )
81
 
82
  # Gradio app
83
  with gr.Blocks() as app:
84
  gr.Markdown("# ShimNet Spectra Correction")
 
85
  gr.Markdown("Upload your input file, configuration, and weights to process the NMR spectrum.")
86
 
87
  with gr.Row():
88
  with gr.Column():
89
- config_file = gr.File(label="Config File (.yaml)", height=120, value="configs/shimnet_600.yaml")
90
- weights_file = gr.File(label="Weights File (.pt)", height=120, value="weights/shimnet_600MHz.pt")
 
 
 
 
 
91
 
92
  with gr.Column():
93
  input_file = gr.File(label="Input File (.txt | .csv)", height=120)
94
  input_spectrometer_frequency = gr.Number(label="Input Spectrometer Frequency (MHz) (0 or empty if the same as in the loaded model)", value=None)
95
- gr.Markdown("Upload reference spectra files (optional). Reference spectra will be plotted for comparison.")
96
- reference_spectra = gr.Files(label="Reference Spectra File(s) (.txt | .csv)", height=120)
97
- normalize_spectra_for_plotting = gr.Checkbox(label="Normalize Spectra for Plotting", value=True)
98
  process_button = gr.Button("Process File")
99
  plot_output = gr.Plot(label="Spectrum Visualization")
100
  download_button = gr.File(label="Download Processed File", interactive=False, height=120)
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  process_button.click(
103
- process_file,
104
- inputs=[input_file, config_file, weights_file, input_spectrometer_frequency, reference_spectra, normalize_spectra_for_plotting],
105
  outputs=[plot_output, download_button]
106
  )
107
 
108
- app.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
13
  import warnings
14
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
15
 
16
+ import argparse
17
+
18
+ # Add argument parsing for server_name
19
+ parser = argparse.ArgumentParser(description="Launch ShimNet Spectra Correction App")
20
+ parser.add_argument(
21
+ "--server_name",
22
+ type=str,
23
+ default="127.0.0.1",
24
+ help="Server name to bind the app (default: 127.0.0.1). Use 0.0.0.0 for external access."
25
+ )
26
+ parser.add_argument(
27
+ "--share",
28
+ action="store_true",
29
+ help="If set, generates a public link to share the app."
30
+ )
31
+ args = parser.parse_args()
32
+
33
+ def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None,reference_spectrum=None):
34
  if input_spectrometer_frequency == 0:
35
  input_spectrometer_frequency = None
36
  # Load configuration and initialize predictor
37
+ config = OmegaConf.load(config_file)
38
  model_ppm_per_point = config.data.frq_step / config.metadata.spectrometer_frequency
39
+ predictor = initialize_predictor(config, weights_file)
40
 
41
  # Load input data
42
+ input_data = np.loadtxt(input_file)
43
  input_freqs_input_ppm, input_spectrum = input_data[:, 0], input_data[:, 1]
44
 
45
  # Convert input frequencies to model's frequency
 
63
 
64
  # Prepare output data for download
65
  output_data = np.column_stack((input_freqs_input_ppm, output_prediction))
66
+ output_file = f"{Path(input_file).stem}_processed{Path(input_file).suffix}"
67
  np.savetxt(output_file, output_data)
68
 
69
  # Create Plotly figure
70
  fig = go.Figure()
 
 
 
 
 
 
71
 
72
+ # Add Input Spectrum and Corrected Spectrum (always visible)
73
+ normalization_value = input_spectrum.max()
74
+ fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=input_spectrum/normalization_value, mode='lines', name='Input Spectrum', visible=True, line=dict(color='#EF553B'))) # red
75
+ fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=output_prediction/normalization_value, mode='lines', name='Corrected Spectrum', visible=True, line=dict(color='#00cc96'))) # green
76
+
77
+ if reference_spectrum is not None:
78
+ reference_spectrum_freqs, reference_spectrum_intensity = np.loadtxt(reference_spectrum).T
79
+ reference_spectrum_intensity /= reference_spectrum_intensity.max()
80
+ n_zooms = 50
81
+ zooms = np.geomspace(0.01, 100, 2 * n_zooms + 1)
82
+
83
+ # Add Reference Data traces (initially invisible)
84
+ for zoom in zooms:
85
+ fig.add_trace(
86
+ go.Scatter(
87
+ x=reference_spectrum_freqs,
88
+ y=reference_spectrum_intensity * zoom,
89
+ mode='lines',
90
+ name=f'Reference Data (Zoom: {zoom:.2f})',
91
+ visible=False,
92
+ line=dict(color='#636efa')
93
+ )
94
+ )
95
+ # Make the middle zoom level visible by default
96
+ fig.data[2 * n_zooms // 2 + 2].visible = True
97
+
98
+ # Create and add slider
99
+ steps = []
100
+ for i in range(2, len(fig.data)): # Start from the reference data traces
101
+ step = dict(
102
+ method="update",
103
+ args=[{"visible": [True, True] + [False] * (len(fig.data) - 2)}], # Keep first two traces visible
104
+ )
105
+ step["args"][0]["visible"][i] = True # Toggle i'th reference trace to "visible"
106
+ steps.append(step)
107
+
108
+ sliders = [dict(
109
+ active=n_zooms,
110
+ currentvalue={"prefix": "Reference zoom: "},
111
+ pad={"t": 50},
112
+ steps=steps
113
+ )]
114
+
115
+ fig.update_layout(
116
+ sliders=sliders
117
+ )
118
+
119
+ fig.update_layout(
120
+ title="Spectrum Visualization",
121
+ xaxis_title="Frequency (ppm)",
122
+ yaxis_title="Intensity"
123
+ )
124
 
125
+ return fig, output_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  # Gradio app
128
  with gr.Blocks() as app:
129
  gr.Markdown("# ShimNet Spectra Correction")
130
+ gr.Markdown("[ShimNet: A neural network for post-acquisition improvement of NMR spectra distorted by magnetic-field inhomogeneity](https://chemrxiv.org/engage/chemrxiv/article-details/67ef86686dde43c90860d315)")
131
  gr.Markdown("Upload your input file, configuration, and weights to process the NMR spectrum.")
132
 
133
  with gr.Row():
134
  with gr.Column():
135
+ model_selection = gr.Radio(
136
+ label="Select Model",
137
+ choices=["600 MHz", "700 MHz", "Custom"],
138
+ value="600 MHz"
139
+ )
140
+ config_file = gr.File(label="Custom Config File (.yaml)", visible=False, height=120)
141
+ weights_file = gr.File(label="Custom Weights File (.pt)", visible=False, height=120)
142
 
143
  with gr.Column():
144
  input_file = gr.File(label="Input File (.txt | .csv)", height=120)
145
  input_spectrometer_frequency = gr.Number(label="Input Spectrometer Frequency (MHz) (0 or empty if the same as in the loaded model)", value=None)
146
+ gr.Markdown("Upload reference spectrum files (optional). Reference spectrum will be plotted for comparison.")
147
+ reference_spectrum_file = gr.File(label="Reference Spectra File (.txt | .csv)", height=120)
148
+
149
  process_button = gr.Button("Process File")
150
  plot_output = gr.Plot(label="Spectrum Visualization")
151
  download_button = gr.File(label="Download Processed File", interactive=False, height=120)
152
 
153
+ # Update visibility of config and weights fields based on model selection
154
+ def update_visibility(selected_model):
155
+ if selected_model == "Custom":
156
+ return gr.update(visible=True), gr.update(visible=True)
157
+ else:
158
+ return gr.update(visible=False), gr.update(visible=False)
159
+
160
+ model_selection.change(
161
+ update_visibility,
162
+ inputs=[model_selection],
163
+ outputs=[config_file, weights_file]
164
+ )
165
+
166
+ # Process button click logic
167
+ def process_file_with_model(input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file):
168
+ if model_selection == "600 MHz":
169
+ config_file = "configs/shimnet_600.yaml"
170
+ weights_file = "weights/shimnet_600MHz.pt"
171
+ elif model_selection == "700 MHz":
172
+ config_file = "configs/shimnet_700.yaml"
173
+ weights_file = "weights/shimnet_700MHz.pt"
174
+ else:
175
+ config_file = config_file.name
176
+ weights_file = weights_file.name
177
+
178
+ return process_file(input_file.name, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file.name if reference_spectrum_file else None)
179
+
180
  process_button.click(
181
+ process_file_with_model,
182
+ inputs=[input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file],
183
  outputs=[plot_output, download_button]
184
  )
185
 
186
+ app.launch(share=args.share, server_name=args.server_name)
187
+
188
+ # '#636efa',
189
+ # '#EF553B',
190
+ # '#00cc96',
191
+ # '#ab63fa',
192
+ # '#FFA15A',
193
+ # '#19d3f3',
194
+ # '#FF6692',
195
+ # '#B6E880',
196
+ # '#FF97FF',
197
+ # '#FECB52'