Marek Bukowicki commited on
Commit
73942d1
·
1 Parent(s): 0268fba

add experimental model M-E01

Browse files
configs/shimnet_600_M-E01.yaml ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ _target_: shimnet.models.ShimnetModular
3
+ encoder:
4
+ _target_: shimnet.models.ConvEncoder
5
+ hidden_dim: 64
6
+ output_dim: 128
7
+ activation: gelu
8
+ kernel_size: 7
9
+ local_feature_processor:
10
+ _target_: shimnet.models.ConvMLP
11
+ input_dim: 128
12
+ output_dim: 64
13
+ hidden_dims:
14
+ - 256
15
+ - 128
16
+ activation: gelu
17
+ attention_module:
18
+ _target_: shimnet.models.KVAttention
19
+ kv_dim: 32
20
+ num_heads: 8
21
+ k_processor:
22
+ _target_: shimnet.models.ConvMLP
23
+ input_dim: 128
24
+ output_dim: 256
25
+ hidden_dims:
26
+ - 512
27
+ - 256
28
+ activation: gelu
29
+ v_processor:
30
+ _target_: shimnet.models.MLP
31
+ input_dim: 128
32
+ output_dim: 256
33
+ hidden_dims:
34
+ - 512
35
+ - 256
36
+ activation: gelu
37
+ global_feature_processor:
38
+ _target_: shimnet.models.MLP
39
+ input_dim: 256
40
+ output_dim: 64
41
+ hidden_dims:
42
+ - 512
43
+ - 256
44
+ activation: gelu
45
+ response_head:
46
+ _target_: shimnet.models.MLP
47
+ input_dim: 256
48
+ output_dim: 81
49
+ hidden_dims:
50
+ - 512
51
+ - 256
52
+ activation: gelu
53
+ decoder:
54
+ _target_: shimnet.models.ConvDecoder
55
+ input_dim: 128
56
+ hidden_dim: 128
57
+ activation: gelu
58
+ kernel_size: 7
59
+ last_bias: false
60
+ last_activation: false
61
+ training:
62
+ #- batch_size: 64
63
+ #learning_rate: 0.001
64
+ #max_iters: 1600000
65
+ - batch_size: 256
66
+ learning_rate: 0.0001
67
+ max_iters: 25600000
68
+ - batch_size: 256
69
+ learning_rate: 0.00002
70
+ max_iters: 12800000
71
+ losses:
72
+ clean:
73
+ function: mae
74
+ weight: 1.0
75
+ noised:
76
+ function: mae
77
+ weight: 1.0
78
+ response:
79
+ function: mae
80
+ weight: 1.0
81
+ data:
82
+ _target_: shimnet.generators.Generator
83
+ include_response_function: true
84
+ seed: null # null means random seed
85
+ batch_size: null # to be set in training script
86
+ clean_spectra_generator:
87
+ _target_: shimnet.generators.TheoreticalMultipletSpectraGenerator
88
+ pixels: 2048
89
+ frq_step: ${metadata.frq_step}
90
+ peaks_parameter_generator:
91
+ _target_: shimnet.generators.PeaksParameterDataGenerator
92
+ atom_groups_data_file: data/multiplets_10000_parsed.txt
93
+ number_of_signals_min: 2
94
+ number_of_signals_max: 5
95
+ relative_frequency_min: -0.4
96
+ relative_frequency_max: 0.4
97
+ spectrum_width_min: 0.2
98
+ spectrum_width_max: 1.0
99
+ relative_width_min: 1.0
100
+ relative_width_max: 2.0
101
+ relative_height_min: 0.5
102
+ relative_height_max: 4
103
+ thf_min: 0.5
104
+ thf_max: 2
105
+ trf_min: 0.0
106
+ trf_max: 1.0
107
+ multiplicity_j1_min: 0.0
108
+ multiplicity_j1_max: 15
109
+ multiplicity_j2_min: 0.0
110
+ multiplicity_j2_max: 15
111
+ response_generator:
112
+ _target_: shimnet.generators.ResponseGenerator
113
+ response_function_library:
114
+ _target_: shimnet.generators.ResponseLibrary
115
+ response_files:
116
+ - data/smoothed_scrf_kernels/scrf_81_600MHz_smoothed_1-1-1.pt
117
+ - data/smoothed_scrf_kernels/scrf_81_600MHz_smoothed_1-2-1.pt
118
+ - data/smoothed_scrf_kernels/scrf_81_600MHz_smoothed_1-4-1.pt
119
+ - data/smoothed_scrf_kernels/scrf_81_600MHz_smoothed_1-3-3-1.pt
120
+ pad_to: 81
121
+ response_function_stretch_min: 0.8
122
+ response_function_stretch_max: 1.5
123
+ response_function_noise: 0.0
124
+ flip_response_function: false
125
+ noise_generator:
126
+ _target_: shimnet.generators.NoiseGenerator
127
+ spectrum_noise_min: 0.0
128
+ spectrum_noise_max: 0.015625
129
+ logging:
130
+ step: 1000000
131
+ num_plots: 32
132
+ metadata:
133
+ frq_step: 0.30048
134
+ spectrometer_frequency: 600.0
download_files.py CHANGED
@@ -10,6 +10,10 @@ ALL_FILES_TO_DOWNLOAD = {
10
  {
11
  "url": "https://drive.google.com/uc?export=download&id=1_VxOpFGJcFsOa5DHOW2GJbP8RvHCmC1N",
12
  "destination": "weights/shimnet_600MHz.pt"
 
 
 
 
13
  }],
