orrp commited on
Commit
a0bbc38
·
1 Parent(s): e60681a

Model into Docker

Browse files
Files changed (2) hide show
  1. Dockerfile +16 -3
  2. vampnet/app.py +17 -21
Dockerfile CHANGED
@@ -1,17 +1,30 @@
1
  FROM python:3.10-slim
2
  COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv
3
 
4
- # Install ffmpeg for audio processing
5
  RUN apt-get update && apt-get install -y ffmpeg git build-essential && rm -rf /var/lib/apt/lists/*
6
 
7
  WORKDIR /app
8
 
9
- # Install dependencies using uv
 
 
 
 
 
 
 
 
 
 
 
10
  COPY pyproject.toml .
11
  RUN uv pip install --system .
12
 
13
- # Copy your code and run the app
14
  COPY . .
 
15
  EXPOSE 7860
16
  ENV GRADIO_SERVER_NAME="0.0.0.0"
 
17
  CMD ["python", "vampnet/app.py"]
 
1
  FROM python:3.10-slim
2
  COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv
3
 
4
+ # 1. system dependencies
5
  RUN apt-get update && apt-get install -y ffmpeg git build-essential && rm -rf /var/lib/apt/lists/*
6
 
7
  WORKDIR /app
8
 
9
+ # 2. hf_hub
10
+ # This ensures changing your code doesn't trigger a re-download
11
+ RUN uv pip install --system huggingface_hub
12
+
13
+ # 3. Download weights
14
+ RUN mkdir -p /app/vampnet/models && \
15
+ python3 -c "from huggingface_hub import hf_hub_download; \
16
+ repo = 'ProjectCETI/wham'; \
17
+ [hf_hub_download(repo_id=repo, filename=f, local_dir='/app/vampnet/models') \
18
+ for f in ['codec.pth', 'coarse.pth', 'c2f.pth', 'wavebeat.pth']]"
19
+
20
+ # 4. Install project dependencies
21
  COPY pyproject.toml .
22
  RUN uv pip install --system .
23
 
24
+ # 5. copy code
25
  COPY . .
26
+
27
  EXPOSE 7860
28
  ENV GRADIO_SERVER_NAME="0.0.0.0"
29
+ # Ensure we run from the root so imports like 'from vampnet' work
30
  CMD ["python", "vampnet/app.py"]
vampnet/app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import sys
2
  import uuid
3
  from pathlib import Path
@@ -8,40 +9,39 @@ import gradio as gr
8
  import numpy as np
9
  import torch
10
  import yaml
 
 
11
 
12
  from vampnet import mask as pmask
13
  from vampnet.interface import Interface
14
 
15
- SCRIPT_DIR = Path(__file__).parent
 
 
 
16
 
17
- # 2. Define the models directory
18
  MODEL_DIR = SCRIPT_DIR / "models"
19
- MODEL_DIR.mkdir(parents=True, exist_ok=True)
20
 
21
 
22
  def ensure_models_exist():
23
- """Downloads weights from HF Hub if they aren't in the models/ folder."""
24
  repo_id = "ProjectCETI/wham"
25
- # List all the .pth files your app needs
26
- files_to_download = ["codec.pth", "coarse.pth", "c2f.pth", "wavebeat.pth"]
27
- print(f"Checking for model weights in {MODEL_DIR}...")
28
- for filename in files_to_download:
29
- target_file = MODEL_DIR / filename
30
- if not target_file.exists():
31
- print(f"Downloading {filename} from {repo_id}...")
32
  hf_hub_download(
33
- repo_id=repo_id,
34
- filename=filename,
35
- local_dir=str(MODEL_DIR),
36
- local_dir_use_symlinks=False,
37
  )
38
- else:
39
- print(f"✓ {filename} found.")
40
 
41
 
 
42
  ensure_models_exist()
43
 
 
44
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
45
  sys.argv = ["app.py", "--args.load", "conf/interface.yml", "--Interface.device", device]
46
 
47
 
@@ -50,9 +50,6 @@ Interface = argbind.bind(Interface)
50
  conf = argbind.parse_args()
51
 
52
 
53
- from torch_pitch_shift import pitch_shift
54
-
55
-
56
  def shift_pitch(signal, interval: int):
57
  signal.samples = pitch_shift(
58
  signal.samples, shift=interval, sample_rate=signal.sample_rate
@@ -232,7 +229,6 @@ def _vamp(
232
 
233
 
234
  def _extract_and_call_vamp(data, return_mask):
235
- """Extract plain values from Gradio data dict so only picklable args cross the ZeroGPU boundary."""
236
  return _vamp(
237
  _input_audio=data[input_audio],
238
  _num_steps=data[num_steps],
 
1
+ import os
2
  import sys
3
  import uuid
4
  from pathlib import Path
 
9
  import numpy as np
10
  import torch
11
  import yaml
12
+ from huggingface_hub import hf_hub_download
13
+ from torch_pitch_shift import pitch_shift
14
 
15
  from vampnet import mask as pmask
16
  from vampnet.interface import Interface
17
 
18
+ # 1. Setup paths and WorkDir
19
+ SCRIPT_DIR = Path(__file__).parent.absolute()
20
+ # This ensures relative paths like 'conf/interface.yml' work correctly
21
+ os.chdir(SCRIPT_DIR)
22
 
 
23
  MODEL_DIR = SCRIPT_DIR / "models"
 
24
 
25
 
26
  def ensure_models_exist():
27
+ """Fallback check for weights. In Docker, these are already baked in."""
28
  repo_id = "ProjectCETI/wham"
29
+ files = ["codec.pth", "coarse.pth", "c2f.pth", "wavebeat.pth"]
30
+
31
+ for filename in files:
32
+ if not (MODEL_DIR / filename).exists():
33
+ print(f"Weight {filename} missing, downloading...")
 
 
34
  hf_hub_download(
35
+ repo_id=repo_id, filename=filename, local_dir=str(MODEL_DIR)
 
 
 
36
  )
 
 
37
 
38
 
39
+ # Run the check
40
  ensure_models_exist()
41
 
42
+ # 2. Hardware Setup
43
  device = "cuda" if torch.cuda.is_available() else "cpu"
44
+ # Update sys.argv so argbind finds the config file correctly
45
  sys.argv = ["app.py", "--args.load", "conf/interface.yml", "--Interface.device", device]
46
 
47
 
 
50
  conf = argbind.parse_args()
51
 
52
 
 
 
 
53
  def shift_pitch(signal, interval: int):
54
  signal.samples = pitch_shift(
55
  signal.samples, shift=interval, sample_rate=signal.sample_rate
 
229
 
230
 
231
  def _extract_and_call_vamp(data, return_mask):
 
232
  return _vamp(
233
  _input_audio=data[input_audio],
234
  _num_steps=data[num_steps],