Marek Bukowicki commited on
Commit
c2e4af2
·
1 Parent(s): e0a6e6f

working gui

Browse files
Files changed (4) hide show
  1. .gitignore +2 -0
  2. Readme.md +22 -0
  3. predict-gui.py +100 -0
  4. requirements-gui.txt +2 -0
.gitignore CHANGED
@@ -9,3 +9,5 @@ data/
9
  # typically weights and data
10
  *.pt
11
 
 
 
 
9
  # typically weights and data
10
  *.pt
11
 
12
+ # gradio
13
+ .gradio/
Readme.md CHANGED
@@ -186,3 +186,25 @@ If you want to train the network using the calibration data from our paper, foll
186
  ```
187
  Training results will appear in `runs/repeat_paper_training_600MHz` or `runs/repeat_paper_training_700MHz` directory.
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  ```
187
  Training results will appear in `runs/repeat_paper_training_600MHz` or `runs/repeat_paper_training_700MHz` directory.
188
 
189
+ ## GUI
190
+
191
+ ### GUI installation
192
+
193
+ GUI requires Python 3.10. Not tested for Python 3.11+
194
+
195
+
196
+ After installing ShimNet requirements (CPU/GPU) install GUI requirements:
197
+
198
+ ```bash
199
+ pip install -r requirements-gui.txt
200
+ ```
201
+
202
+ ### GUI usage
203
+
204
+ ```bash
205
+ python predict-gui.py
206
+ ```
207
+
208
+ Open your browser and go to `http://127.0.0.1:7860` address to use locally.
209
+
210
+ Under address given in the terminal message after `Running on public URL:` the tool may be used on other computers
predict-gui.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ torch.set_grad_enabled(False)
3
+ import numpy as np
4
+ from pathlib import Path
5
+ from omegaconf import OmegaConf
6
+ import gradio as gr
7
+ import plotly.graph_objects as go
8
+
9
+ from src.models import ShimNetWithSCRF, Predictor
10
+ from predict import Defaults, resample_input_spectrum, resample_output_spectrum, initialize_predictor
11
+
12
+ # silent deprecation warnings
13
+ import warnings
14
+ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
15
+
16
+ def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None):
17
+ if input_spectrometer_frequency == 0:
18
+ input_spectrometer_frequency = None
19
+ # Load configuration and initialize predictor
20
+ config = OmegaConf.load(config_file.name)
21
+ model_ppm_per_point = config.data.frq_step / config.metadata.spectrometer_frequency
22
+ predictor = initialize_predictor(config, weights_file.name)
23
+
24
+ # Load input data
25
+ input_data = np.loadtxt(input_file.name)
26
+ input_freqs_input_ppm, input_spectrum = input_data[:, 0], input_data[:, 1]
27
+
28
+ # Convert input frequencies to model's frequency
29
+ if input_spectrometer_frequency is not None:
30
+ input_freqs_model_ppm = input_freqs_input_ppm * input_spectrometer_frequency / config.metadata.spectrometer_frequency
31
+ else:
32
+ input_freqs_model_ppm = input_freqs_input_ppm
33
+
34
+ # Resample input spectrum
35
+ freqs, spectrum = resample_input_spectrum(input_freqs_model_ppm, input_spectrum, model_ppm_per_point)
36
+
37
+ # Scale and process spectrum
38
+ spectrum_tensor = torch.tensor(spectrum).float()
39
+ scaling_factor = Defaults.SCALE / spectrum_tensor.max()
40
+ spectrum_tensor *= scaling_factor
41
+ prediction = predictor(spectrum_tensor).numpy()
42
+ prediction /= scaling_factor
43
+
44
+ # Resample output spectrum
45
+ output_prediction = resample_output_spectrum(input_freqs_model_ppm, freqs, prediction)
46
+
47
+ # Prepare output data for download
48
+ output_data = np.column_stack((input_freqs_input_ppm, output_prediction))
49
+ output_file = f"{Path(input_file.name).stem}_processed{Path(input_file.name).suffix}"
50
+ np.savetxt(output_file, output_data)
51
+
52
+ # Create Plotly figure
53
+ fig = go.Figure()
54
+ fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=input_spectrum, mode='lines', name='Input Spectrum'))
55
+ fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=output_prediction, mode='lines', name='Corrected Spectrum'))
56
+ fig.update_layout(title="Spectrum Visualization", xaxis_title="Frequency (ppm)", yaxis_title="Intensity")
57
+
58
+ return fig, output_file
59
+
60
+ # app = gr.Interface(
61
+ # fn=process_file,
62
+ # inputs=[
63
+ # gr.File(label="Input File (.txt | .csv)"),
64
+ # gr.File(label="Config File (.yaml)"),
65
+ # gr.File(label="Weights File (.pt)"),
66
+ # gr.Number(label="Input Spectrometer Frequency (MHz)", value=None)
67
+ # ],
68
+ # outputs=[
69
+ # gr.Plot(label="Spectrum Visualization"),
70
+ # gr.File(label="Download Processed File")
71
+ # ],
72
+ # title="NMR Spectrum Prediction",
73
+ # description="Upload your input file, configuration, and weights to process the NMR spectrum."
74
+ # )
75
+
76
+ # Gradio app
77
+ with gr.Blocks() as app:
78
+ gr.Markdown("# ShimNet Spectra Correction")
79
+ gr.Markdown("Upload your input file, configuration, and weights to process the NMR spectrum.")
80
+
81
+ with gr.Row():
82
+ with gr.Column():
83
+ config_file = gr.File(label="Config File (.yaml)", height=120, value="configs/shimnet_600.yaml")
84
+ weights_file = gr.File(label="Weights File (.pt)", height=120, value="weights/shimnet_600MHz.pt")
85
+
86
+ with gr.Column():
87
+ input_file = gr.File(label="Input File (.txt | .csv)", height=150)
88
+ input_spectrometer_frequency = gr.Number(label="Input Spectrometer Frequency (MHz) (0 or empty if the same as in the loaded model)", value=None)
89
+
90
+ process_button = gr.Button("Process File")
91
+ plot_output = gr.Plot(label="Spectrum Visualization")
92
+ download_button = gr.File(label="Download Processed File", interactive=False, height=120)
93
+
94
+ process_button.click(
95
+ process_file,
96
+ inputs=[input_file, config_file, weights_file, input_spectrometer_frequency],
97
+ outputs=[plot_output, download_button]
98
+ )
99
+
100
+ app.launch(share=True)
requirements-gui.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio==5.23.2
2
+ plotly==6.0.1