richiejp commited on
Commit
f5e08b6
·
verified ·
1 Parent(s): ace7c6c

point at localvqe-v1-1.3M.pt (renamed on model repo)

Browse files
Files changed (1) hide show
  1. app.py +13 -6
app.py CHANGED
@@ -1,28 +1,35 @@
1
  """Gradio demo for LocalVQE — real-time AEC + NS + dereverb.
2
 
3
- Downloads the shipped `localvqe-v1.pt` checkpoint from
4
- huggingface.co/LocalAI-io/LocalVQE on first run and caches it. Runs
5
- inference on CPU in the Space's default hardware tier.
 
6
  """
 
7
  from pathlib import Path
8
 
9
  import gradio as gr
10
  import numpy as np
11
  import soundfile as sf
12
  import torch
13
- from huggingface_hub import hf_hub_download
14
  from scipy.signal import resample_poly
15
 
16
  from localvqe_model import Config, LocalVQE, apply_ckpt_model_config, load_checkpoint
17
 
18
  SR = 16000
19
  REPO_ID = "LocalAI-io/LocalVQE"
20
- CKPT_FILE = "localvqe-v1.pt"
21
  EXAMPLES_DIR = Path(__file__).resolve().parent / "examples"
22
 
23
 
24
  def _build_model() -> LocalVQE:
25
- ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_FILE)
 
 
 
 
 
 
26
  cfg = Config()
27
  peek = torch.load(ckpt_path, map_location="cpu", weights_only=False)
28
  apply_ckpt_model_config(peek, cfg)
 
1
  """Gradio demo for LocalVQE — real-time AEC + NS + dereverb.
2
 
3
+ By default downloads the shipped `localvqe-v1-1.3M.pt` checkpoint from
4
+ huggingface.co/LocalAI-io/LocalVQE. Set the env var
5
+ `LOCALVQE_LOCAL_CKPT=/path/to/checkpoint.pt` to load a local file
6
+ instead — useful for auditioning new training runs.
7
  """
8
+ import os
9
  from pathlib import Path
10
 
11
  import gradio as gr
12
  import numpy as np
13
  import soundfile as sf
14
  import torch
 
15
  from scipy.signal import resample_poly
16
 
17
  from localvqe_model import Config, LocalVQE, apply_ckpt_model_config, load_checkpoint
18
 
19
  SR = 16000
20
  REPO_ID = "LocalAI-io/LocalVQE"
21
+ CKPT_FILE = "localvqe-v1-1.3M.pt"
22
  EXAMPLES_DIR = Path(__file__).resolve().parent / "examples"
23
 
24
 
25
  def _build_model() -> LocalVQE:
26
+ local_override = os.environ.get("LOCALVQE_LOCAL_CKPT")
27
+ if local_override:
28
+ ckpt_path = local_override
29
+ print(f"Loading local checkpoint: {ckpt_path}")
30
+ else:
31
+ from huggingface_hub import hf_hub_download
32
+ ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_FILE)
33
  cfg = Config()
34
  peek = torch.load(ckpt_path, map_location="cpu", weights_only=False)
35
  apply_ckpt_model_config(peek, cfg)