Marek Bukowicki committed on
Commit
d0daecc
·
1 Parent(s): 7fcb1aa

fix extra spectra evaluation

Browse files
Files changed (3) hide show
  1. predict.py +0 -4
  2. shimnet/predict_utils.py +1 -0
  3. train.py +17 -9
predict.py CHANGED
@@ -14,10 +14,6 @@ from shimnet.predict_utils import Defaults, resample_input_spectrum, resample_ou
14
  import warnings
15
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
16
 
17
- class Defaults:
18
- SCALE = 16.0
19
- SUFFIX = "_processed"
20
-
21
  def parse_args():
22
  parser = argparse.ArgumentParser()
23
  parser.add_argument("input_files", help="Input files", nargs="+")
 
14
  import warnings
15
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
16
 
 
 
 
 
17
  def parse_args():
18
  parser = argparse.ArgumentParser()
19
  parser.add_argument("input_files", help="Input files", nargs="+")
shimnet/predict_utils.py CHANGED
@@ -4,6 +4,7 @@ from .models import ShimNetWithSCRF, Predictor
4
 
5
  class Defaults:
6
  SCALE = 16.0
 
7
 
8
  # functions
9
  def resample_input_spectrum(input_freqs, input_spectrum, Mhz_per_point):
 
4
 
5
  class Defaults:
6
  SCALE = 16.0
7
+ SUFFIX = "_processed"
8
 
9
  # functions
10
  def resample_input_spectrum(input_freqs, input_spectrum, Mhz_per_point):
train.py CHANGED
@@ -17,6 +17,7 @@ warnings.filterwarnings("ignore", category=UserWarning, module='torchdata')
17
 
18
  # from shiment import models
19
  from shimnet.generators import get_datapipe
 
20
 
21
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
22
  if len(sys.argv) < 2:
@@ -35,6 +36,7 @@ else:
35
  extra_spectra_for_evaluation = {}
36
  frq_step = config.data.get("frq_step") or config.metadata.get("frq_step")
37
  model_ppm_per_point = frq_step / config.metadata.spectrometer_frequency
 
38
 
39
  for spectra_data in config.logging.get('extra_spectra_for_evaluation', []):
40
  spectrum_file = Path(spectra_data.path)
@@ -48,6 +50,9 @@ for spectra_data in config.logging.get('extra_spectra_for_evaluation', []):
48
  spectrum_freqs = np.arange(spectrum_freqs_model_ppm.min(), spectrum_freqs_model_ppm.max(), model_ppm_per_point)
49
  spectrum = np.interp(spectrum_freqs, spectrum_freqs_model_ppm, spectrum)
