Spaces:
Sleeping
Sleeping
Marek Bukowicki
commited on
Commit
·
cce557d
1
Parent(s):
20ba4e5
add experimental spectra evaluation
Browse files
train.py
CHANGED
|
@@ -31,6 +31,28 @@ if (run_dir / "train.txt").is_file():
|
|
| 31 |
else:
|
| 32 |
minimum = float("inf")
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
# initialization
|
| 35 |
model = instantiate({"_target_": f"shimnet.models.{config.model.name}", **config.model.kwargs}).to(device)
|
| 36 |
model_weights_file = run_dir / f'model.pt'
|
|
@@ -79,6 +101,44 @@ def evaluate_model(stage=0, epoch=0):
|
|
| 79 |
|
| 80 |
plt.close("all")
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
for i_stage, training_stage in enumerate(config.training):
|
| 83 |
if model_weights_file.is_file():
|
| 84 |
model.load_state_dict(torch.load(model_weights_file, weights_only=True))
|
|
|
|
| 31 |
else:
|
| 32 |
minimum = float("inf")
|
| 33 |
|
| 34 |
+
# prepare spectra for evaluation
|
| 35 |
+
extra_spectra_for_evaluation = {}
|
| 36 |
+
frq_step = config.data.get("frq_step") or config.metadata.get("frq_step")
|
| 37 |
+
model_ppm_per_point = frq_step / config.metadata.spectrometer_frequency
|
| 38 |
+
|
| 39 |
+
for spectra_data in config.logging.get('extra_spectra_for_evaluation', []):
|
| 40 |
+
spectrum_file = Path(spectra_data.path)
|
| 41 |
+
spectrum_freqs_input_ppm, spectrum = np.loadtxt(spectrum_file).T
|
| 42 |
+
|
| 43 |
+
spectrometer_frequency = spectra_data.get("spectrometer_frequency")
|
| 44 |
+
if spectrometer_frequency is None: # spectrometer frequency unknown, assume the same as the model
|
| 45 |
+
spectrum_freqs = spectrum_freqs_input_ppm
|
| 46 |
+
else:
|
| 47 |
+
spectrum_freqs_model_ppm = spectrum_freqs_input_ppm * spectrometer_frequency / config.metadata.spectrometer_frequency
|
| 48 |
+
spectrum_freqs = np.arange(spectrum_freqs_model_ppm.min(), spectrum_freqs_model_ppm.max(), model_ppm_per_point)
|
| 49 |
+
spectrum = np.interp(spectrum_freqs, spectrum_freqs_model_ppm, spectrum)
|
| 50 |
+
|
| 51 |
+
extra_spectra_for_evaluation[Path(spectrum_file).stem] = {
|
| 52 |
+
'frequencies': spectrum_freqs,
|
| 53 |
+
'spectrum': spectrum,
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
# initialization
|
| 57 |
model = instantiate({"_target_": f"shimnet.models.{config.model.name}", **config.model.kwargs}).to(device)
|
| 58 |
model_weights_file = run_dir / f'model.pt'
|
|
|
|
| 101 |
|
| 102 |
plt.close("all")
|
| 103 |
|
| 104 |
+
# evaluate extra spectra
|
| 105 |
+
if len(extra_spectra_for_evaluation) > 0:
|
| 106 |
+
extra_spectra_dir = plot_dir / "extra_spectra"
|
| 107 |
+
extra_spectra_dir.mkdir(exist_ok=True, parents=True)
|
| 108 |
+
for spectrum_name, spectrum_data in extra_spectra_for_evaluation.items():
|
| 109 |
+
spectrum = torch.tensor(spectrum_data['spectrum']).to(device)
|
| 110 |
+
with torch.no_grad():
|
| 111 |
+
out = model(spectrum.unsqueeze(0))
|
| 112 |
+
noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same").cpu().squeeze(0)
|
| 113 |
+
|
| 114 |
+
plt.figure(figsize=(30,6))
|
| 115 |
+
plt.plot(spectrum_data['frequencies'], spectrum.cpu().numpy())
|
| 116 |
+
plt.plot(spectrum_data['frequencies'], out['denoised'].cpu().squeeze(0).numpy())
|
| 117 |
+
plt.savefig(extra_spectra_dir / f"{spectrum_name}_clean.png")
|
| 118 |
+
np.savetxt(extra_spectra_dir / f"{spectrum_name}_clean.csv", np.column_stack((spectrum_data['frequencies'], out['denoised'].cpu().squeeze(0).numpy())))
|
| 119 |
+
|
| 120 |
+
plt.figure(figsize=(30,6))
|
| 121 |
+
plt.plot(spectrum_data['frequencies'], spectrum.cpu().numpy())
|
| 122 |
+
plt.plot(spectrum_data['frequencies'], noised_est.numpy())
|
| 123 |
+
plt.savefig(extra_spectra_dir / f"{spectrum_name}_noised.png")
|
| 124 |
+
np.savetxt(extra_spectra_dir / f"{spectrum_name}_noised.csv", np.column_stack((spectrum_data['frequencies'], noised_est.numpy())))
|
| 125 |
+
|
| 126 |
+
# Save response and attention if available
|
| 127 |
+
if 'response' in out:
|
| 128 |
+
plt.figure(figsize=(10,6))
|
| 129 |
+
plt.plot(out['response'].cpu().squeeze(0).numpy())
|
| 130 |
+
plt.savefig(extra_spectra_dir / f"{spectrum_name}_response.png")
|
| 131 |
+
np.savetxt(extra_spectra_dir / f"{spectrum_name}_response.csv", out['response'].cpu().squeeze(0).numpy())
|
| 132 |
+
|
| 133 |
+
if "attention" in out:
|
| 134 |
+
plt.figure(figsize=(10, 6))
|
| 135 |
+
plt.plot(out['attention'].cpu().numpy())
|
| 136 |
+
plt.savefig(extra_spectra_dir / f"{spectrum_name}_attention.png")
|
| 137 |
+
np.savetxt(extra_spectra_dir / f"{spectrum_name}_attention.csv", out['attention'].cpu().numpy())
|
| 138 |
+
|
| 139 |
+
plt.close("all")
|
| 140 |
+
|
| 141 |
+
|
| 142 |
for i_stage, training_stage in enumerate(config.training):
|
| 143 |
if model_weights_file.is_file():
|
| 144 |
model.load_state_dict(torch.load(model_weights_file, weights_only=True))
|