Spaces:
Running
Running
Initial upload: LocalVQE demo Space
Browse files
- app.py +5 -5
- localvqe_model/align.py +9 -10
- localvqe_model/config.py +0 -1
- localvqe_model/utils.py +3 -0
app.py
CHANGED
|
@@ -36,11 +36,11 @@ def _build_model() -> LocalVQE:
|
|
| 36 |
del peek
|
| 37 |
model = LocalVQE.from_config(cfg).to("cpu")
|
| 38 |
load_checkpoint(ckpt_path, model)
|
| 39 |
-
#
|
| 40 |
-
#
|
| 41 |
-
#
|
| 42 |
-
#
|
| 43 |
-
model.align.fold_temperature(
|
| 44 |
model.eval()
|
| 45 |
n_params = sum(p.numel() for p in model.parameters())
|
| 46 |
print(f"LocalVQE loaded: {n_params:,} params from {ckpt_path}")
|
|
|
|
| 36 |
del peek
|
| 37 |
model = LocalVQE.from_config(cfg).to("cpu")
|
| 38 |
load_checkpoint(ckpt_path, model)
|
| 39 |
+
# Fold the trained AlignBlock softmax temperature (carried in the
|
| 40 |
+
# checkpoint as a buffer) into the smoothing conv weights — without
|
| 41 |
+
# this the model runs at the default 1.0 instead of the trained value
|
| 42 |
+
# and loses several dB of FE-ST ERLE on real recordings.
|
| 43 |
+
model.align.fold_temperature()
|
| 44 |
model.eval()
|
| 45 |
n_params = sum(p.numel() for p in model.parameters())
|
| 46 |
print(f"LocalVQE loaded: {n_params:,} params from {ckpt_path}")
|
localvqe_model/align.py
CHANGED
|
@@ -20,7 +20,9 @@ class AlignBlock(nn.Module):
|
|
| 20 |
self.in_channels = in_channels
|
| 21 |
self.hidden_channels = hidden_channels
|
| 22 |
self.dmax = dmax
|
| 23 |
-
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# Pointwise projections for Q and K
|
| 26 |
self.pconv_mic = nn.Conv2d(in_channels, hidden_channels, 1)
|
|
@@ -39,21 +41,18 @@ class AlignBlock(nn.Module):
|
|
| 39 |
)
|
| 40 |
|
| 41 |
def fold_temperature(self, temperature=None):
|
| 42 |
-
"""Bake
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
weights. Used at inference-load time so PyTorch eval and the GGML
|
| 47 |
-
graph (whose softmax has no temperature parameter) both produce
|
| 48 |
-
the trained-temperature distribution.
|
| 49 |
"""
|
| 50 |
-
t = temperature if temperature is not None else self.temperature
|
| 51 |
if t == 1.0:
|
| 52 |
return
|
| 53 |
with torch.no_grad():
|
| 54 |
self.conv[1].weight.div_(t)
|
| 55 |
self.conv[1].bias.div_(t)
|
| 56 |
-
|
| 57 |
|
| 58 |
def forward(self, x_mic, x_ref, return_delay=False):
|
| 59 |
"""
|
|
|
|
| 20 |
self.in_channels = in_channels
|
| 21 |
self.hidden_channels = hidden_channels
|
| 22 |
self.dmax = dmax
|
| 23 |
+
# Registered as a buffer so the trained value persists in state_dict.
|
| 24 |
+
# Mutate via .fill_(), never re-assign.
|
| 25 |
+
self.register_buffer("temperature", torch.tensor(float(temperature)))
|
| 26 |
|
| 27 |
# Pointwise projections for Q and K
|
| 28 |
self.pconv_mic = nn.Conv2d(in_channels, hidden_channels, 1)
|
|
|
|
| 41 |
)
|
| 42 |
|
| 43 |
def fold_temperature(self, temperature=None):
|
| 44 |
+
"""Bake the AlignBlock softmax temperature into the smoothing-conv
|
| 45 |
+
weights. Reads `self.temperature` (a buffer carried by the
|
| 46 |
+
checkpoint) when called with no argument; after folding the buffer
|
| 47 |
+
is reset to 1.0 so subsequent calls are no-ops.
|
|
|
|
|
|
|
|
|
|
| 48 |
"""
|
| 49 |
+
t = float(temperature if temperature is not None else self.temperature)
|
| 50 |
if t == 1.0:
|
| 51 |
return
|
| 52 |
with torch.no_grad():
|
| 53 |
self.conv[1].weight.div_(t)
|
| 54 |
self.conv[1].bias.div_(t)
|
| 55 |
+
self.temperature.fill_(1.0)
|
| 56 |
|
| 57 |
def forward(self, x_mic, x_ref, return_delay=False):
|
| 58 |
"""
|
localvqe_model/config.py
CHANGED
|
@@ -11,7 +11,6 @@ class ModelConfig:
|
|
| 11 |
power_law_c: float = 0.3
|
| 12 |
kernel_size: Tuple[int, int] = (4, 4)
|
| 13 |
bottleneck_hidden: int = 0
|
| 14 |
-
align_temp_end: float = 0.1
|
| 15 |
|
| 16 |
|
| 17 |
@dataclass
|
|
|
|
| 11 |
power_law_c: float = 0.3
|
| 12 |
kernel_size: Tuple[int, int] = (4, 4)
|
| 13 |
bottleneck_hidden: int = 0
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
@dataclass
|
localvqe_model/utils.py
CHANGED
|
@@ -28,5 +28,8 @@ def load_checkpoint(path, model):
|
|
| 28 |
state = ckpt["model_state_dict"]
|
| 29 |
state = {k.removeprefix("_orig_mod."): v for k, v in state.items()}
|
| 30 |
state.pop("decoder._overlap_count", None)
|
|
|
|
|
|
|
|
|
|
| 31 |
_unwrap(model).load_state_dict(state)
|
| 32 |
return ckpt["epoch"], ckpt.get("loss")
|
|
|
|
| 28 |
state = ckpt["model_state_dict"]
|
| 29 |
state = {k.removeprefix("_orig_mod."): v for k, v in state.items()}
|
| 30 |
state.pop("decoder._overlap_count", None)
|
| 31 |
+
# Pre-buffer checkpoints lack align.temperature; default to 1.0.
|
| 32 |
+
if "align.temperature" not in state:
|
| 33 |
+
state["align.temperature"] = torch.tensor(1.0)
|
| 34 |
_unwrap(model).load_state_dict(state)
|
| 35 |
return ckpt["epoch"], ckpt.get("loss")
|