asdrty123 committed on
Commit
3d56e24
·
verified ·
1 Parent(s): bf0f511

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -33
app.py CHANGED
@@ -10,6 +10,7 @@ import gradio as gr
10
  import librosa
11
  import torch
12
  from fairseq import checkpoint_utils
 
13
 
14
  from config import Config
15
  from lib.infer_pack.models import (
@@ -20,6 +21,8 @@ from lib.infer_pack.models import (
20
  )
21
  from rmvpe import RMVPE
22
  from vc_infer_pipeline import VC
 
 
23
 
24
  logging.getLogger("fairseq").setLevel(logging.WARNING)
25
  logging.getLogger("numba").setLevel(logging.WARNING)
@@ -36,30 +39,110 @@ tts_voice_list = asyncio.get_event_loop().run_until_complete(edge_tts.list_voice
36
  tts_voices = [f"{v['ShortName']}-{v['Gender']}" for v in tts_voice_list]
37
 
38
  model_root = "weights"
39
- models = [
 
 
40
  d for d in os.listdir(model_root) if os.path.isdir(os.path.join(model_root, d))
41
  ]
42
- if len(models) == 0:
43
- raise ValueError("No model found in `weights` folder")
44
- models.sort()
45
-
46
-
47
- def model_data(model_name):
48
- # global n_spk, tgt_sr, net_g, vc, cpt, version, index_file
49
- pth_files = [
50
- os.path.join(model_root, model_name, f)
51
- for f in os.listdir(os.path.join(model_root, model_name))
52
- if f.endswith(".pth")
53
- ]
54
- if len(pth_files) == 0:
55
- raise ValueError(f"No pth file found in {model_root}/{model_name}")
56
- pth_path = pth_files[0]
57
- print(f"Loading {pth_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  cpt = torch.load(pth_path, map_location="cpu")
59
  tgt_sr = cpt["config"][-1]
60
  cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
61
  if_f0 = cpt.get("f0", 1)
62
  version = cpt.get("version", "v1")
 
 
 
 
63
  if version == "v1":
64
  if if_f0 == 1:
65
  net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=config.is_half)
@@ -72,22 +155,21 @@ def model_data(model_name):
72
  net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
73
  else:
74
  raise ValueError("Unknown version")
 
75
  del net_g.enc_q
76
  net_g.load_state_dict(cpt["weight"], strict=False)
77
- print("Model loaded")
78
  net_g.eval().to(config.device)
 
79
  if config.is_half:
80
  net_g = net_g.half()
81
  else:
82
  net_g = net_g.float()
 
83
  vc = VC(tgt_sr, config)
84
- # n_spk = cpt["config"][-3]
85
 
86
- index_files = [
87
- os.path.join(model_root, model_name, f)
88
- for f in os.listdir(os.path.join(model_root, model_name))
89
- if f.endswith(".index")
90
- ]
91
  if len(index_files) == 0:
92
  print("No index file found")
93
  index_file = ""
@@ -100,19 +182,17 @@ def model_data(model_name):
100
 
101
  def load_hubert():
102
  global hubert_model
103
- models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
104
- ["hubert_base.pt"],
105
- suffix="",
106
- )
 
 
107
  hubert_model = models[0]
108
  hubert_model = hubert_model.to(config.device)
109
- if config.is_half:
110
- hubert_model = hubert_model.half()
111
- else:
112
- hubert_model = hubert_model.float()
113
  return hubert_model.eval()
114
 
115
-
116
  print("Loading hubert model...")
117
  hubert_model = load_hubert()
118
  print("Hubert model loaded.")
 
10
  import librosa
11
  import torch
12
  from fairseq import checkpoint_utils
13
+ from fairseq.data import dictionary
14
 
15
  from config import Config
16
  from lib.infer_pack.models import (
 
21
  )
22
  from rmvpe import RMVPE
23
  from vc_infer_pipeline import VC
24
+ from huggingface_hub import HfApi, hf_hub_download
25
+ from collections import defaultdict
26
 
27
  logging.getLogger("fairseq").setLevel(logging.WARNING)
28
  logging.getLogger("numba").setLevel(logging.WARNING)
 
39
  tts_voices = [f"{v['ShortName']}-{v['Gender']}" for v in tts_voice_list]
40
 
41
  model_root = "weights"
42
+
43
+ # ---- Local models ----
44
+ local_models = [
45
  d for d in os.listdir(model_root) if os.path.isdir(os.path.join(model_root, d))
46
  ]
47
+ if not local_models:
48
+ print("⚠️ No model found in local `weights` folder")
49
+ local_models.sort()
50
+
51
+ # ---- HF models ----
52
+ REPO_ID = "simpsonsaiorg/stream-models"
53
+ api = HfApi()
54
+
55
+ all_files = api.list_repo_files(REPO_ID)
56
+
57
+ hf_models = defaultdict(list)
58
+ for f in all_files:
59
+ parts = f.split("/")
60
+ if len(parts) == 3 and parts[0] == "weights":
61
+ model_name, filename = parts[1], parts[2]
62
+ hf_models[model_name].append(filename)
63
+
64
+ # ---- Merge / display ----
65
+ av_models = sorted(set(local_models) | set(hf_models.keys()))
66
+
67
+ print("Local models:", local_models)
68
+ print("HF models:", list(hf_models.keys()))
69
+ print("Available models (combined):", av_models)
70
+
71
+ models = hf_models
72
+
73
+
74
+ # Example: load a specific model (like your hubert loader)
75
+ """
76
+ def load_model(model_name):
77
+ if model_name not in models:
78
+ raise ValueError(f"Model '{model_name}' not found in repo")
79
+ files = models[model_name]
80
+ loaded = {}
81
+ for file in files:
82
+ path = hf_hub_download(repo_id=REPO_ID, filename=f"{model_name}/{file}")
83
+ loaded[file] = torch.load(path, map_location="cpu")
84
+ return loaded
85
+
86
+ # Load homer model
87
+ #homer_model = load_model("homer")
88
+ """
89
+
90
+
91
+
92
+ def model_data(model_name, model_root="weights", repo_id="simpsonsaiorg/stream-models", use_hf=True):
93
+ """
94
+ Load a model either from local disk or HuggingFace repo.
95
+ """
96
+
97
+ if not use_hf:
98
+ # --- Local load ---
99
+ pth_files = [
100
+ os.path.join(model_root, model_name, f)
101
+ for f in os.listdir(os.path.join(model_root, model_name))
102
+ if f.endswith(".pth")
103
+ ]
104
+ if len(pth_files) == 0:
105
+ raise ValueError(f"No .pth file found in {model_root}/{model_name}")
106
+ pth_path = pth_files[0]
107
+
108
+ index_files = [
109
+ os.path.join(model_root, model_name, f)
110
+ for f in os.listdir(os.path.join(model_root, model_name))
111
+ if f.endswith(".index")
112
+ ]
113
+ else:
114
+ # --- HuggingFace load ---
115
+ all_files = api.list_repo_files(repo_id)
116
+ model_files = [f for f in all_files if f.startswith(f"weights/{model_name}/")]
117
+
118
+ # Find .pth file
119
+ pth_files = [f for f in model_files if f.endswith(".pth")]
120
+ if not pth_files:
121
+ raise ValueError(f"No .pth file found for model {model_name} in repo")
122
+ pth_path = hf_hub_download(repo_id=repo_id, filename=pth_files[0])
123
+
124
+ # Find index files
125
+ index_files = [
126
+ hf_hub_download(repo_id=repo_id, filename=f)
127
+ for f in model_files
128
+ if f.endswith(".index") and "added_IVF" in f
129
+ ]
130
+
131
+ print(f"Loading {pth_path}") # <-- safe to do for both cases
132
+
133
+
134
+ # -----------------------
135
+ # 2. Load checkpoint
136
+ # -----------------------
137
  cpt = torch.load(pth_path, map_location="cpu")
138
  tgt_sr = cpt["config"][-1]
139
  cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
140
  if_f0 = cpt.get("f0", 1)
141
  version = cpt.get("version", "v1")
142
+
143
+ # -----------------------
144
+ # 3. Init network
145
+ # -----------------------
146
  if version == "v1":
147
  if if_f0 == 1:
148
  net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=config.is_half)
 
155
  net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
156
  else:
157
  raise ValueError("Unknown version")
158
+
159
  del net_g.enc_q
160
  net_g.load_state_dict(cpt["weight"], strict=False)
 
161
  net_g.eval().to(config.device)
162
+
163
  if config.is_half:
164
  net_g = net_g.half()
165
  else:
166
  net_g = net_g.float()
167
+
168
  vc = VC(tgt_sr, config)
 
169
 
170
+ # -----------------------
171
+ # 4. Index file
172
+ # -----------------------
 
 
173
  if len(index_files) == 0:
174
  print("No index file found")
175
  index_file = ""
 
182
 
183
  def load_hubert():
184
  global hubert_model
185
+ safe_globals = [dictionary.Dictionary] # allow this class
186
+ with torch.serialization.safe_globals(safe_globals):
187
+ models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
188
+ ["hubert_base.pt"],
189
+ suffix="",
190
+ )
191
  hubert_model = models[0]
192
  hubert_model = hubert_model.to(config.device)
193
+ hubert_model = hubert_model.half() if config.is_half else hubert_model.float()
 
 
 
194
  return hubert_model.eval()
195
 
 
196
  print("Loading hubert model...")
197
  hubert_model = load_hubert()
198
  print("Hubert model loaded.")