donghufeng
init
d57fabf
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hydra, sys, torch, os, json, numpy as np
from omegaconf import DictConfig, OmegaConf
from training.train import main as train_main
from model.factory import ModelFactory
from data.factory import DatapipeFactory
from hydra.utils import to_absolute_path
from workflows.config_validator import (
apply_public_defaults_and_model,
validate_public_config,
)
from training.distributed import DistributedManager
from torch.utils.data import DataLoader
def _ensure_inference_io_channels(cfg):
# 1) Ensure out_channels matches the model’s heads (4: z_data, x_data, syn_x, syn_z)
if not getattr(cfg.model, "out_channels", None) or cfg.model.out_channels == 0:
cfg.model.out_channels = 4
# 2) Infer input_channels from a single inference sample if not set
if not getattr(cfg.model, "input_channels", None) or cfg.model.input_channels == 0:
ds = DatapipeFactory.create_datapipe_inference(cfg)
tmp = DataLoader(ds, batch_size=1)
sample = next(iter(tmp))
cfg.model.input_channels = int(sample["trainX"].shape[1])
# 3) Keep num_filters consistent with out_channels
if hasattr(cfg.model, "num_filters"):
filters = list(cfg.model.num_filters)
if filters and filters[-1] != cfg.model.out_channels:
print(
f"[run] Adjusting model.num_filters[-1] {filters[-1]} -> {cfg.model.out_channels}"
)
filters[-1] = cfg.model.out_channels
cfg.model.num_filters = filters
@hydra.main(version_base="1.3", config_path="../../conf", config_name="config")
def run(cfg: DictConfig) -> None:
    """Hydra entry point: validate the config, apply defaults, and dispatch the workflow.

    The public config surface is validated *before* hidden defaults are merged,
    so injected or unsupported fields fail fast. The TF32 switches for matmul and
    cuDNN are then taken from ``cfg.enable_matmul_tf32`` / ``cfg.enable_cudnn_tf32``.
    Only the ``surface`` and ``surface_partition`` codes are dispatched (to
    ``run_surface``); any other ``cfg.code`` returns without doing anything.
    """
    # Early-access public release: check the public surface first, then merge hidden defaults.
    spec = validate_public_config(cfg)
    cfg = apply_public_defaults_and_model(cfg, spec)

    backends = torch.backends
    backends.cuda.matmul.allow_tf32 = cfg.enable_matmul_tf32
    backends.cudnn.allow_tf32 = cfg.enable_cudnn_tf32

    # NOTE(review): other codes fall through silently — confirm this is intended.
    if cfg.code in ("surface", "surface_partition"):
        run_surface(cfg)
def run_surface(cfg: DictConfig):
    """Dispatch a surface-code workflow according to ``cfg.workflow.task``.

    Handled tasks: ``train``, ``inference``, ``data`` (prints the syndrome
    shape of each training batch) and ``decoder_ablation``.

    Raises:
        ValueError: for the renamed ``threshold`` task, for tasks withheld from
            the early-access release (``sampling``, ``visualize``), and for any
            other unrecognized task name (previously these were silently ignored).
    """
    task = cfg.workflow.task
    if task == "train":
        train_main(cfg)
    elif task == "threshold":
        raise ValueError(
            "workflow.task='threshold' has been renamed to workflow.task='inference'. "
            "Please update your config/env var to WORKFLOW=inference."
        )
    elif task == "inference":
        # Imported lazily so the other workflows do not pay for evaluation-only imports.
        from evaluation.inference import run_inference

        DistributedManager.initialize()
        dist = DistributedManager()
        model = _load_model(cfg, dist)
        run_inference(model, dist.device, dist, cfg)
    elif task == "data":
        DistributedManager.initialize()
        dist = DistributedManager()
        train_loader, _ = DatapipeFactory.create_dataloader(cfg, dist.world_size, dist.rank)
        for j, dl in enumerate(train_loader):
            print(f"Batch {j}: syndrome_shape: {dl['syndrome'].shape}")
    elif task == "decoder_ablation":
        from evaluation.failure_analysis import decoder_ablation_study

        DistributedManager.initialize()
        dist = DistributedManager()
        model = _load_model(cfg, dist)
        decoder_ablation_study(model, dist.device, dist, cfg)
    elif task in ("sampling", "visualize"):
        raise ValueError(
            f"workflow.task={task!r} is not supported in the early-access public release. "
            "Supported workflows: train, inference, decoder_ablation."
        )
    else:
        # Fail loudly instead of exiting successfully without doing any work.
        raise ValueError(
            f"Unknown workflow.task={task!r}. "
            "Supported workflows: train, inference, decoder_ablation."
        )
