nukopy committed on
Commit
e72f744
·
1 Parent(s): ddbbe8e

fix: base dir

Browse files
apps/audio_cloning/vallex/g2p/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
  """from https://github.com/keithito/tacotron"""
2
 
 
3
  import os
4
 
5
  # import utils.g2p.cleaners
@@ -9,19 +10,29 @@ import apps.audio_cloning.vallex.g2p.cleaners as cleaners
9
 
10
  from .symbols import symbols
11
 
 
 
12
  # Mappings from symbol to numeric ID and vice versa:
13
  _symbol_to_id = {s: i for i, s in enumerate(symbols)}
14
  _id_to_symbol = {i: s for i, s in enumerate(symbols)}
15
 
16
 
17
- BASE_DIR = os.getenv("HF_HOME", ".")
18
- TOKENIZER_PATH = os.path.join(BASE_DIR, "apps/audio_cloning/vallex/g2p/bpe_1024.json")
 
 
19
 
20
 
21
  class PhonemeBpeTokenizer:
22
  def __init__(self, tokenizer_path=TOKENIZER_PATH):
23
  print(f"Initializing PhonemeBpeTokenizer with tokenizer path: {tokenizer_path}")
24
- self.tokenizer = Tokenizer.from_file(tokenizer_path)
 
 
 
 
 
 
25
 
26
  def tokenize(self, text):
27
  # 1. convert text to phoneme
 
1
  """from https://github.com/keithito/tacotron"""
2
 
3
+ import logging
4
  import os
5
 
6
  # import utils.g2p.cleaners
 
10
 
11
  from .symbols import symbols
12
 
13
+ logger = logging.getLogger(__name__)
14
+
15
  # Mappings from symbol to numeric ID and vice versa:
16
  _symbol_to_id = {s: i for i, s in enumerate(symbols)}
17
  _id_to_symbol = {i: s for i, s in enumerate(symbols)}
18
 
19
 
20
+ PREPARED_BASE_DIR = "."
21
+ TOKENIZER_PATH = os.path.join(
22
+ PREPARED_BASE_DIR, "apps/audio_cloning/vallex/g2p/bpe_1024.json"
23
+ )
24
 
25
 
26
  class PhonemeBpeTokenizer:
27
  def __init__(self, tokenizer_path=TOKENIZER_PATH):
28
  print(f"Initializing PhonemeBpeTokenizer with tokenizer path: {tokenizer_path}")
29
+ try:
30
+ self.tokenizer = Tokenizer.from_file(tokenizer_path)
31
+ except Exception as e:
32
+ logger.error(
33
+ f"Error initializing PhonemeBpeTokenizer when reading file: {tokenizer_path}: {e}"
34
+ )
35
+ raise e
36
 
37
  def tokenize(self, text):
38
  # 1. convert text to phoneme
apps/audio_cloning/vallex/main.py CHANGED
@@ -41,8 +41,10 @@ from .models.vallex import VALLE
41
  logger = logging.getLogger(__name__)
42
 
43
  # set base directory
44
- BASE_DIR = os.getenv("HF_HOME", ".")
45
- logger.info("Base directory: %s", BASE_DIR)
 
 
46
 
47
  # set languages
48
  langid.set_languages(["en", "zh", "ja"])
@@ -90,7 +92,9 @@ else:
90
 
91
  # set text tokenizer and collater
92
  logger.info("Setting text tokenizer and collater...")
93
- tokenizer_path = os.path.join(BASE_DIR, "apps/audio_cloning/vallex/g2p/bpe_69.json")
 
 
94
  text_tokenizer = PhonemeBpeTokenizer(tokenizer_path=tokenizer_path)
95
  text_collater = get_text_token_collater()
96
 
@@ -104,7 +108,7 @@ if torch.cuda.is_available():
104
  logger.info("Device set to %s", device)
105
 
106
  # Download VALL-E-X model weights if not exists
107
- OUTPUT_DIR_CHECKPOINTS = os.path.join(BASE_DIR, "models/checkpoints")
108
  OUTPUT_FILENAME_CHECKPOINTS = "vallex-checkpoint.pt"
