Spaces:
Running
Running
File size: 1,242 Bytes
4313d1d 6e0a6e4 4313d1d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | import torch
def _unwrap(model):
return getattr(model, "_orig_mod", model)
def apply_ckpt_model_config(ckpt, cfg):
    """Copy width metadata stored in a checkpoint onto ``cfg.model``.

    When ``ckpt["model_config"]`` is present and truthy, each recognised key
    found in it overwrites the matching attribute of ``cfg.model``:
    ``mic_channels`` and ``far_channels`` are coerced to ``list``, and
    ``bottleneck_hidden`` to ``int``. Unrecognised keys are ignored, and keys
    that are absent leave the existing ``cfg.model`` value in place.

    Call this before ``LocalVQE.from_config(cfg)`` so that the rebuilt model
    matches the tensor shapes of a pruned checkpoint.
    """
    meta = ckpt.get("model_config")
    if meta:
        # (attribute name, coercion) pairs, applied in this fixed order.
        width_fields = (
            ("mic_channels", list),
            ("far_channels", list),
            ("bottleneck_hidden", int),
        )
        for field, coerce in width_fields:
            if field in meta:
                setattr(cfg.model, field, coerce(meta[field]))
def load_checkpoint(path, model):
    """Load a training checkpoint into ``model`` in place.

    Tensors are mapped onto the device of the model's first parameter. If the
    model has no parameters, its first buffer's device is used instead, and
    CPU is used if it has neither. State-dict keys saved from a compiled
    (``_orig_mod``-wrapped) model have their ``_orig_mod.`` prefix stripped
    before loading.

    Args:
        path: Anything accepted by ``torch.load`` (path or file-like object).
        model: Module to load into. If it is a compiled wrapper, the
            underlying module is the one updated.

    Returns:
        ``(epoch, loss)``; ``loss`` is ``None`` when the checkpoint does not
        store it.

    Raises:
        KeyError: if the checkpoint lacks ``"model_state_dict"`` or
            ``"epoch"``.
        RuntimeError: from ``load_state_dict`` on missing/unexpected keys or
            shape mismatches.
    """
    net = _unwrap(model)

    # A bare next() would raise StopIteration on a parameter-less module.
    first = next(net.parameters(), None)
    if first is None:
        first = next(net.buffers(), None)
    target_device = first.device if first is not None else torch.device("cpu")

    # NOTE(security): weights_only=False unpickles arbitrary Python objects
    # and can execute code; only load checkpoints from trusted sources.
    ckpt = torch.load(path, weights_only=False, map_location=target_device)

    state = {
        key.removeprefix("_orig_mod."): value
        for key, value in ckpt["model_state_dict"].items()
    }
    # This entry is discarded before loading. Presumably it is a legacy or
    # derived buffer the current model does not expect (confirm).
    state.pop("decoder._overlap_count", None)

    # Pre-buffer checkpoints lack align.temperature; default it to 1.0.
    # The default is built from the model's own entry so its shape, dtype and
    # device match. It is injected only when the model actually has that key;
    # otherwise a strict load would reject it as unexpected.
    expected = net.state_dict()
    temp_key = "align.temperature"
    if temp_key in expected and temp_key not in state:
        state[temp_key] = torch.ones_like(expected[temp_key])

    net.load_state_dict(state)
    return ckpt["epoch"], ckpt.get("loss")
|