14
  "SCRF": [{
15
  "url": "https://drive.google.com/uc?export=download&id=113al7A__yYALx_2hkESuzFIDU3feVtNY",
 
10
  {
11
  "url": "https://drive.google.com/uc?export=download&id=1_VxOpFGJcFsOa5DHOW2GJbP8RvHCmC1N",
12
  "destination": "weights/shimnet_600MHz.pt"
13
+ },
14
+ {
15
+ "url": "https://drive.google.com/uc?export=download&id=1643Il3qgCupY0n8Mar6WBc2WVuoQRzie",
16
+ "destination": "weights/shimnet_600MHz_M-E01.pt"
17
  }],
18
  "SCRF": [{
19
  "url": "https://drive.google.com/uc?export=download&id=113al7A__yYALx_2hkESuzFIDU3feVtNY",
predict-gui.py CHANGED
@@ -140,7 +140,7 @@ with gr.Blocks() as app:
140
  with gr.Column():
141
  model_selection = gr.Radio(
142
  label="Select Model",
143
- choices=["600 MHz", "700 MHz", "Custom"],
144
  value="600 MHz"
145
  )
146
  config_file = gr.File(label="Custom Config File (.yaml)", visible=False, height=120)
@@ -189,6 +189,9 @@ with gr.Blocks() as app:
189
  elif model_selection == "700 MHz":
190
  config_file = os.path.join(os.path.dirname(__file__), "configs/shimnet_700.yaml")
191
  weights_file = os.path.join(os.path.dirname(__file__), "weights/shimnet_700MHz.pt")
 
 
 
192
  else:
193
  config_file = config_file.name
194
  weights_file = weights_file.name
 
140
  with gr.Column():
141
  model_selection = gr.Radio(
142
  label="Select Model",
143
+ choices=["600 MHz", "700 MHz", "M-E01", "Custom"],
144
  value="600 MHz"
145
  )
146
  config_file = gr.File(label="Custom Config File (.yaml)", visible=False, height=120)
 
189
  elif model_selection == "700 MHz":
190
  config_file = os.path.join(os.path.dirname(__file__), "configs/shimnet_700.yaml")
191
  weights_file = os.path.join(os.path.dirname(__file__), "weights/shimnet_700MHz.pt")
192
+ elif model_selection == "M-E01":
193
+ config_file = os.path.join(os.path.dirname(__file__), "configs/shimnet_600_M-E01.yaml")
194
+ weights_file = os.path.join(os.path.dirname(__file__), "weights/shimnet_600MHz_M-E01.pt")
195
  else:
196
  config_file = config_file.name
197
  weights_file = weights_file.name