Spaces:
Sleeping
Sleeping
Marek Bukowicki commited on
Commit ·
d0daecc
1
Parent(s): 7fcb1aa
fix extra spectra evaluation
Browse files- predict.py +0 -4
- shimnet/predict_utils.py +1 -0
- train.py +17 -9
predict.py
CHANGED
|
@@ -14,10 +14,6 @@ from shimnet.predict_utils import Defaults, resample_input_spectrum, resample_ou
|
|
| 14 |
import warnings
|
| 15 |
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
| 16 |
|
| 17 |
-
class Defaults:
|
| 18 |
-
SCALE = 16.0
|
| 19 |
-
SUFFIX = "_processed"
|
| 20 |
-
|
| 21 |
def parse_args():
|
| 22 |
parser = argparse.ArgumentParser()
|
| 23 |
parser.add_argument("input_files", help="Input files", nargs="+")
|
|
|
|
| 14 |
import warnings
|
| 15 |
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
def parse_args():
|
| 18 |
parser = argparse.ArgumentParser()
|
| 19 |
parser.add_argument("input_files", help="Input files", nargs="+")
|
shimnet/predict_utils.py
CHANGED
|
@@ -4,6 +4,7 @@ from .models import ShimNetWithSCRF, Predictor
|
|
| 4 |
|
| 5 |
class Defaults:
|
| 6 |
SCALE = 16.0
|
|
|
|
| 7 |
|
| 8 |
# functions
|
| 9 |
def resample_input_spectrum(input_freqs, input_spectrum, Mhz_per_point):
|
|
|
|
| 4 |
|
| 5 |
class Defaults:
|
| 6 |
SCALE = 16.0
|
| 7 |
+
SUFFIX = "_processed"
|
| 8 |
|
| 9 |
# functions
|
| 10 |
def resample_input_spectrum(input_freqs, input_spectrum, Mhz_per_point):
|
train.py
CHANGED
|
@@ -17,6 +17,7 @@ warnings.filterwarnings("ignore", category=UserWarning, module='torchdata')
|
|
| 17 |
|
| 18 |
# from shiment import models
|
| 19 |
from shimnet.generators import get_datapipe
|
|
|
|
| 20 |
|
| 21 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 22 |
if len(sys.argv) < 2:
|
|
@@ -35,6 +36,7 @@ else:
|
|
| 35 |
extra_spectra_for_evaluation = {}
|
| 36 |
frq_step = config.data.get("frq_step") or config.metadata.get("frq_step")
|
| 37 |
model_ppm_per_point = frq_step / config.metadata.spectrometer_frequency
|
|
|
|
| 38 |
|
| 39 |
for spectra_data in config.logging.get('extra_spectra_for_evaluation', []):
|
| 40 |
spectrum_file = Path(spectra_data.path)
|
|
@@ -48,6 +50,9 @@ for spectra_data in config.logging.get('extra_spectra_for_evaluation', []):
|
|
| 48 |
spectrum_freqs = np.arange(spectrum_freqs_model_ppm.min(), spectrum_freqs_model_ppm.max(), model_ppm_per_point)
|
| 49 |
spectrum = np.interp(spectrum_freqs, spectrum_freqs_model_ppm, spectrum)
|
| 50 |
|
|
|
|
|
|
|
|
|
|
| 51 |
extra_spectra_for_evaluation[Path(spectrum_file).stem] = {
|
| 52 |
'frequencies': spectrum_freqs,
|
| 53 |
'spectrum': spectrum,
|
|
@@ -106,22 +111,23 @@ def evaluate_model(stage=0, epoch=0):
|
|
| 106 |
extra_spectra_dir = plot_dir / "extra_spectra"
|
| 107 |
extra_spectra_dir.mkdir(exist_ok=True, parents=True)
|
| 108 |
for spectrum_name, spectrum_data in extra_spectra_for_evaluation.items():
|
| 109 |
-
|
| 110 |
with torch.no_grad():
|
| 111 |
-
out = model(
|
| 112 |
-
noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same").cpu().squeeze(0)
|
|
|
|
| 113 |
|
| 114 |
plt.figure(figsize=(30,6))
|
| 115 |
-
plt.plot(spectrum_data['frequencies'], spectrum
|
| 116 |
-
plt.plot(spectrum_data['frequencies'],
|
| 117 |
plt.savefig(extra_spectra_dir / f"{spectrum_name}_clean.png")
|
| 118 |
-
np.savetxt(extra_spectra_dir / f"{spectrum_name}_clean.csv", np.column_stack((spectrum_data['frequencies'],
|
| 119 |
|
| 120 |
plt.figure(figsize=(30,6))
|
| 121 |
-
plt.plot(spectrum_data['frequencies'], spectrum
|
| 122 |
-
plt.plot(spectrum_data['frequencies'], noised_est
|
| 123 |
plt.savefig(extra_spectra_dir / f"{spectrum_name}_noised.png")
|
| 124 |
-
np.savetxt(extra_spectra_dir / f"{spectrum_name}_noised.csv", np.column_stack((spectrum_data['frequencies'], noised_est
|
| 125 |
|
| 126 |
# Save response and attention if available
|
| 127 |
if 'response' in out:
|
|
@@ -139,6 +145,8 @@ def evaluate_model(stage=0, epoch=0):
|
|
| 139 |
plt.close("all")
|
| 140 |
|
| 141 |
|
|
|
|
|
|
|
| 142 |
for i_stage, training_stage in enumerate(config.training):
|
| 143 |
if model_weights_file.is_file():
|
| 144 |
model.load_state_dict(torch.load(model_weights_file, weights_only=True))
|
|
|
|
| 17 |
|
| 18 |
# from shiment import models
|
| 19 |
from shimnet.generators import get_datapipe
|
| 20 |
+
from shimnet.predict_utils import Defaults as PredictDefaults
|
| 21 |
|
| 22 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 23 |
if len(sys.argv) < 2:
|
|
|
|
| 36 |
extra_spectra_for_evaluation = {}
|
| 37 |
frq_step = config.data.get("frq_step") or config.metadata.get("frq_step")
|
| 38 |
model_ppm_per_point = frq_step / config.metadata.spectrometer_frequency
|
| 39 |
+
evaluation_spectra_normalization = config.logging.get("evaluation_spectra_normalization", PredictDefaults.SCALE)
|
| 40 |
|
| 41 |
for spectra_data in config.logging.get('extra_spectra_for_evaluation', []):
|
| 42 |
spectrum_file = Path(spectra_data.path)
|
|
|
|
| 50 |
spectrum_freqs = np.arange(spectrum_freqs_model_ppm.min(), spectrum_freqs_model_ppm.max(), model_ppm_per_point)
|
| 51 |
spectrum = np.interp(spectrum_freqs, spectrum_freqs_model_ppm, spectrum)
|
| 52 |
|
| 53 |
+
if evaluation_spectra_normalization is not None:
|
| 54 |
+
spectrum = spectrum * (evaluation_spectra_normalization / np.max(spectrum))
|
| 55 |
+
|
| 56 |
extra_spectra_for_evaluation[Path(spectrum_file).stem] = {
|
| 57 |
'frequencies': spectrum_freqs,
|
| 58 |
'spectrum': spectrum,
|
|
|
|
| 111 |
extra_spectra_dir = plot_dir / "extra_spectra"
|
| 112 |
extra_spectra_dir.mkdir(exist_ok=True, parents=True)
|
| 113 |
for spectrum_name, spectrum_data in extra_spectra_for_evaluation.items():
|
| 114 |
+
spectrum_input = torch.tensor(spectrum_data['spectrum']).float().to(device).unsqueeze(0).unsqueeze(0)
|
| 115 |
with torch.no_grad():
|
| 116 |
+
out = model(spectrum_input)
|
| 117 |
+
noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same").cpu().squeeze(0).squeeze(0).numpy()
|
| 118 |
+
denoised_est = out['denoised'].cpu().squeeze(0).squeeze(0).numpy()
|
| 119 |
|
| 120 |
plt.figure(figsize=(30,6))
|
| 121 |
+
plt.plot(spectrum_data['frequencies'], spectrum_data['spectrum'])
|
| 122 |
+
plt.plot(spectrum_data['frequencies'], denoised_est)
|
| 123 |
plt.savefig(extra_spectra_dir / f"{spectrum_name}_clean.png")
|
| 124 |
+
np.savetxt(extra_spectra_dir / f"{spectrum_name}_clean.csv", np.column_stack((spectrum_data['frequencies'], denoised_est)))
|
| 125 |
|
| 126 |
plt.figure(figsize=(30,6))
|
| 127 |
+
plt.plot(spectrum_data['frequencies'], spectrum_data['spectrum'])
|
| 128 |
+
plt.plot(spectrum_data['frequencies'], noised_est)
|
| 129 |
plt.savefig(extra_spectra_dir / f"{spectrum_name}_noised.png")
|
| 130 |
+
np.savetxt(extra_spectra_dir / f"{spectrum_name}_noised.csv", np.column_stack((spectrum_data['frequencies'], noised_est)))
|
| 131 |
|
| 132 |
# Save response and attention if available
|
| 133 |
if 'response' in out:
|
|
|
|
| 145 |
plt.close("all")
|
| 146 |
|
| 147 |
|
| 148 |
+
|
| 149 |
+
print("BatchInStage Loss AvgLoss CleanLoss RespLoss NoisedLoss MultiscaleCleanLoss")
|
| 150 |
for i_stage, training_stage in enumerate(config.training):
|
| 151 |
if model_weights_file.is_file():
|
| 152 |
model.load_state_dict(torch.load(model_weights_file, weights_only=True))
|