File size: 1,242 Bytes
4313d1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e0a6e4
 
 
4313d1d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch


def _unwrap(model):
    """Return ``model._orig_mod`` when that attribute exists, else ``model``.

    NOTE(review): ``_orig_mod`` is presumably the attribute a compiled-module
    wrapper (e.g. ``torch.compile``) uses to expose the original module --
    confirm against the callers.
    """
    try:
        return model._orig_mod
    except AttributeError:
        # Not wrapped: hand the object back untouched.
        return model


def apply_ckpt_model_config(ckpt, cfg):
    """Copy width metadata stored in a checkpoint onto ``cfg.model``.

    When ``ckpt`` carries a non-empty ``model_config`` entry, each recognised
    field found in it overwrites the attribute of the same name on
    ``cfg.model``: ``mic_channels`` and ``far_channels`` are converted to
    lists, ``bottleneck_hidden`` to an int.  Fields that are absent are left
    alone, and a missing or empty ``model_config`` makes this a no-op.

    Call this before ``LocalVQE.from_config(cfg)`` so the rebuilt model
    matches the pruned-checkpoint shapes.
    """
    overrides = ckpt.get("model_config")
    if not overrides:
        return

    # (field name, coercion applied to the stored value), in application order.
    field_casts = (
        ("mic_channels", list),
        ("far_channels", list),
        ("bottleneck_hidden", int),
    )
    for field, cast in field_casts:
        if field not in overrides:
            continue
        # Coerce first, then touch cfg.model, so cfg is only accessed when
        # there is actually something to write.
        value = cast(overrides[field])
        setattr(cfg.model, field, value)


def load_checkpoint(path, model):
    """Load the weights stored at ``path`` into ``model``.

    The checkpoint must be a mapping with a ``"model_state_dict"`` entry and
    an ``"epoch"`` entry; ``"loss"`` is optional.  Before loading, the state
    dict is normalised:

    * a leading ``"_orig_mod."`` is stripped from every key,
    * ``"decoder._overlap_count"`` is dropped if present,
    * ``"align.temperature"`` is filled in with ``1.0`` when missing.

    Args:
        path: Checkpoint location, passed straight to ``torch.load``.
        model: Module to populate.  If it exposes ``_orig_mod`` the weights
            are loaded into that inner module instead.

    Returns:
        A ``(epoch, loss)`` tuple taken from the checkpoint; ``loss`` is
        ``None`` when the checkpoint does not store one.

    Raises:
        KeyError: If the checkpoint lacks ``"model_state_dict"`` or
            ``"epoch"``.
        RuntimeError: Propagated from ``load_state_dict`` when the
            normalised keys do not match the module.
    """
    net = _unwrap(model)

    # Load tensors onto the device the model already lives on.  A module with
    # no parameters would make a bare next() raise StopIteration, so fall
    # back to CPU in that case.
    first_param = next(net.parameters(), None)
    if first_param is not None:
        target_device = first_param.device
    else:
        target_device = torch.device("cpu")

    # NOTE(security): weights_only=False unpickles arbitrary Python objects,
    # which can execute code.  Only load checkpoints from trusted sources.
    ckpt = torch.load(path, weights_only=False, map_location=target_device)

    state = {
        key.removeprefix("_orig_mod."): tensor
        for key, tensor in ckpt["model_state_dict"].items()
    }
    # NOTE(review): presumably a buffer older checkpoints saved that the
    # current model no longer registers -- confirm.
    state.pop("decoder._overlap_count", None)
    # Pre-buffer checkpoints lack align.temperature; default to 1.0.
    if "align.temperature" not in state:
        state["align.temperature"] = torch.tensor(1.0)

    net.load_state_dict(state)
    return ckpt["epoch"], ckpt.get("loss")