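"""Hydra entry point for the surface-code pre-decoder workflows.

Dispatches workflow.task to training, inference, datapipe inspection, or the
decoder ablation study. Illustrative invocations (the script name is an
example; the Hydra config is loaded from the repo's conf/ directory, as
configured in @hydra.main below):

    python run.py workflow.task=train
    python run.py workflow.task=inference model_checkpoint_file=models/decoder.pt
"""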
import os

import hydra
import torch
from omegaconf import DictConfig
from torch.utils.data import DataLoader

from data.factory import DatapipeFactory
from model.factory import ModelFactory
from training.distributed import DistributedManager
from training.train import main as train_main
from workflows.config_validator import (
    apply_public_defaults_and_model,
    validate_public_config,
)

def _ensure_inference_io_channels(cfg):
    """Fill in model I/O channel counts that are unset (or 0) in the config."""
    # Default output channel count for the decoder head.
    if not getattr(cfg.model, "out_channels", None):
        cfg.model.out_channels = 4

    # Infer the input channel count from a single sample of the inference datapipe.
    if not getattr(cfg.model, "input_channels", None):
        ds = DatapipeFactory.create_datapipe_inference(cfg)
        tmp = DataLoader(ds, batch_size=1)
        sample = next(iter(tmp))
        cfg.model.input_channels = int(sample["trainX"].shape[1])

    # Keep the final layer width consistent with the output channel count.
    if hasattr(cfg.model, "num_filters"):
        filters = list(cfg.model.num_filters)
        if filters and filters[-1] != cfg.model.out_channels:
            print(
                f"[run] Adjusting model.num_filters[-1] {filters[-1]} -> {cfg.model.out_channels}"
            )
            filters[-1] = cfg.model.out_channels
            cfg.model.num_filters = filters

@hydra.main(version_base="1.3", config_path="../../conf", config_name="config")
def run(cfg: DictConfig) -> None:
    # Validate the public config and merge in model defaults before dispatching.
    model_spec = validate_public_config(cfg)
    cfg = apply_public_defaults_and_model(cfg, model_spec)

    torch.backends.cuda.matmul.allow_tf32 = cfg.enable_matmul_tf32
    torch.backends.cudnn.allow_tf32 = cfg.enable_cudnn_tf32

    if cfg.code in ("surface", "surface_partition"):
        run_surface(cfg)
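
# Config keys consumed by run() above, as they might appear in the Hydra config
# (an illustrative sketch, not the authoritative schema):
#
#   enable_matmul_tf32: true
#   enable_cudnn_tf32: true
#   code: surface              # or surface_partition
#   workflow:
#     task: inference          # train | inference | data | decoder_ablation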

def run_surface(cfg: DictConfig):
    if cfg.workflow.task == "train":
        train_main(cfg)
    elif cfg.workflow.task == "threshold":
        raise ValueError(
            "workflow.task='threshold' has been renamed to workflow.task='inference'. "
            "Please update your config/env var to WORKFLOW=inference."
        )
    elif cfg.workflow.task == "inference":
        from evaluation.inference import run_inference

        DistributedManager.initialize()
        dist = DistributedManager()
        model = _load_model(cfg, dist)
        run_inference(model, dist.device, dist, cfg)
    elif cfg.workflow.task == "data":
        # Smoke-test the datapipe: iterate the training loader and print shapes.
        DistributedManager.initialize()
        dist = DistributedManager()
        train_loader, _ = DatapipeFactory.create_dataloader(cfg, dist.world_size, dist.rank)
        for j, dl in enumerate(train_loader):
            print(f"Batch {j}: syndrome_shape: {dl['syndrome'].shape}")
    elif cfg.workflow.task == "decoder_ablation":
        from evaluation.failure_analysis import decoder_ablation_study

        DistributedManager.initialize()
        dist = DistributedManager()
        model = _load_model(cfg, dist)
        decoder_ablation_study(model, dist.device, dist, cfg)
    elif cfg.workflow.task in ("sampling", "visualize"):
        raise ValueError(
            f"workflow.task={cfg.workflow.task!r} is not supported in the early-access public release. "
            "Supported workflows: train, inference, decoder_ablation."
        )
    else:
        raise ValueError(
            f"Unknown workflow.task: {cfg.workflow.task!r}. "
            "Supported workflows: train, inference, data, decoder_ablation."
        )
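
# Checkpoint directory layout expected by find_best_model() below
# (filenames are illustrative):
#
#   models/best_model/
#       PreDecoderModelMemory_v1.0.12.pt
#       PreDecoderModelMemory_v1.0.37.pt   <- highest epoch is selected
#
# A directory containing only explicitly named .pt files (e.g.
# Ising-Decoder-SurfaceCode-1-Fast.pt) also works via the fallback below.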
def find_best_model(path, *, rank: int = 0):
    if rank == 0:
        print(f"Searching for best model in: {path}")
    if not os.path.isdir(path):
        raise FileNotFoundError(f"Model directory does not exist: {path}")

    max_value = -1
    best_file = None
    model_files = []
    # .pt files that do not follow the PreDecoderModelMemory_* naming scheme.
    named_pt_files = []

    for filename in os.listdir(path):
        if not filename.endswith(".pt"):
            continue
        if filename.startswith("PreDecoderModelMemory_"):
            try:
                # e.g. "PreDecoderModelMemory_v1.0.37.pt".split(".") gives
                # ["PreDecoderModelMemory_v1", "0", "37", "pt"]; index 2 is the epoch.
                value = float(filename.split(".")[2])
                model_files.append((filename, value))
                if value > max_value:
                    max_value = value
                    best_file = filename
            except (IndexError, ValueError) as e:
                if rank == 0:
                    print(f"Warning: could not parse epoch from filename {filename}: {e}")
        else:
            named_pt_files.append(filename)

    # Fall back to explicitly named checkpoints (last in sorted order wins)
    # when no epoch-numbered files were found.
    if best_file is None and named_pt_files:
        named_pt_files.sort()
        best_file = named_pt_files[-1]
        model_files = [(f, None) for f in named_pt_files]

    if rank == 0:
        print(f"Found {len(model_files)} model file(s):")
        for filename, epoch in sorted(model_files, key=lambda x: (x[1] is None, x[1] or 0)):
            marker = "*" if filename == best_file else " "
            epoch_str = str(epoch) if epoch is not None else "n/a"
            print(f"  [{marker}] {filename} (epoch {epoch_str})")

    if best_file is None:
        raise FileNotFoundError(
            f"No valid model checkpoint files found in {path}\n"
            f"Expected .pt files (e.g. Ising-Decoder-SurfaceCode-1-Fast.pt or "
            f"PreDecoderModelMemory_*.pt).\n"
            f"Hint: download the pretrained weights and place them in this directory, "
            f"or set model_checkpoint_file in your config to an explicit path."
        )

    best_model_path = os.path.join(path, best_file)
    if rank == 0:
        epoch_str = str(max_value) if max_value >= 0 else "n/a"
        print(f"Selected best model: {best_file} (epoch {epoch_str})")

    return best_model_path

def _resolve_dir(path: str) -> str:
    """Return an absolute version of path, resolving relative paths from the repo root."""
    if os.path.isabs(path):
        return path
    repo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    return os.path.join(repo_root, path)

def _load_state_dict_from_pt(model_path: str, device) -> dict:
    """Load a state dict from a .pt checkpoint, handling multiple saved formats.

    Supports:
      - a bare state dict (keys are layer names)
      - {"model_state_dict": ...}
      - {"state_dict": ...}
    Also strips the DDP "module." prefix if present.
    """
    # weights_only=False is needed for older pickled checkpoints; only load trusted files.
    raw = torch.load(model_path, map_location=device, weights_only=False)
    if isinstance(raw, dict):
        if "model_state_dict" in raw:
            state_dict = raw["model_state_dict"]
        elif "state_dict" in raw:
            state_dict = raw["state_dict"]
        else:
            state_dict = raw
    else:
        raise ValueError(f"Unexpected checkpoint format: expected a dict, got {type(raw).__name__}")
    return {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in state_dict.items()
    }
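
# For reference, the three accepted on-disk formats (illustrative):
#   torch.save(model.state_dict(), "ckpt.pt")                        # bare state dict
#   torch.save({"model_state_dict": model.state_dict()}, "ckpt.pt")  # training checkpoint
#   torch.save({"state_dict": model.state_dict()}, "ckpt.pt")        # alternate wrapper key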

def _load_model(cfg, dist):
    if dist.rank == 0:
        print(f"Loading model for task: {cfg.workflow.task}")

    _ensure_inference_io_channels(cfg)

    # A SafeTensors checkpoint supplied via environment variable takes precedence.
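    # Illustrative usage (the variable name comes from the lookup below; the
    # script name is an example):
    #   PREDECODER_SAFETENSORS_CHECKPOINT=weights/predecoder.safetensors \
    #       python run.py workflow.task=inference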
    safetensors_path = os.environ.get("PREDECODER_SAFETENSORS_CHECKPOINT", "").strip()
    if safetensors_path:
        from export.safetensors_utils import load_safetensors

        if dist.rank == 0:
            print(f"Loading model from SafeTensors: {safetensors_path}")

        # model_id=None defers to the model_id stored in the file's metadata.
        model, metadata = load_safetensors(
            safetensors_path,
            model_id=None,
            device=str(dist.device),
        )
        if dist.rank == 0:
            loaded_model_id = metadata.get("model_id", "unknown")
            dtype = metadata.get("quant_format", "fp32")
            receptive_field = metadata.get("receptive_field", "unknown")
            param_count = sum(p.numel() for p in model.parameters())
            print(f"  model_id: {loaded_model_id} (from SafeTensors metadata)")
            print(f"  receptive_field: {receptive_field}")
            print(f"  dtype: {dtype}")
            print(f"  parameters: {param_count:,}")

            # Warn when the config requests a different model than the file provides.
            config_model_id = getattr(cfg, "model_id", None)
            if config_model_id is not None and str(config_model_id) != str(loaded_model_id):
                print(
                    f"  Warning: config model_id={config_model_id} differs from "
                    f"file model_id={loaded_model_id}; using {loaded_model_id}"
                )

        if metadata.get("quant_format") == "fp16":
            cfg.enable_fp16 = True
        return model
    # An explicit single-file checkpoint path from the config is tried next.
    model_checkpoint_file = getattr(cfg, "model_checkpoint_file", None)
    if model_checkpoint_file:
        model_checkpoint_file = _resolve_dir(str(model_checkpoint_file))
        if not os.path.exists(model_checkpoint_file):
            raise FileNotFoundError(f"Checkpoint not found: {model_checkpoint_file}")
        if dist.rank == 0:
            print(f"Loading model from: {model_checkpoint_file}")
        model = ModelFactory.create_model(cfg).to(dist.device)
        if cfg.enable_fp16:
            model = model.half()
        state_dict = _load_state_dict_from_pt(model_checkpoint_file, dist.device)
        model.load_state_dict(state_dict)
        if dist.rank == 0:
            param_count = sum(p.numel() for p in model.parameters())
            print(f"Model loaded ({param_count:,} parameters)")
        return model
    # Otherwise build the model from the config and search for a checkpoint
    # under model_checkpoint_dir (or the run's output directory).
    model = ModelFactory.create_model(cfg).to(dist.device)

    if cfg.enable_fp16:
        model = model.half()
        if dist.rank == 0:
            print("Model converted to float16 for fp16 inference")

    model_checkpoint_dir = getattr(cfg, "model_checkpoint_dir", None)
    # test.use_model_checkpoint == -1 selects the best checkpoint automatically;
    # any other value selects that specific epoch.
    use_checkpoint = getattr(cfg.test, "use_model_checkpoint", -1)

    if use_checkpoint == -1:
        model_dir = _resolve_dir(
            os.path.join(model_checkpoint_dir, "best_model")
            if model_checkpoint_dir else f"{cfg.output}/models/best_model"
        )
        if dist.rank == 0:
            print(f"Loading best model from: {model_dir}")

        # Fall back to the parent models/ directory when best_model/ is absent.
        if not os.path.isdir(model_dir):
            fallback_dir = _resolve_dir(
                model_checkpoint_dir if model_checkpoint_dir else f"{cfg.output}/models"
            )
            if dist.rank == 0:
                print(f"best_model/ not found; falling back to: {fallback_dir}")
            model_dir = fallback_dir

        model_path = find_best_model(model_dir, rank=dist.rank)
    else:
        checkpoint_dir = _resolve_dir(
            model_checkpoint_dir if model_checkpoint_dir else f"{cfg.output}/models"
        )
        if dist.rank == 0:
            print(f"Loading checkpoint {use_checkpoint} from: {checkpoint_dir}")

        # Match any version prefix, e.g. "PreDecoderModelMemory_v1.0.<epoch>.pt".
        target_suffix = f".0.{use_checkpoint}.pt"
        checkpoint_filename = None
        try:
            for f in os.listdir(checkpoint_dir):
                if f.startswith("PreDecoderModelMemory_") and f.endswith(target_suffix):
                    checkpoint_filename = f
                    break
        except OSError:
            pass
        if checkpoint_filename is None:
            checkpoint_filename = f"PreDecoderModelMemory_v1.0.{use_checkpoint}.pt"
        model_path = os.path.join(checkpoint_dir, checkpoint_filename)

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Checkpoint not found: {model_path}")

    if dist.rank == 0:
        print(f"Loading model parameters from: {model_path}")

    state_dict = _load_state_dict_from_pt(model_path, dist.device)
    model.load_state_dict(state_dict)

    if dist.rank == 0:
        param_count = sum(p.numel() for p in model.parameters())
        print(f"Model loaded ({param_count:,} parameters)")

    return model

if __name__ == "__main__":
    run()