Spaces:
Sleeping
Sleeping
Marek Bukowicki commited on
Commit ·
7bc8c9d
1
Parent(s): 7c6af4e
add output name setting, hide advanced options
Browse files- predict-gui.py +39 -24
predict-gui.py
CHANGED
|
@@ -30,7 +30,10 @@ parser.add_argument(
|
|
| 30 |
)
|
| 31 |
args = parser.parse_args()
|
| 32 |
|
| 33 |
-
def
|
|
|
|
|
|
|
|
|
|
| 34 |
if input_spectrometer_frequency == 0:
|
| 35 |
input_spectrometer_frequency = None
|
| 36 |
# Load configuration and initialize predictor
|
|
@@ -65,7 +68,10 @@ def process_file(input_file, config_file, weights_file, input_spectrometer_frequ
|
|
| 65 |
|
| 66 |
# Prepare output data for download
|
| 67 |
output_data = np.column_stack((input_freqs_input_ppm, output_prediction))
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
| 69 |
np.savetxt(output_file, output_data)
|
| 70 |
|
| 71 |
# Create Plotly figure
|
|
@@ -128,11 +134,13 @@ def process_file(input_file, config_file, weights_file, input_spectrometer_frequ
|
|
| 128 |
|
| 129 |
return fig, output_file
|
| 130 |
|
|
|
|
| 131 |
# Gradio app
|
| 132 |
with gr.Blocks() as app:
|
| 133 |
gr.Markdown("# ShimNet Spectra Correction")
|
| 134 |
-
gr.Markdown("[ShimNet: A neural network for post-acquisition improvement of NMR spectra distorted by magnetic-field inhomogeneity](https://
|
| 135 |
-
gr.Markdown("Upload your input file
|
|
|
|
| 136 |
|
| 137 |
with gr.Row():
|
| 138 |
with gr.Column():
|
|
@@ -144,12 +152,14 @@ with gr.Blocks() as app:
|
|
| 144 |
config_file = gr.File(label="Custom Config File (.yaml)", visible=False, height=120)
|
| 145 |
weights_file = gr.File(label="Custom Weights File (.pt)", visible=False, height=120)
|
| 146 |
|
| 147 |
-
|
| 148 |
-
label="
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
|
|
|
|
|
|
| 153 |
with gr.Column():
|
| 154 |
input_file = gr.File(label="Input File (.txt | .csv)", height=120)
|
| 155 |
input_spectrometer_frequency = gr.Number(label="Input Spectrometer Frequency (MHz) (0 or empty if the same as in the loaded model)", value=None)
|
|
@@ -174,8 +184,23 @@ with gr.Blocks() as app:
|
|
| 174 |
outputs=[config_file, weights_file]
|
| 175 |
)
|
| 176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
# Process button click logic
|
| 178 |
-
def process_file_with_model(input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file, scale):
|
| 179 |
if model_selection == "600 MHz":
|
| 180 |
config_file = "configs/shimnet_600.yaml"
|
| 181 |
weights_file = "weights/shimnet_600MHz.pt"
|
|
@@ -192,24 +217,14 @@ with gr.Blocks() as app:
|
|
| 192 |
weights_file,
|
| 193 |
input_spectrometer_frequency,
|
| 194 |
reference_spectrum_file.name if reference_spectrum_file else None,
|
| 195 |
-
scale
|
|
|
|
| 196 |
)
|
| 197 |
|
| 198 |
process_button.click(
|
| 199 |
process_file_with_model,
|
| 200 |
-
inputs=[input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file, scale_input],
|
| 201 |
outputs=[plot_output, download_button]
|
| 202 |
)
|
| 203 |
|
| 204 |
app.launch(share=args.share, server_name=args.server_name)
|
| 205 |
-
|
| 206 |
-
# '#636efa',
|
| 207 |
-
# '#EF553B',
|
| 208 |
-
# '#00cc96',
|
| 209 |
-
# '#ab63fa',
|
| 210 |
-
# '#FFA15A',
|
| 211 |
-
# '#19d3f3',
|
| 212 |
-
# '#FF6692',
|
| 213 |
-
# '#B6E880',
|
| 214 |
-
# '#FF97FF',
|
| 215 |
-
# '#FECB52'
|
|
|
|
| 30 |
)
|
| 31 |
args = parser.parse_args()
|
| 32 |
|
| 33 |
+
def default_out_filename_from_input_filename(input_file):
|
| 34 |
+
return f"{Path(input_file.name).stem}_processed{Path(input_file.name).suffix}"
|
| 35 |
+
|
| 36 |
+
def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None,reference_spectrum=None, scale=None, output_filename=None):
|
| 37 |
if input_spectrometer_frequency == 0:
|
| 38 |
input_spectrometer_frequency = None
|
| 39 |
# Load configuration and initialize predictor
|
|
|
|
| 68 |
|
| 69 |
# Prepare output data for download
|
| 70 |
output_data = np.column_stack((input_freqs_input_ppm, output_prediction))
|
| 71 |
+
if output_filename is None or output_filename.strip() == "":
|
| 72 |
+
output_file = default_out_filename_from_input_filename(input_file)
|
| 73 |
+
else:
|
| 74 |
+
output_file = output_filename
|
| 75 |
np.savetxt(output_file, output_data)
|
| 76 |
|
| 77 |
# Create Plotly figure
|
|
|
|
| 134 |
|
| 135 |
return fig, output_file
|
| 136 |
|
| 137 |
+
|
| 138 |
# Gradio app
|
| 139 |
with gr.Blocks() as app:
|
| 140 |
gr.Markdown("# ShimNet Spectra Correction")
|
| 141 |
+
gr.Markdown("[ShimNet: A neural network for post-acquisition improvement of NMR spectra distorted by magnetic-field inhomogeneity](https://doi.org/10.1021/acs.jpcb.5c02632)")
|
| 142 |
+
gr.Markdown("Upload your input file and select the model to process the NMR spectrum. Select 'Custom' to provide your own configuration and weights files.")
|
| 143 |
+
gr.Markdown("Sample input files: https://huggingface.co/spaces/NMR-CeNT-UW/ShimNet/tree/main/sample_data")
|
| 144 |
|
| 145 |
with gr.Row():
|
| 146 |
with gr.Column():
|
|
|
|
| 152 |
config_file = gr.File(label="Custom Config File (.yaml)", visible=False, height=120)
|
| 153 |
weights_file = gr.File(label="Custom Weights File (.pt)", visible=False, height=120)
|
| 154 |
|
| 155 |
+
with gr.Accordion("Custom output file name / advanced options", open=False):
|
| 156 |
+
output_filename = gr.Textbox(label="Output File Name", placeholder="set automatically after input file upload", interactive=True)
|
| 157 |
+
scale_input = gr.Number(
|
| 158 |
+
label="Scale (Intensity Normalization)",
|
| 159 |
+
value=Defaults.SCALE,
|
| 160 |
+
info="Adjust the scaling factor for intensity normalization. Default is 16.",
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
with gr.Column():
|
| 164 |
input_file = gr.File(label="Input File (.txt | .csv)", height=120)
|
| 165 |
input_spectrometer_frequency = gr.Number(label="Input Spectrometer Frequency (MHz) (0 or empty if the same as in the loaded model)", value=None)
|
|
|
|
| 184 |
outputs=[config_file, weights_file]
|
| 185 |
)
|
| 186 |
|
| 187 |
+
# Auto-populate output filename when input file is uploaded
|
| 188 |
+
def generate_output_filename(input_file, output_filename):
|
| 189 |
+
# do not overwrite if user has provided a filename
|
| 190 |
+
if output_filename is not None and output_filename.strip() != "":
|
| 191 |
+
return output_filename
|
| 192 |
+
if input_file is None:
|
| 193 |
+
return ""
|
| 194 |
+
return default_out_filename_from_input_filename(input_file)
|
| 195 |
+
|
| 196 |
+
input_file.change(
|
| 197 |
+
generate_output_filename,
|
| 198 |
+
inputs=[input_file, output_filename],
|
| 199 |
+
outputs=[output_filename]
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
# Process button click logic
|
| 203 |
+
def process_file_with_model(input_file, output_filename, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file, scale):
|
| 204 |
if model_selection == "600 MHz":
|
| 205 |
config_file = "configs/shimnet_600.yaml"
|
| 206 |
weights_file = "weights/shimnet_600MHz.pt"
|
|
|
|
| 217 |
weights_file,
|
| 218 |
input_spectrometer_frequency,
|
| 219 |
reference_spectrum_file.name if reference_spectrum_file else None,
|
| 220 |
+
scale,
|
| 221 |
+
output_filename
|
| 222 |
)
|
| 223 |
|
| 224 |
process_button.click(
|
| 225 |
process_file_with_model,
|
| 226 |
+
inputs=[input_file, output_filename, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file, scale_input],
|
| 227 |
outputs=[plot_output, download_button]
|
| 228 |
)
|
| 229 |
|
| 230 |
app.launch(share=args.share, server_name=args.server_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|