50
 
 
 
 
51
  extra_spectra_for_evaluation[Path(spectrum_file).stem] = {
52
  'frequencies': spectrum_freqs,
53
  'spectrum': spectrum,
@@ -106,22 +111,23 @@ def evaluate_model(stage=0, epoch=0):
106
  extra_spectra_dir = plot_dir / "extra_spectra"
107
  extra_spectra_dir.mkdir(exist_ok=True, parents=True)
108
  for spectrum_name, spectrum_data in extra_spectra_for_evaluation.items():
109
- spectrum = torch.tensor(spectrum_data['spectrum']).to(device)
110
  with torch.no_grad():
111
- out = model(spectrum.unsqueeze(0))
112
- noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same").cpu().squeeze(0)
 
113
 
114
  plt.figure(figsize=(30,6))
115
- plt.plot(spectrum_data['frequencies'], spectrum.cpu().numpy())
116
- plt.plot(spectrum_data['frequencies'], out['denoised'].cpu().squeeze(0).numpy())
117
  plt.savefig(extra_spectra_dir / f"{spectrum_name}_clean.png")
118
- np.savetxt(extra_spectra_dir / f"{spectrum_name}_clean.csv", np.column_stack((spectrum_data['frequencies'], out['denoised'].cpu().squeeze(0).numpy())))
119
 
120
  plt.figure(figsize=(30,6))
121
- plt.plot(spectrum_data['frequencies'], spectrum.cpu().numpy())
122
- plt.plot(spectrum_data['frequencies'], noised_est.numpy())
123
  plt.savefig(extra_spectra_dir / f"{spectrum_name}_noised.png")
124
- np.savetxt(extra_spectra_dir / f"{spectrum_name}_noised.csv", np.column_stack((spectrum_data['frequencies'], noised_est.numpy())))
125
 
126
  # Save response and attention if available
127
  if 'response' in out:
@@ -139,6 +145,8 @@ def evaluate_model(stage=0, epoch=0):
139
  plt.close("all")
140
 
141
 
 
 
142
  for i_stage, training_stage in enumerate(config.training):
143
  if model_weights_file.is_file():
144
  model.load_state_dict(torch.load(model_weights_file, weights_only=True))
 
17
 
18
  # from shiment import models
19
  from shimnet.generators import get_datapipe
20
+ from shimnet.predict_utils import Defaults as PredictDefaults
21
 
22
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
23
  if len(sys.argv) < 2:
 
36
  extra_spectra_for_evaluation = {}
37
  frq_step = config.data.get("frq_step") or config.metadata.get("frq_step")
38
  model_ppm_per_point = frq_step / config.metadata.spectrometer_frequency
39
+ evaluation_spectra_normalization = config.logging.get("evaluation_spectra_normalization", PredictDefaults.SCALE)
40
 
41
  for spectra_data in config.logging.get('extra_spectra_for_evaluation', []):
42
  spectrum_file = Path(spectra_data.path)
 
50
  spectrum_freqs = np.arange(spectrum_freqs_model_ppm.min(), spectrum_freqs_model_ppm.max(), model_ppm_per_point)
51
  spectrum = np.interp(spectrum_freqs, spectrum_freqs_model_ppm, spectrum)
52
 
53
+ if evaluation_spectra_normalization is not None:
54
+ spectrum = spectrum * (evaluation_spectra_normalization / np.max(spectrum))
55
+
56
  extra_spectra_for_evaluation[Path(spectrum_file).stem] = {
57
  'frequencies': spectrum_freqs,
58
  'spectrum': spectrum,
 
111
  extra_spectra_dir = plot_dir / "extra_spectra"
112
  extra_spectra_dir.mkdir(exist_ok=True, parents=True)
113
  for spectrum_name, spectrum_data in extra_spectra_for_evaluation.items():
114
+ spectrum_input = torch.tensor(spectrum_data['spectrum']).float().to(device).unsqueeze(0).unsqueeze(0)
115
  with torch.no_grad():
116
+ out = model(spectrum_input)
117
+ noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same").cpu().squeeze(0).squeeze(0).numpy()
118
+ denoised_est = out['denoised'].cpu().squeeze(0).squeeze(0).numpy()
119
 
120
  plt.figure(figsize=(30,6))
121
+ plt.plot(spectrum_data['frequencies'], spectrum_data['spectrum'])
122
+ plt.plot(spectrum_data['frequencies'], denoised_est)
123
  plt.savefig(extra_spectra_dir / f"{spectrum_name}_clean.png")
124
+ np.savetxt(extra_spectra_dir / f"{spectrum_name}_clean.csv", np.column_stack((spectrum_data['frequencies'], denoised_est)))
125
 
126
  plt.figure(figsize=(30,6))
127
+ plt.plot(spectrum_data['frequencies'], spectrum_data['spectrum'])
128
+ plt.plot(spectrum_data['frequencies'], noised_est)
129
  plt.savefig(extra_spectra_dir / f"{spectrum_name}_noised.png")
130
+ np.savetxt(extra_spectra_dir / f"{spectrum_name}_noised.csv", np.column_stack((spectrum_data['frequencies'], noised_est)))
131
 
132
  # Save response and attention if available
133
  if 'response' in out:
 
145
  plt.close("all")
146
 
147
 
148
+
149
+ print("BatchInStage Loss AvgLoss CleanLoss RespLoss NoisedLoss MultiscaleCleanLoss")
150
  for i_stage, training_stage in enumerate(config.training):
151
  if model_weights_file.is_file():
152
  model.load_state_dict(torch.load(model_weights_file, weights_only=True))