Spaces:
Running
Running
| import torch | |
| def _unwrap(model): | |
| return getattr(model, "_orig_mod", model) | |
def apply_ckpt_model_config(ckpt, cfg):
    """Copy width metadata stored in a checkpoint onto ``cfg.model``, in place.

    When ``ckpt["model_config"]`` is present and non-empty, each of
    ``mic_channels`` / ``far_channels`` (coerced to ``list``) and
    ``bottleneck_hidden`` (coerced to ``int``) that it contains overwrites the
    matching attribute of ``cfg.model``; keys it lacks are left untouched.
    A missing, ``None`` or empty ``model_config`` makes this a no-op.

    Must be called before LocalVQE.from_config(cfg) so the rebuilt model
    matches the pruned-checkpoint shapes.
    """
    model_config = ckpt.get("model_config")
    if model_config:
        # (metadata key, converter) pairs, applied in this order.
        width_fields = (
            ("mic_channels", list),
            ("far_channels", list),
            ("bottleneck_hidden", int),
        )
        for field, convert in width_fields:
            if field in model_config:
                setattr(cfg.model, field, convert(model_config[field]))
def load_checkpoint(path, model):
    """Load a checkpoint's ``model_state_dict`` into *model* in place.

    Tensors are mapped onto the device *model* currently lives on. Before the
    (strict) ``load_state_dict`` call, the stored state dict is normalised:

    * a leading ``_orig_mod.`` is stripped from every key, so checkpoints
      saved from a ``torch.compile`` wrapper load into the plain module;
    * ``decoder._overlap_count`` is dropped if present;
    * ``align.temperature`` is back-filled with ones when the checkpoint lacks
      it but the model has such an entry (older "pre-buffer" checkpoints).

    Args:
        path: Checkpoint location, passed straight to ``torch.load``.
        model: Module to load into; may be a wrapper exposing ``_orig_mod``.

    Returns:
        ``(ckpt["epoch"], ckpt.get("loss"))`` -- ``loss`` is ``None`` when the
        checkpoint has no ``"loss"`` entry.

    Raises:
        KeyError: If the checkpoint has no ``"model_state_dict"`` or
            ``"epoch"`` entry (the latter only after weights were loaded).
        RuntimeError: From ``load_state_dict`` when keys or shapes still do
            not match the model.
    """
    module = _unwrap(model)

    # Map tensors onto the model's device. Fall back to buffers, then CPU, so
    # a module without parameters does not fail with a bare StopIteration.
    first_tensor = next(module.parameters(), None)
    if first_tensor is None:
        first_tensor = next(module.buffers(), None)
    target_device = (
        first_tensor.device if first_tensor is not None else torch.device("cpu")
    )

    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects and
    # can execute code -- only load checkpoints from trusted sources.
    ckpt = torch.load(path, weights_only=False, map_location=target_device)

    state = {
        key.removeprefix("_orig_mod."): value
        for key, value in ckpt["model_state_dict"].items()
    }
    # NOTE(review): presumably a buffer the current decoder no longer persists
    # in its state dict; dropped so strict loading succeeds -- confirm.
    state.pop("decoder._overlap_count", None)

    # Pre-buffer checkpoints lack align.temperature; default it to 1.0, shaped
    # like the model's own entry so shape, dtype and device always match, and
    # only when the model actually expects that key.
    current = module.state_dict()
    if "align.temperature" not in state and "align.temperature" in current:
        state["align.temperature"] = torch.ones_like(current["align.temperature"])

    module.load_state_dict(state)
    return ckpt["epoch"], ckpt.get("loss")