import torch


def _unwrap(model):
    """Return ``model._orig_mod`` when that attribute exists, else ``model``.

    NOTE(review): ``_orig_mod`` looks like the attribute a ``torch.compile``
    wrapper uses to hold the original module -- confirm against callers.
    """
    return getattr(model, "_orig_mod", model)


def _module_device(module):
    """Return the device of the module's first parameter or buffer.

    Falls back to CPU when the module holds neither, so that callers do not
    hit a bare ``StopIteration`` on parameterless modules.
    """
    tensor = next(module.parameters(), None)
    if tensor is None:
        tensor = next(module.buffers(), None)
    if tensor is None:
        return torch.device("cpu")
    return tensor.device


def apply_ckpt_model_config(ckpt, cfg):
    """If ckpt has a `model_config` key (width metadata), apply to cfg.model.

    Must be called before LocalVQE.from_config(cfg) so the rebuilt model
    matches the pruned-checkpoint shapes.

    Only the keys present in ``ckpt["model_config"]`` are written; every
    other attribute of ``cfg.model`` is left untouched. ``cfg`` is mutated
    in place and nothing is returned.

    Args:
        ckpt: Mapping loaded from a checkpoint file.
        cfg: Config object exposing a mutable ``model`` attribute.
    """
    mc = ckpt.get("model_config")
    if not mc:
        # Missing or empty metadata: keep the configured widths as they are.
        return
    if "mic_channels" in mc:
        cfg.model.mic_channels = list(mc["mic_channels"])
    if "far_channels" in mc:
        cfg.model.far_channels = list(mc["far_channels"])
    if "bottleneck_hidden" in mc:
        cfg.model.bottleneck_hidden = int(mc["bottleneck_hidden"])


def load_checkpoint(path, model):
    """Load the weights stored at ``path`` into ``model``.

    Args:
        path: Location of a checkpoint written with ``torch.save``. It must
            contain the keys ``"model_state_dict"`` and ``"epoch"``;
            ``"loss"`` is optional.
        model: Target module, either plain or wrapped so that the real
            module is reachable through ``_orig_mod``.

    Returns:
        A ``(epoch, loss)`` tuple, where ``loss`` is ``None`` when the
        checkpoint does not store one.

    Raises:
        KeyError: If ``"model_state_dict"`` or ``"epoch"`` is missing.
        RuntimeError: If the state dict does not match the model (raised by
            ``load_state_dict``).
    """
    inner = _unwrap(model)
    target_device = _module_device(inner)

    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # Python objects, which can execute code. Only load trusted checkpoints.
    ckpt = torch.load(path, weights_only=False, map_location=target_device)

    # Strip the wrapper prefix so keys line up with the unwrapped module.
    # This builds a new dict, so the checkpoint's own state dict is not
    # modified by the edits below.
    state = {
        key.removeprefix("_orig_mod."): value
        for key, value in ckpt["model_state_dict"].items()
    }
    # Dropped when present so it is never handed to load_state_dict.
    state.pop("decoder._overlap_count", None)
    # Pre-buffer checkpoints lack align.temperature; default to 1.0.
    if "align.temperature" not in state:
        state["align.temperature"] = torch.tensor(1.0)

    inner.load_state_dict(state)
    return ckpt["epoch"], ckpt.get("loss")