Commit d57fabf by donghufeng
Parent(s): 7fb034a
init
Files changed:
- Quantispect_RF13_v1.0.10.pt +3 -0
- README.md +216 -3
- bias_subcard.md +7 -0
- code/model/factory.py +62 -0
- code/model/predecoder_fasthyper_rf13_v1.py +170 -0
- code/model/registry.py +110 -0
- code/scripts/local_run.sh +252 -0
- code/workflows/config_validator.py +507 -0
- code/workflows/run.py +319 -0
- conf/config_public.yaml +84 -0
- framework.png +0 -0
Quantispect_RF13_v1.0.10.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c899ee6674d1d78bb7570c5284086b4dfc5a2d3aba63fbac6ded0beaecfb831e
+size 2693053
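The three added lines above are a Git LFS pointer rather than the checkpoint itself. As a minimal sketch, a downloaded checkpoint can be checked against the pointer's `oid` and `size` using only the standard library. The helper name `verify_lfs_object` is hypothetical and not part of this commit.

```python
import hashlib
from pathlib import Path


def verify_lfs_object(path, expected_oid, expected_size, chunk_size=1 << 20):
    """Return True if the file matches the sha256 oid and byte size of an LFS pointer."""
    file_path = Path(path)
    # Cheap check first: the byte size recorded in the pointer.
    if file_path.stat().st_size != expected_size:
        return False
    digest = hashlib.sha256()
    with file_path.open("rb") as handle:
        # Read in chunks so large checkpoints are not loaded into memory at once.
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_oid


# Values copied from the pointer in this commit.
EXPECTED_OID = "c899ee6674d1d78bb7570c5284086b4dfc5a2d3aba63fbac6ded0beaecfb831e"
EXPECTED_SIZE = 2693053
```

A call such as `verify_lfs_object("Quantispect_RF13_v1.0.10.pt", EXPECTED_OID, EXPECTED_SIZE)` would then report whether the local file is the object this pointer refers to.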
README.md
CHANGED
@@ -1,3 +1,216 @@
----
-
-
+---
+library_name: ising-decoding
+tags:
+- quantum
+- qec
+- error_correction
+- decoders
+- surface_code
+- predecoder
+license: apache-2.0
+---
+
+# Quantispect Overview
+
+
+
+## Model Summary
+
+| Item | Value |
+|---|---:|
+| Model name | Quantispect |
+| Checkpoint file | `Quantispect_RF13_v1.0.10.pt` |
+| Total parameters | ~0.663M |
+| Checkpoint size | ~2.63 MB |
+| Architecture | FastHyper-style 3D CNN neural pre-decoder |
+| Receptive field | R=13 |
+| Input tensor | `(B, 4, T, D, D)` |
+| Output tensor | `(B, 4, T, D, D)` |
+| Release date | April 26, 2026 |
+
+## Description:
+
+Quantispect is a compact neural pre-decoder for rotated surface-code quantum error correction. It consumes five-dimensional syndrome volumes across batch, channel, time, and two spatial dimensions, and predicts local correction maps that are consumed by a downstream global decoder such as MWPM / PyMatching or an Ising-decoding post-processing pipeline.
+
+Quantispect is designed to run inside an NVIDIA Ising-Decoding-compatible workflow after applying the Quantispect code patch included with this model release.
+
+## Model Architecture:
+
+Architecture type: 3D Convolutional Neural Network (3D CNN)
+
+Network architecture: custom multi-branch spatio-temporal 3D CNN with residual FastHyper blocks.
+
+### Input
+
+Input shape:
+
+```text
+(B, 4, T, D, D)
+```
+
+### Stem
+
+```text
+Conv3D 4 -> 96, kernel 3x3x3
+GroupNorm
+GELU
+```
+
+Stem output shape:
+
+```text
+(B, 96, T, D, D)
+```
+
+### Main Body
+
+The main body contains five repeated `FastHyperBlock` modules:
+
+```text
+FastHyperBlock x5
+```
+
+Each `FastHyperBlock` first expands the feature width from 96 to 144 channels with a 1x1x1 convolution, then applies three parallel feature extraction branches:
+
+```text
+Pre-projection: GroupNorm -> 1x1x1 Conv3D, 96 -> 144 -> GELU
+
+Branch A: Depthwise Conv3D, kernel 1x3x3, spatial branch
+Branch B: Depthwise Conv3D, kernel 3x1x1, temporal branch
+Branch C: GroupNorm -> Grouped Conv3D, kernel 3x3x3, groups=6, joint local spatio-temporal branch
+```
+
+The three branch outputs are aligned and fused by element-wise summation rather than channel concatenation. The fused feature is then projected and recalibrated:
+
+```text
+Element-wise sum fusion
+1x1x1 Conv3D projection, 144 -> 96
+GELU
+ChannelGate / SE-style channel attention
+Dropout3D
+Residual connection
+```
+
+Main body output shape:
+
+```text
+(B, 96, T, D, D)
+```
+
+### Head
+
+```text
+GroupNorm
+1x1x1 Conv3D, 96 -> 96
+GELU
+1x1x1 Conv3D, 96 -> 4
+```
+
+Output shape:
+
+```text
+(B, 4, T, D, D)
+```
+
+The output maps are used by the residual-syndrome construction module and then passed to MWPM / Ising-decoder post-processing.
+
+## Usage:
+
+Quantispect is intended to be used with the NVIDIA Ising-Decoding environment:
+
+```text
+https://github.com/NVIDIA/Ising-Decoding
+```
+
+A clean NVIDIA Ising-Decoding checkout does not natively know the Quantispect / FastHyper architecture. To run `Quantispect_RF13_v1.0.10.pt`, first apply the Quantispect code patch included in this model repository.
+
+
+### Required code patch files
+
+The patch package should preserve the following relative paths:
+
+```text
+quantispect_code_patch/
+├── conf/
+│   └── config_public.yaml
+└── code/
+    ├── model/
+    │   ├── predecoder_fasthyper_rf13_v1.py
+    │   ├── factory.py
+    │   └── registry.py
+    ├── workflows/
+    │   ├── config_validator.py
+    │   └── run.py
+    └── scripts/
+        └── local_run.sh
+```
+
+These files should be copied into the NVIDIA Ising-Decoding repository with the same relative paths:
+
+```text
+conf/config_public.yaml -> Ising-Decoding/conf/config_public.yaml
+code/model/predecoder_fasthyper_rf13_v1.py -> Ising-Decoding/code/model/predecoder_fasthyper_rf13_v1.py
+code/model/factory.py -> Ising-Decoding/code/model/factory.py
+code/model/registry.py -> Ising-Decoding/code/model/registry.py
+code/workflows/config_validator.py -> Ising-Decoding/code/workflows/config_validator.py
+code/workflows/run.py -> Ising-Decoding/code/workflows/run.py
+code/scripts/local_run.sh -> Ising-Decoding/code/scripts/local_run.sh
+```
+
+The patch mainly adds the `predecoder_fasthyper_rf13_v1` model implementation, registers `model_id: 6`, adds the Quantispect model hyperparameters to `config_public.yaml`, and enables explicit `.pt` checkpoint loading through `model_checkpoint_file`.
+
+### Apply the patch
+
+From the directory containing both the clean NVIDIA Ising-Decoding repository and this downloaded patch package:
+
+```bash
+cp -r code/* Ising-Decoding/code/
+cp -r conf/* Ising-Decoding/conf/
+```
+
+Then place the Quantispect checkpoint under the repository model directory:
+
+```bash
+mkdir -p Ising-Decoding/models
+cp Quantispect_RF13_v1.0.10.pt Ising-Decoding/models/Quantispect_RF13_v1.0.10.pt
+```
+
+Expected directory layout:
+
+```text
+Ising-Decoding/
+├── code/
+│   ├── model/
+│   │   └── predecoder_fasthyper_rf13_v1.py
+│   ├── workflows/
+│   │   ├── config_validator.py
+│   │   └── run.py
+│   └── scripts/
+│       └── local_run.sh
+├── conf/
+│   └── config_public.yaml
+├── models/
+│   └── Quantispect_RF13_v1.0.10.pt
+└── README.md
+```
+
+## Inference Deployment:
+
+Configure the NVIDIA Ising-Decoding repository for inference, apply the Quantispect patch files above, and place the downloaded model checkpoint at `models/Quantispect_RF13_v1.0.10.pt`.
+
+Run from the repository root:
+
+```bash
+cd Ising-Decoding
+
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+PYTHONUNBUFFERED=1 \
+PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
+WORKFLOW=inference \
+EXPERIMENT_NAME=infer_quantispect \
+TORCH_COMPILE=0 \
+EXTRA_PARAMS="+model_checkpoint_file=models/Quantispect_RF13_v1.0.10.pt" \
+bash code/scripts/local_run.sh \
+2>&1 | tee infer_quantispect.log
+```
+
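The README states that every stage keeps the `(T, D, D)` extent unchanged. The sketch below checks that claim with the standard convolution output-size formula. The per-layer paddings are taken from `predecoder_fasthyper_rf13_v1.py` in this same commit, the helper names are hypothetical, and the example extents are arbitrary.

```python
def conv_out_size(n, kernel, padding, stride=1, dilation=1):
    """Output length along one axis for a standard convolution."""
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# (kernel, padding) per axis (T, D, D) for the layer types used by the model.
LAYERS = {
    "stem 3x3x3": ((3, 3, 3), (1, 1, 1)),
    "spatial 1x3x3": ((1, 3, 3), (0, 1, 1)),
    "temporal 3x1x1": ((3, 1, 1), (1, 0, 0)),
    "mixed 3x3x3": ((3, 3, 3), (1, 1, 1)),
    "pointwise 1x1x1": ((1, 1, 1), (0, 0, 0)),
}


def output_shape(shape, kernel, padding):
    """Apply conv_out_size independently along each of the three axes."""
    return tuple(conv_out_size(n, k, p) for n, k, p in zip(shape, kernel, padding))


if __name__ == "__main__":
    t_d_d = (9, 7, 7)  # arbitrary example extents for (T, D, D)
    for name, (kernel, padding) in LAYERS.items():
        print(name, output_shape(t_d_d, kernel, padding))
```

Because every layer uses stride 1 with padding equal to half the kernel size (rounded down), each axis length is returned unchanged, which is why only the channel count differs between the stem, body, and head shapes.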
bias_subcard.md
ADDED
@@ -0,0 +1,7 @@
+# Bias Subcard
+
+Field | Response
+:-----|:---------
+Participation considerations from adversely impacted groups [protected classes](https://www.senate.ca.gov/content/protected-classes) in model design and testing: | Not Applicable
+Measures taken to mitigate against unwanted bias: | Not Applicable
+Bias Metric (If Measured): | Not Applicable
code/model/factory.py
ADDED
@@ -0,0 +1,62 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Factory module for creating models.
+
+Provides ModelFactory for instantiating pre-decoder models from config.
+"""
+
+
+class ModelFactory:
+
+    @staticmethod
+    def create_model(cfg):
+        if cfg.code == "surface":
+            return ModelFactory._create_surface_model(cfg)
+        else:
+            raise ValueError("Invalid model name")
+
+    @staticmethod
+    def _create_surface_model(cfg):
+        if cfg.model.version == "predecoder_memory_v1":
+            from model.predecoder import PreDecoderModelMemory_v1
+            model = PreDecoderModelMemory_v1(cfg)
+            return model
+        elif cfg.model.version == "predecoder_sd_litenet_v1":
+            from model.predecoder_sd_litenet_v1 import PredecoderSDLiteNetV1
+            model = PredecoderSDLiteNetV1(
+                input_channels=getattr(cfg.model, "input_channels", 4),
+                out_channels=getattr(cfg.model, "out_channels", 4),
+                hidden_dim=getattr(cfg.model, "hidden_dim", 64),
+                bottleneck_dim=getattr(cfg.model, "bottleneck_dim", 16),
+                dropout_p=getattr(cfg.model, "dropout_p", 0.05),
+            )
+            return model
+        elif cfg.model.version == "predecoder_fasthyper_rf13_v1":
+            from model.predecoder_fasthyper_rf13_v1 import PredecoderFastHyperRF13V1
+            model = PredecoderFastHyperRF13V1(
+                input_channels=getattr(cfg.model, "input_channels", 4),
+                out_channels=getattr(cfg.model, "out_channels", 4),
+                hidden_dim=getattr(cfg.model, "hidden_dim", 96),
+                mid_dim=getattr(cfg.model, "mid_dim", 144),
+                mix_groups=getattr(cfg.model, "mix_groups", 6),
+                num_blocks=getattr(cfg.model, "num_blocks", 5),
+                stem_kernel_size=getattr(cfg.model, "stem_kernel_size", 3),
+                dropout_p=getattr(cfg.model, "dropout_p", 0.02),
+                gate_reduction=getattr(cfg.model, "gate_reduction", 4),
+            )
+            return model
+        else:
+            raise ValueError(f"Invalid model version: {cfg.model.version}")
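The `predecoder_fasthyper_rf13_v1` branch reads each hyperparameter with `getattr(cfg.model, name, default)`, so any key missing from the config falls back to a built-in default. The sketch below illustrates that pattern in isolation. It uses `types.SimpleNamespace` as a stand-in for the real config object, which is an assumption, and `resolve_fasthyper_kwargs` is a hypothetical helper that does not exist in the repository.

```python
from types import SimpleNamespace

# Defaults copied from the predecoder_fasthyper_rf13_v1 branch of the factory.
FASTHYPER_DEFAULTS = {
    "input_channels": 4,
    "out_channels": 4,
    "hidden_dim": 96,
    "mid_dim": 144,
    "mix_groups": 6,
    "num_blocks": 5,
    "stem_kernel_size": 3,
    "dropout_p": 0.02,
    "gate_reduction": 4,
}


def resolve_fasthyper_kwargs(model_cfg):
    """Mirror the getattr(cfg.model, name, default) lookups done by the factory."""
    return {
        name: getattr(model_cfg, name, default)
        for name, default in FASTHYPER_DEFAULTS.items()
    }


# A config that overrides only num_blocks; everything else uses the defaults.
model_cfg = SimpleNamespace(version="predecoder_fasthyper_rf13_v1", num_blocks=3)
kwargs = resolve_fasthyper_kwargs(model_cfg)
print(kwargs["num_blocks"], kwargs["hidden_dim"])
```

One consequence of this design is that a misspelled config key is silently ignored in favour of the default, so the released checkpoint is only reproduced when the config either omits these keys or matches the values the checkpoint was trained with.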
code/model/predecoder_fasthyper_rf13_v1.py
ADDED
@@ -0,0 +1,170 @@
+from __future__ import annotations
+
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+
+def _choose_gn_groups(channels: int, max_groups: int = 8) -> int:
+    for g in range(min(max_groups, channels), 0, -1):
+        if channels % g == 0:
+            return g
+    return 1
+
+
+class _ChannelGate(nn.Module):
+    def __init__(self, channels: int, reduction: int = 4) -> None:
+        super().__init__()
+        hidden = max(channels // reduction, 8)
+        self.pool = nn.AdaptiveAvgPool3d(1)
+        self.fc1 = nn.Conv3d(channels, hidden, kernel_size=1, bias=True)
+        self.act = nn.GELU()
+        self.fc2 = nn.Conv3d(hidden, channels, kernel_size=1, bias=True)
+        self.gate = nn.Sigmoid()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        s = self.pool(x)
+        s = self.fc1(s)
+        s = self.act(s)
+        s = self.fc2(s)
+        return x * self.gate(s)
+
+
+class _FastHyperBlock(nn.Module):
+    """
+    Efficient RF-expanding residual block.
+
+    Each block contributes one effective k=3 receptive-field expansion stage via
+    three parallel branches operating on the same expanded activation:
+    - spatial depthwise (1,3,3)
+    - temporal depthwise (3,1,1)
+    - grouped 3D mixing (3,3,3)
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        mid_dim: int,
+        mix_groups: int = 6,
+        dropout_p: float = 0.02,
+        gate_reduction: int = 4,
+    ) -> None:
+        super().__init__()
+        gn1 = _choose_gn_groups(channels)
+        gn2 = _choose_gn_groups(mid_dim)
+        mix_groups = max(1, min(mix_groups, mid_dim))
+        while mid_dim % mix_groups != 0 and mix_groups > 1:
+            mix_groups -= 1
+
+        self.pre = nn.Sequential(
+            nn.GroupNorm(gn1, channels),
+            nn.Conv3d(channels, mid_dim, kernel_size=1, bias=True),
+            nn.GELU(),
+        )
+        self.spatial = nn.Sequential(
+            nn.Conv3d(
+                mid_dim,
+                mid_dim,
+                kernel_size=(1, 3, 3),
+                padding=(0, 1, 1),
+                groups=mid_dim,
+                bias=True,
+            ),
+            nn.GELU(),
+        )
+        self.temporal = nn.Sequential(
+            nn.Conv3d(
+                mid_dim,
+                mid_dim,
+                kernel_size=(3, 1, 1),
+                padding=(1, 0, 0),
+                groups=mid_dim,
+                bias=True,
+            ),
+            nn.GELU(),
+        )
+        self.mixed = nn.Sequential(
+            nn.GroupNorm(gn2, mid_dim),
+            nn.Conv3d(
+                mid_dim,
+                mid_dim,
+                kernel_size=3,
+                padding=1,
+                groups=mix_groups,
+                bias=True,
+            ),
+            nn.GELU(),
+        )
+        self.fuse = nn.Sequential(
+            nn.Conv3d(mid_dim, channels, kernel_size=1, bias=True),
+            nn.GELU(),
+        )
+        self.gate = _ChannelGate(channels, reduction=gate_reduction)
+        self.dropout = nn.Dropout3d(dropout_p) if dropout_p > 0 else nn.Identity()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        h = self.pre(x)
+        h = self.spatial(h) + self.temporal(h) + self.mixed(h)
+        h = self.fuse(h)
+        h = self.gate(h)
+        h = self.dropout(h)
+        return x + h
+
+
+class PredecoderFastHyperRF13V1(nn.Module):
+    """
+    Faster-stronger candidate for model 6 under the public Ising-Decoding API.
+
+    Input / output shape:
+    (B, 4, T, D, D) -> (B, 4, T, D, D)
+    """
+
+    def __init__(
+        self,
+        input_channels: int = 4,
+        out_channels: int = 4,
+        hidden_dim: int = 96,
+        mid_dim: int = 144,
+        mix_groups: int = 6,
+        num_blocks: int = 5,
+        stem_kernel_size: int = 3,
+        dropout_p: float = 0.02,
+        gate_reduction: int = 4,
+        **_: Any,
+    ) -> None:
+        super().__init__()
+        pad = stem_kernel_size // 2
+        gn = _choose_gn_groups(hidden_dim)
+        self.stem = nn.Sequential(
+            nn.Conv3d(
+                input_channels,
+                hidden_dim,
+                kernel_size=stem_kernel_size,
+                padding=pad,
+                bias=True,
+            ),
+            nn.GroupNorm(gn, hidden_dim),
+            nn.GELU(),
+        )
+        self.blocks = nn.Sequential(*[
+            _FastHyperBlock(
+                channels=hidden_dim,
+                mid_dim=mid_dim,
+                mix_groups=mix_groups,
+                dropout_p=dropout_p,
+                gate_reduction=gate_reduction,
+            ) for _ in range(num_blocks)
+        ])
+        self.head = nn.Sequential(
+            nn.GroupNorm(gn, hidden_dim),
+            nn.Conv3d(hidden_dim, hidden_dim, kernel_size=1, bias=True),
+            nn.GELU(),
+            nn.Conv3d(hidden_dim, out_channels, kernel_size=1, bias=True),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.stem(x)
+        x = self.blocks(x)
+        x = self.head(x)
+        return x
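The README in this commit quotes roughly 0.663M parameters. As a cross-check, the count can be derived by hand from the layer definitions in this file, without importing PyTorch. The sketch below is plain arithmetic under the default hyperparameters; the helper names are hypothetical, and it assumes `mix_groups` divides `mid_dim` so the group-adjustment loop in `_FastHyperBlock` does not change it.

```python
def conv3d_params(c_in, c_out, kernel, groups=1, bias=True):
    """Parameter count of a 3D convolution; kernel is a (kt, kh, kw) tuple."""
    kt, kh, kw = kernel
    weights = c_out * (c_in // groups) * kt * kh * kw
    return weights + (c_out if bias else 0)


def group_norm_params(channels):
    """Affine GroupNorm: one scale and one shift per channel."""
    return 2 * channels


def fasthyper_param_count(c_in=4, c_out=4, hidden=96, mid=144, mix_groups=6,
                          num_blocks=5, stem_k=3, gate_reduction=4):
    one = (1, 1, 1)
    stem = conv3d_params(c_in, hidden, (stem_k,) * 3) + group_norm_params(hidden)
    gate_hidden = max(hidden // gate_reduction, 8)
    block = (
        group_norm_params(hidden) + conv3d_params(hidden, mid, one)        # pre
        + conv3d_params(mid, mid, (1, 3, 3), groups=mid)                   # spatial
        + conv3d_params(mid, mid, (3, 1, 1), groups=mid)                   # temporal
        + group_norm_params(mid)
        + conv3d_params(mid, mid, (3, 3, 3), groups=mix_groups)            # mixed
        + conv3d_params(mid, hidden, one)                                  # fuse
        + conv3d_params(hidden, gate_hidden, one)                          # gate fc1
        + conv3d_params(gate_hidden, hidden, one)                          # gate fc2
    )
    head = (
        group_norm_params(hidden)
        + conv3d_params(hidden, hidden, one)
        + conv3d_params(hidden, c_out, one)
    )
    return stem + num_blocks * block + head


print(fasthyper_param_count())  # 663388
```

With the defaults this gives 663,388 parameters, consistent with the "~0.663M" figure. The grouped 3x3x3 branch dominates at 93,456 parameters per block, which is where `mix_groups` has the largest effect on model size.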
code/model/registry.py
ADDED
@@ -0,0 +1,110 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Public model registry for the early-access public release.
+
+External users choose `model_id` in {1..6}. This registry maps model_id to:
+- the underlying architecture parameters (num_filters, kernel_size)
+- the model receptive field R (in rounds / distance units)
+
+Receptive field convention matches `compare_receptive_field_with_window_data`
+in `code/training/utils.py`:
+R = 1 + sum_i (k_i - 1) for kernel sizes k_i (assumed odd, with same-padding)
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Dict, List
+
+
+def compute_receptive_field(kernel_sizes: List[int]) -> int:
+    """Compute receptive field R from a list of kernel sizes."""
+    if not kernel_sizes:
+        raise ValueError("kernel_sizes must be non-empty")
+    if any(not isinstance(k, int) for k in kernel_sizes):
+        raise ValueError(f"kernel_sizes must be ints, got: {kernel_sizes!r}")
+    if any(k <= 0 for k in kernel_sizes):
+        raise ValueError(f"kernel_sizes must be positive, got: {kernel_sizes!r}")
+    return 1 + sum(kernel_sizes) - len(kernel_sizes)
+
+
+@dataclass(frozen=True)
+class PublicModelSpec:
+    model_id: int
+    num_filters: List[int]
+    kernel_size: List[int]
+    receptive_field: int
+    model_version: str = "predecoder_memory_v1"
+
+
+_MODEL_SPECS: Dict[int, PublicModelSpec] = {
+    1:
+        PublicModelSpec(
+            model_id=1,
+            num_filters=[128, 128, 128, 4],
+            kernel_size=[3, 3, 3, 3],
+            receptive_field=compute_receptive_field([3, 3, 3, 3]),
+        ),
+    2:
+        PublicModelSpec(
+            model_id=2,
+            num_filters=[256, 256, 256, 4],
+            kernel_size=[3, 3, 3, 3],
+            receptive_field=compute_receptive_field([3, 3, 3, 3]),
+        ),
+    3:
+        PublicModelSpec(
+            model_id=3,
+            num_filters=[128, 128, 128, 4],
+            kernel_size=[5, 5, 5, 5],
+            receptive_field=compute_receptive_field([5, 5, 5, 5]),
+        ),
+    4:
+        PublicModelSpec(
+            model_id=4,
+            num_filters=[128, 128, 128, 128, 128, 4],
+            kernel_size=[3, 3, 3, 3, 3, 3],
+            receptive_field=compute_receptive_field([3, 3, 3, 3, 3, 3]),
+        ),
+    5:
+        PublicModelSpec(
+            model_id=5,
+            num_filters=[256, 256, 256, 256, 256, 4],
+            kernel_size=[3, 3, 3, 3, 3, 3],
+            receptive_field=compute_receptive_field([3, 3, 3, 3, 3, 3]),
+        ),
+    6:
+        PublicModelSpec(
+            model_id=6,
+            num_filters=[96, 96, 96, 96, 96, 4],
+            kernel_size=[3, 3, 3, 3, 3, 3],
+            receptive_field=compute_receptive_field([3, 3, 3, 3, 3, 3]),
+            model_version="predecoder_fasthyper_rf13_v1",
+        ),
+}
+
+
+def get_model_spec(model_id: int) -> PublicModelSpec:
+    """Return the public model spec for a given model_id (1..6)."""
+    try:
+        mid = int(model_id)
+    except Exception as e:
+        raise ValueError(f"model_id must be an int in [1..6], got: {model_id!r}") from e
+    if mid == 0:
+        raise ValueError("model_id=0 is not supported in the public release")
+    if mid not in _MODEL_SPECS:
+        raise ValueError(f"model_id must be in [1..6], got: {mid}")
+    return _MODEL_SPECS[mid]
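The registry's convention `R = 1 + sum_i (k_i - 1)` explains the `R=13` figure quoted for model 6. The sketch below reimplements the formula as a standalone function (named `receptive_field` here to avoid implying it is the repository's own helper) and evaluates it for the kernel lists in `_MODEL_SPECS`. Reading the six `k=3` entries of model 6 as one stem convolution plus five FastHyper blocks is an interpretation based on the block docstring, not something the registry states.

```python
def receptive_field(kernel_sizes):
    """R = 1 + sum_i (k_i - 1), the convention stated in the registry docstring."""
    return 1 + sum(k - 1 for k in kernel_sizes)


# Model 6: one k=3 stem plus five blocks, each counted as one k=3 stage.
# The 1x1x1 head convolutions would add (1 - 1) = 0 and are omitted.
model_6_stages = [3] + [3] * 5

print(receptive_field(model_6_stages))  # 13
print(receptive_field([3, 3, 3, 3]))    # 9, model_id 1 and 2
print(receptive_field([5, 5, 5, 5]))    # 17, model_id 3
```

Each `k=3` stage widens the window by two sites, so six stages starting from a single site give 1 + 6 * 2 = 13.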
code/scripts/local_run.sh
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
set -euo pipefail
|
| 18 |
+
|
| 19 |
+
# Minimal local runner.
|
| 20 |
+
#
|
| 21 |
+
# Examples:
|
| 22 |
+
# bash code/scripts/local_run.sh
|
| 23 |
+
# WORKFLOW=inference bash code/scripts/local_run.sh
|
| 24 |
+
# GPUS=4 bash code/scripts/local_run.sh
|
| 25 |
+
# CUDA_VISIBLE_DEVICES=1 bash code/scripts/local_run.sh # use only GPU 1
|
| 26 |
+
#
|
| 27 |
+
# ONNX / TRT fast inference (requires tensorrt; set ONNX_WORKFLOW before running):
|
| 28 |
+
# ONNX_WORKFLOW=1 WORKFLOW=inference bash code/scripts/local_run.sh # export ONNX only (inspect/reuse later)
|
| 29 |
+
# ONNX_WORKFLOW=2 WORKFLOW=inference bash code/scripts/local_run.sh # export ONNX + build TRT + run TRT inference
|
| 30 |
+
# ONNX_WORKFLOW=2 QUANT_FORMAT=int8 WORKFLOW=inference bash code/scripts/local_run.sh # INT8 quantized TRT
|
| 31 |
+
# ONNX_WORKFLOW=2 QUANT_FORMAT=fp8 WORKFLOW=inference bash code/scripts/local_run.sh # FP8 quantized TRT (requires nvidia-modelopt)
|
| 32 |
+
# ONNX_WORKFLOW=3 WORKFLOW=inference bash code/scripts/local_run.sh # load pre-built engine, skip export
|
| 33 |
+
#
|
| 34 |
+
# Decoder ablation study with cudaq-qec global decoders (requires cudaq-qec):
|
| 35 |
+
# WORKFLOW=decoder_ablation bash code/scripts/local_run.sh
|
| 36 |
+
#
|
| 37 |
# Decoder ablation with TRT pre-decoder + cudaq-qec global decoders
# (combines fast TRT inference for the neural pre-decoder with GPU-accelerated
# cudaq-qec decoders for the residual syndromes; full GPU pipeline end-to-end):
#   ONNX_WORKFLOW=2 WORKFLOW=decoder_ablation bash code/scripts/local_run.sh   # export+build TRT, then ablation
#   ONNX_WORKFLOW=3 WORKFLOW=decoder_ablation bash code/scripts/local_run.sh   # load existing engine, then ablation
#
# Notes:
# - Public config is `conf/config_public.yaml`. Users should edit only that file.
# - Training knobs are auto-managed in code (epochs, shots/epoch, batch schedule, etc.).
# - SafeTensors (optional): after training, convert the best .pt checkpoint with
#   code/export/checkpoint_to_safetensors.py (see README), then pass the result as:
#     PREDECODER_SAFETENSORS_CHECKPOINT=<path>.safetensors WORKFLOW=inference bash code/scripts/local_run.sh

EXPERIMENT_NAME="${EXPERIMENT_NAME:-test1}"
CONFIG_NAME="${CONFIG_NAME:-config_public}"   # conf/<name>.yaml (no extension)
WORKFLOW="${WORKFLOW:-train}"                 # train | inference | decoder_ablation
WORKFLOW="$(echo "${WORKFLOW}" | tr '[:upper:]' '[:lower:]')"
GPUS="${GPUS:-}"                              # if empty, auto-detect
FRESH_START="${FRESH_START:-0}"               # 1 => don't load checkpoint
EXTRA_PARAMS="${EXTRA_PARAMS:-}"              # advanced hydra overrides (discouraged)
TORCH_COMPILE="${TORCH_COMPILE:-}"            # 0/1 to disable/enable torch.compile
TORCH_COMPILE_MODE="${TORCH_COMPILE_MODE:-}"  # optional: default | reduce-overhead | max-autotune

DISTANCE="${DISTANCE:-}"
N_ROUNDS="${N_ROUNDS:-}"
if [ $# -eq 1 ]; then DISTANCE="$1"; fi
if [ $# -eq 2 ]; then DISTANCE="$1"; N_ROUNDS="$2"; fi

SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"
# local_run.sh lives at: <repo_root>/code/scripts/local_run.sh
# so repo_root is two levels up from SCRIPT_DIR.
REPO_ROOT="$(cd -- "${SCRIPT_DIR}/../.." && pwd)"
CODE_ROOT="${CODE_ROOT:-${REPO_ROOT}/code}"

# Default output locations live inside the repo (avoid surprises from generic env vars).
# Some environments set BASE_OUTPUT_DIR/LOG_BASE_DIR globally; ignore those by default to
# prevent creating confusing extra folders like /root/outputs or /root/logs.
if [ -n "${BASE_OUTPUT_DIR:-}" ] || [ -n "${LOG_BASE_DIR:-}" ]; then
  echo "[local_run.sh] Note: ignoring BASE_OUTPUT_DIR/LOG_BASE_DIR from the environment."
  echo "[local_run.sh] To override paths, use PREDECODER_BASE_OUTPUT_DIR / PREDECODER_LOG_BASE_DIR."
fi
BASE_OUTPUT_DIR="${PREDECODER_BASE_OUTPUT_DIR:-${REPO_ROOT}/outputs}"
LOG_BASE_DIR="${PREDECODER_LOG_BASE_DIR:-${REPO_ROOT}/logs}"
mkdir -p "${BASE_OUTPUT_DIR}" "${LOG_BASE_DIR}"

if [ "${FRESH_START}" -eq 1 ]; then
  RESUME_FLAG="++load_checkpoint=False"
else
  RESUME_FLAG="++load_checkpoint=True"
fi

# GPU-only runs: require a visible GPU and nvidia-smi.
if ! command -v nvidia-smi >/dev/null 2>&1; then
  echo "[local_run.sh] Error: GPU-only mode requires nvidia-smi on PATH." >&2
  echo "[local_run.sh] Hint: run on a GPU host or pass CUDA_VISIBLE_DEVICES." >&2
  exit 1
fi

# Respect CUDA_VISIBLE_DEVICES if set; otherwise auto-detect via nvidia-smi.
if [ -z "${GPUS}" ]; then
  if [ -n "${CUDA_VISIBLE_DEVICES:-}" ]; then
    GPUS="$(python3 - <<'PY'
import os
v=os.environ.get('CUDA_VISIBLE_DEVICES','').strip()
print(len([x for x in v.split(',') if x.strip()]) or 1)
PY
)"
  else
    GPUS="$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l | tr -d ' ')"
  fi
fi

if [ "${GPUS}" -le 0 ]; then
  echo "[local_run.sh] Error: no GPUs detected. GPU-only mode requires CUDA." >&2
  exit 1
fi

if [ -z "${MASTER_PORT:-}" ]; then
  MASTER_PORT="$(python3 - <<'PY'
import socket
s=socket.socket()
s.bind(('127.0.0.1', 0))
print(s.getsockname()[1])
s.close()
PY
)"
  export MASTER_PORT
fi

TIMESTAMP="$(date +%Y%m%d_%H%M%S)"
# Add nanoseconds to avoid collisions when launching multiple runs within the same second.
TIMESTAMP_NS="$(date +%Y%m%d_%H%M%S_%N)"
RUN_ID="${EXPERIMENT_NAME}_${TIMESTAMP}"
LOG_DIR="${LOG_BASE_DIR}/${RUN_ID}"
OUTPUT_DIR="${BASE_OUTPUT_DIR}/${EXPERIMENT_NAME}"
CHECKPOINT_DIR="${OUTPUT_DIR}/models"
mkdir -p "${LOG_DIR}" "${OUTPUT_DIR}" "${CHECKPOINT_DIR}"

# Force Hydra run dir to writable OUTPUT_DIR (avoids read-only repo/outputs in containers)
OVERRIDES="hydra.run.dir=${OUTPUT_DIR}"
if [ -n "${DISTANCE}" ]; then OVERRIDES+=" distance=${DISTANCE}"; fi
if [ -n "${N_ROUNDS}" ]; then OVERRIDES+=" n_rounds=${N_ROUNDS}"; fi
if [ -n "${EXTRA_PARAMS}" ]; then OVERRIDES+=" ${EXTRA_PARAMS}"; fi

CONFIG_SNAPSHOT_DIR="${OUTPUT_DIR}/config"
mkdir -p "${CONFIG_SNAPSHOT_DIR}"
CONFIG_PATH="${REPO_ROOT}/conf/${CONFIG_NAME}.yaml"
if [ -f "${CONFIG_PATH}" ]; then
  # Never overwrite existing snapshots: keep full history.
  base_yaml="${CONFIG_SNAPSHOT_DIR}/${CONFIG_NAME}_${TIMESTAMP_NS}.yaml"
  dest_yaml="${base_yaml}"
  i=0
  while [ -e "${dest_yaml}" ]; do
    i=$((i+1))
    dest_yaml="${base_yaml%.yaml}_${i}.yaml"
  done
  cp "${CONFIG_PATH}" "${dest_yaml}"
  # Also save the exact CLI overrides used for this run (useful when configs change over time).
  base_ovr="${CONFIG_SNAPSHOT_DIR}/${CONFIG_NAME}_${TIMESTAMP_NS}.overrides.txt"
  dest_ovr="${base_ovr}"
  j=0
  while [ -e "${dest_ovr}" ]; do
    j=$((j+1))
    dest_ovr="${base_ovr%.txt}_${j}.txt"
  done
  {
    echo "workflow.task=${WORKFLOW}"
    # Recorded with the leading '+' so the snapshot matches the override passed to run.py.
    echo "+exp_tag=${EXPERIMENT_NAME}"
    echo "${RESUME_FLAG}"
    echo "${OVERRIDES:-}"
  } > "${dest_ovr}"
else
  echo "[local_run.sh] Warning: could not find config file to snapshot: ${CONFIG_PATH}"
fi

echo "=========================================="
echo "Local run"
echo "=========================================="
echo "workflow.task: ${WORKFLOW}"
echo "config: ${CONFIG_NAME}"
echo "GPUS: ${GPUS} (CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-<unset>})"
echo "output: ${OUTPUT_DIR}"
echo "logs: ${LOG_DIR}"
echo "overrides: ${OVERRIDES:-<none>}"
echo "=========================================="

export PYTHONPATH="${CODE_ROOT}:${PYTHONPATH:-}"
export HDF5_USE_FILE_LOCKING=FALSE
export CUDNN_V8_API_ENABLED=1
export OMP_NUM_THREADS="$(nproc)"
export JOB_START_TIMESTAMP="$(date +%s)"
export JOB_START_DATETIME="$(date)"
if [ -n "${TORCH_COMPILE}" ]; then
  export PREDECODER_TORCH_COMPILE="${TORCH_COMPILE}"
fi
if [ -n "${TORCH_COMPILE_MODE}" ]; then
  export PREDECODER_TORCH_COMPILE_MODE="${TORCH_COMPILE_MODE}"
fi

# Prefer PREDECODER_PYTHON (cluster/container venv) when set
PYTHON_BIN="${PYTHON_BIN:-${PREDECODER_PYTHON:-python}}"
if ! command -v "${PYTHON_BIN}" >/dev/null 2>&1; then
  if command -v python3 >/dev/null 2>&1; then
    PYTHON_BIN="python3"
  else
    echo "[local_run.sh] Error: no python interpreter found on PATH." >&2
    exit 1
  fi
fi

# Ensure CUDA is usable before launching the workflow.
if ! "${PYTHON_BIN}" - <<'PY'
import sys
try:
    import torch
except Exception as exc:
    print(f"[local_run.sh] Error: PyTorch is required for GPU-only runs ({exc}).", file=sys.stderr)
    sys.exit(1)
if not torch.cuda.is_available():
    print("[local_run.sh] Error: torch.cuda.is_available() is false. GPU-only mode requires CUDA.", file=sys.stderr)
    sys.exit(1)
PY
then
  exit 1
fi

# Run from repo root so config defaults like `output: outputs/${exp_tag}` land in <repo_root>/outputs.
cd "${REPO_ROOT}"

LOG_FILE="${LOG_DIR}/${WORKFLOW}.log"

if [ "${GPUS}" -gt 1 ]; then
  "${PYTHON_BIN}" -m torch.distributed.run \
    --nproc_per_node="${GPUS}" \
    --nnodes=1 \
    --node_rank=0 \
    --master_port="${MASTER_PORT}" \
    code/workflows/run.py \
    --config-name="${CONFIG_NAME}" \
    workflow.task="${WORKFLOW}" \
    +exp_tag="${EXPERIMENT_NAME}" \
    ${RESUME_FLAG} \
    ${OVERRIDES} \
    2>&1 | tee -a "${LOG_FILE}"
else
  "${PYTHON_BIN}" -u code/workflows/run.py \
    --config-name="${CONFIG_NAME}" \
    workflow.task="${WORKFLOW}" \
    +exp_tag="${EXPERIMENT_NAME}" \
    ${RESUME_FLAG} \
    ${OVERRIDES} \
    2>&1 | tee -a "${LOG_FILE}"
fi

cp -f "${LOG_FILE}" "${OUTPUT_DIR}/run.log"
echo "Done. Log: ${LOG_FILE}"
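The GPU-count and rendezvous-port logic in `local_run.sh` can be sketched in Python as follows. This is a minimal illustration, not part of the repository: the function names `resolve_gpu_count` and `pick_free_port` are hypothetical, and the `detected` argument stands in for the `nvidia-smi` line count the script computes.

```python
import socket
from typing import Optional


def resolve_gpu_count(gpus: str, cuda_visible_devices: Optional[str], detected: int) -> int:
    """Hypothetical helper mirroring the script's precedence order.

    1. An explicit GPUS value wins.
    2. Otherwise count the non-empty entries of CUDA_VISIBLE_DEVICES (at least 1).
    3. Otherwise fall back to the detected device count.
    """
    if gpus.strip():
        return int(gpus)
    if cuda_visible_devices:
        entries = [x for x in cuda_visible_devices.split(",") if x.strip()]
        return len(entries) or 1
    return detected


def pick_free_port() -> int:
    """Ask the OS for an unused TCP port by binding to port 0, as the script does for MASTER_PORT."""
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


if __name__ == "__main__":
    print(resolve_gpu_count("", "0,1,2", detected=8))
    print(pick_free_port())
```

Binding to port 0 only reserves the port while the socket is open, so another process could in principle claim it before the launcher starts. The script accepts the same small race.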
code/workflows/config_validator.py
ADDED
@@ -0,0 +1,507 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Public config normalization / validation for the early-access public release.

Responsibilities:
- Fail-fast if the user tries to set hidden/experimental fields (via Hydra CLI `+foo=...`)
- Merge in hidden defaults (sourced from model_1_d9 config) so training runs with a minimal public config
- Apply the selected public model architecture (model_id -> model.*)
- Clamp distance/n_rounds to the model receptive field:
    D   = min(distance, R)
    N_R = min(n_rounds, R)
"""

from __future__ import annotations

from pathlib import Path
import os
from typing import Any, Dict, Iterable, Tuple

from omegaconf import DictConfig, OmegaConf

from model.registry import PublicModelSpec, get_model_spec

_PUBLIC_ROTATION_TO_INTERNAL = {
    # Public user-facing aliases
    "O1": "XV",
    "O2": "XH",
    "O3": "ZV",
    "O4": "ZH",
}
_INTERNAL_ROTATION_TO_PUBLIC = {v: k for k, v in _PUBLIC_ROTATION_TO_INTERNAL.items()}

_PUBLIC_MODEL_ID_TO_LR = {
    1: 3e-4,
    2: 2e-4,
    3: 1e-4,
    4: 2e-4,
    5: 1e-4,
    6: 2e-4,
}


def _default_precomputed_frames_dir() -> str:
    """
    Default location for precomputed frames shipped with (or generated inside) this repo.

    We compute this path relative to the codebase so it is stable regardless of the user's
    current working directory.
    """
    # .../<repo>/code/workflows/config_validator.py -> repo root is parents[2]
    repo_root = Path(__file__).resolve().parents[2]
    return str((repo_root / "frames_data").resolve())


def _get_env_bool(name: str, default: bool) -> bool:
    """Read a boolean from the environment; unset returns `default`, falsy strings return False."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    val = str(raw).strip().lower()
    if val in ("0", "false", "no", "off", ""):
        return False
    return True


def _normalize_code_rotation(value: Any) -> str | None:
    """
    Normalize code rotation values.

    Public config accepts O1..O4 for user convenience. Internally we keep using:
    XV, XH, ZV, ZH (as expected by SurfaceCode / MemoryCircuit).

    None is passed through unchanged.
    """
    if value is None:
        return value
    s = str(value).strip().upper()
    if s in _PUBLIC_ROTATION_TO_INTERNAL:
        return _PUBLIC_ROTATION_TO_INTERNAL[s]
    if s in _INTERNAL_ROTATION_TO_PUBLIC:
        return s
    raise ValueError(
        f"Invalid data.code_rotation={value!r}. "
        f"Use one of {sorted(_PUBLIC_ROTATION_TO_INTERNAL.keys())} (public) "
        f"or {sorted(_INTERNAL_ROTATION_TO_PUBLIC.keys())} (internal)."
    )


def _base_hidden_defaults_dict() -> Dict[str, Any]:
    """
    Baseline config used as the source-of-truth for hidden defaults.

    IMPORTANT: We intentionally embed these defaults directly in code so the public
    release does not ship internal/legacy config files. These values were copied
    from the historical `config_pre_decoder_memory_surface_model_1_d9.yaml`.
    """
    base_output_dir = os.environ.get("PREDECODER_BASE_OUTPUT_DIR", "outputs")
    output_root = f"{base_output_dir}/${{exp_tag}}"
    return {
        "exp_tag": "pre-decoder",
        "output": output_root,
        "hydra": {
            "run": {
                "dir": "${output}"
            },
            "output_subdir": "hydra"
        },
        "resume_dir": f"{output_root}/models",
        "enable_fp16": False,
        "enable_bf16": False,
        "enable_matmul_tf32": True,
        "enable_cudnn_tf32": True,
        "enable_cudnn_benchmark": True,
        "torch_compile": _get_env_bool("PREDECODER_TORCH_COMPILE", True),
        "torch_compile_mode": os.environ.get("PREDECODER_TORCH_COMPILE_MODE", "default"),
        "load_checkpoint": False,
        "code": "surface",
        "distance": 9,
        "n_rounds": 9,
        "multiple_distances": [13, 13],
        "multiple_rounds": [13, 13],
        "use_multiple_patches": False,
        "meas_basis": "both",
        "workflow": {
            "task": "train"
        },
        "data":
            {
                "timelike_he": True,
                "num_he_cycles": 1,
                "use_weight2_timelike": False,
                "max_passes_w1": 8,
                "max_passes_w2": 4,
                "decompose_y": True,
                "p_error": None,
                "p_min": 0.001,
                "p_max": 0.006,
                "error_mode": "circuit_level_surface_custom",
                # Public config overrides this; keep the historical default for completeness.
                "precomputed_frames_dir": _default_precomputed_frames_dir(),
                "enable_correlated_pymatching": False,
                "code_rotation": "XV",
                "noise_model": None,
            },
        "model":
            {
                "version": "predecoder_memory_v1",
                "dropout_p": 0.05,
                "activation": "gelu",
                "num_filters": [128, 128, 128, 4],
                "kernel_size": [3, 3, 3, 3],
                "input_channels": 4,
                "out_channels": 4,
            },
        "datapipe": "memory",
        "data_method": "train",
        "train":
            {
                # Production baseline: 2^26 shots / epoch when training with 8 GPUs.
                # The training script will auto-scale this based on detected world size / GPU count.
                "num_samples": 67108864,
                "accumulate_steps": 2,
                "checkpoint_interval": 1,
                "save_every_datasets": 5,
                "epochs": 100,
            },
        # NOTE: temporarily reduced for faster iteration during refactor/testing.
        "val": {
            "num_samples": 65536,
            "threshold": 0.5,
            "trials": 1
        },
        "optimizer_type": "Lion",
        "optimizer": {
            "lr": 1e-4,
            "weight_decay": 1e-7,
            "beta2": 0.95
        },
        "lr_scheduler":
            {
                "type": "warmup_then_decay",
                "warmup_steps": 100,
                "milestones": [0.25, 0.5, 1.0],
                "gamma": 0.7,
                "min_lr": 1e-6,
            },
        "batch_schedule":
            {
                "enabled": True,
                "initial": 256,
                "final": 1024,
                "start_epoch": 1,
                "end_epoch": 3,
            },
        "validation_ler": True,
        "early_stopping": {
            "enabled": True,
            "patience": 100
        },
        "time_based_early_stopping": {
            "enabled": False,
            "safety_margin_minutes": 5
        },
        "ema": {
            "use_ema": True,
            "decay": 0.0001
        },
        "test":
            {
                "num_samples": 262144,
                "trials": 1,
                "distance": 9,
                "n_rounds": 9,
                "noise_model": "train",
                "p_error": 0.006,
                "dataloader":
                    {
                        "batch_size": 64,
                        "num_workers": 0,
                        "persistent_workers": False,
                    },
                "latency_num_samples": 1000,
                "sampler": {
                    "shuffle": False,
                    "drop_last": False
                },
                "syn_red": "full",
                "th_data": 0.0,
                "th_syn": 0.0,
                "sampling_mode": "threshold",
                "temperature": 0.0,
                "temperature_data": None,
                "temperature_syn": None,
                "per_round": False,
                "meas_basis_test": "both",
                "use_model_checkpoint": -1,
            },
        "threshold":
            {
                "p_values": [0.002, 0.003, 0.004, 0.005, 0.006, 0.007, 0.008],
                "distances": [5, 7, 9, 11, 13],
                "n_rounds": None,
            },
    }


def _select(cfg: DictConfig, key: str) -> Tuple[bool, Any]:
    """
    Return (exists, value) for a dot-path in cfg.
    Note: OmegaConf.select returns None both for missing keys and explicit nulls,
    so we treat a key as existing iff it is present in the underlying container.
    """
    # OmegaConf doesn't provide a direct 'has_key' for dotted paths; implement via container walk.
    cur: Any = cfg
    parts = key.split(".")
    for p in parts:
        if not isinstance(cur, DictConfig) or p not in cur:
            return False, None
        cur = cur[p]
    return True, cur


def _assert_not_present(cfg: DictConfig, keys: Iterable[str], *, context: str) -> None:
    for k in keys:
        exists, _ = _select(cfg, k)
        if exists:
            raise ValueError(
                f"Config field '{k}' is not supported in the public release ({context}). "
                f"Remove it from the config/CLI overrides."
            )


def validate_public_config(cfg: DictConfig) -> PublicModelSpec:
    """
    Validate the user-facing config BEFORE we merge in hidden defaults.

    Returns:
        PublicModelSpec for cfg.model_id (validated).
    """
    # model_id must exist in public config
    if "model_id" not in cfg:
        raise ValueError("Missing required field: 'model_id' (choose 1..5).")

    model_spec = get_model_spec(cfg.model_id)

    # Public config requires distance/n_rounds (evaluation targets)
    if "distance" not in cfg or "n_rounds" not in cfg:
        raise ValueError("Missing required fields: 'distance' and 'n_rounds'.")
    try:
        d = int(cfg.distance)
        r = int(cfg.n_rounds)
    except Exception as e:
        raise ValueError(
            f"Invalid distance/n_rounds: distance={cfg.distance!r}, n_rounds={cfg.n_rounds!r}"
        ) from e
    if d <= 0 or r <= 0:
        raise ValueError(
            f"Invalid distance/n_rounds: distance={d}, n_rounds={r} (must be positive integers)"
        )

    if "train" in cfg:
        raise ValueError("Config field 'train' is not supported in the public release.")
    if "val" in cfg:
        raise ValueError("Config field 'val' is not supported in the public release.")
    if "test" in cfg:
        raise ValueError("Config field 'test' is not supported in the public release.")

    # Fail-fast on known hidden fields if the user tries to inject them.
    _assert_not_present(
        cfg,
        keys=(
            # output paths are managed by the runner scripts; not user-configurable in public release
            "output",
            "resume_dir",
            # precision / tf32 knobs (always fp32 + tf32 enabled)
            "enable_fp16",
            "enable_bf16",
            "enable_matmul_tf32",
            "enable_cudnn_tf32",
            # always both bases
            "meas_basis",
            # multi-patch curriculum mode (hidden)
            "use_multiple_patches",
            "multiple_distances",
            "multiple_rounds",
            # optimizer knobs (not user-configurable; lr is fixed per model_id)
            "optimizer",
            "optimizer_type",
            "lr_scheduler",
            "batch_schedule",
            # obsolete/confusing
            "train.save_every_datasets",
            # validation hidden knobs
            "val.threshold",
            "val.trials",
            # early stopping extras hidden
            "time_based_early_stopping",
            "ema",
        ),
        context="hidden field override",
    )

    # Restrict cfg.data to a small public surface (others can be too experimental).
    if "data" in cfg and isinstance(cfg.data, DictConfig):
        # NOTE: precomputed frames path is intentionally hidden from the public config.
        # We default it internally to <repo>/frames_data (see _default_precomputed_frames_dir).
        if "precomputed_frames_dir" in cfg.data:
            raise ValueError(
                "Config field 'data.precomputed_frames_dir' is not supported in the public release. "
                "Remove it from the config/CLI overrides."
            )
        allowed_data_keys = {"code_rotation", "noise_model"}
        for k in cfg.data.keys():
            if k not in allowed_data_keys:
                raise ValueError(
                    f"Config field 'data.{k}' is not supported in the public release. "
                    f"Allowed data fields are: {sorted(allowed_data_keys)}"
                )
        # Validate rotation value (accept O1..O4; also allow internal XV/XH/ZV/ZH for compatibility).
        if "code_rotation" in cfg.data:
            _normalize_code_rotation(cfg.data.code_rotation)

    # Defensive check on optimizer sub-keys. 'optimizer' is already rejected by the
    # hidden-field check above, so this only matters if that rule is relaxed.
    if "optimizer" in cfg and isinstance(cfg.optimizer, DictConfig):
        for k in cfg.optimizer.keys():
            if k != "lr":
                raise ValueError(
                    f"Config field 'optimizer.{k}' is not supported in the public release. "
                    f"Only 'optimizer.lr' is user-configurable."
                )

    return model_spec


def clamp_to_receptive_field(cfg: DictConfig, R: int) -> None:
    """In-place clamp of cfg.distance and cfg.n_rounds to receptive field R."""
    if not isinstance(R, int) or R <= 0:
        raise ValueError(f"Invalid receptive field R={R!r}")
    if "distance" not in cfg or "n_rounds" not in cfg:
        raise ValueError("Both 'distance' and 'n_rounds' must be present in config.")
    cfg.distance = int(min(int(cfg.distance), R))
    cfg.n_rounds = int(min(int(cfg.n_rounds), R))

|
| 395 |
+
def apply_public_defaults_and_model(cfg: DictConfig, model_spec: PublicModelSpec) -> DictConfig:
|
| 396 |
+
"""
|
| 397 |
+
Merge hidden defaults and apply public model settings.
|
| 398 |
+
|
| 399 |
+
Returns a new DictConfig (does not mutate input).
|
| 400 |
+
"""
|
| 401 |
+
base_cfg = OmegaConf.create(_base_hidden_defaults_dict())
|
| 402 |
+
|
| 403 |
+
# Merge: base provides full training-ready config; public cfg overrides user-visible fields.
|
| 404 |
+
merged = OmegaConf.merge(base_cfg, cfg)
|
| 405 |
+
OmegaConf.set_struct(merged, False)
|
| 406 |
+
|
| 407 |
+
# In the public release:
|
| 408 |
+
# - cfg.distance / cfg.n_rounds are the *evaluation targets* the user cares about
|
| 409 |
+
# - training always uses distance=n_rounds=R (the model receptive field)
|
| 410 |
+
requested_distance = int(merged.distance)
|
| 411 |
+
requested_n_rounds = int(merged.n_rounds)
|
| 412 |
+
|
| 413 |
+
# Enforce public invariants (hidden from user)
|
| 414 |
+
merged.enable_fp16 = False
|
| 415 |
+
merged.enable_bf16 = False
|
| 416 |
+
merged.enable_matmul_tf32 = True
|
| 417 |
+
merged.enable_cudnn_tf32 = True
|
| 418 |
+
|
| 419 |
+
merged.meas_basis = "both"
|
| 420 |
+
|
| 421 |
+
# Disable multi-patch mode explicitly
|
| 422 |
+
if "data" not in merged:
|
| 423 |
+
merged.data = {}
|
| 424 |
+
merged.data.use_multiple_patches = False
|
| 425 |
+
merged.multiple_distances = None
|
| 426 |
+
merged.multiple_rounds = None
|
| 427 |
+
|
| 428 |
+
# Always use repo-relative frames_data by default (hidden from public config).
|
| 429 |
+
merged.data.precomputed_frames_dir = _default_precomputed_frames_dir()
|
| 430 |
+
|
| 431 |
+
# Apply model architecture from registry
|
| 432 |
+
if "model" not in merged:
|
| 433 |
+
merged.model = {}
|
| 434 |
+
merged.model.version = model_spec.model_version
|
| 435 |
+
merged.model.num_filters = list(model_spec.num_filters)
|
| 436 |
+
merged.model.kernel_size = list(model_spec.kernel_size)
|
| 437 |
+
|
| 438 |
+
# Public release: hard-code optimizer.lr based on model choice.
|
| 439 |
+
# (User is not allowed to override optimizer settings.)
|
| 440 |
+
if "optimizer" not in merged:
|
| 441 |
+
merged.optimizer = {}
|
| 442 |
+
lr = _PUBLIC_MODEL_ID_TO_LR.get(int(model_spec.model_id))
|
| 443 |
+
if lr is None:
|
| 444 |
+
raise ValueError(f"No public LR mapping for model_id={model_spec.model_id!r}")
|
| 445 |
+
merged.optimizer.lr = float(lr)
|
| 446 |
+
|
| 447 |
+
# Public release: production-like batch schedule defaults.
|
| 448 |
+
# Target behavior: per-GPU batch size is 512 in the first epoch, 2048 thereafter.
|
| 449 |
+
# Model 3 is heavier; use a smaller schedule there.
|
| 450 |
+
if "batch_schedule" not in merged:
|
| 451 |
+
merged.batch_schedule = {}
|
| 452 |
+
merged.batch_schedule.enabled = True
|
| 453 |
+
if int(model_spec.model_id) == 3:
|
| 454 |
+
merged.batch_schedule.initial = 256
|
| 455 |
+
merged.batch_schedule.final = 1024
|
| 456 |
+
elif int(model_spec.model_id) == 6:
|
| 457 |
+
merged.batch_schedule.initial = 256
|
| 458 |
+
merged.batch_schedule.final = 512
|
| 459 |
+
else:
|
| 460 |
+
merged.batch_schedule.initial = 512
|
| 461 |
+
merged.batch_schedule.final = 2048
|
| 462 |
+
# "First epoch only" initial, then final for all later epochs.
|
| 463 |
+
merged.batch_schedule.start_epoch = 0
|
| 464 |
+
merged.batch_schedule.end_epoch = 0
|
| 465 |
+
|
| 466 |
+
# Public release: training epochs default to production values,
|
| 467 |
+
# but honor explicit user overrides for quick validation runs.
|
| 468 |
+
if "train" not in merged:
|
| 469 |
+
merged.train = {}
|
| 470 |
+
if not ("train" in cfg and isinstance(cfg.train, DictConfig) and "epochs" in cfg.train):
|
| 471 |
+
merged.train.epochs = 100
|
| 472 |
+
|
| 473 |
+
# Public release: validation sample count defaults to production values,
|
| 474 |
+
# but honor explicit user overrides for quick validation runs.
|
| 475 |
+
if "val" not in merged:
|
| 476 |
+
merged.val = {}
|
| 477 |
+
# NOTE: temporarily reduced for faster iteration during refactor/testing.
|
| 478 |
+
if not ("val" in cfg and isinstance(cfg.val, DictConfig) and "num_samples" in cfg.val):
|
| 479 |
+
merged.val.num_samples = 65536
|
| 480 |
+
|
| 481 |
+
# Train vs inference window semantics (public release):
|
| 482 |
+
# - Top-level cfg.distance / cfg.n_rounds are the user-specified *evaluation* targets.
|
| 483 |
+
# - Training always runs on the model receptive field R (distance=n_rounds=R).
|
| 484 |
+
task = str(getattr(getattr(merged, "workflow", None), "task", "train")).strip().lower()
|
| 485 |
+
R = int(model_spec.receptive_field)
|
| 486 |
+
if R <= 0:
|
| 487 |
+
raise ValueError(f"Invalid receptive field R={R!r}")
|
| 488 |
+
if task == "train":
|
| 489 |
+
merged.distance = R
|
| 490 |
+
merged.n_rounds = R
|
| 491 |
+
else:
|
| 492 |
+
merged.distance = int(requested_distance)
|
| 493 |
+
merged.n_rounds = int(requested_n_rounds)
|
| 494 |
+
|
| 495 |
+
# Public code_rotation aliases: normalize O1..O4 -> internal XV/XH/ZV/ZH.
|
| 496 |
+
if "data" in merged and "code_rotation" in merged.data:
|
| 497 |
+
merged.data.code_rotation = _normalize_code_rotation(merged.data.code_rotation)
|
| 498 |
+
|
| 499 |
+
# Test/evaluation config is hidden and always uses the user-requested window.
|
| 500 |
+
if "test" not in merged:
|
| 501 |
+
merged.test = {}
|
| 502 |
+
if not ("test" in cfg and isinstance(cfg.test, DictConfig) and "num_samples" in cfg.test):
|
| 503 |
+
merged.test.num_samples = 262144
|
| 504 |
+
merged.test.distance = int(requested_distance)
|
| 505 |
+
merged.test.n_rounds = int(requested_n_rounds)
|
| 506 |
+
merged.test.noise_model = "train"
|
| 507 |
+
return merged
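The window semantics and the "default unless explicitly overridden" pattern used in `apply_public_defaults_and_model` can be sketched with plain dictionaries. This is a minimal illustration rather than the repository's code: `resolve_window` and `default_unless_set` are hypothetical helper names, a plain `dict` stands in for OmegaConf's `DictConfig`, and the receptive field of 13 is an assumption taken from the `rf13` model name.

```python
from typing import Any, Tuple


def resolve_window(task: str, receptive_field: int,
                   requested_distance: int, requested_n_rounds: int) -> Tuple[int, int]:
    """Return (distance, n_rounds) for a workflow task (hypothetical helper)."""
    if receptive_field <= 0:
        raise ValueError(f"Invalid receptive field R={receptive_field!r}")
    if task.strip().lower() == "train":
        # Training always runs on the receptive field: distance = n_rounds = R.
        return receptive_field, receptive_field
    # Any other task evaluates on the user-requested window.
    return int(requested_distance), int(requested_n_rounds)


def default_unless_set(user_cfg: dict, section: str, key: str, default: Any) -> Any:
    """Keep an explicit user value; otherwise fall back to the release default."""
    sub = user_cfg.get(section)
    if isinstance(sub, dict) and key in sub:
        return sub[key]
    return default


# Assumed values: R = 13, with the public config's distance=13 and n_rounds=104.
print(resolve_window("train", 13, 13, 104))
print(resolve_window("inference", 13, 13, 104))
print(default_unless_set({"train": {"epochs": 2}}, "train", "epochs", 100))
print(default_unless_set({}, "val", "num_samples", 65536))
```

With these assumed inputs, a training run ignores the requested `n_rounds` and uses the receptive field for both dimensions, while an inference run keeps the requested window.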
|
code/workflows/run.py
ADDED
|
@@ -0,0 +1,319 @@
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import hydra, sys, torch, os, json, numpy as np
|
| 17 |
+
from omegaconf import DictConfig, OmegaConf
|
| 18 |
+
from training.train import main as train_main
|
| 19 |
+
from model.factory import ModelFactory
|
| 20 |
+
from data.factory import DatapipeFactory
|
| 21 |
+
from hydra.utils import to_absolute_path
|
| 22 |
+
from workflows.config_validator import (
|
| 23 |
+
apply_public_defaults_and_model,
|
| 24 |
+
validate_public_config,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
from training.distributed import DistributedManager
|
| 28 |
+
|
| 29 |
+
from torch.utils.data import DataLoader
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _ensure_inference_io_channels(cfg):
|
| 33 |
+
# 1) Ensure out_channels matches the model’s heads (4: z_data, x_data, syn_x, syn_z)
|
| 34 |
+
if not getattr(cfg.model, "out_channels", None) or cfg.model.out_channels == 0:
|
| 35 |
+
cfg.model.out_channels = 4
|
| 36 |
+
|
| 37 |
+
# 2) Infer input_channels from a single inference sample if not set
|
| 38 |
+
if not getattr(cfg.model, "input_channels", None) or cfg.model.input_channels == 0:
|
| 39 |
+
ds = DatapipeFactory.create_datapipe_inference(cfg)
|
| 40 |
+
tmp = DataLoader(ds, batch_size=1)
|
| 41 |
+
sample = next(iter(tmp))
|
| 42 |
+
cfg.model.input_channels = int(sample["trainX"].shape[1])
|
| 43 |
+
|
| 44 |
+
# 3) Keep num_filters consistent with out_channels
|
| 45 |
+
if hasattr(cfg.model, "num_filters"):
|
| 46 |
+
filters = list(cfg.model.num_filters)
|
| 47 |
+
if filters and filters[-1] != cfg.model.out_channels:
|
| 48 |
+
print(
|
| 49 |
+
f"[run] Adjusting model.num_filters[-1] {filters[-1]} -> {cfg.model.out_channels}"
|
| 50 |
+
)
|
| 51 |
+
filters[-1] = cfg.model.out_channels
|
| 52 |
+
cfg.model.num_filters = filters
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@hydra.main(version_base="1.3", config_path="../../conf", config_name="config")
|
| 56 |
+
def run(cfg: DictConfig) -> None:
|
| 57 |
+
# Early-access public release: validate public surface, then merge in hidden defaults.
|
| 58 |
+
# NOTE: Validation is done BEFORE merging defaults so we can fail fast on injected fields.
|
| 59 |
+
model_spec = validate_public_config(cfg)
|
| 60 |
+
cfg = apply_public_defaults_and_model(cfg, model_spec)
|
| 61 |
+
|
| 62 |
+
torch.backends.cuda.matmul.allow_tf32 = cfg.enable_matmul_tf32
|
| 63 |
+
torch.backends.cudnn.allow_tf32 = cfg.enable_cudnn_tf32
|
| 64 |
+
|
| 65 |
+
if cfg.code == "surface" or cfg.code == "surface_partition":
|
| 66 |
+
run_surface(cfg)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def run_surface(cfg: DictConfig):
|
| 70 |
+
if cfg.workflow.task == "train":
|
| 71 |
+
train_main(cfg)
|
| 72 |
+
elif cfg.workflow.task == "threshold":
|
| 73 |
+
raise ValueError(
|
| 74 |
+
"workflow.task='threshold' has been renamed to workflow.task='inference'. "
|
| 75 |
+
"Please update your config/env var to WORKFLOW=inference."
|
| 76 |
+
)
|
| 77 |
+
elif cfg.workflow.task == "inference":
|
| 78 |
+
from evaluation.inference import run_inference
|
| 79 |
+
DistributedManager.initialize()
|
| 80 |
+
dist = DistributedManager()
|
| 81 |
+
model = _load_model(cfg, dist)
|
| 82 |
+
run_inference(model, dist.device, dist, cfg)
|
| 83 |
+
elif cfg.workflow.task == "data":
|
| 84 |
+
DistributedManager.initialize()
|
| 85 |
+
dist = DistributedManager()
|
| 86 |
+
train_loader, _ = DatapipeFactory.create_dataloader(cfg, dist.world_size, dist.rank)
|
| 87 |
+
for j, dl in enumerate(train_loader):
|
| 88 |
+
print(f"Batch {j}: syndrome_shape: {dl['syndrome'].shape}")
|
| 89 |
+
elif cfg.workflow.task == "decoder_ablation":
|
| 90 |
+
from evaluation.failure_analysis import decoder_ablation_study
|
| 91 |
+
DistributedManager.initialize()
|
| 92 |
+
dist = DistributedManager()
|
| 93 |
+
model = _load_model(cfg, dist)
|
| 94 |
+
decoder_ablation_study(model, dist.device, dist, cfg)
|
| 95 |
+
elif cfg.workflow.task in ("sampling", "visualize"):
|
| 96 |
+
raise ValueError(
|
| 97 |
+
f"workflow.task={cfg.workflow.task!r} is not supported in the early-access public release. "
|
| 98 |
+
"Supported workflows: train, inference, decoder_ablation."
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def find_best_model(path, *, rank: int = 0):
|
| 103 |
+
if rank == 0:
|
| 104 |
+
print(f"Searching for best model in: {path}")
|
| 105 |
+
if not os.path.isdir(path):
|
| 106 |
+
raise FileNotFoundError(f"Model directory does not exist: {path}")
|
| 107 |
+
|
| 108 |
+
max_value = -1 # Start with -1 to include epoch 0
|
| 109 |
+
best_file = None
|
| 110 |
+
model_files = []
|
| 111 |
+
# Named .pt files without epoch numbers (e.g. Ising-Decoder-SurfaceCode-1-Fast.pt)
|
| 112 |
+
named_pt_files = []
|
| 113 |
+
|
| 114 |
+
for filename in os.listdir(path):
|
| 115 |
+
if not filename.endswith(".pt"):
|
| 116 |
+
continue
|
| 117 |
+
if filename.startswith("PreDecoderModelMemory_"):
|
| 118 |
+
try:
|
| 119 |
+
value = float(filename.split(".")[2]) # Gets epoch number
|
| 120 |
+
model_files.append((filename, value))
|
| 121 |
+
if value > max_value:
|
| 122 |
+
max_value = value
|
| 123 |
+
best_file = filename
|
| 124 |
+
except (IndexError, ValueError) as e:
|
| 125 |
+
print(f"Warning: could not parse epoch from filename {filename}: {e}")
|
| 126 |
+
else:
|
| 127 |
+
named_pt_files.append(filename)
|
| 128 |
+
|
| 129 |
+
# Fall back to named .pt files when no epoch-numbered checkpoints are present
|
| 130 |
+
if best_file is None and named_pt_files:
|
| 131 |
+
named_pt_files.sort()
|
| 132 |
+
best_file = named_pt_files[-1]
|
| 133 |
+
model_files = [(f, None) for f in named_pt_files]
|
| 134 |
+
|
| 135 |
+
if rank == 0:
|
| 136 |
+
print(f"Found {len(model_files)} model file(s):")
|
| 137 |
+
for filename, epoch in sorted(model_files, key=lambda x: (x[1] is None, x[1] or 0)):
|
| 138 |
+
marker = "*" if filename == best_file else " "
|
| 139 |
+
epoch_str = str(epoch) if epoch is not None else "n/a"
|
| 140 |
+
print(f" [{marker}] {filename} (epoch {epoch_str})")
|
| 141 |
+
|
| 142 |
+
if best_file is None:
|
| 143 |
+
raise FileNotFoundError(
|
| 144 |
+
f"No valid model checkpoint files found in {path}\n"
|
| 145 |
+
f"Expected .pt files (e.g. Ising-Decoder-SurfaceCode-1-Fast.pt or "
|
| 146 |
+
f"PreDecoderModelMemory_*.pt).\n"
|
| 147 |
+
f"Hint: download the pretrained weights and place them in this directory, "
|
| 148 |
+
f"or set model_checkpoint_file in your config to an explicit path."
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
best_model_path = os.path.join(path, best_file)
|
| 152 |
+
if rank == 0:
|
| 153 |
+
epoch_str = str(max_value) if max_value >= 0 else "n/a"
|
| 154 |
+
print(f"Selected best model: {best_file} (epoch {epoch_str})")
|
| 155 |
+
|
| 156 |
+
return best_model_path
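The selection rule in `find_best_model` can be illustrated without touching the filesystem. The sketch below is an assumption-labelled simplification: `parse_epoch` and `pick_checkpoint` are hypothetical names, the function takes a list of filenames instead of a directory, and it returns `None` where the original raises `FileNotFoundError`.

```python
from typing import List, Optional

PREFIX = "PreDecoderModelMemory_"


def parse_epoch(filename: str) -> Optional[float]:
    """Return the epoch encoded as the third dot-separated field, or None."""
    if not (filename.endswith(".pt") and filename.startswith(PREFIX)):
        return None
    try:
        return float(filename.split(".")[2])
    except (IndexError, ValueError):
        return None


def pick_checkpoint(filenames: List[str]) -> Optional[str]:
    """Prefer the highest epoch; otherwise the last named .pt file in sorted order."""
    numbered = [(parse_epoch(f), f) for f in filenames if parse_epoch(f) is not None]
    if numbered:
        # Numeric comparison, so epoch 10 beats epoch 2.
        return max(numbered, key=lambda pair: pair[0])[1]
    named = sorted(f for f in filenames if f.endswith(".pt") and not f.startswith(PREFIX))
    return named[-1] if named else None


print(pick_checkpoint(["PreDecoderModelMemory_v1.0.2.pt", "PreDecoderModelMemory_v1.0.10.pt"]))
print(pick_checkpoint(["Quantispect_RF13_v1.0.10.pt"]))
```

Parsing the epoch as a number matters here: a plain string sort would rank `...0.2.pt` above `...0.10.pt`.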
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _resolve_dir(path: str) -> str:
|
| 160 |
+
"""Return an absolute version of path, resolving relative paths from the repo root."""
|
| 161 |
+
if os.path.isabs(path):
|
| 162 |
+
return path
|
| 163 |
+
repo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 164 |
+
return os.path.join(repo_root, path)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _load_state_dict_from_pt(model_path: str, device) -> dict:
|
| 168 |
+
"""Load a state dict from a .pt checkpoint, handling multiple saved formats.
|
| 169 |
+
|
| 170 |
+
Supports:
|
| 171 |
+
- bare state dict (keys are layer names)
|
| 172 |
+
- {"model_state_dict": ...}
|
| 173 |
+
- {"state_dict": ...}
|
| 174 |
+
Also strips the DDP "module." prefix if present.
|
| 175 |
+
"""
|
| 176 |
+
raw = torch.load(model_path, map_location=device, weights_only=False)
|
| 177 |
+
if isinstance(raw, dict):
|
| 178 |
+
if "model_state_dict" in raw:
|
| 179 |
+
state_dict = raw["model_state_dict"]
|
| 180 |
+
elif "state_dict" in raw:
|
| 181 |
+
state_dict = raw["state_dict"]
|
| 182 |
+
else:
|
| 183 |
+
state_dict = raw
|
| 184 |
+
else:
|
| 185 |
+
raise ValueError(f"Unexpected checkpoint format: expected a dict, got {type(raw).__name__}")
|
| 186 |
+
return {
|
| 187 |
+
(k[len("module."):] if k.startswith("module.") else k): v for k, v in state_dict.items()
|
| 188 |
+
}
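The unwrapping logic in `_load_state_dict_from_pt` can be exercised without PyTorch. The sketch below is only an illustration: `unwrap_state_dict` is a hypothetical name, it operates on an already-loaded object in place of the result of `torch.load`, and the integer values stand in for tensors.

```python
def unwrap_state_dict(raw):
    """Extract the weights mapping and strip the DDP "module." prefix from keys."""
    if not isinstance(raw, dict):
        raise ValueError(
            f"Unexpected checkpoint format: expected a dict, got {type(raw).__name__}"
        )
    # Check the wrapped formats first, then treat the dict as a bare state dict.
    if "model_state_dict" in raw:
        state = raw["model_state_dict"]
    elif "state_dict" in raw:
        state = raw["state_dict"]
    else:
        state = raw
    prefix = "module."
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state.items()}


# Placeholder values stand in for tensors in this example.
wrapped = {"state_dict": {"module.stem.weight": 1, "head.bias": 2}}
print(unwrap_state_dict(wrapped))
```

Stripping the prefix lets a checkpoint saved from a `DistributedDataParallel` wrapper load into an unwrapped model whose parameter names lack `module.`.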
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _load_model(cfg, dist):
|
| 192 |
+
if dist.rank == 0:
|
| 193 |
+
print(f"Loading model for task: {cfg.workflow.task}")
|
| 194 |
+
|
| 195 |
+
_ensure_inference_io_channels(cfg)
|
| 196 |
+
|
| 197 |
+
# SafeTensors path: load fp16/fp32 model from SafeTensors file
|
| 198 |
+
safetensors_path = os.environ.get("PREDECODER_SAFETENSORS_CHECKPOINT", "").strip()
|
| 199 |
+
if safetensors_path:
|
| 200 |
+
from export.safetensors_utils import load_safetensors
|
| 201 |
+
if dist.rank == 0:
|
| 202 |
+
print(f"Loading model from SafeTensors: {safetensors_path}")
|
| 203 |
+
|
| 204 |
+
# Auto-detect model_id from SafeTensors metadata (don't override with config)
|
| 205 |
+
model, metadata = load_safetensors(
|
| 206 |
+
safetensors_path,
|
| 207 |
+
model_id=None,
|
| 208 |
+
device=str(dist.device),
|
| 209 |
+
)
|
| 210 |
+
if dist.rank == 0:
|
| 211 |
+
loaded_model_id = metadata.get("model_id", "unknown")
|
| 212 |
+
dtype = metadata.get("quant_format", "fp32")
|
| 213 |
+
receptive_field = metadata.get("receptive_field", "unknown")
|
| 214 |
+
param_count = sum(p.numel() for p in model.parameters())
|
| 215 |
+
print(f" model_id: {loaded_model_id} (from SafeTensors metadata)")
|
| 216 |
+
print(f" receptive_field: {receptive_field}")
|
| 217 |
+
print(f" dtype: {dtype}")
|
| 218 |
+
print(f" parameters: {param_count:,}")
|
| 219 |
+
|
| 220 |
+
# Warn if config model_id doesn't match file metadata
|
| 221 |
+
config_model_id = getattr(cfg, "model_id", None)
|
| 222 |
+
if config_model_id is not None and str(config_model_id) != str(loaded_model_id):
|
| 223 |
+
print(
|
| 224 |
+
f" Warning: config model_id={config_model_id} differs from "
|
| 225 |
+
f"file model_id={loaded_model_id}; using {loaded_model_id}"
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
if metadata.get("quant_format") == "fp16":
|
| 229 |
+
cfg.enable_fp16 = True
|
| 230 |
+
return model
|
| 231 |
+
|
| 232 |
+
# Direct file path override (for named pretrained models without epoch numbers)
|
| 233 |
+
model_checkpoint_file = getattr(cfg, 'model_checkpoint_file', None)
|
| 234 |
+
if model_checkpoint_file:
|
| 235 |
+
model_checkpoint_file = _resolve_dir(str(model_checkpoint_file))
|
| 236 |
+
if not os.path.exists(model_checkpoint_file):
|
| 237 |
+
raise FileNotFoundError(f"Checkpoint not found: {model_checkpoint_file}")
|
| 238 |
+
if dist.rank == 0:
|
| 239 |
+
print(f"Loading model from: {model_checkpoint_file}")
|
| 240 |
+
model = ModelFactory.create_model(cfg).to(dist.device)
|
| 241 |
+
if cfg.enable_fp16:
|
| 242 |
+
model = model.half()
|
| 243 |
+
state_dict = _load_state_dict_from_pt(model_checkpoint_file, dist.device)
|
| 244 |
+
model.load_state_dict(state_dict)
|
| 245 |
+
if dist.rank == 0:
|
| 246 |
+
param_count = sum(p.numel() for p in model.parameters())
|
| 247 |
+
print(f"Model loaded ({param_count:,} parameters)")
|
| 248 |
+
return model
|
| 249 |
+
|
| 250 |
+
model = ModelFactory.create_model(cfg).to(dist.device)
|
| 251 |
+
|
| 252 |
+
if cfg.enable_fp16:
|
| 253 |
+
model = model.half()
|
| 254 |
+
if dist.rank == 0:
|
| 255 |
+
print("Model converted to float16 for fp16 inference")
|
| 256 |
+
|
| 257 |
+
# Determine model directory
|
| 258 |
+
# Priority: 1) model_checkpoint_dir (for inference configs)
|
| 259 |
+
# 2) cfg.output/models (for training configs)
|
| 260 |
+
model_checkpoint_dir = getattr(cfg, 'model_checkpoint_dir', None)
|
| 261 |
+
use_checkpoint = getattr(cfg.test, 'use_model_checkpoint', -1)
|
| 262 |
+
|
| 263 |
+
if use_checkpoint == -1:
|
| 264 |
+
model_dir = _resolve_dir(
|
| 265 |
+
os.path.join(model_checkpoint_dir, "best_model")
|
| 266 |
+
if model_checkpoint_dir else f"{cfg.output}/models/best_model"
|
| 267 |
+
)
|
| 268 |
+
if dist.rank == 0:
|
| 269 |
+
print(f"Loading best model from: {model_dir}")
|
| 270 |
+
|
| 271 |
+
# Fallback: older runs may not have a best_model/ folder
|
| 272 |
+
if not os.path.isdir(model_dir):
|
| 273 |
+
fallback_dir = _resolve_dir(
|
| 274 |
+
model_checkpoint_dir if model_checkpoint_dir else f"{cfg.output}/models"
|
| 275 |
+
)
|
| 276 |
+
if dist.rank == 0:
|
| 277 |
+
print(f"best_model/ not found; falling back to: {fallback_dir}")
|
| 278 |
+
model_dir = fallback_dir
|
| 279 |
+
|
| 280 |
+
model_path = find_best_model(model_dir, rank=dist.rank)
|
| 281 |
+
else:
|
| 282 |
+
checkpoint_dir = _resolve_dir(
|
| 283 |
+
model_checkpoint_dir if model_checkpoint_dir else f"{cfg.output}/models"
|
| 284 |
+
)
|
| 285 |
+
if dist.rank == 0:
|
| 286 |
+
print(f"Loading checkpoint {use_checkpoint} from: {checkpoint_dir}")
|
| 287 |
+
|
| 288 |
+
# Prefer any PreDecoderModelMemory_* file ending with .0.{use_checkpoint}.pt
|
| 289 |
+
target_suffix = f".0.{use_checkpoint}.pt"
|
| 290 |
+
checkpoint_filename = None
|
| 291 |
+
try:
|
| 292 |
+
for f in os.listdir(checkpoint_dir):
|
| 293 |
+
if f.startswith("PreDecoderModelMemory_") and f.endswith(target_suffix):
|
| 294 |
+
checkpoint_filename = f
|
| 295 |
+
break
|
| 296 |
+
except OSError:
|
| 297 |
+
pass
|
| 298 |
+
if checkpoint_filename is None:
|
| 299 |
+
checkpoint_filename = f"PreDecoderModelMemory_v1.0.{use_checkpoint}.pt"
|
| 300 |
+
model_path = os.path.join(checkpoint_dir, checkpoint_filename)
|
| 301 |
+
|
| 302 |
+
if not os.path.exists(model_path):
|
| 303 |
+
raise FileNotFoundError(f"Checkpoint not found: {model_path}")
|
| 304 |
+
|
| 305 |
+
if dist.rank == 0:
|
| 306 |
+
print(f"Loading model parameters from: {model_path}")
|
| 307 |
+
|
| 308 |
+
state_dict = _load_state_dict_from_pt(model_path, dist.device)
|
| 309 |
+
model.load_state_dict(state_dict)
|
| 310 |
+
|
| 311 |
+
if dist.rank == 0:
|
| 312 |
+
param_count = sum(p.numel() for p in model.parameters())
|
| 313 |
+
print(f"Model loaded ({param_count:,} parameters)")
|
| 314 |
+
|
| 315 |
+
return model
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
if __name__ == "__main__":
|
| 319 |
+
run()
|
conf/config_public.yaml
ADDED
|
@@ -0,0 +1,84 @@
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
# Public, single-file config for external users.
|
| 17 |
+
#
|
| 18 |
+
# Users should only edit the fields in this file.
|
| 19 |
+
# Advanced/experimental fields are intentionally omitted and will be populated
|
| 20 |
+
# from internal defaults (and validated to prevent unsupported overrides).
|
| 21 |
+
|
| 22 |
+
# === Model selection (required) ===
|
| 23 |
+
model_id: 6 # Choose a supported model ID
|
| 24 |
+
|
| 25 |
+
model:
|
| 26 |
+
version: predecoder_fasthyper_rf13_v1
|
| 27 |
+
input_channels: 4
|
| 28 |
+
out_channels: 4
|
| 29 |
+
hidden_dim: 96
|
| 30 |
+
mid_dim: 144
|
| 31 |
+
mix_groups: 6
|
| 32 |
+
num_blocks: 5
|
| 33 |
+
stem_kernel_size: 3
|
| 34 |
+
gate_reduction: 4
|
| 35 |
+
dropout_p: 0.02
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# === Values for evaluation. Training window is hardcoded to model receptive field. ===
|
| 39 |
+
distance: 13
|
| 40 |
+
n_rounds: 104
|
| 41 |
+
|
| 42 |
+
# === Workflow ===
|
| 43 |
+
workflow:
|
| 44 |
+
task: train # train, inference
|
| 45 |
+
# TODO: simplify inference logs to report only PyMatching results before and after predecoding; batch_size=1
|
| 46 |
+
|
| 47 |
+
# === Data (public surface only) ===
|
| 48 |
+
data:
|
| 49 |
+
# Surface code orientation (public naming): O1, O2, O3, O4
|
| 50 |
+
code_rotation: O1
|
| 51 |
+
# Circuit-level noise model (25-parameter). This is the default public noise specification.
|
| 52 |
+
# The defaults are chosen for p=0.003.
|
| 53 |
+
noise_model:
|
| 54 |
+
# State preparation errors (2)
|
| 55 |
+
p_prep_X: 0.002 # |+> state-prep fails with this probability (apply Z), 2*p/3
|
| 56 |
+
p_prep_Z: 0.002 # |0> state-prep fails with this probability (apply X), 2*p/3
|
| 57 |
+
# Measurement errors (2)
|
| 58 |
+
p_meas_X: 0.002 # Measurement in X-basis fails with this probability (apply Z before measurement), 2*p/3
|
| 59 |
+
p_meas_Z: 0.002 # Measurement in Z-basis fails with this probability (apply X before measurement), 2*p/3
|
| 60 |
+
# Idle during CNOT layers / bulk (3)
|
| 61 |
+
p_idle_cnot_X: 0.001 # p/3
|
| 62 |
+
p_idle_cnot_Y: 0.001 # p/3
|
| 63 |
+
p_idle_cnot_Z: 0.001 # p/3
|
| 64 |
+
# Idle during SPAM window (ancilla prep+reset) on data qubits only (3)
|
| 65 |
+
p_idle_spam_X: 0.001998 # 2*p/3 - 2*p^2/9
|
| 66 |
+
p_idle_spam_Y: 0.001998 # 2*p/3 - 2*p^2/9
|
| 67 |
+
p_idle_spam_Z: 0.001998 # 2*p/3 - 2*p^2/9
|
| 68 |
+
# CNOT two-qubit errors (15) - keys are p_cnot_{Pauli}{Pauli} excluding II, p/15
|
| 69 |
+
p_cnot_IX: 0.0002
|
| 70 |
+
p_cnot_IY: 0.0002
|
| 71 |
+
p_cnot_IZ: 0.0002
|
| 72 |
+
p_cnot_XI: 0.0002
|
| 73 |
+
p_cnot_XX: 0.0002
|
| 74 |
+
p_cnot_XY: 0.0002
|
| 75 |
+
p_cnot_XZ: 0.0002
|
| 76 |
+
p_cnot_YI: 0.0002
|
| 77 |
+
p_cnot_YX: 0.0002
|
| 78 |
+
p_cnot_YY: 0.0002
|
| 79 |
+
p_cnot_YZ: 0.0002
|
| 80 |
+
p_cnot_ZI: 0.0002
|
| 81 |
+
p_cnot_ZX: 0.0002
|
| 82 |
+
p_cnot_ZY: 0.0002
|
| 83 |
+
p_cnot_ZZ: 0.0002
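The 25 noise parameters above follow from a single physical error rate `p` through the formulas given in the YAML comments. The sketch below recomputes them in Python as a consistency check; `noise_model_from_p` is a hypothetical helper written for this illustration and is not part of the repository.

```python
from itertools import product


def noise_model_from_p(p: float) -> dict:
    """Derive the 25 circuit-level noise parameters from one error rate p."""
    params = {}
    for basis in ("X", "Z"):
        params[f"p_prep_{basis}"] = 2 * p / 3   # state preparation errors
        params[f"p_meas_{basis}"] = 2 * p / 3   # measurement errors
    for pauli in "XYZ":
        params[f"p_idle_cnot_{pauli}"] = p / 3  # idle during CNOT layers
        # Idle during the SPAM window, per the YAML comment: 2*p/3 - 2*p^2/9.
        params[f"p_idle_spam_{pauli}"] = 2 * p / 3 - 2 * p**2 / 9
    for a, b in product("IXYZ", repeat=2):
        if a + b != "II":                       # 15 non-identity Pauli pairs
            params[f"p_cnot_{a}{b}"] = p / 15
    return params


params = noise_model_from_p(0.003)
print(len(params))
for name in ("p_prep_X", "p_idle_cnot_X", "p_idle_spam_X", "p_cnot_IX"):
    print(name, round(params[name], 6))
```

For `p = 0.003` this yields 0.002, 0.001, 0.001998, and 0.0002 for the four parameter groups, matching the values in the config.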
|
| 84 |
+
|
framework.png
ADDED
|