def find_best_model(path, *, rank: int = 0):
    """Return the path of the checkpoint with the highest epoch found in ``path``.

    Epoch-numbered checkpoints are ``PreDecoderModelMemory_*.pt`` files whose
    third dot-separated component is the epoch (``PreDecoderModelMemory_v1.0.12.pt``
    -> 12). When no such file can be parsed, the other ``.pt`` files (e.g. named
    pretrained weights like ``Ising-Decoder-SurfaceCode-1-Fast.pt``) are used and
    the lexicographically last one is chosen.

    Args:
        path: Directory to search.
        rank: Distributed rank of the caller; only rank 0 prints anything.

    Returns:
        ``os.path.join(path, <selected filename>)``.

    Raises:
        FileNotFoundError: if ``path`` is not a directory or contains no usable
            ``.pt`` file.
    """
    is_main = rank == 0
    if is_main:
        print(f"Searching for best model in: {path}")
    if not os.path.isdir(path):
        raise FileNotFoundError(f"Model directory does not exist: {path}")

    max_value = -1  # Start with -1 so an epoch-0 checkpoint can still be selected.
    best_file = None
    model_files = []  # (filename, epoch) pairs reported in the listing below
    # Named .pt files without epoch numbers (e.g. Ising-Decoder-SurfaceCode-1-Fast.pt)
    named_pt_files = []

    # os.listdir order is arbitrary; sorting makes ties (same epoch, different
    # prefix) resolve identically on every run and every rank.
    for filename in sorted(os.listdir(path)):
        if not filename.endswith(".pt"):
            continue
        if not filename.startswith("PreDecoderModelMemory_"):
            named_pt_files.append(filename)
            continue
        try:
            value = float(filename.split(".")[2])  # Gets epoch number
        except (IndexError, ValueError) as e:
            if is_main:
                print(f"Warning: could not parse epoch from filename {filename}: {e}")
            continue
        model_files.append((filename, value))
        if value > max_value:
            max_value = value
            best_file = filename

    # Fall back to named .pt files when no epoch-numbered checkpoints are present
    if best_file is None and named_pt_files:
        named_pt_files.sort()
        best_file = named_pt_files[-1]
        model_files = [(f, None) for f in named_pt_files]

    if is_main:
        print(f"Found {len(model_files)} model file(s):")
        for filename, epoch in sorted(model_files, key=lambda x: (x[1] is None, x[1] or 0)):
            marker = "*" if filename == best_file else " "
            epoch_str = str(epoch) if epoch is not None else "n/a"
            print(f" [{marker}] {filename} (epoch {epoch_str})")

    if best_file is None:
        raise FileNotFoundError(
            f"No valid model checkpoint files found in {path}\n"
            f"Expected .pt files (e.g. Ising-Decoder-SurfaceCode-1-Fast.pt or "
            f"PreDecoderModelMemory_*.pt).\n"
            f"Hint: download the pretrained weights and place them in this directory, "
            f"or set model_checkpoint_file in your config to an explicit path."
        )

    best_model_path = os.path.join(path, best_file)
    if is_main:
        epoch_str = str(max_value) if max_value >= 0 else "n/a"
        print(f"Selected best model: {best_file} (epoch {epoch_str})")
    return best_model_path
