# Author: richiejp
# Commit: 6e0a6e4 (verified) — "Initial upload: LocalVQE demo Space"
import torch
def _unwrap(model):
    """Return the module held in ``model._orig_mod`` if present, else ``model``.

    ``_orig_mod`` is the attribute under which a wrapper (e.g. the module
    returned by ``torch.compile``) keeps the original module, so callers can
    reach the real parameters / state dict regardless of wrapping.
    """
    try:
        return model._orig_mod
    except AttributeError:
        # Not wrapped: the model is already the underlying module.
        return model
def apply_ckpt_model_config(ckpt, cfg):
    """Copy width metadata stored in a checkpoint onto ``cfg.model``.

    Reads ``ckpt["model_config"]``. When it is missing or empty, this is a
    no-op. Otherwise each of ``mic_channels`` / ``far_channels`` (coerced to
    ``list``) and ``bottleneck_hidden`` (coerced to ``int``) that is present
    overwrites the matching attribute of ``cfg.model``. Absent keys leave
    ``cfg.model`` untouched.

    Must run before ``LocalVQE.from_config(cfg)`` so the rebuilt model has the
    same layer widths as the pruned checkpoint it will load.

    Args:
        ckpt: Checkpoint mapping (as returned by ``torch.load``).
        cfg: Config object whose ``model`` attribute receives the widths.
    """
    width_meta = ckpt.get("model_config")
    if width_meta:
        # (metadata key, converter) pairs, applied in this order.
        for field, coerce in (
            ("mic_channels", list),
            ("far_channels", list),
            ("bottleneck_hidden", int),
        ):
            if field in width_meta:
                setattr(cfg.model, field, coerce(width_meta[field]))
def _module_device(module):
    """Return the device of *module*'s first parameter, else first buffer, else CPU."""
    # next(..., None) avoids a bare StopIteration for tensor-less modules.
    first = next(module.parameters(), None)
    if first is None:
        first = next(module.buffers(), None)
    return torch.device("cpu") if first is None else first.device


def load_checkpoint(path, model):
    """Load the weights stored at *path* into *model* in place.

    The checkpoint tensors are mapped onto the device the (unwrapped) model
    already lives on. The state dict is then normalised before loading:

    * a leading ``"_orig_mod."`` is stripped from every key, matching the
      attribute ``_unwrap`` looks through, so checkpoints saved from a wrapped
      (e.g. compiled) model load into the plain module;
    * ``decoder._overlap_count`` is discarded if present (presumably runtime
      state the current model does not load — TODO confirm);
    * ``align.temperature`` defaults to ``1.0`` for older ("pre-buffer")
      checkpoints that predate it.

    Args:
        path: Path (or file-like object) accepted by ``torch.load``.
        model: Module to load into; may be a wrapper exposing ``_orig_mod``.

    Returns:
        ``(epoch, loss)`` — ``ckpt["epoch"]`` and ``ckpt.get("loss")``
        (``None`` when the checkpoint has no loss entry).

    Raises:
        KeyError: if the checkpoint lacks ``model_state_dict`` or ``epoch``.
        RuntimeError: from ``load_state_dict`` (strict by default) when the
            keys or shapes do not match the model.
    """
    net = _unwrap(model)
    target_device = _module_device(net)
    # SECURITY: weights_only=False unpickles arbitrary Python objects. Only
    # load checkpoints from trusted sources.
    ckpt = torch.load(path, weights_only=False, map_location=target_device)

    state = {
        key.removeprefix("_orig_mod."): tensor
        for key, tensor in ckpt["model_state_dict"].items()
    }
    state.pop("decoder._overlap_count", None)
    # Pre-buffer checkpoints lack align.temperature; default to 1.0.
    if "align.temperature" not in state:
        state["align.temperature"] = torch.tensor(1.0, device=target_device)

    net.load_state_dict(state)
    return ckpt["epoch"], ckpt.get("loss")