Spaces:
Sleeping
Sleeping
make scaling adjustable in GUI
Browse files- predict-gui.py +31 -6
predict-gui.py
CHANGED
|
@@ -30,7 +30,7 @@ parser.add_argument(
|
|
| 30 |
)
|
| 31 |
args = parser.parse_args()
|
| 32 |
|
| 33 |
-
def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None,reference_spectrum=None):
|
| 34 |
if input_spectrometer_frequency == 0:
|
| 35 |
input_spectrometer_frequency = None
|
| 36 |
# Load configuration and initialize predictor
|
|
@@ -53,7 +53,9 @@ def process_file(input_file, config_file, weights_file, input_spectrometer_frequ
|
|
| 53 |
|
| 54 |
# Scale and process spectrum
|
| 55 |
spectrum_tensor = torch.tensor(spectrum).float()
|
| 56 |
-
|
|
|
|
|
|
|
| 57 |
spectrum_tensor *= scaling_factor
|
| 58 |
prediction = predictor(spectrum_tensor).numpy()
|
| 59 |
prediction /= scaling_factor
|
|
@@ -129,7 +131,7 @@ def process_file(input_file, config_file, weights_file, input_spectrometer_frequ
|
|
| 129 |
# Gradio app
|
| 130 |
with gr.Blocks() as app:
|
| 131 |
gr.Markdown("# ShimNet Spectra Correction")
|
| 132 |
-
gr.Markdown("[ShimNet: A neural network for post-acquisition improvement of NMR spectra distorted by magnetic-field inhomogeneity](https://
|
| 133 |
gr.Markdown("Upload your input file, configuration, and weights to process the NMR spectrum.")
|
| 134 |
|
| 135 |
with gr.Row():
|
|
@@ -148,6 +150,11 @@ with gr.Blocks() as app:
|
|
| 148 |
gr.Markdown("Upload reference spectrum files (optional). Reference spectrum will be plotted for comparison.")
|
| 149 |
reference_spectrum_file = gr.File(label="Reference Spectra File (.txt | .csv)", height=120)
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
process_button = gr.Button("Process File")
|
| 152 |
plot_output = gr.Plot(label="Spectrum Visualization")
|
| 153 |
download_button = gr.File(label="Download Processed File", interactive=False, height=120)
|
|
@@ -166,7 +173,7 @@ with gr.Blocks() as app:
|
|
| 166 |
)
|
| 167 |
|
| 168 |
# Process button click logic
|
| 169 |
-
def process_file_with_model(input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file):
|
| 170 |
if model_selection == "600 MHz":
|
| 171 |
config_file = "configs/shimnet_600.yaml"
|
| 172 |
weights_file = "weights/shimnet_600MHz.pt"
|
|
@@ -177,12 +184,30 @@ with gr.Blocks() as app:
|
|
| 177 |
config_file = config_file.name
|
| 178 |
weights_file = weights_file.name
|
| 179 |
|
| 180 |
-
return process_file(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
process_button.click(
|
| 183 |
process_file_with_model,
|
| 184 |
-
inputs=[input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file],
|
| 185 |
outputs=[plot_output, download_button]
|
| 186 |
)
|
| 187 |
|
| 188 |
app.launch(share=args.share, server_name=args.server_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
)
|
| 31 |
args = parser.parse_args()
|
| 32 |
|
| 33 |
+
def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None,reference_spectrum=None, scale=None):
|
| 34 |
if input_spectrometer_frequency == 0:
|
| 35 |
input_spectrometer_frequency = None
|
| 36 |
# Load configuration and initialize predictor
|
|
|
|
| 53 |
|
| 54 |
# Scale and process spectrum
|
| 55 |
spectrum_tensor = torch.tensor(spectrum).float()
|
| 56 |
+
if scale is None:
|
| 57 |
+
scale = Defaults.SCALE
|
| 58 |
+
scaling_factor = scale / spectrum_tensor.max()
|
| 59 |
spectrum_tensor *= scaling_factor
|
| 60 |
prediction = predictor(spectrum_tensor).numpy()
|
| 61 |
prediction /= scaling_factor
|
|
|
|
| 131 |
# Gradio app
|
| 132 |
with gr.Blocks() as app:
|
| 133 |
gr.Markdown("# ShimNet Spectra Correction")
|
| 134 |
+
gr.Markdown("[ShimNet: A neural network for post-acquisition improvement of NMR spectra distorted by magnetic-field inhomogeneity](https://chemrxiv.org/engage/chemrxiv/article-details/67ef86686dde43c90860d315)")
|
| 135 |
gr.Markdown("Upload your input file, configuration, and weights to process the NMR spectrum.")
|
| 136 |
|
| 137 |
with gr.Row():
|
|
|
|
| 150 |
gr.Markdown("Upload reference spectrum files (optional). Reference spectrum will be plotted for comparison.")
|
| 151 |
reference_spectrum_file = gr.File(label="Reference Spectra File (.txt | .csv)", height=120)
|
| 152 |
|
| 153 |
+
scale_input = gr.Number(
|
| 154 |
+
label="Scale (Intensity Normalization)",
|
| 155 |
+
value=Defaults.SCALE,
|
| 156 |
+
info="Adjust the scaling factor for intensity normalization. Default is 16.",
|
| 157 |
+
)
|
| 158 |
process_button = gr.Button("Process File")
|
| 159 |
plot_output = gr.Plot(label="Spectrum Visualization")
|
| 160 |
download_button = gr.File(label="Download Processed File", interactive=False, height=120)
|
|
|
|
| 173 |
)
|
| 174 |
|
| 175 |
# Process button click logic
|
| 176 |
+
def process_file_with_model(input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file, scale):
|
| 177 |
if model_selection == "600 MHz":
|
| 178 |
config_file = "configs/shimnet_600.yaml"
|
| 179 |
weights_file = "weights/shimnet_600MHz.pt"
|
|
|
|
| 184 |
config_file = config_file.name
|
| 185 |
weights_file = weights_file.name
|
| 186 |
|
| 187 |
+
return process_file(
|
| 188 |
+
input_file.name,
|
| 189 |
+
config_file,
|
| 190 |
+
weights_file,
|
| 191 |
+
input_spectrometer_frequency,
|
| 192 |
+
reference_spectrum_file.name if reference_spectrum_file else None,
|
| 193 |
+
scale
|
| 194 |
+
)
|
| 195 |
|
| 196 |
process_button.click(
|
| 197 |
process_file_with_model,
|
| 198 |
+
inputs=[input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file, scale_input],
|
| 199 |
outputs=[plot_output, download_button]
|
| 200 |
)
|
| 201 |
|
| 202 |
app.launch(share=args.share, server_name=args.server_name)
|
| 203 |
+
|
| 204 |
+
# '#636efa',
|
| 205 |
+
# '#EF553B',
|
| 206 |
+
# '#00cc96',
|
| 207 |
+
# '#ab63fa',
|
| 208 |
+
# '#FFA15A',
|
| 209 |
+
# '#19d3f3',
|
| 210 |
+
# '#FF6692',
|
| 211 |
+
# '#B6E880',
|
| 212 |
+
# '#FF97FF',
|
| 213 |
+
# '#FECB52'
|