richiejp committed on
Commit
6e0a6e4
·
verified ·
1 Parent(s): 3a7d44e

Initial upload: LocalVQE demo Space

Browse files
app.py CHANGED
@@ -36,11 +36,11 @@ def _build_model() -> LocalVQE:
36
  del peek
37
  model = LocalVQE.from_config(cfg).to("cpu")
38
  load_checkpoint(ckpt_path, model)
39
- # Bake the trained AlignBlock softmax temperature into the smoothing
40
- # conv weights — checkpoints don't persist the temperature scalar, so
41
- # without this the model runs at the default 1.0 instead of the
42
- # trained 0.1, costing several dB of FE-ST ERLE on real recordings.
43
- model.align.fold_temperature(cfg.model.align_temp_end)
44
  model.eval()
45
  n_params = sum(p.numel() for p in model.parameters())
46
  print(f"LocalVQE loaded: {n_params:,} params from {ckpt_path}")
 
36
  del peek
37
  model = LocalVQE.from_config(cfg).to("cpu")
38
  load_checkpoint(ckpt_path, model)
39
+ # Fold the trained AlignBlock softmax temperature (carried in the
40
+ # checkpoint as a buffer) into the smoothing conv weights — without
41
+ # this the model runs at the default 1.0 instead of the trained value
42
+ # and loses several dB of FE-ST ERLE on real recordings.
43
+ model.align.fold_temperature()
44
  model.eval()
45
  n_params = sum(p.numel() for p in model.parameters())
46
  print(f"LocalVQE loaded: {n_params:,} params from {ckpt_path}")
localvqe_model/align.py CHANGED
@@ -20,7 +20,9 @@ class AlignBlock(nn.Module):
20
  self.in_channels = in_channels
21
  self.hidden_channels = hidden_channels
22
  self.dmax = dmax
23
- self.temperature = temperature
 
 
24
 
25
  # Pointwise projections for Q and K
26
  self.pconv_mic = nn.Conv2d(in_channels, hidden_channels, 1)
@@ -39,21 +41,18 @@ class AlignBlock(nn.Module):
39
  )
40
 
41
  def fold_temperature(self, temperature=None):
42
- """Bake a softmax temperature into the smoothing-conv weights.
43
-
44
- After this, forward() at self.temperature=1.0 is mathematically
45
- equivalent to running with the given temperature on the original
46
- weights. Used at inference-load time so PyTorch eval and the GGML
47
- graph (whose softmax has no temperature parameter) both produce
48
- the trained-temperature distribution.
49
  """
50
- t = temperature if temperature is not None else self.temperature
51
  if t == 1.0:
52
  return
53
  with torch.no_grad():
54
  self.conv[1].weight.div_(t)
55
  self.conv[1].bias.div_(t)
56
- self.temperature = 1.0
57
 
58
  def forward(self, x_mic, x_ref, return_delay=False):
59
  """
 
20
  self.in_channels = in_channels
21
  self.hidden_channels = hidden_channels
22
  self.dmax = dmax
23
+ # Registered as a buffer so the trained value persists in state_dict.
24
+ # Mutate via .fill_(), never re-assign.
25
+ self.register_buffer("temperature", torch.tensor(float(temperature)))
26
 
27
  # Pointwise projections for Q and K
28
  self.pconv_mic = nn.Conv2d(in_channels, hidden_channels, 1)
 
41
  )
42
 
43
  def fold_temperature(self, temperature=None):
44
+ """Bake the AlignBlock softmax temperature into the smoothing-conv
45
+ weights. Reads `self.temperature` (a buffer carried by the
46
+ checkpoint) when called with no argument; after folding the buffer
47
+ is reset to 1.0 so subsequent calls are no-ops.
 
 
 
48
  """
49
+ t = float(temperature if temperature is not None else self.temperature)
50
  if t == 1.0:
51
  return
52
  with torch.no_grad():
53
  self.conv[1].weight.div_(t)
54
  self.conv[1].bias.div_(t)
55
+ self.temperature.fill_(1.0)
56
 
57
  def forward(self, x_mic, x_ref, return_delay=False):
58
  """
localvqe_model/config.py CHANGED
@@ -11,7 +11,6 @@ class ModelConfig:
11
  power_law_c: float = 0.3
12
  kernel_size: Tuple[int, int] = (4, 4)
13
  bottleneck_hidden: int = 0
14
- align_temp_end: float = 0.1
15
 
16
 
17
  @dataclass
 
11
  power_law_c: float = 0.3
12
  kernel_size: Tuple[int, int] = (4, 4)
13
  bottleneck_hidden: int = 0
 
14
 
15
 
16
  @dataclass
localvqe_model/utils.py CHANGED
@@ -28,5 +28,8 @@ def load_checkpoint(path, model):
28
  state = ckpt["model_state_dict"]
29
  state = {k.removeprefix("_orig_mod."): v for k, v in state.items()}
30
  state.pop("decoder._overlap_count", None)
 
 
 
31
  _unwrap(model).load_state_dict(state)
32
  return ckpt["epoch"], ckpt.get("loss")
 
28
  state = ckpt["model_state_dict"]
29
  state = {k.removeprefix("_orig_mod."): v for k, v in state.items()}
30
  state.pop("decoder._overlap_count", None)
31
+ # Pre-buffer checkpoints lack align.temperature; default to 1.0.
32
+ if "align.temperature" not in state:
33
+ state["align.temperature"] = torch.tensor(1.0)
34
  _unwrap(model).load_state_dict(state)
35
  return ckpt["epoch"], ckpt.get("loss")