109
  OUTPUT_PATH_CHECKPOINTS = os.path.join(
110
  OUTPUT_DIR_CHECKPOINTS, OUTPUT_FILENAME_CHECKPOINTS
@@ -142,7 +146,9 @@ model = VALLE(
142
  prepend_bos=True,
143
  num_quantizers=NUM_QUANTIZERS,
144
  )
145
- checkpoint = torch.load(OUTPUT_PATH_CHECKPOINTS, map_location="cpu", weights_only=False)
 
 
146
  missing_keys, unexpected_keys = model.load_state_dict(checkpoint["model"], strict=True)
147
  assert not missing_keys
148
  model.eval()
@@ -155,7 +161,7 @@ audio_tokenizer = AudioTokenizer(device)
155
  vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(device)
156
 
157
  # initialize ASR model
158
- OUTPUT_DIR_WHISPER = os.path.join(BASE_DIR, "models/whisper")
159
  if not os.path.exists(OUTPUT_DIR_WHISPER):
160
  os.makedirs(OUTPUT_DIR_WHISPER, exist_ok=True)
161
 
@@ -176,7 +182,7 @@ except Exception as e:
176
 
177
  # Initialize Voice Presets
178
  logger.info("Initializing Voice Presets...")
179
- PRESETS_DIR = os.path.join(BASE_DIR, "apps/audio_cloning/vallex/presets")
180
  preset_list = os.walk(PRESETS_DIR).__next__()[2]
181
  preset_list = [preset[:-4] for preset in preset_list if preset.endswith(".npz")]
182
 
 
41
  logger = logging.getLogger(__name__)
42
 
43
  # set base directory
44
+ OUTPUT_BASE_DIR = os.getenv("HF_HOME", ".")
45
+ PREPARED_BASE_DIR = "."
46
+ logger.info("Base directory: %s", OUTPUT_BASE_DIR)
47
+ logger.info("Prepared base directory: %s", PREPARED_BASE_DIR)
48
 
49
  # set languages
50
  langid.set_languages(["en", "zh", "ja"])
 
92
 
93
  # set text tokenizer and collater
94
  logger.info("Setting text tokenizer and collater...")
95
+ tokenizer_path = os.path.join(
96
+ PREPARED_BASE_DIR, "apps/audio_cloning/vallex/g2p/bpe_69.json"
97
+ )
98
  text_tokenizer = PhonemeBpeTokenizer(tokenizer_path=tokenizer_path)
99
  text_collater = get_text_token_collater()
100
 
 
108
  logger.info("Device set to %s", device)
109
 
110
  # Download VALL-E-X model weights if not exists
111
+ OUTPUT_DIR_CHECKPOINTS = os.path.join(OUTPUT_BASE_DIR, "models/checkpoints")
112
  OUTPUT_FILENAME_CHECKPOINTS = "vallex-checkpoint.pt"
113
  OUTPUT_PATH_CHECKPOINTS = os.path.join(
114
  OUTPUT_DIR_CHECKPOINTS, OUTPUT_FILENAME_CHECKPOINTS
 
146
  prepend_bos=True,
147
  num_quantizers=NUM_QUANTIZERS,
148
  )
149
+ checkpoint = torch.load(
150
+ OUTPUT_PATH_CHECKPOINTS, map_location=device, weights_only=False
151
+ )
152
  missing_keys, unexpected_keys = model.load_state_dict(checkpoint["model"], strict=True)
153
  assert not missing_keys
154
  model.eval()
 
161
  vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(device)
162
 
163
  # initialize ASR model
164
+ OUTPUT_DIR_WHISPER = os.path.join(OUTPUT_BASE_DIR, "models/whisper")
165
  if not os.path.exists(OUTPUT_DIR_WHISPER):
166
  os.makedirs(OUTPUT_DIR_WHISPER, exist_ok=True)
167
 
 
182
 
183
  # Initialize Voice Presets
184
  logger.info("Initializing Voice Presets...")
185
+ PRESETS_DIR = os.path.join(PREPARED_BASE_DIR, "apps/audio_cloning/vallex/presets")
186
  preset_list = os.walk(PRESETS_DIR).__next__()[2]
187
  preset_list = [preset[:-4] for preset in preset_list if preset.endswith(".npz")]
188