Marek Bukowicki commited on
Commit
cce557d
·
1 Parent(s): 20ba4e5

add experimental spectra evaluation

Browse files
Files changed (1) hide show
  1. train.py +60 -0
train.py CHANGED
@@ -31,6 +31,28 @@ if (run_dir / "train.txt").is_file():
31
  else:
32
  minimum = float("inf")
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # initialization
35
  model = instantiate({"_target_": f"shimnet.models.{config.model.name}", **config.model.kwargs}).to(device)
36
  model_weights_file = run_dir / f'model.pt'
@@ -79,6 +101,44 @@ def evaluate_model(stage=0, epoch=0):
79
 
80
  plt.close("all")
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  for i_stage, training_stage in enumerate(config.training):
83
  if model_weights_file.is_file():
84
  model.load_state_dict(torch.load(model_weights_file, weights_only=True))
 
31
  else:
32
  minimum = float("inf")
33
 
34
+ # prepare spectra for evaluation
35
+ extra_spectra_for_evaluation = {}
36
+ frq_step = config.data.get("frq_step") or config.metadata.get("frq_step")
37
+ model_ppm_per_point = frq_step / config.metadata.spectrometer_frequency
38
+
39
+ for spectra_data in config.logging.get('extra_spectra_for_evaluation', []):
40
+ spectrum_file = Path(spectra_data.path)
41
+ spectrum_freqs_input_ppm, spectrum = np.loadtxt(spectrum_file).T
42
+
43
+ spectrometer_frequency = spectra_data.get("spectrometer_frequency")
44
+ if spectrometer_frequency is None: # spectrometer frequency unknown, assume the same as the model
45
+ spectrum_freqs = spectrum_freqs_input_ppm
46
+ else:
47
+ spectrum_freqs_model_ppm = spectrum_freqs_input_ppm * spectrometer_frequency / config.metadata.spectrometer_frequency
48
+ spectrum_freqs = np.arange(spectrum_freqs_model_ppm.min(), spectrum_freqs_model_ppm.max(), model_ppm_per_point)
49
+ spectrum = np.interp(spectrum_freqs, spectrum_freqs_model_ppm, spectrum)
50
+
51
+ extra_spectra_for_evaluation[Path(spectrum_file).stem] = {
52
+ 'frequencies': spectrum_freqs,
53
+ 'spectrum': spectrum,
54
+ }
55
+
56
  # initialization
57
  model = instantiate({"_target_": f"shimnet.models.{config.model.name}", **config.model.kwargs}).to(device)
58
  model_weights_file = run_dir / f'model.pt'
 
101
 
102
  plt.close("all")
103
 
104
+ # evaluate extra spectra
105
+ if len(extra_spectra_for_evaluation) > 0:
106
+ extra_spectra_dir = plot_dir / "extra_spectra"
107
+ extra_spectra_dir.mkdir(exist_ok=True, parents=True)
108
+ for spectrum_name, spectrum_data in extra_spectra_for_evaluation.items():
109
+ spectrum = torch.tensor(spectrum_data['spectrum']).to(device)
110
+ with torch.no_grad():
111
+ out = model(spectrum.unsqueeze(0))
112
+ noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same").cpu().squeeze(0)
113
+
114
+ plt.figure(figsize=(30,6))
115
+ plt.plot(spectrum_data['frequencies'], spectrum.cpu().numpy())
116
+ plt.plot(spectrum_data['frequencies'], out['denoised'].cpu().squeeze(0).numpy())
117
+ plt.savefig(extra_spectra_dir / f"{spectrum_name}_clean.png")
118
+ np.savetxt(extra_spectra_dir / f"{spectrum_name}_clean.csv", np.column_stack((spectrum_data['frequencies'], out['denoised'].cpu().squeeze(0).numpy())))
119
+
120
+ plt.figure(figsize=(30,6))
121
+ plt.plot(spectrum_data['frequencies'], spectrum.cpu().numpy())
122
+ plt.plot(spectrum_data['frequencies'], noised_est.numpy())
123
+ plt.savefig(extra_spectra_dir / f"{spectrum_name}_noised.png")
124
+ np.savetxt(extra_spectra_dir / f"{spectrum_name}_noised.csv", np.column_stack((spectrum_data['frequencies'], noised_est.numpy())))
125
+
126
+ # Save response and attention if available
127
+ if 'response' in out:
128
+ plt.figure(figsize=(10,6))
129
+ plt.plot(out['response'].cpu().squeeze(0).numpy())
130
+ plt.savefig(extra_spectra_dir / f"{spectrum_name}_response.png")
131
+ np.savetxt(extra_spectra_dir / f"{spectrum_name}_response.csv", out['response'].cpu().squeeze(0).numpy())
132
+
133
+ if "attention" in out:
134
+ plt.figure(figsize=(10, 6))
135
+ plt.plot(out['attention'].cpu().numpy())
136
+ plt.savefig(extra_spectra_dir / f"{spectrum_name}_attention.png")
137
+ np.savetxt(extra_spectra_dir / f"{spectrum_name}_attention.csv", out['attention'].cpu().numpy())
138
+
139
+ plt.close("all")
140
+
141
+
142
  for i_stage, training_stage in enumerate(config.training):
143
  if model_weights_file.is_file():
144
  model.load_state_dict(torch.load(model_weights_file, weights_only=True))