orrp commited on
Commit
4c5d66a
·
1 Parent(s): 9583919

Fixed missing model

Browse files
Files changed (1) hide show
  1. vampnet/app.py +27 -0
vampnet/app.py CHANGED
@@ -16,6 +16,33 @@ from vampnet.interface import Interface
16
  SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
17
  os.chdir(SCRIPT_DIR)
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
  sys.argv = ["app.py", "--args.load", "conf/interface.yml", "--Interface.device", device]
21
 
 
16
  SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
17
  os.chdir(SCRIPT_DIR)
18
 
19
+ # 2. Define the models directory
20
+ MODEL_DIR = SCRIPT_DIR / "models"
21
+ MODEL_DIR.mkdir(parents=True, exist_ok=True)
22
+
23
+
24
+ def ensure_models_exist():
25
+ """Downloads weights from HF Hub if they aren't in the models/ folder."""
26
+ repo_id = "ProjectCETI/wham"
27
+ # List all the .pth files your app needs
28
+ files_to_download = ["codec.pth", "coarse.pth", "c2f.pth", "wavebeat.pth"]
29
+ print(f"Checking for model weights in {MODEL_DIR}...")
30
+ for filename in files_to_download:
31
+ target_file = MODEL_DIR / filename
32
+ if not target_file.exists():
33
+ print(f"Downloading {filename} from {repo_id}...")
34
+ hf_hub_download(
35
+ repo_id=repo_id,
36
+ filename=filename,
37
+ local_dir=str(MODEL_DIR),
38
+ local_dir_use_symlinks=False,
39
+ )
40
+ else:
41
+ print(f"✓ {filename} found.")
42
+
43
+
44
+ ensure_models_exist()
45
+
46
  device = "cuda" if torch.cuda.is_available() else "cpu"
47
  sys.argv = ["app.py", "--args.load", "conf/interface.yml", "--Interface.device", device]
48