def _resolve_dir(path: str) -> str:
"""Return an absolute version of path, resolving relative paths from the repo root."""
if os.path.isabs(path):
return path
repo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
return os.path.join(repo_root, path)
def _load_state_dict_from_pt(model_path: str, device) -> dict:
"""Load a state dict from a .pt checkpoint, handling multiple saved formats.
Supports:
- bare state dict (keys are layer names)
- {"model_state_dict": ...}
- {"state_dict": ...}
Also strips the DDP "module." prefix if present.
"""
raw = torch.load(model_path, map_location=device, weights_only=False)
if isinstance(raw, dict):
if "model_state_dict" in raw:
state_dict = raw["model_state_dict"]
elif "state_dict" in raw:
state_dict = raw["state_dict"]
else:
state_dict = raw
else:
raise ValueError(f"Unexpected checkpoint format: expected a dict, got {type(raw).__name__}")
return {
(k[len("module."):] if k.startswith("module.") else k): v for k, v in state_dict.items()
}
def _model_from_safetensors(cfg, dist, safetensors_path):
    """Load a complete model from a SafeTensors file, taking model_id from its metadata."""
    from export.safetensors_utils import load_safetensors

    is_main = dist.rank == 0
    if is_main:
        print(f"Loading model from SafeTensors: {safetensors_path}")
    # Auto-detect model_id from SafeTensors metadata (don't override with config)
    model, metadata = load_safetensors(
        safetensors_path,
        model_id=None,
        device=str(dist.device),
    )
    if is_main:
        loaded_model_id = metadata.get("model_id", "unknown")
        dtype = metadata.get("quant_format", "fp32")
        receptive_field = metadata.get("receptive_field", "unknown")
        param_count = sum(p.numel() for p in model.parameters())
        print(f" model_id: {loaded_model_id} (from SafeTensors metadata)")
        print(f" receptive_field: {receptive_field}")
        print(f" dtype: {dtype}")
        print(f" parameters: {param_count:,}")
        # Warn if config model_id doesn't match file metadata
        config_model_id = getattr(cfg, "model_id", None)
        if config_model_id is not None and str(config_model_id) != str(loaded_model_id):
            print(
                f" Warning: config model_id={config_model_id} differs from "
                f"file model_id={loaded_model_id}; using {loaded_model_id}"
            )
    # Record fp16 weights in cfg so later readers of cfg.enable_fp16 see the actual precision.
    if metadata.get("quant_format") == "fp16":
        cfg.enable_fp16 = True
    return model


def _create_model(cfg, dist):
    """Build a fresh model from ``cfg`` on ``dist.device``; cast to fp16 if ``cfg.enable_fp16``."""
    model = ModelFactory.create_model(cfg).to(dist.device)
    if cfg.enable_fp16:
        model = model.half()
        if dist.rank == 0:
            print("Model converted to float16 for fp16 inference")
    return model


def _load_weights(model, model_path, dist):
    """Load the ``.pt`` checkpoint at ``model_path`` into ``model`` and report its size."""
    state_dict = _load_state_dict_from_pt(model_path, dist.device)
    model.load_state_dict(state_dict)
    if dist.rank == 0:
        param_count = sum(p.numel() for p in model.parameters())
        print(f"Model loaded ({param_count:,} parameters)")
    return model


