donghufeng committed on
Commit
d57fabf
·
1 Parent(s): 7fb034a
Quantispect_RF13_v1.0.10.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c899ee6674d1d78bb7570c5284086b4dfc5a2d3aba63fbac6ded0beaecfb831e
+ size 2693053
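The pointer above records the SHA-256 digest and byte size of the real checkpoint. The following is a minimal sketch, assuming the checkpoint has already been downloaded locally, of checking a file against those two fields. The helper names `parse_lfs_pointer` and `matches_pointer` are illustrative and are not part of Git LFS or of this repository.

```python
import hashlib
from pathlib import Path

# Pointer text copied from the diff above.
POINTER_TEXT = """\
version https://git-lfs.github.com/spec/v1
oid sha256:c899ee6674d1d78bb7570c5284086b4dfc5a2d3aba63fbac6ded0beaecfb831e
size 2693053
"""


def parse_lfs_pointer(text: str) -> dict:
    """Split each 'key value' line of a Git LFS pointer into a dict."""
    fields = {}
    for line in text.splitlines():
        if line.strip():
            key, _, value = line.partition(" ")
            fields[key] = value.strip()
    return fields


def matches_pointer(path: Path, pointer: dict) -> bool:
    """Return True if the file's size and SHA-256 digest match the pointer."""
    algo, _, expected_digest = pointer["oid"].partition(":")
    if algo != "sha256":
        raise ValueError(f"unsupported oid algorithm: {algo}")
    if path.stat().st_size != int(pointer["size"]):
        return False
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        # Hash in 1 MiB chunks so large files are not read into memory at once.
        for chunk in iter(lambda: fh.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_digest


pointer = parse_lfs_pointer(POINTER_TEXT)
```

A call such as `matches_pointer(Path("Quantispect_RF13_v1.0.10.pt"), pointer)` would then report whether a local copy matches this commit.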
README.md CHANGED
@@ -1,3 +1,216 @@
- ---
- license: apache-2.0
- ---
+ ---
+ library_name: ising-decoding
+ tags:
+ - quantum
+ - qec
+ - error_correction
+ - decoders
+ - surface_code
+ - predecoder
+ license: apache-2.0
+ ---
+
+ # Quantispect Overview
+
+ ![Quantispect Neural Pre-Decoder Architecture](framework.png)
+
+ ## Model Summary
+
+ | Item | Value |
+ |---|---:|
+ | Model name | Quantispect |
+ | Checkpoint file | `Quantispect_RF13_v1.0.10.pt` |
+ | Total parameters | ~0.663M |
+ | Checkpoint size | ~2.69 MB (2,693,053 bytes) |
+ | Architecture | FastHyper-style 3D CNN neural pre-decoder |
+ | Receptive field | R=13 |
+ | Input tensor | `(B, 4, T, D, D)` |
+ | Output tensor | `(B, 4, T, D, D)` |
+ | Release date | April 26, 2026 |
+
+ ## Description:
+
+ Quantispect is a compact neural pre-decoder for rotated surface-code quantum error correction. It takes five-dimensional syndrome tensors (batch, channel, time, and two spatial dimensions) and predicts local correction maps, which are passed to a downstream global decoder such as MWPM / PyMatching or an Ising-decoding post-processing pipeline.
+
+ Quantispect is designed to run inside an NVIDIA Ising-Decoding-compatible workflow after applying the Quantispect code patch included with this model release.
+
+ ## Model Architecture:
+
+ Architecture type: 3D Convolutional Neural Network (3D CNN)
+
+ Network architecture: custom multi-branch spatio-temporal 3D CNN with residual FastHyper blocks.
+
+ ### Input
+
+ Input shape:
+
+ ```text
+ (B, 4, T, D, D)
+ ```
+
+ ### Stem
+
+ ```text
+ Conv3D 4 -> 96, kernel 3x3x3
+ GroupNorm
+ GELU
+ ```
+
+ Stem output shape:
+
+ ```text
+ (B, 96, T, D, D)
+ ```
+
+ ### Main Body
+
+ The main body contains five repeated `FastHyperBlock` modules:
+
+ ```text
+ FastHyperBlock x5
+ ```
+
+ Each `FastHyperBlock` first expands the feature width from 96 to 144 channels with a 1x1x1 convolution, then applies three parallel feature extraction branches:
+
+ ```text
+ Pre-projection: GroupNorm -> 1x1x1 Conv3D (96 -> 144) -> GELU
+
+ Branch A (spatial): Depthwise Conv3D, kernel 1x3x3 -> GELU
+ Branch B (temporal): Depthwise Conv3D, kernel 3x1x1 -> GELU
+ Branch C (joint local spatio-temporal): GroupNorm -> Grouped Conv3D, kernel 3x3x3, groups=6 -> GELU
+ ```
+
+ The three branch outputs have identical shapes, so they are fused by element-wise summation rather than channel concatenation. The fused features are then projected back to 96 channels and recalibrated by channel attention:
+
+ ```text
+ Element-wise sum fusion
+ 1x1x1 Conv3D projection, 144 -> 96
+ GELU
+ ChannelGate / SE-style channel attention
+ Dropout3D
+ Residual connection
+ ```
+
+ Main body output shape:
+
+ ```text
+ (B, 96, T, D, D)
+ ```
+
+ ### Head
+
+ ```text
+ GroupNorm
+ 1x1x1 Conv3D, 96 -> 96
+ GELU
+ 1x1x1 Conv3D, 96 -> 4
+ ```
+
+ Output shape:
+
+ ```text
+ (B, 4, T, D, D)
+ ```
+
+ The residual-syndrome construction module uses the output maps to build residual syndromes, which are then passed to MWPM / Ising-decoder post-processing.
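The `~0.663M` parameter figure in the summary table can be cross-checked from the layer list above. The sketch below is a back-of-the-envelope count, not the authors' tooling. It assumes every convolution carries a bias and that the channel gate uses a hidden width of `max(96 // 4, 8) = 24`, both taken from `predecoder_fasthyper_rf13_v1.py` in this commit. The helper names are illustrative.

```python
def conv3d_params(c_in, c_out, kernel, groups=1, bias=True):
    """Weights plus biases of a Conv3d layer; kernel is (kt, kh, kw)."""
    kt, kh, kw = kernel
    weights = c_out * (c_in // groups) * kt * kh * kw
    return weights + (c_out if bias else 0)


def group_norm_params(channels):
    """GroupNorm learns one scale and one shift per channel."""
    return 2 * channels


def block_params(hidden=96, mid=144, mix_groups=6, gate_reduction=4):
    """Parameter count of one FastHyperBlock."""
    gate_hidden = max(hidden // gate_reduction, 8)
    return (
        group_norm_params(hidden)
        + conv3d_params(hidden, mid, (1, 1, 1))                  # pre-projection
        + conv3d_params(mid, mid, (1, 3, 3), groups=mid)         # branch A, depthwise
        + conv3d_params(mid, mid, (3, 1, 1), groups=mid)         # branch B, depthwise
        + group_norm_params(mid)                                 # branch C norm
        + conv3d_params(mid, mid, (3, 3, 3), groups=mix_groups)  # branch C, grouped
        + conv3d_params(mid, hidden, (1, 1, 1))                  # fuse projection
        + conv3d_params(hidden, gate_hidden, (1, 1, 1))          # gate fc1
        + conv3d_params(gate_hidden, hidden, (1, 1, 1))          # gate fc2
    )


def total_params(in_ch=4, out_ch=4, hidden=96, num_blocks=5):
    """Stem + repeated blocks + head."""
    stem = conv3d_params(in_ch, hidden, (3, 3, 3)) + group_norm_params(hidden)
    head = (
        group_norm_params(hidden)
        + conv3d_params(hidden, hidden, (1, 1, 1))
        + conv3d_params(hidden, out_ch, (1, 1, 1))
    )
    return stem + num_blocks * block_params(hidden=hidden) + head


print(total_params())  # 663388, i.e. ~0.663M
```

Under these assumptions the grouped 3x3x3 convolution in Branch C accounts for most of each block (93,456 of 128,568 parameters).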
+
+ ## Usage:
+
+ Quantispect is intended to be used with the NVIDIA Ising-Decoding environment:
+
+ ```text
+ https://github.com/NVIDIA/Ising-Decoding
+ ```
+
+ A clean NVIDIA Ising-Decoding checkout does not include the Quantispect / FastHyper architecture. To run `Quantispect_RF13_v1.0.10.pt`, first apply the Quantispect code patch included in this model repository.
+
+
+ ### Required code patch files
+
+ The patch package should preserve the following relative paths:
+
+ ```text
+ quantispect_code_patch/
+ ├── conf/
+ │   └── config_public.yaml
+ └── code/
+     ├── model/
+     │   ├── predecoder_fasthyper_rf13_v1.py
+     │   ├── factory.py
+     │   └── registry.py
+     ├── workflows/
+     │   ├── config_validator.py
+     │   └── run.py
+     └── scripts/
+         └── local_run.sh
+ ```
+
+ These files should be copied into the NVIDIA Ising-Decoding repository with the same relative paths:
+
+ ```text
+ conf/config_public.yaml -> Ising-Decoding/conf/config_public.yaml
+ code/model/predecoder_fasthyper_rf13_v1.py -> Ising-Decoding/code/model/predecoder_fasthyper_rf13_v1.py
+ code/model/factory.py -> Ising-Decoding/code/model/factory.py
+ code/model/registry.py -> Ising-Decoding/code/model/registry.py
+ code/workflows/config_validator.py -> Ising-Decoding/code/workflows/config_validator.py
+ code/workflows/run.py -> Ising-Decoding/code/workflows/run.py
+ code/scripts/local_run.sh -> Ising-Decoding/code/scripts/local_run.sh
+ ```
+
+ The patch mainly adds the `predecoder_fasthyper_rf13_v1` model implementation, registers `model_id: 6`, adds the Quantispect model hyperparameters to `config_public.yaml`, and enables explicit `.pt` checkpoint loading through `model_checkpoint_file`.
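The path mapping above can also be applied programmatically. This is a minimal sketch using only the Python standard library. The function name `apply_patch` and the two directory arguments are hypothetical, and the copy overwrites files of the same name in the target checkout, just as a plain `cp` would.

```python
from __future__ import annotations

import shutil
from pathlib import Path

# Relative paths listed in the mapping above.
PATCH_FILES = [
    "conf/config_public.yaml",
    "code/model/predecoder_fasthyper_rf13_v1.py",
    "code/model/factory.py",
    "code/model/registry.py",
    "code/workflows/config_validator.py",
    "code/workflows/run.py",
    "code/scripts/local_run.sh",
]


def apply_patch(patch_root: Path, repo_root: Path) -> list[Path]:
    """Copy each patch file into repo_root under the same relative path."""
    # Check everything up front so a partial patch is never applied.
    missing = [rel for rel in PATCH_FILES if not (patch_root / rel).is_file()]
    if missing:
        raise FileNotFoundError(f"patch package is missing: {missing}")
    copied = []
    for rel in PATCH_FILES:
        dest = repo_root / rel
        dest.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(patch_root / rel, dest)
        copied.append(dest)
    return copied
```

Checking for missing files before copying means an incomplete download fails loudly instead of leaving the checkout half patched.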
+
+ ### Apply the patch
+
+ From a directory that contains the clean NVIDIA Ising-Decoding checkout (`Ising-Decoding/`) alongside the `code/` and `conf/` directories of this patch package:
+
+ ```bash
+ cp -r code/* Ising-Decoding/code/
+ cp -r conf/* Ising-Decoding/conf/
+ ```
+
+ Then place the Quantispect checkpoint under the repository model directory:
+
+ ```bash
+ mkdir -p Ising-Decoding/models
+ cp Quantispect_RF13_v1.0.10.pt Ising-Decoding/models/Quantispect_RF13_v1.0.10.pt
+ ```
+
+ Expected directory layout:
+
+ ```text
+ Ising-Decoding/
+ ├── code/
+ │   ├── model/
+ │   │   └── predecoder_fasthyper_rf13_v1.py
+ │   ├── workflows/
+ │   │   ├── config_validator.py
+ │   │   └── run.py
+ │   └── scripts/
+ │       └── local_run.sh
+ ├── conf/
+ │   └── config_public.yaml
+ ├── models/
+ │   └── Quantispect_RF13_v1.0.10.pt
+ └── README.md
+ ```
+
+ ## Inference Deployment:
+
+ Configure the NVIDIA Ising-Decoding repository for inference, apply the Quantispect patch files above, and place the downloaded model checkpoint at `models/Quantispect_RF13_v1.0.10.pt`.
+
+ Run from the repository root:
+
+ ```bash
+ cd Ising-Decoding
+
+ CUDA_VISIBLE_DEVICES=0,1,2,3 \
+ PYTHONUNBUFFERED=1 \
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
+ WORKFLOW=inference \
+ EXPERIMENT_NAME=infer_quantispect \
+ TORCH_COMPILE=0 \
+ EXTRA_PARAMS="+model_checkpoint_file=models/Quantispect_RF13_v1.0.10.pt" \
+ bash code/scripts/local_run.sh \
+ 2>&1 | tee infer_quantispect.log
+ ```
+
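When the inference command above is driven from Python instead of a shell, the same environment variables can be assembled as a dictionary. This is only a sketch: the variable names and values come from the README command, while `build_inference_env` is a hypothetical helper and the launch line is left as a comment because it needs a patched checkout and a GPU host.

```python
import os


def build_inference_env(checkpoint, gpus=(0, 1, 2, 3),
                        experiment="infer_quantispect", base_env=None):
    """Return a copy of the environment with the README's inference settings."""
    env = dict(os.environ if base_env is None else base_env)
    env.update({
        "CUDA_VISIBLE_DEVICES": ",".join(str(g) for g in gpus),
        "PYTHONUNBUFFERED": "1",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "WORKFLOW": "inference",
        "EXPERIMENT_NAME": experiment,
        "TORCH_COMPILE": "0",
        # Hydra-style override that local_run.sh forwards to the workflow.
        "EXTRA_PARAMS": f"+model_checkpoint_file={checkpoint}",
    })
    return env


COMMAND = ["bash", "code/scripts/local_run.sh"]
# base_env={} keeps this demonstration independent of the caller's environment;
# pass base_env=None for a real launch so PATH and friends are inherited.
env = build_inference_env("models/Quantispect_RF13_v1.0.10.pt", base_env={})
# To launch: subprocess.run(COMMAND, cwd="Ising-Decoding", env=env, check=True)
```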
bias_subcard.md ADDED
@@ -0,0 +1,7 @@
+ # Bias Subcard
+
+ Field | Response
+ :-----|:---------
+ Participation considerations from adversely impacted groups [protected classes](https://www.senate.ca.gov/content/protected-classes) in model design and testing: | Not Applicable
+ Measures taken to mitigate against unwanted bias: | Not Applicable
+ Bias Metric (If Measured): | Not Applicable
code/model/factory.py ADDED
@@ -0,0 +1,62 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Factory module for creating models.
+
+ Provides ModelFactory for instantiating pre-decoder models from config.
+ """
+
+
+ class ModelFactory:
+
+     @staticmethod
+     def create_model(cfg):
+         if cfg.code == "surface":
+             return ModelFactory._create_surface_model(cfg)
+         else:
+             raise ValueError("Invalid model name")
+
+     @staticmethod
+     def _create_surface_model(cfg):
+         if cfg.model.version == "predecoder_memory_v1":
+             from model.predecoder import PreDecoderModelMemory_v1
+             model = PreDecoderModelMemory_v1(cfg)
+             return model
+         elif cfg.model.version == "predecoder_sd_litenet_v1":
+             from model.predecoder_sd_litenet_v1 import PredecoderSDLiteNetV1
+             model = PredecoderSDLiteNetV1(
+                 input_channels=getattr(cfg.model, "input_channels", 4),
+                 out_channels=getattr(cfg.model, "out_channels", 4),
+                 hidden_dim=getattr(cfg.model, "hidden_dim", 64),
+                 bottleneck_dim=getattr(cfg.model, "bottleneck_dim", 16),
+                 dropout_p=getattr(cfg.model, "dropout_p", 0.05),
+             )
+             return model
+         elif cfg.model.version == "predecoder_fasthyper_rf13_v1":
+             from model.predecoder_fasthyper_rf13_v1 import PredecoderFastHyperRF13V1
+             model = PredecoderFastHyperRF13V1(
+                 input_channels=getattr(cfg.model, "input_channels", 4),
+                 out_channels=getattr(cfg.model, "out_channels", 4),
+                 hidden_dim=getattr(cfg.model, "hidden_dim", 96),
+                 mid_dim=getattr(cfg.model, "mid_dim", 144),
+                 mix_groups=getattr(cfg.model, "mix_groups", 6),
+                 num_blocks=getattr(cfg.model, "num_blocks", 5),
+                 stem_kernel_size=getattr(cfg.model, "stem_kernel_size", 3),
+                 dropout_p=getattr(cfg.model, "dropout_p", 0.02),
+                 gate_reduction=getattr(cfg.model, "gate_reduction", 4),
+             )
+             return model
+         else:
+             raise ValueError(f"Invalid model version: {cfg.model.version}")
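The factory resolves each FastHyper hyperparameter with `getattr(cfg.model, name, default)`, so a config only needs to list the values it wants to change. The sketch below isolates that pattern. It uses `types.SimpleNamespace` as a stand-in for the real config object (an assumption; the workflow itself appears to use OmegaConf), and `resolve_fasthyper_kwargs` is an illustrative name.

```python
from types import SimpleNamespace

# Defaults copied from the predecoder_fasthyper_rf13_v1 branch of the factory.
FASTHYPER_DEFAULTS = {
    "input_channels": 4,
    "out_channels": 4,
    "hidden_dim": 96,
    "mid_dim": 144,
    "mix_groups": 6,
    "num_blocks": 5,
    "stem_kernel_size": 3,
    "dropout_p": 0.02,
    "gate_reduction": 4,
}


def resolve_fasthyper_kwargs(model_cfg):
    """Take each value from model_cfg if present, else fall back to the default."""
    return {
        name: getattr(model_cfg, name, default)
        for name, default in FASTHYPER_DEFAULTS.items()
    }


# A config that overrides only the block count.
model_cfg = SimpleNamespace(version="predecoder_fasthyper_rf13_v1", num_blocks=3)
kwargs = resolve_fasthyper_kwargs(model_cfg)
```

Fields that are not hyperparameters, such as `version`, are ignored, and every key absent from the config keeps the value the released checkpoint was built with.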
code/model/predecoder_fasthyper_rf13_v1.py ADDED
@@ -0,0 +1,170 @@
+ from __future__ import annotations
+
+ from typing import Any
+
+ import torch
+ import torch.nn as nn
+
+
+ def _choose_gn_groups(channels: int, max_groups: int = 8) -> int:
+     for g in range(min(max_groups, channels), 0, -1):
+         if channels % g == 0:
+             return g
+     return 1
+
+
+ class _ChannelGate(nn.Module):
+     def __init__(self, channels: int, reduction: int = 4) -> None:
+         super().__init__()
+         hidden = max(channels // reduction, 8)
+         self.pool = nn.AdaptiveAvgPool3d(1)
+         self.fc1 = nn.Conv3d(channels, hidden, kernel_size=1, bias=True)
+         self.act = nn.GELU()
+         self.fc2 = nn.Conv3d(hidden, channels, kernel_size=1, bias=True)
+         self.gate = nn.Sigmoid()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         s = self.pool(x)
+         s = self.fc1(s)
+         s = self.act(s)
+         s = self.fc2(s)
+         return x * self.gate(s)
+
+
+ class _FastHyperBlock(nn.Module):
+     """
+     Efficient RF-expanding residual block.
+
+     Each block contributes one effective k=3 receptive-field expansion stage via
+     three parallel branches operating on the same expanded activation:
+     - spatial depthwise (1,3,3)
+     - temporal depthwise (3,1,1)
+     - grouped 3D mixing (3,3,3)
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         mid_dim: int,
+         mix_groups: int = 6,
+         dropout_p: float = 0.02,
+         gate_reduction: int = 4,
+     ) -> None:
+         super().__init__()
+         gn1 = _choose_gn_groups(channels)
+         gn2 = _choose_gn_groups(mid_dim)
+         mix_groups = max(1, min(mix_groups, mid_dim))
+         while mid_dim % mix_groups != 0 and mix_groups > 1:
+             mix_groups -= 1
+
+         self.pre = nn.Sequential(
+             nn.GroupNorm(gn1, channels),
+             nn.Conv3d(channels, mid_dim, kernel_size=1, bias=True),
+             nn.GELU(),
+         )
+         self.spatial = nn.Sequential(
+             nn.Conv3d(
+                 mid_dim,
+                 mid_dim,
+                 kernel_size=(1, 3, 3),
+                 padding=(0, 1, 1),
+                 groups=mid_dim,
+                 bias=True,
+             ),
+             nn.GELU(),
+         )
+         self.temporal = nn.Sequential(
+             nn.Conv3d(
+                 mid_dim,
+                 mid_dim,
+                 kernel_size=(3, 1, 1),
+                 padding=(1, 0, 0),
+                 groups=mid_dim,
+                 bias=True,
+             ),
+             nn.GELU(),
+         )
+         self.mixed = nn.Sequential(
+             nn.GroupNorm(gn2, mid_dim),
+             nn.Conv3d(
+                 mid_dim,
+                 mid_dim,
+                 kernel_size=3,
+                 padding=1,
+                 groups=mix_groups,
+                 bias=True,
+             ),
+             nn.GELU(),
+         )
+         self.fuse = nn.Sequential(
+             nn.Conv3d(mid_dim, channels, kernel_size=1, bias=True),
+             nn.GELU(),
+         )
+         self.gate = _ChannelGate(channels, reduction=gate_reduction)
+         self.dropout = nn.Dropout3d(dropout_p) if dropout_p > 0 else nn.Identity()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         h = self.pre(x)
+         h = self.spatial(h) + self.temporal(h) + self.mixed(h)
+         h = self.fuse(h)
+         h = self.gate(h)
+         h = self.dropout(h)
+         return x + h
+
+
+ class PredecoderFastHyperRF13V1(nn.Module):
+     """
+     Faster-stronger candidate for model 6 under the public Ising-Decoding API.
+
+     Input / output shape:
+         (B, 4, T, D, D) -> (B, 4, T, D, D)
+     """
+
+     def __init__(
+         self,
+         input_channels: int = 4,
+         out_channels: int = 4,
+         hidden_dim: int = 96,
+         mid_dim: int = 144,
+         mix_groups: int = 6,
+         num_blocks: int = 5,
+         stem_kernel_size: int = 3,
+         dropout_p: float = 0.02,
+         gate_reduction: int = 4,
+         **_: Any,
+     ) -> None:
+         super().__init__()
+         pad = stem_kernel_size // 2
+         gn = _choose_gn_groups(hidden_dim)
+         self.stem = nn.Sequential(
+             nn.Conv3d(
+                 input_channels,
+                 hidden_dim,
+                 kernel_size=stem_kernel_size,
+                 padding=pad,
+                 bias=True,
+             ),
+             nn.GroupNorm(gn, hidden_dim),
+             nn.GELU(),
+         )
+         self.blocks = nn.Sequential(*[
+             _FastHyperBlock(
+                 channels=hidden_dim,
+                 mid_dim=mid_dim,
+                 mix_groups=mix_groups,
+                 dropout_p=dropout_p,
+                 gate_reduction=gate_reduction,
+             ) for _ in range(num_blocks)
+         ])
+         self.head = nn.Sequential(
+             nn.GroupNorm(gn, hidden_dim),
+             nn.Conv3d(hidden_dim, hidden_dim, kernel_size=1, bias=True),
+             nn.GELU(),
+             nn.Conv3d(hidden_dim, out_channels, kernel_size=1, bias=True),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.stem(x)
+         x = self.blocks(x)
+         x = self.head(x)
+         return x
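Two properties of the model file above can be checked without PyTorch: every convolution pads so that `(T, D, D)` is preserved, and `_choose_gn_groups` picks 8 groups for both the 96-channel and 144-channel GroupNorm layers. The sketch below re-implements the group chooser and applies the standard convolution output-size formula. The window size `T = D = 9` is an illustrative choice, not a value taken from the release.

```python
def conv_out_size(size, kernel, padding, stride=1, dilation=1):
    """Output length along one axis, using the standard Conv3d size formula."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def choose_gn_groups(channels, max_groups=8):
    """Largest group count <= max_groups that divides channels evenly."""
    for g in range(min(max_groups, channels), 0, -1):
        if channels % g == 0:
            return g
    return 1


# (kernel, padding) along (T, H, W) for each convolution type in the model.
LAYERS = {
    "stem": ((3, 3, 3), (1, 1, 1)),
    "spatial": ((1, 3, 3), (0, 1, 1)),
    "temporal": ((3, 1, 1), (1, 0, 0)),
    "mixed": ((3, 3, 3), (1, 1, 1)),
    "pointwise": ((1, 1, 1), (0, 0, 0)),
}


def output_shape(shape, kernel, padding):
    """Apply conv_out_size independently to each of the three axes."""
    return tuple(conv_out_size(s, k, p) for s, k, p in zip(shape, kernel, padding))


T, D = 9, 9  # illustrative window size (assumption)
for name, (kernel, padding) in LAYERS.items():
    # Each layer keeps the (T, D, D) extent, so the branch outputs can be summed.
    assert output_shape((T, D, D), kernel, padding) == (T, D, D), name
```

Because all three branches preserve the extent, their outputs can be added element-wise and the residual `x + h` is well defined.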
code/model/registry.py ADDED
@@ -0,0 +1,110 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Public model registry for the early-access public release.
+
+ External users choose `model_id` in {1..6}. This registry maps model_id to:
+ - the underlying architecture parameters (num_filters, kernel_size)
+ - the model receptive field R (in rounds / distance units)
+
+ Receptive field convention matches `compare_receptive_field_with_window_data`
+ in `code/training/utils.py`:
+     R = 1 + sum_i (k_i - 1) for kernel sizes k_i (assumed odd, with same-padding)
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Dict, List
+
+
+ def compute_receptive_field(kernel_sizes: List[int]) -> int:
+     """Compute receptive field R from a list of kernel sizes."""
+     if not kernel_sizes:
+         raise ValueError("kernel_sizes must be non-empty")
+     if any(not isinstance(k, int) for k in kernel_sizes):
+         raise ValueError(f"kernel_sizes must be ints, got: {kernel_sizes!r}")
+     if any(k <= 0 for k in kernel_sizes):
+         raise ValueError(f"kernel_sizes must be positive, got: {kernel_sizes!r}")
+     return 1 + sum(kernel_sizes) - len(kernel_sizes)
+
+
+ @dataclass(frozen=True)
+ class PublicModelSpec:
+     model_id: int
+     num_filters: List[int]
+     kernel_size: List[int]
+     receptive_field: int
+     model_version: str = "predecoder_memory_v1"
+
+
+ _MODEL_SPECS: Dict[int, PublicModelSpec] = {
+     1:
+         PublicModelSpec(
+             model_id=1,
+             num_filters=[128, 128, 128, 4],
+             kernel_size=[3, 3, 3, 3],
+             receptive_field=compute_receptive_field([3, 3, 3, 3]),
+         ),
+     2:
+         PublicModelSpec(
+             model_id=2,
+             num_filters=[256, 256, 256, 4],
+             kernel_size=[3, 3, 3, 3],
+             receptive_field=compute_receptive_field([3, 3, 3, 3]),
+         ),
+     3:
+         PublicModelSpec(
+             model_id=3,
+             num_filters=[128, 128, 128, 4],
+             kernel_size=[5, 5, 5, 5],
+             receptive_field=compute_receptive_field([5, 5, 5, 5]),
+         ),
+     4:
+         PublicModelSpec(
+             model_id=4,
+             num_filters=[128, 128, 128, 128, 128, 4],
+             kernel_size=[3, 3, 3, 3, 3, 3],
+             receptive_field=compute_receptive_field([3, 3, 3, 3, 3, 3]),
+         ),
+     5:
+         PublicModelSpec(
+             model_id=5,
+             num_filters=[256, 256, 256, 256, 256, 4],
+             kernel_size=[3, 3, 3, 3, 3, 3],
+             receptive_field=compute_receptive_field([3, 3, 3, 3, 3, 3]),
+         ),
+     6:
+         PublicModelSpec(
+             model_id=6,
+             num_filters=[96, 96, 96, 96, 96, 4],
+             kernel_size=[3, 3, 3, 3, 3, 3],
+             receptive_field=compute_receptive_field([3, 3, 3, 3, 3, 3]),
+             model_version="predecoder_fasthyper_rf13_v1",
+         ),
+ }
+
+
+ def get_model_spec(model_id: int) -> PublicModelSpec:
+     """Return the public model spec for a given model_id (1..6)."""
+     try:
+         mid = int(model_id)
+     except Exception as e:
+         raise ValueError(f"model_id must be an int in [1..6], got: {model_id!r}") from e
+     if mid == 0:
+         raise ValueError("model_id=0 is not supported in the public release")
+     if mid not in _MODEL_SPECS:
+         raise ValueError(f"model_id must be in [1..6], got: {mid}")
+     return _MODEL_SPECS[mid]
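The registry's convention `R = 1 + sum_i (k_i - 1)` is what yields the "RF13" in the checkpoint name: model 6 is registered with six `k = 3` stages (the stem plus five blocks, each of which the block docstring describes as one effective `k = 3` expansion), giving `1 + 6 * 2 = 13`. The sketch below recomputes R for all six registered kernel lists and adds the clamp rule quoted in the `config_validator.py` docstring from this commit (`D = min(distance, R)`, `N_R = min(n_rounds, R)`). Function names are illustrative.

```python
def receptive_field(kernel_sizes):
    """R = 1 + sum_i (k_i - 1), the convention stated in the registry docstring."""
    return 1 + sum(k - 1 for k in kernel_sizes)


def clamp_to_receptive_field(distance, n_rounds, r):
    """Clamp rule from the validator docstring: neither value may exceed R."""
    return min(distance, r), min(n_rounds, r)


# Kernel lists copied from _MODEL_SPECS above.
KERNELS = {
    1: [3, 3, 3, 3],
    2: [3, 3, 3, 3],
    3: [5, 5, 5, 5],
    4: [3, 3, 3, 3, 3, 3],
    5: [3, 3, 3, 3, 3, 3],
    6: [3, 3, 3, 3, 3, 3],
}

FIELDS = {model_id: receptive_field(ks) for model_id, ks in KERNELS.items()}
print(FIELDS)  # {1: 9, 2: 9, 3: 17, 4: 13, 5: 13, 6: 13}

# A distance-21, 21-round request is clamped to the model 6 receptive field.
print(clamp_to_receptive_field(21, 21, FIELDS[6]))  # (13, 13)
```

The expression `1 + sum(kernel_sizes) - len(kernel_sizes)` used in `compute_receptive_field` is the same formula with the subtraction factored out.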
code/scripts/local_run.sh ADDED
@@ -0,0 +1,252 @@
+ #!/usr/bin/env bash
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ set -euo pipefail
+
+ # Minimal local runner.
+ #
+ # Examples:
+ #   bash code/scripts/local_run.sh
+ #   WORKFLOW=inference bash code/scripts/local_run.sh
+ #   GPUS=4 bash code/scripts/local_run.sh
+ #   CUDA_VISIBLE_DEVICES=1 bash code/scripts/local_run.sh # use only GPU 1
+ #
+ # ONNX / TRT fast inference (requires tensorrt; set ONNX_WORKFLOW before running):
+ #   ONNX_WORKFLOW=1 WORKFLOW=inference bash code/scripts/local_run.sh # export ONNX only (inspect/reuse later)
+ #   ONNX_WORKFLOW=2 WORKFLOW=inference bash code/scripts/local_run.sh # export ONNX + build TRT + run TRT inference
+ #   ONNX_WORKFLOW=2 QUANT_FORMAT=int8 WORKFLOW=inference bash code/scripts/local_run.sh # INT8 quantized TRT
+ #   ONNX_WORKFLOW=2 QUANT_FORMAT=fp8 WORKFLOW=inference bash code/scripts/local_run.sh # FP8 quantized TRT (requires nvidia-modelopt)
+ #   ONNX_WORKFLOW=3 WORKFLOW=inference bash code/scripts/local_run.sh # load pre-built engine, skip export
+ #
+ # Decoder ablation study with cudaq-qec global decoders (requires cudaq-qec):
+ #   WORKFLOW=decoder_ablation bash code/scripts/local_run.sh
+ #
+ # Decoder ablation with TRT pre-decoder + cudaq-qec global decoders
+ # (combines fast TRT inference for the neural pre-decoder with GPU-accelerated
+ # cudaq-qec decoders for the residual syndromes; full GPU pipeline end-to-end):
+ #   ONNX_WORKFLOW=2 WORKFLOW=decoder_ablation bash code/scripts/local_run.sh # export+build TRT, then ablation
+ #   ONNX_WORKFLOW=3 WORKFLOW=decoder_ablation bash code/scripts/local_run.sh # load existing engine, then ablation
+ #
+ # Notes:
+ # - Public config is `conf/config_public.yaml`. Users should edit only that file.
+ # - Training knobs are auto-managed in code (epochs, shots/epoch, batch schedule, etc.).
+ # - SafeTensors (optional): after training, convert the best .pt checkpoint with
+ #   code/export/checkpoint_to_safetensors.py (see README), then pass the result as:
+ #     PREDECODER_SAFETENSORS_CHECKPOINT=<path>.safetensors WORKFLOW=inference bash code/scripts/local_run.sh
+
+ EXPERIMENT_NAME="${EXPERIMENT_NAME:-test1}"
+ CONFIG_NAME="${CONFIG_NAME:-config_public}" # conf/<name>.yaml (no extension)
+ WORKFLOW="${WORKFLOW:-train}" # train | inference
+ WORKFLOW="$(echo "${WORKFLOW}" | tr '[:upper:]' '[:lower:]')"
+ GPUS="${GPUS:-}" # if empty, auto-detect
+ FRESH_START="${FRESH_START:-0}" # 1 => don't load checkpoint
+ EXTRA_PARAMS="${EXTRA_PARAMS:-}" # advanced hydra overrides (discouraged)
+ TORCH_COMPILE="${TORCH_COMPILE:-}" # 0/1 to disable/enable torch.compile
+ TORCH_COMPILE_MODE="${TORCH_COMPILE_MODE:-}" # optional: default | reduce-overhead | max-autotune
+
+ DISTANCE="${DISTANCE:-}"
+ N_ROUNDS="${N_ROUNDS:-}"
+ if [ $# -eq 1 ]; then DISTANCE="$1"; fi
+ if [ $# -eq 2 ]; then DISTANCE="$1"; N_ROUNDS="$2"; fi
+
+ SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"
+ # local_run.sh lives at: <repo_root>/code/scripts/local_run.sh
+ # so repo_root is two levels up from SCRIPT_DIR.
+ REPO_ROOT="$(cd -- "${SCRIPT_DIR}/../.." && pwd)"
+ CODE_ROOT="${CODE_ROOT:-${REPO_ROOT}/code}"
+
+ # Default output locations live inside the repo (avoid surprises from generic env vars).
+ # Some environments set BASE_OUTPUT_DIR/LOG_BASE_DIR globally; ignore those by default to
+ # prevent creating confusing extra folders like /root/outputs or /root/logs.
+ if [ -n "${BASE_OUTPUT_DIR:-}" ] || [ -n "${LOG_BASE_DIR:-}" ]; then
+   echo "[local_run.sh] Note: ignoring BASE_OUTPUT_DIR/LOG_BASE_DIR from the environment."
+   echo "[local_run.sh] To override paths, use PREDECODER_BASE_OUTPUT_DIR / PREDECODER_LOG_BASE_DIR."
+ fi
+ BASE_OUTPUT_DIR="${PREDECODER_BASE_OUTPUT_DIR:-${REPO_ROOT}/outputs}"
+ LOG_BASE_DIR="${PREDECODER_LOG_BASE_DIR:-${REPO_ROOT}/logs}"
+ mkdir -p "${BASE_OUTPUT_DIR}" "${LOG_BASE_DIR}"
+
+ if [ "${FRESH_START}" -eq 1 ]; then
+   RESUME_FLAG="++load_checkpoint=False"
+ else
+   RESUME_FLAG="++load_checkpoint=True"
+ fi
+
+ # GPU-only runs: require a visible GPU and nvidia-smi.
+ if ! command -v nvidia-smi >/dev/null 2>&1; then
+   echo "[local_run.sh] Error: GPU-only mode requires nvidia-smi on PATH." >&2
+   echo "[local_run.sh] Hint: run on a GPU host or pass CUDA_VISIBLE_DEVICES." >&2
+   exit 1
+ fi
+
+ # Respect CUDA_VISIBLE_DEVICES if set; otherwise auto-detect via nvidia-smi.
+ if [ -z "${GPUS}" ]; then
+   if [ -n "${CUDA_VISIBLE_DEVICES:-}" ]; then
+     GPUS="$(python3 - <<'PY'
+ import os
+ v=os.environ.get('CUDA_VISIBLE_DEVICES','').strip()
+ print(len([x for x in v.split(',') if x.strip()]) or 1)
+ PY
+ )"
+   else
+     GPUS="$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l | tr -d ' ')"
+   fi
+ fi
+
+ if [ "${GPUS}" -le 0 ]; then
+   echo "[local_run.sh] Error: no GPUs detected. GPU-only mode requires CUDA." >&2
+   exit 1
+ fi
+
+ if [ -z "${MASTER_PORT:-}" ]; then
+   MASTER_PORT="$(python3 - <<'PY'
+ import socket
+ s=socket.socket()
+ s.bind(('127.0.0.1', 0))
+ print(s.getsockname()[1])
+ s.close()
+ PY
+ )"
+   export MASTER_PORT
+ fi
+
+ TIMESTAMP="$(date +%Y%m%d_%H%M%S)"
+ # Add nanoseconds to avoid collisions when launching multiple runs within the same second.
+ TIMESTAMP_NS="$(date +%Y%m%d_%H%M%S_%N)"
+ RUN_ID="${EXPERIMENT_NAME}_${TIMESTAMP}"
+ LOG_DIR="${LOG_BASE_DIR}/${RUN_ID}"
+ OUTPUT_DIR="${BASE_OUTPUT_DIR}/${EXPERIMENT_NAME}"
+ CHECKPOINT_DIR="${OUTPUT_DIR}/models"
+ mkdir -p "${LOG_DIR}" "${OUTPUT_DIR}" "${CHECKPOINT_DIR}"
+
+ # Force Hydra run dir to writable OUTPUT_DIR (avoids read-only repo/outputs in containers)
+ OVERRIDES="hydra.run.dir=${OUTPUT_DIR}"
+ if [ -n "${DISTANCE}" ]; then OVERRIDES+=" distance=${DISTANCE}"; fi
+ if [ -n "${N_ROUNDS}" ]; then OVERRIDES+=" n_rounds=${N_ROUNDS}"; fi
+ if [ -n "${EXTRA_PARAMS}" ]; then OVERRIDES+=" ${EXTRA_PARAMS}"; fi
+
+ CONFIG_SNAPSHOT_DIR="${OUTPUT_DIR}/config"
+ mkdir -p "${CONFIG_SNAPSHOT_DIR}"
+ CONFIG_PATH="${REPO_ROOT}/conf/${CONFIG_NAME}.yaml"
+ if [ -f "${CONFIG_PATH}" ]; then
+   # Never overwrite existing snapshots: keep full history.
+   base_yaml="${CONFIG_SNAPSHOT_DIR}/${CONFIG_NAME}_${TIMESTAMP_NS}.yaml"
+   dest_yaml="${base_yaml}"
+   i=0
+   while [ -e "${dest_yaml}" ]; do
+     i=$((i+1))
+     dest_yaml="${base_yaml%.yaml}_${i}.yaml"
+   done
+   cp "${CONFIG_PATH}" "${dest_yaml}"
+   # Also save the exact CLI overrides used for this run (useful when configs change over time).
+   base_ovr="${CONFIG_SNAPSHOT_DIR}/${CONFIG_NAME}_${TIMESTAMP_NS}.overrides.txt"
+   dest_ovr="${base_ovr}"
+   j=0
+   while [ -e "${dest_ovr}" ]; do
+     j=$((j+1))
+     dest_ovr="${base_ovr%.txt}_${j}.txt"
+   done
+   {
+     echo "workflow.task=${WORKFLOW}"
+     echo "exp_tag=${EXPERIMENT_NAME}"
+     echo "${RESUME_FLAG}"
+     echo "${OVERRIDES:-}"
+   } > "${dest_ovr}"
+ else
+   echo "[local_run.sh] Warning: could not find config file to snapshot: ${CONFIG_PATH}"
+ fi
+
+ echo "=========================================="
+ echo "Local run"
+ echo "=========================================="
+ echo "workflow.task: ${WORKFLOW}"
+ echo "config: ${CONFIG_NAME}"
+ echo "GPUS: ${GPUS} (CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-<unset>})"
+ echo "output: ${OUTPUT_DIR}"
+ echo "logs: ${LOG_DIR}"
+ echo "overrides: ${OVERRIDES:-<none>}"
+ echo "=========================================="
+
+ export PYTHONPATH="${CODE_ROOT}:${PYTHONPATH:-}"
+ export HDF5_USE_FILE_LOCKING=FALSE
+ export CUDNN_V8_API_ENABLED=1
+ export OMP_NUM_THREADS="$(nproc)"
+ export JOB_START_TIMESTAMP="$(date +%s)"
+ export JOB_START_DATETIME="$(date)"
+ if [ -n "${TORCH_COMPILE}" ]; then
+   export PREDECODER_TORCH_COMPILE="${TORCH_COMPILE}"
+ fi
+ if [ -n "${TORCH_COMPILE_MODE}" ]; then
+   export PREDECODER_TORCH_COMPILE_MODE="${TORCH_COMPILE_MODE}"
+ fi
+
+ # Prefer PREDECODER_PYTHON (cluster/container venv) when set
+ PYTHON_BIN="${PYTHON_BIN:-${PREDECODER_PYTHON:-python}}"
+ if ! command -v "${PYTHON_BIN}" >/dev/null 2>&1; then
+   if command -v python3 >/dev/null 2>&1; then
+     PYTHON_BIN="python3"
+   else
+     echo "[local_run.sh] Error: no python interpreter found on PATH." >&2
+     exit 1
+   fi
+ fi
+
+ # Ensure CUDA is usable before launching the workflow.
+ if ! "${PYTHON_BIN}" - <<'PY'
+ import sys
+ try:
+     import torch
+ except Exception as exc:
+     print(f"[local_run.sh] Error: PyTorch is required for GPU-only runs ({exc}).", file=sys.stderr)
+     sys.exit(1)
+ if not torch.cuda.is_available():
+     print("[local_run.sh] Error: torch.cuda.is_available() is false. GPU-only mode requires CUDA.", file=sys.stderr)
+     sys.exit(1)
+ PY
+ then
+   exit 1
+ fi
+
+ # Run from repo root so config defaults like `output: outputs/${exp_tag}` land in <repo_root>/outputs.
+ cd "${REPO_ROOT}"
+
+ LOG_FILE="${LOG_DIR}/${WORKFLOW}.log"
+
+ if [ "${GPUS}" -gt 1 ]; then
+   "${PYTHON_BIN}" -m torch.distributed.run \
+     --nproc_per_node="${GPUS}" \
+     --nnodes=1 \
+     --node_rank=0 \
+     --master_port="${MASTER_PORT}" \
+     code/workflows/run.py \
+     --config-name="${CONFIG_NAME}" \
+     workflow.task="${WORKFLOW}" \
+     +exp_tag="${EXPERIMENT_NAME}" \
+     ${RESUME_FLAG} \
+     ${OVERRIDES} \
+     2>&1 | tee -a "${LOG_FILE}"
+ else
+   "${PYTHON_BIN}" -u code/workflows/run.py \
+     --config-name="${CONFIG_NAME}" \
+     workflow.task="${WORKFLOW}" \
+     +exp_tag="${EXPERIMENT_NAME}" \
+     ${RESUME_FLAG} \
+     ${OVERRIDES} \
+     2>&1 | tee -a "${LOG_FILE}"
+ fi
+
+ cp -f "${LOG_FILE}" "${OUTPUT_DIR}/run.log"
+ echo "Done. Log: ${LOG_FILE}"
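The script embeds two small Python heredocs: one counts the entries in `CUDA_VISIBLE_DEVICES` to decide how many processes to launch, and one asks the operating system for a free port to use as `MASTER_PORT`. The sketch below restates both as ordinary functions so the logic can be read and tested in isolation. The function names are illustrative and do not exist in the repository.

```python
import socket


def count_visible_gpus(cuda_visible_devices: str) -> int:
    """Count non-empty comma-separated entries, falling back to 1 if none remain."""
    entries = [x for x in cuda_visible_devices.strip().split(",") if x.strip()]
    return len(entries) or 1


def pick_free_port() -> int:
    """Bind to port 0 on loopback so the OS assigns an unused port, then report it."""
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


# With the README's setting, the script would start four worker processes.
print(count_visible_gpus("0,1,2,3"))  # 4
```

A count above 1 is what switches the script from a plain `python -u code/workflows/run.py` call to `python -m torch.distributed.run`. Note that the port is released before the launcher binds it, so another process could in principle claim it in between.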
code/workflows/config_validator.py ADDED
@@ -0,0 +1,507 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Public config normalization / validation for the early-access public release.
17
+
18
+ Responsibilities:
19
+ - Fail-fast if the user tries to set hidden/experimental fields (via Hydra CLI `+foo=...`)
20
+ - Merge in hidden defaults (sourced from model_1_d9 config) so training runs with a minimal public config
21
+ - Apply the selected public model architecture (model_id -> model.*)
22
+ - Clamp distance/n_rounds to the model receptive field:
23
+ D = min(distance, R)
24
+ N_R = min(n_rounds, R)
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ from pathlib import Path
30
+ import os
31
+ from typing import Any, Dict, Iterable, Tuple
32
+
33
+ from omegaconf import DictConfig, OmegaConf
34
+
35
+ from model.registry import PublicModelSpec, get_model_spec
36
+
37
+ _PUBLIC_ROTATION_TO_INTERNAL = {
38
+ # Public user-facing aliases
39
+ "O1": "XV",
40
+ "O2": "XH",
41
+ "O3": "ZV",
42
+ "O4": "ZH",
43
+ }
44
+ _INTERNAL_ROTATION_TO_PUBLIC = {v: k for k, v in _PUBLIC_ROTATION_TO_INTERNAL.items()}
45
+
46
+ _PUBLIC_MODEL_ID_TO_LR = {
47
+ 1: 3e-4,
48
+ 2: 2e-4,
49
+ 3: 1e-4,
50
+ 4: 2e-4,
51
+ 5: 1e-4,
52
+ 6: 2e-4,
53
+ }
54
+
55
+
56
+ def _default_precomputed_frames_dir() -> str:
57
+ """
58
+ Default location for precomputed frames shipped with (or generated inside) this repo.
59
+
60
+ We compute this path relative to the codebase so it is stable regardless of the user's
61
+ current working directory.
62
+ """
63
+ # .../<repo>/code/workflows/config_validator.py -> repo root is parents[2]
64
+ repo_root = Path(__file__).resolve().parents[2]
65
+ return str((repo_root / "frames_data").resolve())
66
+
67
+
68
+ def _get_env_bool(name: str, default: bool) -> bool:
69
+ raw = os.environ.get(name)
70
+ if raw is None:
71
+ return default
72
+ val = str(raw).strip().lower()
73
+ if val in ("0", "false", "no", "off", ""):
74
+ return False
75
+ return True
76
+
77
+
78
+ def _normalize_code_rotation(value: Any) -> str:
79
+ """
80
+ Normalize code rotation values.
81
+
82
+ Public config accepts O1..O4 for user convenience. Internally we keep using:
83
+ XV, XH, ZV, ZH (as expected by SurfaceCode / MemoryCircuit).
84
+ """
85
+ if value is None:
86
+ return value
87
+ s = str(value).strip().upper()
88
+ if s in _PUBLIC_ROTATION_TO_INTERNAL:
89
+ return _PUBLIC_ROTATION_TO_INTERNAL[s]
90
+ if s in _INTERNAL_ROTATION_TO_PUBLIC:
91
+ return s
92
+ raise ValueError(
93
+ f"Invalid data.code_rotation={value!r}. "
94
+ f"Use one of {sorted(_PUBLIC_ROTATION_TO_INTERNAL.keys())} (public) "
95
+ f"or {sorted(_INTERNAL_ROTATION_TO_PUBLIC.keys())} (internal)."
96
+ )
97
+
98
+
99
+ def _base_hidden_defaults_dict() -> Dict[str, Any]:
100
+ """
101
+ Baseline config used as the source-of-truth for hidden defaults.
102
+
103
+ IMPORTANT: We intentionally embed these defaults directly in code so the public
104
+ release does not ship internal/legacy config files. These values were copied
105
+ from the historical `config_pre_decoder_memory_surface_model_1_d9.yaml`.
106
+ """
107
+ base_output_dir = os.environ.get("PREDECODER_BASE_OUTPUT_DIR", "outputs")
108
+ output_root = f"{base_output_dir}/${{exp_tag}}"
109
+ return {
110
+ "exp_tag": "pre-decoder",
111
+ "output": output_root,
112
+ "hydra": {
113
+ "run": {
114
+ "dir": "${output}"
115
+ },
116
+ "output_subdir": "hydra"
117
+ },
118
+ "resume_dir": f"{output_root}/models",
119
+ "enable_fp16": False,
120
+ "enable_bf16": False,
121
+ "enable_matmul_tf32": True,
122
+ "enable_cudnn_tf32": True,
123
+ "enable_cudnn_benchmark": True,
124
+ "torch_compile": _get_env_bool("PREDECODER_TORCH_COMPILE", True),
125
+ "torch_compile_mode": os.environ.get("PREDECODER_TORCH_COMPILE_MODE", "default"),
126
+ "load_checkpoint": False,
127
+ "code": "surface",
128
+ "distance": 9,
129
+ "n_rounds": 9,
130
+ "multiple_distances": [13, 13],
131
+ "multiple_rounds": [13, 13],
132
+ "use_multiple_patches": False,
133
+ "meas_basis": "both",
134
+ "workflow": {
135
+ "task": "train"
136
+ },
137
+ "data":
138
+ {
139
+ "timelike_he": True,
140
+ "num_he_cycles": 1,
141
+ "use_weight2_timelike": False,
142
+ "max_passes_w1": 8,
143
+ "max_passes_w2": 4,
144
+ "decompose_y": True,
145
+ "p_error": None,
146
+ "p_min": 0.001,
147
+ "p_max": 0.006,
148
+ "error_mode": "circuit_level_surface_custom",
149
+ # Public config overrides this; keep the historical default for completeness.
150
+ "precomputed_frames_dir": _default_precomputed_frames_dir(),
151
+ "enable_correlated_pymatching": False,
152
+ "code_rotation": "XV",
153
+ "noise_model": None,
154
+ },
155
+ "model":
156
+ {
157
+ "version": "predecoder_memory_v1",
158
+ "dropout_p": 0.05,
159
+ "activation": "gelu",
160
+ "num_filters": [128, 128, 128, 4],
161
+ "kernel_size": [3, 3, 3, 3],
162
+ "input_channels": 4,
163
+ "out_channels": 4,
164
+ },
165
+ "datapipe": "memory",
166
+ "data_method": "train",
167
+ "train":
168
+ {
169
+ # Production baseline: 2^26 shots / epoch when training with 8 GPUs.
170
+ # The training script will auto-scale this based on detected world size / GPU count.
171
+ "num_samples": 67108864,
172
+ "accumulate_steps": 2,
173
+ "checkpoint_interval": 1,
174
+ "save_every_datasets": 5,
175
+ "epochs": 100,
176
+ },
177
+ # NOTE: temporarily reduced for faster iteration during refactor/testing.
178
+ "val": {
179
+ "num_samples": 65536,
180
+ "threshold": 0.5,
181
+ "trials": 1
182
+ },
183
+ "optimizer_type": "Lion",
184
+ "optimizer": {
185
+ "lr": 1e-4,
186
+ "weight_decay": 1e-7,
187
+ "beta2": 0.95
188
+ },
189
+ "lr_scheduler":
190
+ {
191
+ "type": "warmup_then_decay",
192
+ "warmup_steps": 100,
193
+ "milestones": [0.25, 0.5, 1.0],
194
+ "gamma": 0.7,
195
+ "min_lr": 1e-6,
196
+ },
197
+ "batch_schedule":
198
+ {
199
+ "enabled": True,
200
+ "initial": 256,
201
+ "final": 1024,
202
+ "start_epoch": 1,
203
+ "end_epoch": 3,
204
+ },
205
+ "validation_ler": True,
206
+ "early_stopping": {
207
+ "enabled": True,
208
+ "patience": 100
209
+ },
210
+ "time_based_early_stopping": {
211
+ "enabled": False,
212
+ "safety_margin_minutes": 5
213
+ },
214
+ "ema": {
215
+ "use_ema": True,
216
+ "decay": 0.0001
217
+ },
218
+ "test":
219
+ {
220
+ "num_samples": 262144,
221
+ "trials": 1,
222
+ "distance": 9,
223
+ "n_rounds": 9,
224
+ "noise_model": "train",
225
+ "p_error": 0.006,
226
+ "dataloader":
227
+ {
228
+ "batch_size": 64,
229
+ "num_workers": 0,
230
+ "persistent_workers": False,
231
+ },
232
+ "latency_num_samples": 1000,
233
+ "sampler": {
234
+ "shuffle": False,
235
+ "drop_last": False
236
+ },
237
+ "syn_red": "full",
238
+ "th_data": 0.0,
239
+ "th_syn": 0.0,
240
+ "sampling_mode": "threshold",
241
+ "temperature": 0.0,
242
+ "temperature_data": None,
243
+ "temperature_syn": None,
244
+ "per_round": False,
245
+ "meas_basis_test": "both",
246
+ "use_model_checkpoint": -1,
247
+ },
248
+ "threshold":
249
+ {
250
+ "p_values": [0.002, 0.003, 0.004, 0.005, 0.006, 0.007, 0.008],
251
+ "distances": [5, 7, 9, 11, 13],
252
+ "n_rounds": None,
253
+ },
254
+ }
255
+
256
+
257
+ def _select(cfg: DictConfig, key: str) -> Tuple[bool, Any]:
258
+ """
259
+ Return (exists, value) for a dot-path in cfg.
260
+ Note: OmegaConf.select returns None both for missing keys and explicit nulls,
261
+ so we treat a key as existing iff it is present in the underlying container.
262
+ """
263
+ # OmegaConf doesn't provide a direct 'has_key' for dotted paths; implement via container walk.
264
+ cur: Any = cfg
265
+ parts = key.split(".")
266
+ for p in parts:
267
+ if not isinstance(cur, DictConfig) or p not in cur:
268
+ return False, None
269
+ cur = cur[p]
270
+ return True, cur
271
+
272
+
273
+ def _assert_not_present(cfg: DictConfig, keys: Iterable[str], *, context: str) -> None:
274
+ for k in keys:
275
+ exists, _ = _select(cfg, k)
276
+ if exists:
277
+ raise ValueError(
278
+ f"Config field '{k}' is not supported in the public release ({context}). "
279
+ f"Remove it from the config/CLI overrides."
280
+ )
281
+
282
+
283
+ def validate_public_config(cfg: DictConfig) -> PublicModelSpec:
284
+ """
285
+ Validate the user-facing config BEFORE we merge in hidden defaults.
286
+
287
+ Returns:
288
+ PublicModelSpec for cfg.model_id (validated).
289
+ """
290
+ # model_id must exist in public config
291
+ if "model_id" not in cfg:
292
+ raise ValueError("Missing required field: 'model_id' (choose 1..6).")
293
+
294
+ model_spec = get_model_spec(cfg.model_id)
295
+
296
+ # Public config requires distance/n_rounds (evaluation targets)
297
+ if "distance" not in cfg or "n_rounds" not in cfg:
298
+ raise ValueError("Missing required fields: 'distance' and 'n_rounds'.")
299
+ try:
300
+ d = int(cfg.distance)
301
+ r = int(cfg.n_rounds)
302
+ except Exception as e:
303
+ raise ValueError(
304
+ f"Invalid distance/n_rounds: distance={cfg.distance!r}, n_rounds={cfg.n_rounds!r}"
305
+ ) from e
306
+ if d <= 0 or r <= 0:
307
+ raise ValueError(
308
+ f"Invalid distance/n_rounds: distance={d}, n_rounds={r} (must be positive integers)"
309
+ )
310
+
311
+ if "train" in cfg:
312
+ raise ValueError("Config field 'train' is not supported in the public release.")
313
+ if "val" in cfg:
314
+ raise ValueError("Config field 'val' is not supported in the public release.")
315
+ if "test" in cfg:
316
+ raise ValueError("Config field 'test' is not supported in the public release.")
317
+
318
+ # Fail-fast on known hidden fields if the user tries to inject them.
319
+ _assert_not_present(
320
+ cfg,
321
+ keys=(
322
+ # output paths are managed by the runner scripts; not user-configurable in public release
323
+ "output",
324
+ "resume_dir",
325
+ # precision / tf32 knobs (always fp32 + tf32 enabled)
326
+ "enable_fp16",
327
+ "enable_bf16",
328
+ "enable_matmul_tf32",
329
+ "enable_cudnn_tf32",
330
+ # always both bases
331
+ "meas_basis",
332
+ # multi-patch curriculum mode (hidden)
333
+ "use_multiple_patches",
334
+ "multiple_distances",
335
+ "multiple_rounds",
336
+ # optimizer knobs (hidden; optimizer.lr is set per model_id)
337
+ "optimizer",
338
+ "optimizer_type",
339
+ "lr_scheduler",
340
+ "batch_schedule",
341
+ # obsolete/confusing
342
+ "train.save_every_datasets",
343
+ # validation hidden knobs
344
+ "val.threshold",
345
+ "val.trials",
346
+ # early stopping extras hidden
347
+ "time_based_early_stopping",
348
+ "ema",
349
+ ),
350
+ context="hidden field override",
351
+ )
352
+
353
+ # Restrict cfg.data to a small public surface (others can be too experimental).
354
+ if "data" in cfg and isinstance(cfg.data, DictConfig):
355
+ # NOTE: precomputed frames path is intentionally hidden from the public config.
356
+ # We default it internally to <repo>/frames_data (see _default_precomputed_frames_dir).
357
+ if "precomputed_frames_dir" in cfg.data:
358
+ raise ValueError(
359
+ "Config field 'data.precomputed_frames_dir' is not supported in the public release. "
360
+ "Remove it from the config/CLI overrides."
361
+ )
362
+ allowed_data_keys = {"code_rotation", "noise_model"}
363
+ for k in cfg.data.keys():
364
+ if k not in allowed_data_keys:
365
+ raise ValueError(
366
+ f"Config field 'data.{k}' is not supported in the public release. "
367
+ f"Allowed data fields are: {sorted(allowed_data_keys)}"
368
+ )
369
+ # Validate rotation value (accept O1..O4; also allow internal XV/XH/ZV/ZH for compatibility).
370
+ if "code_rotation" in cfg.data:
371
+ _normalize_code_rotation(cfg.data.code_rotation)
372
+
373
+ # Restrict optimizer sub-keys: only lr is public.
374
+ if "optimizer" in cfg and isinstance(cfg.optimizer, DictConfig):
375
+ for k in cfg.optimizer.keys():
376
+ if k != "lr":
377
+ raise ValueError(
378
+ f"Config field 'optimizer.{k}' is not supported in the public release. "
379
+ f"Only 'optimizer.lr' is user-configurable."
380
+ )
381
+
382
+ return model_spec
383
+
384
+
385
+ def clamp_to_receptive_field(cfg: DictConfig, R: int) -> None:
386
+ """In-place clamp of cfg.distance and cfg.n_rounds to receptive field R."""
387
+ if not isinstance(R, int) or R <= 0:
388
+ raise ValueError(f"Invalid receptive field R={R!r}")
389
+ if "distance" not in cfg or "n_rounds" not in cfg:
390
+ raise ValueError("Both 'distance' and 'n_rounds' must be present in config.")
391
+ cfg.distance = int(min(int(cfg.distance), R))
392
+ cfg.n_rounds = int(min(int(cfg.n_rounds), R))
393
+
394
+
395
+ def apply_public_defaults_and_model(cfg: DictConfig, model_spec: PublicModelSpec) -> DictConfig:
396
+ """
397
+ Merge hidden defaults and apply public model settings.
398
+
399
+ Returns a new DictConfig (does not mutate input).
400
+ """
401
+ base_cfg = OmegaConf.create(_base_hidden_defaults_dict())
402
+
403
+ # Merge: base provides full training-ready config; public cfg overrides user-visible fields.
404
+ merged = OmegaConf.merge(base_cfg, cfg)
405
+ OmegaConf.set_struct(merged, False)
406
+
407
+ # In the public release:
408
+ # - cfg.distance / cfg.n_rounds are the *evaluation targets* the user cares about
409
+ # - training always uses distance=n_rounds=R (the model receptive field)
410
+ requested_distance = int(merged.distance)
411
+ requested_n_rounds = int(merged.n_rounds)
412
+
413
+ # Enforce public invariants (hidden from user)
414
+ merged.enable_fp16 = False
415
+ merged.enable_bf16 = False
416
+ merged.enable_matmul_tf32 = True
417
+ merged.enable_cudnn_tf32 = True
418
+
419
+ merged.meas_basis = "both"
420
+
421
+ # Disable multi-patch mode explicitly
422
+ if "data" not in merged:
423
+ merged.data = {}
424
+ merged.data.use_multiple_patches = False
425
+ merged.multiple_distances = None
426
+ merged.multiple_rounds = None
427
+
428
+ # Always use repo-relative frames_data by default (hidden from public config).
429
+ merged.data.precomputed_frames_dir = _default_precomputed_frames_dir()
430
+
431
+ # Apply model architecture from registry
432
+ if "model" not in merged:
433
+ merged.model = {}
434
+ merged.model.version = model_spec.model_version
435
+ merged.model.num_filters = list(model_spec.num_filters)
436
+ merged.model.kernel_size = list(model_spec.kernel_size)
437
+
438
+ # Public release: hard-code optimizer.lr based on model choice.
439
+ # (User is not allowed to override optimizer settings.)
440
+ if "optimizer" not in merged:
441
+ merged.optimizer = {}
442
+ lr = _PUBLIC_MODEL_ID_TO_LR.get(int(model_spec.model_id))
443
+ if lr is None:
444
+ raise ValueError(f"No public LR mapping for model_id={model_spec.model_id!r}")
445
+ merged.optimizer.lr = float(lr)
446
+
447
+ # Public release: production-like batch schedule defaults.
448
+ # Target behavior: per-GPU batch size is 512 in the first epoch, 2048 thereafter.
449
+ # Model 3 is heavier and uses a smaller schedule (256 -> 1024); model 6 also uses a smaller one (256 -> 512).
450
+ if "batch_schedule" not in merged:
451
+ merged.batch_schedule = {}
452
+ merged.batch_schedule.enabled = True
453
+ if int(model_spec.model_id) == 3:
454
+ merged.batch_schedule.initial = 256
455
+ merged.batch_schedule.final = 1024
456
+ elif int(model_spec.model_id) == 6:
457
+ merged.batch_schedule.initial = 256
458
+ merged.batch_schedule.final = 512
459
+ else:
460
+ merged.batch_schedule.initial = 512
461
+ merged.batch_schedule.final = 2048
462
+ # "First epoch only" initial, then final for all later epochs.
463
+ merged.batch_schedule.start_epoch = 0
464
+ merged.batch_schedule.end_epoch = 0
465
+
466
+ # Public release: training epochs default to production values,
467
+ # but honor explicit user overrides for quick validation runs.
468
+ if "train" not in merged:
469
+ merged.train = {}
470
+ if not ("train" in cfg and isinstance(cfg.train, DictConfig) and "epochs" in cfg.train):
471
+ merged.train.epochs = 100
472
+
473
+ # Public release: validation sample count defaults to production values,
474
+ # but honor explicit user overrides for quick validation runs.
475
+ if "val" not in merged:
476
+ merged.val = {}
477
+ # NOTE: temporarily reduced for faster iteration during refactor/testing.
478
+ if not ("val" in cfg and isinstance(cfg.val, DictConfig) and "num_samples" in cfg.val):
479
+ merged.val.num_samples = 65536
480
+
481
+ # Train vs inference window semantics (public release):
482
+ # - Top-level cfg.distance / cfg.n_rounds are the user-specified *evaluation* targets.
483
+ # - Training always runs on the model receptive field R (distance=n_rounds=R).
484
+ task = str(getattr(getattr(merged, "workflow", None), "task", "train")).strip().lower()
485
+ R = int(model_spec.receptive_field)
486
+ if R <= 0:
487
+ raise ValueError(f"Invalid receptive field R={R!r}")
488
+ if task == "train":
489
+ merged.distance = R
490
+ merged.n_rounds = R
491
+ else:
492
+ merged.distance = int(requested_distance)
493
+ merged.n_rounds = int(requested_n_rounds)
494
+
495
+ # Public code_rotation aliases: normalize O1..O4 -> internal XV/XH/ZV/ZH.
496
+ if "data" in merged and "code_rotation" in merged.data:
497
+ merged.data.code_rotation = _normalize_code_rotation(merged.data.code_rotation)
498
+
499
+ # Test/evaluation config is hidden and always uses the user-requested window.
500
+ if "test" not in merged:
501
+ merged.test = {}
502
+ if not ("test" in cfg and isinstance(cfg.test, DictConfig) and "num_samples" in cfg.test):
503
+ merged.test.num_samples = 262144
504
+ merged.test.distance = int(requested_distance)
505
+ merged.test.n_rounds = int(requested_n_rounds)
506
+ merged.test.noise_model = "train"
507
+ return merged
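
The train versus evaluation window rule applied in `apply_public_defaults_and_model`, together with the clamp in `clamp_to_receptive_field`, can be illustrated with plain integers. The sketch below is a simplified restatement under stated assumptions, not the module's code: `resolve_window` and `clamp_window` are hypothetical names, and the example values (R=13, `distance: 13`, `n_rounds: 104`) are taken from the model card and `conf/config_public.yaml` in this commit.

```python
from typing import Tuple


def resolve_window(task: str, distance: int, n_rounds: int, receptive_field: int) -> Tuple[int, int]:
    """Hypothetical helper: pick the (distance, n_rounds) window for a workflow task."""
    if receptive_field <= 0:
        raise ValueError(f"Invalid receptive field R={receptive_field!r}")
    if task.strip().lower() == "train":
        # Training always runs on the receptive field: distance = n_rounds = R.
        return receptive_field, receptive_field
    # Any other task keeps the user-requested evaluation window unchanged.
    return distance, n_rounds


def clamp_window(distance: int, n_rounds: int, receptive_field: int) -> Tuple[int, int]:
    """Hypothetical helper: D = min(distance, R), N_R = min(n_rounds, R)."""
    if receptive_field <= 0:
        raise ValueError(f"Invalid receptive field R={receptive_field!r}")
    return min(distance, receptive_field), min(n_rounds, receptive_field)


if __name__ == "__main__":
    print(resolve_window("train", 13, 104, 13))      # training window
    print(resolve_window("inference", 13, 104, 13))  # evaluation window
    print(clamp_window(13, 104, 13))                 # clamped window
```

Note that in the code above the `test.distance` and `test.n_rounds` fields always receive the requested values, so only the top-level training window is replaced by R.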
code/workflows/run.py ADDED
@@ -0,0 +1,319 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import hydra, sys, torch, os, json, numpy as np
17
+ from omegaconf import DictConfig, OmegaConf
18
+ from training.train import main as train_main
19
+ from model.factory import ModelFactory
20
+ from data.factory import DatapipeFactory
21
+ from hydra.utils import to_absolute_path
22
+ from workflows.config_validator import (
23
+ apply_public_defaults_and_model,
24
+ validate_public_config,
25
+ )
26
+
27
+ from training.distributed import DistributedManager
28
+
29
+ from torch.utils.data import DataLoader
30
+
31
+
32
+ def _ensure_inference_io_channels(cfg):
33
+ # 1) Ensure out_channels matches the model’s heads (4: z_data, x_data, syn_x, syn_z)
34
+ if not getattr(cfg.model, "out_channels", None) or cfg.model.out_channels == 0:
35
+ cfg.model.out_channels = 4
36
+
37
+ # 2) Infer input_channels from a single inference sample if not set
38
+ if not getattr(cfg.model, "input_channels", None) or cfg.model.input_channels == 0:
39
+ ds = DatapipeFactory.create_datapipe_inference(cfg)
40
+ tmp = DataLoader(ds, batch_size=1)
41
+ sample = next(iter(tmp))
42
+ cfg.model.input_channels = int(sample["trainX"].shape[1])
43
+
44
+ # 3) Keep num_filters consistent with out_channels
45
+ if hasattr(cfg.model, "num_filters"):
46
+ filters = list(cfg.model.num_filters)
47
+ if filters and filters[-1] != cfg.model.out_channels:
48
+ print(
49
+ f"[run] Adjusting model.num_filters[-1] {filters[-1]} -> {cfg.model.out_channels}"
50
+ )
51
+ filters[-1] = cfg.model.out_channels
52
+ cfg.model.num_filters = filters
53
+
54
+
55
+ @hydra.main(version_base="1.3", config_path="../../conf", config_name="config")
56
+ def run(cfg: DictConfig) -> None:
57
+ # Early-access public release: validate public surface, then merge in hidden defaults.
58
+ # NOTE: Validation is done BEFORE merging defaults so we can fail fast on injected fields.
59
+ model_spec = validate_public_config(cfg)
60
+ cfg = apply_public_defaults_and_model(cfg, model_spec)
61
+
62
+ torch.backends.cuda.matmul.allow_tf32 = cfg.enable_matmul_tf32
63
+ torch.backends.cudnn.allow_tf32 = cfg.enable_cudnn_tf32
64
+
65
+ if cfg.code == "surface" or cfg.code == "surface_partition":
66
+ run_surface(cfg)
67
+
68
+
69
+ def run_surface(cfg: DictConfig):
70
+ if cfg.workflow.task == "train":
71
+ train_main(cfg)
72
+ elif cfg.workflow.task == "threshold":
73
+ raise ValueError(
74
+ "workflow.task='threshold' has been renamed to workflow.task='inference'. "
75
+ "Please update your config/env var to WORKFLOW=inference."
76
+ )
77
+ elif cfg.workflow.task == "inference":
78
+ from evaluation.inference import run_inference
79
+ DistributedManager.initialize()
80
+ dist = DistributedManager()
81
+ model = _load_model(cfg, dist)
82
+ run_inference(model, dist.device, dist, cfg)
83
+ elif cfg.workflow.task == "data":
84
+ DistributedManager.initialize()
85
+ dist = DistributedManager()
86
+ train_loader, _ = DatapipeFactory.create_dataloader(cfg, dist.world_size, dist.rank)
87
+ for j, dl in enumerate(train_loader):
88
+ print(f"Batch {j}: syndrome_shape: {dl['syndrome'].shape}")
89
+ elif cfg.workflow.task == "decoder_ablation":
90
+ from evaluation.failure_analysis import decoder_ablation_study
91
+ DistributedManager.initialize()
92
+ dist = DistributedManager()
93
+ model = _load_model(cfg, dist)
94
+ decoder_ablation_study(model, dist.device, dist, cfg)
95
+ elif cfg.workflow.task in ("sampling", "visualize"):
96
+ raise ValueError(
97
+ f"workflow.task={cfg.workflow.task!r} is not supported in the early-access public release. "
98
+ "Supported workflows: train, inference, decoder_ablation."
99
+ )
100
+
101
+
102
+ def find_best_model(path, *, rank: int = 0):
103
+ if rank == 0:
104
+ print(f"Searching for best model in: {path}")
105
+ if not os.path.isdir(path):
106
+ raise FileNotFoundError(f"Model directory does not exist: {path}")
107
+
108
+ max_value = -1 # Start with -1 to include epoch 0
109
+ best_file = None
110
+ model_files = []
111
+ # Named .pt files without epoch numbers (e.g. Ising-Decoder-SurfaceCode-1-Fast.pt)
112
+ named_pt_files = []
113
+
114
+ for filename in os.listdir(path):
115
+ if not filename.endswith(".pt"):
116
+ continue
117
+ if filename.startswith("PreDecoderModelMemory_"):
118
+ try:
119
+ value = float(filename.split(".")[2]) # Gets epoch number
120
+ model_files.append((filename, value))
121
+ if value > max_value:
122
+ max_value = value
123
+ best_file = filename
124
+ except (IndexError, ValueError) as e:
125
+ print(f"Warning: could not parse epoch from filename {filename}: {e}")
126
+ else:
127
+ named_pt_files.append(filename)
128
+
129
+ # Fall back to named .pt files when no epoch-numbered checkpoints are present
130
+ if best_file is None and named_pt_files:
131
+ named_pt_files.sort()
132
+ best_file = named_pt_files[-1]
133
+ model_files = [(f, None) for f in named_pt_files]
134
+
135
+ if rank == 0:
136
+ print(f"Found {len(model_files)} model file(s):")
137
+ for filename, epoch in sorted(model_files, key=lambda x: (x[1] is None, x[1] or 0)):
138
+ marker = "*" if filename == best_file else " "
139
+ epoch_str = str(epoch) if epoch is not None else "n/a"
140
+ print(f" [{marker}] {filename} (epoch {epoch_str})")
141
+
142
+ if best_file is None:
143
+ raise FileNotFoundError(
144
+ f"No valid model checkpoint files found in {path}\n"
145
+ f"Expected .pt files (e.g. Ising-Decoder-SurfaceCode-1-Fast.pt or "
146
+ f"PreDecoderModelMemory_*.pt).\n"
147
+ f"Hint: download the pretrained weights and place them in this directory, "
148
+ f"or set model_checkpoint_file in your config to an explicit path."
149
+ )
150
+
151
+ best_model_path = os.path.join(path, best_file)
152
+ if rank == 0:
153
+ epoch_str = str(max_value) if max_value >= 0 else "n/a"
154
+ print(f"Selected best model: {best_file} (epoch {epoch_str})")
155
+
156
+ return best_model_path
157
+
158
+
159
+ def _resolve_dir(path: str) -> str:
160
+ """Return an absolute version of path, resolving relative paths from the repo root."""
161
+ if os.path.isabs(path):
162
+ return path
163
+ repo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
164
+ return os.path.join(repo_root, path)
165
+
166
+
167
+ def _load_state_dict_from_pt(model_path: str, device) -> dict:
168
+ """Load a state dict from a .pt checkpoint, handling multiple saved formats.
169
+
170
+ Supports:
171
+ - bare state dict (keys are layer names)
172
+ - {"model_state_dict": ...}
173
+ - {"state_dict": ...}
174
+ Also strips the DDP "module." prefix if present.
175
+ """
176
+ raw = torch.load(model_path, map_location=device, weights_only=False)
177
+ if isinstance(raw, dict):
178
+ if "model_state_dict" in raw:
179
+ state_dict = raw["model_state_dict"]
180
+ elif "state_dict" in raw:
181
+ state_dict = raw["state_dict"]
182
+ else:
183
+ state_dict = raw
184
+ else:
185
+ raise ValueError(f"Unexpected checkpoint format: expected a dict, got {type(raw).__name__}")
186
+ return {
187
+ (k[len("module."):] if k.startswith("module.") else k): v for k, v in state_dict.items()
188
+ }
189
+
190
+
191
+ def _load_model(cfg, dist):
192
+ if dist.rank == 0:
193
+ print(f"Loading model for task: {cfg.workflow.task}")
194
+
195
+ _ensure_inference_io_channels(cfg)
196
+
197
+ # SafeTensors path: load fp16/fp32 model from SafeTensors file
198
+ safetensors_path = os.environ.get("PREDECODER_SAFETENSORS_CHECKPOINT", "").strip()
199
+ if safetensors_path:
200
+ from export.safetensors_utils import load_safetensors
201
+ if dist.rank == 0:
202
+ print(f"Loading model from SafeTensors: {safetensors_path}")
203
+
204
+ # Auto-detect model_id from SafeTensors metadata (don't override with config)
205
+ model, metadata = load_safetensors(
206
+ safetensors_path,
207
+ model_id=None,
208
+ device=str(dist.device),
209
+ )
210
+ if dist.rank == 0:
211
+ loaded_model_id = metadata.get("model_id", "unknown")
212
+ dtype = metadata.get("quant_format", "fp32")
213
+ receptive_field = metadata.get("receptive_field", "unknown")
214
+ param_count = sum(p.numel() for p in model.parameters())
215
+ print(f" model_id: {loaded_model_id} (from SafeTensors metadata)")
216
+ print(f" receptive_field: {receptive_field}")
217
+ print(f" dtype: {dtype}")
218
+ print(f" parameters: {param_count:,}")
219
+
220
+ # Warn if config model_id doesn't match file metadata
221
+ config_model_id = getattr(cfg, "model_id", None)
222
+ if config_model_id is not None and str(config_model_id) != str(loaded_model_id):
223
+ print(
224
+ f" Warning: config model_id={config_model_id} differs from "
225
+ f"file model_id={loaded_model_id}; using {loaded_model_id}"
226
+ )
227
+
228
+ if metadata.get("quant_format") == "fp16":
229
+ cfg.enable_fp16 = True
230
+ return model
231
+
232
+ # Direct file path override (for named pretrained models without epoch numbers)
233
+ model_checkpoint_file = getattr(cfg, 'model_checkpoint_file', None)
234
+ if model_checkpoint_file:
235
+ model_checkpoint_file = _resolve_dir(str(model_checkpoint_file))
236
+ if not os.path.exists(model_checkpoint_file):
237
+ raise FileNotFoundError(f"Checkpoint not found: {model_checkpoint_file}")
238
+ if dist.rank == 0:
239
+ print(f"Loading model from: {model_checkpoint_file}")
240
+ model = ModelFactory.create_model(cfg).to(dist.device)
241
+ if cfg.enable_fp16:
242
+ model = model.half()
243
+ state_dict = _load_state_dict_from_pt(model_checkpoint_file, dist.device)
244
+ model.load_state_dict(state_dict)
245
+ if dist.rank == 0:
246
+ param_count = sum(p.numel() for p in model.parameters())
247
+ print(f"Model loaded ({param_count:,} parameters)")
248
+ return model
249
+
250
+ model = ModelFactory.create_model(cfg).to(dist.device)
251
+
252
+ if cfg.enable_fp16:
253
+ model = model.half()
254
+ if dist.rank == 0:
255
+ print("Model converted to float16 for fp16 inference")
256
+
257
+ # Determine model directory
258
+ # Priority: 1) model_checkpoint_dir (for inference configs)
259
+ # 2) cfg.output/models (for training configs)
260
+ model_checkpoint_dir = getattr(cfg, 'model_checkpoint_dir', None)
261
+ use_checkpoint = getattr(cfg.test, 'use_model_checkpoint', -1)
262
+
263
+ if use_checkpoint == -1:
264
+ model_dir = _resolve_dir(
265
+ os.path.join(model_checkpoint_dir, "best_model")
266
+ if model_checkpoint_dir else f"{cfg.output}/models/best_model"
267
+ )
268
+ if dist.rank == 0:
269
+ print(f"Loading best model from: {model_dir}")
270
+
271
+ # Fallback: older runs may not have a best_model/ folder
272
+ if not os.path.isdir(model_dir):
273
+ fallback_dir = _resolve_dir(
274
+ model_checkpoint_dir if model_checkpoint_dir else f"{cfg.output}/models"
275
+ )
276
+ if dist.rank == 0:
277
+ print(f"best_model/ not found; falling back to: {fallback_dir}")
278
+ model_dir = fallback_dir
279
+
280
+ model_path = find_best_model(model_dir, rank=dist.rank)
281
+ else:
282
+ checkpoint_dir = _resolve_dir(
283
+ model_checkpoint_dir if model_checkpoint_dir else f"{cfg.output}/models"
284
+ )
285
+ if dist.rank == 0:
286
+ print(f"Loading checkpoint {use_checkpoint} from: {checkpoint_dir}")
287
+
288
+ # Prefer any PreDecoderModelMemory_* file ending with .0.{use_checkpoint}.pt
289
+ target_suffix = f".0.{use_checkpoint}.pt"
290
+ checkpoint_filename = None
291
+ try:
292
+ for f in os.listdir(checkpoint_dir):
293
+ if f.startswith("PreDecoderModelMemory_") and f.endswith(target_suffix):
294
+ checkpoint_filename = f
295
+ break
296
+ except OSError:
297
+ pass
298
+ if checkpoint_filename is None:
299
+ checkpoint_filename = f"PreDecoderModelMemory_v1.0.{use_checkpoint}.pt"
300
+ model_path = os.path.join(checkpoint_dir, checkpoint_filename)
301
+
302
+ if not os.path.exists(model_path):
303
+ raise FileNotFoundError(f"Checkpoint not found: {model_path}")
304
+
305
+ if dist.rank == 0:
306
+ print(f"Loading model parameters from: {model_path}")
307
+
308
+ state_dict = _load_state_dict_from_pt(model_path, dist.device)
309
+ model.load_state_dict(state_dict)
310
+
311
+ if dist.rank == 0:
312
+ param_count = sum(p.numel() for p in model.parameters())
313
+ print(f"Model loaded ({param_count:,} parameters)")
314
+
315
+ return model
316
+
317
+
318
+ if __name__ == "__main__":
319
+ run()
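
The selection rule in `find_best_model` can be restated as a pure function over a list of filenames, which makes the precedence easier to see: epoch-numbered `PreDecoderModelMemory_*.pt` files win, and other `.pt` files (such as `Quantispect_RF13_v1.0.10.pt` from this commit) are used only when none are present. This is a sketch, not the function itself: `pick_checkpoint` is a hypothetical name, and unlike the original it skips unparsable names silently instead of printing a warning.

```python
from typing import Iterable, List, Optional

EPOCH_PREFIX = "PreDecoderModelMemory_"


def pick_checkpoint(filenames: Iterable[str]) -> Optional[str]:
    """Hypothetical helper mirroring the selection order of find_best_model."""
    best_name: Optional[str] = None
    best_epoch = -1.0  # start below zero so that epoch 0 is still selectable
    named: List[str] = []
    for name in filenames:
        if not name.endswith(".pt"):
            continue
        if name.startswith(EPOCH_PREFIX):
            try:
                # "PreDecoderModelMemory_v1.0.10.pt".split(".")[2] is the epoch field.
                epoch = float(name.split(".")[2])
            except (IndexError, ValueError):
                continue  # the original prints a warning here
            if epoch > best_epoch:
                best_epoch, best_name = epoch, name
        else:
            named.append(name)
    # Fall back to named checkpoints only if no epoch-numbered file was found.
    if best_name is None and named:
        best_name = sorted(named)[-1]
    return best_name


if __name__ == "__main__":
    print(pick_checkpoint(["PreDecoderModelMemory_v1.0.3.pt", "PreDecoderModelMemory_v1.0.10.pt"]))
    print(pick_checkpoint(["Quantispect_RF13_v1.0.10.pt"]))
```

One consequence of the fallback is that, with several named `.pt` files in the same directory, the lexicographically last one is chosen; setting `model_checkpoint_file` avoids that ambiguity.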
conf/config_public.yaml ADDED
@@ -0,0 +1,84 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Public, single-file config for external users.
+ #
+ # Users should only edit the fields in this file.
+ # Advanced/experimental fields are intentionally omitted and will be populated
+ # from internal defaults (and validated to prevent unsupported overrides).
+
+ # === Model selection (required) ===
+ model_id: 6  # Choose 1, 2, 3, 4, 5, or 6
+
+ model:
+   version: predecoder_fasthyper_rf13_v1
+   input_channels: 4
+   out_channels: 4
+   hidden_dim: 96
+   mid_dim: 144
+   mix_groups: 6
+   num_blocks: 5
+   stem_kernel_size: 3
+   gate_reduction: 4
+   dropout_p: 0.02
+
+
+ # === Values for evaluation. Training window is hardcoded to model receptive field. ===
+ distance: 13
+ n_rounds: 104
+
+ # === Workflow ===
+ workflow:
+   task: train  # train, inference
+   # Simplify inference logs to show only PyMatching results before and after pre-decoding. TODO: batch_size=1
+
+ # === Data (public surface only) ===
+ data:
+   # Surface code orientation (public naming): O1, O2, O3, O4
+   code_rotation: O1
+   # Circuit-level noise model (25-parameter). This is the default public noise specification.
+   # The defaults are chosen for p=0.003.
+   noise_model:
+     # State preparation errors (2)
+     p_prep_X: 0.002  # |+> state-prep fails with this probability (apply Z), 2*p/3
+     p_prep_Z: 0.002  # |0> state-prep fails with this probability (apply X), 2*p/3
+     # Measurement errors (2)
+     p_meas_X: 0.002  # Measurement in X-basis fails with this probability (apply Z before measurement), 2*p/3
+     p_meas_Z: 0.002  # Measurement in Z-basis fails with this probability (apply X before measurement), 2*p/3
+     # Idle during CNOT layers / bulk (3)
+     p_idle_cnot_X: 0.001  # p/3
+     p_idle_cnot_Y: 0.001  # p/3
+     p_idle_cnot_Z: 0.001  # p/3
+     # Idle during SPAM window (ancilla prep+reset) on data qubits only (3)
+     p_idle_spam_X: 0.001998  # 2*p/3 - 2*p^2/9
+     p_idle_spam_Y: 0.001998  # 2*p/3 - 2*p^2/9
+     p_idle_spam_Z: 0.001998  # 2*p/3 - 2*p^2/9
+     # CNOT two-qubit errors (15) - keys are p_cnot_{Pauli}{Pauli} excluding II, p/15
+     p_cnot_IX: 0.0002
+     p_cnot_IY: 0.0002
+     p_cnot_IZ: 0.0002
+     p_cnot_XI: 0.0002
+     p_cnot_XX: 0.0002
+     p_cnot_XY: 0.0002
+     p_cnot_XZ: 0.0002
+     p_cnot_YI: 0.0002
+     p_cnot_YX: 0.0002
+     p_cnot_YY: 0.0002
+     p_cnot_YZ: 0.0002
+     p_cnot_ZI: 0.0002
+     p_cnot_ZX: 0.0002
+     p_cnot_ZY: 0.0002
+     p_cnot_ZZ: 0.0002
+
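The per-key comments in `noise_model` give each default as a formula in a single physical error rate `p`. As a rough sketch of how the 25 values could be regenerated for another `p`, the snippet below applies those formulas directly. `noise_model_from_p` is a hypothetical helper written for illustration and is not part of the repository; only the key names and formulas are taken from the config comments.

```python
from itertools import product


def noise_model_from_p(p: float) -> dict:
    """Hypothetical helper: build the 25 noise parameters from one rate p."""
    spam = 2 * p / 3                       # prep and measurement errors: 2*p/3
    idle_cnot = p / 3                      # idle during CNOT layers: p/3
    idle_spam = 2 * p / 3 - 2 * p**2 / 9   # idle during SPAM window: 2*p/3 - 2*p^2/9
    params = {
        "p_prep_X": spam,
        "p_prep_Z": spam,
        "p_meas_X": spam,
        "p_meas_Z": spam,
    }
    for pauli in "XYZ":
        params[f"p_idle_cnot_{pauli}"] = idle_cnot
        params[f"p_idle_spam_{pauli}"] = idle_spam
    # 15 two-qubit Pauli errors on the CNOT, excluding the identity II: p/15 each
    for a, b in product("IXYZ", repeat=2):
        if a + b != "II":
            params[f"p_cnot_{a}{b}"] = p / 15
    return params


noise = noise_model_from_p(0.003)
for key, value in noise.items():
    print(f"{key}: {value:.6g}")
```

With `p = 0.003` the values agree with the defaults listed in the config up to floating-point rounding.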
framework.png ADDED