import os import torch torch.set_grad_enabled(False) import numpy as np from pathlib import Path from omegaconf import OmegaConf import gradio as gr import plotly.graph_objects as go from shimnet.predict_utils import Defaults, resample_input_spectrum, resample_output_spectrum, initialize_predictor, get_model_ppm_per_point # silent deprecation warnings import warnings warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') import argparse # Add argument parsing for server_name parser = argparse.ArgumentParser(description="Launch ShimNet Spectra Correction App") parser.add_argument( "--server_name", type=str, default="127.0.0.1", help="Server name to bind the app (default: 127.0.0.1). Use 0.0.0.0 for external access." ) parser.add_argument( "--share", action="store_true", help="If set, generates a public link to share the app." ) args = parser.parse_args() def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None,reference_spectrum=None, scale=None, suffix=None): if input_spectrometer_frequency == 0: input_spectrometer_frequency = None # Load configuration and initialize predictor config = OmegaConf.load(config_file) model_ppm_per_point = get_model_ppm_per_point(config) predictor = initialize_predictor(config, weights_file) # Load input data input_data = np.loadtxt(input_file) input_freqs_input_ppm, input_spectrum = input_data[:, 0], input_data[:, 1] # Convert input frequencies to model's frequency if input_spectrometer_frequency is not None: input_freqs_model_ppm = input_freqs_input_ppm * input_spectrometer_frequency / config.metadata.spectrometer_frequency else: input_freqs_model_ppm = input_freqs_input_ppm # Resample input spectrum freqs, spectrum = resample_input_spectrum(input_freqs_model_ppm, input_spectrum, model_ppm_per_point) # Scale and process spectrum spectrum_tensor = torch.tensor(spectrum).float() if scale is None: scale = Defaults.SCALE scaling_factor = scale / spectrum_tensor.max() spectrum_tensor *= scaling_factor prediction = predictor(spectrum_tensor).numpy() prediction /= scaling_factor # Resample output spectrum output_prediction = resample_output_spectrum(input_freqs_model_ppm, freqs, prediction) # Prepare output data for download output_data = np.column_stack((input_freqs_input_ppm, output_prediction)) if suffix is None: suffix = Defaults.SUFFIX output_file = f"{Path(input_file).stem}{suffix}{Path(input_file).suffix}" np.savetxt(output_file, output_data) # Create Plotly figure fig = go.Figure() # Add Input Spectrum and Corrected Spectrum (always visible) normalization_value = input_spectrum.max() fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=input_spectrum/normalization_value, mode='lines', name='Input Spectrum', visible=True, line=dict(color='#EF553B'))) # red fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=output_prediction/normalization_value, mode='lines', name='Corrected Spectrum', visible=True, line=dict(color='#00cc96'))) # green if reference_spectrum is not None: reference_spectrum_freqs, reference_spectrum_intensity = np.loadtxt(reference_spectrum).T reference_spectrum_intensity /= reference_spectrum_intensity.max() n_zooms = 50 zooms = np.geomspace(0.01, 100, 2 * n_zooms + 1) # Add Reference Data traces (initially invisible) for zoom in zooms: fig.add_trace( go.Scatter( x=reference_spectrum_freqs, y=reference_spectrum_intensity * zoom, mode='lines', name=f'Reference Data (Zoom: {zoom:.2f})', visible=False, line=dict(color='#636efa') ) ) # Make the middle zoom level visible by default fig.data[2 * n_zooms // 2 + 2].visible = True # Create and add slider steps = [] for i in range(2, len(fig.data)): # Start from the reference data traces step = dict( method="update", args=[{"visible": [True, True] + [False] * (len(fig.data) - 2)}], # Keep first two traces visible ) step["args"][0]["visible"][i] = True # Toggle i'th reference trace to "visible" steps.append(step) sliders = [dict( active=n_zooms, currentvalue={"prefix": "Reference zoom: "}, pad={"t": 50}, steps=steps )] fig.update_layout( sliders=sliders ) fig.update_layout( title="Spectrum Visualization", xaxis_title="Frequency (ppm)", yaxis_title="Intensity" ) # reverse x-axis fig.update_xaxes(autorange="reversed") return fig, output_file # Gradio app with gr.Blocks() as app: gr.Markdown("# ShimNet Spectra Correction") gr.Markdown("[ShimNet: A neural network for post-acquisition improvement of NMR spectra distorted by magnetic-field inhomogeneity](https://doi.org/10.1021/acs.jpcb.5c02632)") gr.Markdown("Upload your input file, configuration, and weights to process the NMR spectrum.") with gr.Row(): with gr.Column(): model_selection = gr.Radio( label="Select Model", choices=["600 MHz", "700 MHz", "M-E01", "Custom"], value="600 MHz" ) config_file = gr.File(label="Custom Config File (.yaml)", visible=False, height=120) weights_file = gr.File(label="Custom Weights File (.pt)", visible=False, height=120) with gr.Accordion("Advanced", open=False): scale_input = gr.Number( label="Intensity Scale", value=Defaults.SCALE, info=f"Adjust the scaling factor for intensity normalization. Default is {Defaults.SCALE}.", ) suffix_input = gr.Textbox( label="Output File Suffix", value=Defaults.SUFFIX, info=f"Suffix to add to processed output filenames. Default is '{Defaults.SUFFIX}'.", ) with gr.Column(): input_file = gr.File(label="Input File (.txt | .csv)", height=120) input_spectrometer_frequency = gr.Number(label="Input Spectrometer Frequency (MHz) (0 or empty if the same as in the loaded model)", value=None) gr.Markdown("Upload reference spectrum files (optional). Reference spectrum will be plotted for comparison.") reference_spectrum_file = gr.File(label="Reference Spectra File (.txt | .csv)", height=120) process_button = gr.Button("Process File") plot_output = gr.Plot(label="Spectrum Visualization") download_button = gr.File(label="Download Processed File", interactive=False, height=120) # Update visibility of config and weights fields based on model selection def update_visibility(selected_model): if selected_model == "Custom": return gr.update(visible=True), gr.update(visible=True) else: return gr.update(visible=False), gr.update(visible=False) model_selection.change( update_visibility, inputs=[model_selection], outputs=[config_file, weights_file] ) # Process button click logic def process_file_with_model(input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file, scale, suffix): if model_selection == "600 MHz": config_file = os.path.join(os.path.dirname(__file__), "configs/shimnet_600.yaml") weights_file = os.path.join(os.path.dirname(__file__), "weights/shimnet_600MHz.pt") elif model_selection == "700 MHz": config_file = os.path.join(os.path.dirname(__file__), "configs/shimnet_700.yaml") weights_file = os.path.join(os.path.dirname(__file__), "weights/shimnet_700MHz.pt") elif model_selection == "M-E01": config_file = os.path.join(os.path.dirname(__file__), "configs/shimnet_600_M-E01.yaml") weights_file = os.path.join(os.path.dirname(__file__), "weights/shimnet_600MHz_M-E01.pt") else: config_file = config_file.name weights_file = weights_file.name return process_file( input_file.name, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file.name if reference_spectrum_file else None, scale, suffix ) process_button.click( process_file_with_model, inputs=[input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file, scale_input, suffix_input], outputs=[plot_output, download_button] ) app.launch(share=args.share, server_name=args.server_name)