def _find_checkpoint_path(cfg, dist):
    """Resolve the checkpoint file chosen by model_checkpoint_dir / test.use_model_checkpoint.

    Raises:
        FileNotFoundError: if no matching checkpoint exists.
    """
    is_main = dist.rank == 0
    # Directory priority: 1) model_checkpoint_dir (inference configs)
    #                     2) cfg.output/models (training configs)
    model_checkpoint_dir = getattr(cfg, "model_checkpoint_dir", None)
    base_dir = model_checkpoint_dir if model_checkpoint_dir else f"{cfg.output}/models"
    # Tolerate configs without a `test` section: treat as "use best model".
    use_checkpoint = getattr(getattr(cfg, "test", None), "use_model_checkpoint", -1)

    if use_checkpoint == -1:
        model_dir = _resolve_dir(os.path.join(base_dir, "best_model"))
        if is_main:
            print(f"Loading best model from: {model_dir}")
        # Fallback: older runs may not have a best_model/ folder
        if not os.path.isdir(model_dir):
            model_dir = _resolve_dir(base_dir)
            if is_main:
                print(f"best_model/ not found; falling back to: {model_dir}")
        return find_best_model(model_dir, rank=dist.rank)

    checkpoint_dir = _resolve_dir(base_dir)
    if is_main:
        print(f"Loading checkpoint {use_checkpoint} from: {checkpoint_dir}")
    # Prefer any PreDecoderModelMemory_* file ending with .0.{use_checkpoint}.pt
    target_suffix = f".0.{use_checkpoint}.pt"
    try:
        # Sorted so that, if several files match, the same one is chosen every time.
        entries = sorted(os.listdir(checkpoint_dir))
    except OSError as err:
        # Best-effort scan: fall back to the conventional filename; the existence
        # check below reports what is actually missing.
        if is_main:
            print(f"Warning: could not list checkpoint directory {checkpoint_dir}: {err}")
        entries = []
    checkpoint_filename = next(
        (f for f in entries if f.startswith("PreDecoderModelMemory_") and f.endswith(target_suffix)),
        f"PreDecoderModelMemory_v1.0.{use_checkpoint}.pt",
    )
    model_path = os.path.join(checkpoint_dir, checkpoint_filename)
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Checkpoint not found: {model_path}")
    return model_path


def _load_model(cfg, dist):
    """Build the model for an evaluation workflow and load its weights.

    ``_ensure_inference_io_channels(cfg)`` runs first and may mutate ``cfg.model``.
    Weight sources are tried in this order:

    1. The ``PREDECODER_SAFETENSORS_CHECKPOINT`` environment variable: a
       SafeTensors file that fully defines the model (model_id from its metadata;
       fp16 weights set ``cfg.enable_fp16 = True``).
    2. ``cfg.model_checkpoint_file``: an explicit ``.pt`` file (relative paths are
       resolved by ``_resolve_dir``).
    3. A checkpoint directory (``cfg.model_checkpoint_dir`` or ``{cfg.output}/models``):
       the best checkpoint when ``cfg.test.use_model_checkpoint`` is -1 or unset,
       otherwise the checkpoint for that epoch.

    Args:
        cfg: Workflow config.
        dist: ``DistributedManager`` providing ``rank`` and ``device``; only rank 0 logs.

    Returns:
        The model with weights loaded, on ``dist.device``.

    Raises:
        FileNotFoundError: if the selected checkpoint file or directory is missing.
    """
    if dist.rank == 0:
        print(f"Loading model for task: {cfg.workflow.task}")
    _ensure_inference_io_channels(cfg)

    # 1) SafeTensors path: load fp16/fp32 model from SafeTensors file
    safetensors_path = os.environ.get("PREDECODER_SAFETENSORS_CHECKPOINT", "").strip()
    if safetensors_path:
        return _model_from_safetensors(cfg, dist, safetensors_path)

    # 2) Direct file path override (for named pretrained models without epoch numbers)
    model_checkpoint_file = getattr(cfg, "model_checkpoint_file", None)
    if model_checkpoint_file:
        model_path = _resolve_dir(str(model_checkpoint_file))
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Checkpoint not found: {model_path}")
        if dist.rank == 0:
            print(f"Loading model from: {model_path}")
        return _load_weights(_create_model(cfg, dist), model_path, dist)

    # 3) Checkpoint directory (best model or a specific epoch)
    model = _create_model(cfg, dist)
    model_path = _find_checkpoint_path(cfg, dist)
    if dist.rank == 0:
        print(f"Loading model parameters from: {model_path}")
    return _load_weights(model, model_path, dist)
if __name__ == "__main__":
    # run() is wrapped by @hydra.main (config_name="config" under ../../conf),
    # which supplies cfg, so it is called with no arguments here.
    run()