Spaces:
Sleeping
Sleeping
Marek Bukowicki commited on
Commit ·
62da572
1
Parent(s): c2e4af2
add reference spectra in GUI
Browse files- predict-gui.py +14 -6
predict-gui.py
CHANGED
|
@@ -13,7 +13,10 @@ from predict import Defaults, resample_input_spectrum, resample_output_spectrum,
|
|
| 13 |
import warnings
|
| 14 |
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
| 15 |
|
| 16 |
-
def
|
|
|
|
|
|
|
|
|
|
| 17 |
if input_spectrometer_frequency == 0:
|
| 18 |
input_spectrometer_frequency = None
|
| 19 |
# Load configuration and initialize predictor
|
|
@@ -51,8 +54,11 @@ def process_file(input_file, config_file, weights_file, input_spectrometer_frequ
|
|
| 51 |
|
| 52 |
# Create Plotly figure
|
| 53 |
fig = go.Figure()
|
| 54 |
-
fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=input_spectrum, mode='lines', name='Input Spectrum'))
|
| 55 |
-
fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=output_prediction, mode='lines', name='Corrected Spectrum'))
|
|
|
|
|
|
|
|
|
|
| 56 |
fig.update_layout(title="Spectrum Visualization", xaxis_title="Frequency (ppm)", yaxis_title="Intensity")
|
| 57 |
|
| 58 |
return fig, output_file
|
|
@@ -84,16 +90,18 @@ with gr.Blocks() as app:
|
|
| 84 |
weights_file = gr.File(label="Weights File (.pt)", height=120, value="weights/shimnet_600MHz.pt")
|
| 85 |
|
| 86 |
with gr.Column():
|
| 87 |
-
input_file = gr.File(label="Input File (.txt | .csv)", height=
|
| 88 |
input_spectrometer_frequency = gr.Number(label="Input Spectrometer Frequency (MHz) (0 or empty if the same as in the loaded model)", value=None)
|
| 89 |
-
|
|
|
|
|
|
|
| 90 |
process_button = gr.Button("Process File")
|
| 91 |
plot_output = gr.Plot(label="Spectrum Visualization")
|
| 92 |
download_button = gr.File(label="Download Processed File", interactive=False, height=120)
|
| 93 |
|
| 94 |
process_button.click(
|
| 95 |
process_file,
|
| 96 |
-
inputs=[input_file, config_file, weights_file, input_spectrometer_frequency],
|
| 97 |
outputs=[plot_output, download_button]
|
| 98 |
)
|
| 99 |
|
|
|
|
| 13 |
import warnings
|
| 14 |
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
| 15 |
|
| 16 |
+
def fast_normalize(x):
|
| 17 |
+
return x / np.max(x)
|
| 18 |
+
|
| 19 |
+
def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None,reference_spectra=[], normalize_spectra_for_plotting=True):
|
| 20 |
if input_spectrometer_frequency == 0:
|
| 21 |
input_spectrometer_frequency = None
|
| 22 |
# Load configuration and initialize predictor
|
|
|
|
| 54 |
|
| 55 |
# Create Plotly figure
|
| 56 |
fig = go.Figure()
|
| 57 |
+
fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=fast_normalize(input_spectrum) if normalize_spectra_for_plotting else input_spectrum, mode='lines', name='Input Spectrum'))
|
| 58 |
+
fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=fast_normalize(output_prediction) if normalize_spectra_for_plotting else output_prediction, mode='lines', name='Corrected Spectrum'))
|
| 59 |
+
for reference_spectrum_file in reference_spectra:
|
| 60 |
+
reference_data = np.loadtxt(reference_spectrum_file.name)
|
| 61 |
+
fig.add_trace(go.Scatter(x=reference_data[:, 0], y=fast_normalize(reference_data[:, 1]) if normalize_spectra_for_plotting else reference_data[:, 1], mode='lines', name=f'Reference Spectrum {Path(reference_spectrum_file.name).stem}'))
|
| 62 |
fig.update_layout(title="Spectrum Visualization", xaxis_title="Frequency (ppm)", yaxis_title="Intensity")
|
| 63 |
|
| 64 |
return fig, output_file
|
|
|
|
| 90 |
weights_file = gr.File(label="Weights File (.pt)", height=120, value="weights/shimnet_600MHz.pt")
|
| 91 |
|
| 92 |
with gr.Column():
|
| 93 |
+
input_file = gr.File(label="Input File (.txt | .csv)", height=120)
|
| 94 |
input_spectrometer_frequency = gr.Number(label="Input Spectrometer Frequency (MHz) (0 or empty if the same as in the loaded model)", value=None)
|
| 95 |
+
gr.Markdown("Upload reference spectra files (optional). Reference spectra will be plotted for comparison.")
|
| 96 |
+
reference_spectra = gr.Files(label="Reference Spectra File(s) (.txt | .csv)", height=120)
|
| 97 |
+
normalize_spectra_for_plotting = gr.Checkbox(label="Normalize Spectra for Plotting", value=True)
|
| 98 |
process_button = gr.Button("Process File")
|
| 99 |
plot_output = gr.Plot(label="Spectrum Visualization")
|
| 100 |
download_button = gr.File(label="Download Processed File", interactive=False, height=120)
|
| 101 |
|
| 102 |
process_button.click(
|
| 103 |
process_file,
|
| 104 |
+
inputs=[input_file, config_file, weights_file, input_spectrometer_frequency, reference_spectra, normalize_spectra_for_plotting],
|
| 105 |
outputs=[plot_output, download_button]
|
| 106 |
)
|
| 107 |
|