jhansss commited on
Commit
64e2d77
·
2 Parent(s): 651aefd 92276c4

Merge branch 'refactor' into hf

Browse files
Files changed (6) hide show
  1. README.md +16 -1
  2. cli.py +22 -17
  3. interface.py +2 -2
  4. modules/llm/minimax.py +2 -1
  5. modules/melody.py +3 -2
  6. modules/utils/g2p.py +6 -4
README.md CHANGED
@@ -58,7 +58,22 @@ pip install -r requirements.txt
58
  #### Example Usage
59
 
60
  ```bash
61
- python cli.py --query_audio tests/audio/hello.wav --config_path config/cli/yaoyin_default.yaml --output_audio outputs/yaoyin_hello.wav
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  ```
63
 
64
  #### Parameter Description
 
58
  #### Example Usage
59
 
60
  ```bash
61
+ python cli.py \
62
+ --query_audio tests/audio/hello.wav \
63
+ --config_path config/cli/yaoyin_default.yaml \
64
+ --output_audio outputs/yaoyin_hello.wav \
65
+ --eval_results_csv outputs/yaoyin_test.csv
66
+ ```
67
+
68
+ #### Inference-Only Mode
69
+
70
+ Run minimal inference without evaluation.
71
+
72
+ ```bash
73
+ python cli.py \
74
+ --query_audio tests/audio/hello.wav \
75
+ --config_path config/cli/yaoyin_default_infer_only.yaml \
76
+ --output_audio outputs/yaoyin_hello.wav
77
  ```
78
 
79
  #### Parameter Description
cli.py CHANGED
@@ -17,7 +17,7 @@ def get_parser():
17
  "--config_path", type=Path, default="config/cli/yaoyin_default.yaml"
18
  )
19
  parser.add_argument("--output_audio_folder", type=Path, required=True)
20
- parser.add_argument("--eval_results_csv", type=Path, required=True)
21
  return parser
22
 
23
 
@@ -38,11 +38,15 @@ def main():
38
  character = get_character(character_name)
39
  prompt_template = character.prompt
40
  args.output_audio_folder.mkdir(parents=True, exist_ok=True)
41
- args.eval_results_csv.parent.mkdir(parents=True, exist_ok=True)
42
- with open(args.eval_results_csv, "a") as f:
43
- f.write(
44
- f"query_audio,asr_model,llm_model,svs_model,melody_source,language,speaker,output_audio,asr_text,llm_text,metrics\n"
45
- )
 
 
 
 
46
  try:
47
  for query_audio in args.query_audios:
48
  output_audio = args.output_audio_folder / f"{query_audio.stem}_response.wav"
@@ -53,19 +57,20 @@ def main():
53
  speaker,
54
  output_audio_path=output_audio,
55
  )
56
- metrics = pipeline.evaluate(output_audio, **results)
57
- metrics.update(results.get("metrics", {}))
58
- metrics_str = ",".join([f"{metrics[k]}" for k in sorted(metrics.keys())])
59
- logger.info(
60
- f"Input: {query_audio}, Output: {output_audio}, ASR results: {results['asr_text']}, LLM results: {results['llm_text']}"
61
- )
62
- with open(args.eval_results_csv, "a") as f:
63
- f.write(
64
- f"{query_audio},{config['asr_model']},{config['llm_model']},{config['svs_model']},{config['melody_source']},{config['language']},{config['speaker']},{output_audio},{results['asr_text']},{results['llm_text']},{metrics_str}\n"
65
  )
 
 
 
 
66
  except Exception as e:
67
- logger.error(f"Error in main: {e}")
68
- breakpoint()
69
  raise e
70
 
71
 
 
17
  "--config_path", type=Path, default="config/cli/yaoyin_default.yaml"
18
  )
19
  parser.add_argument("--output_audio_folder", type=Path, required=True)
20
+ parser.add_argument("--eval_results_csv", type=Path, default=None)
21
  return parser
22
 
23
 
 
38
  character = get_character(character_name)
39
  prompt_template = character.prompt
40
  args.output_audio_folder.mkdir(parents=True, exist_ok=True)
41
+ if config.get("evaluators", {}):
42
+ if args.eval_results_csv:
43
+ args.eval_results_csv.parent.mkdir(parents=True, exist_ok=True)
44
+ with open(args.eval_results_csv, "a") as f:
45
+ f.write(
46
+ f"query_audio,asr_model,llm_model,svs_model,melody_source,language,speaker,output_audio,asr_text,llm_text,metrics\n"
47
+ )
48
+ else:
49
+ logger.warning("No eval_results_csv provided, skipping evaluation")
50
  try:
51
  for query_audio in args.query_audios:
52
  output_audio = args.output_audio_folder / f"{query_audio.stem}_response.wav"
 
57
  speaker,
58
  output_audio_path=output_audio,
59
  )
60
+ if args.eval_results_csv and config.get("evaluators", {}):
61
+ metrics = pipeline.evaluate(output_audio, **results)
62
+ metrics.update(results.get("metrics", {}))
63
+ metrics_str = ",".join([f"{metrics[k]}" for k in sorted(metrics.keys())])
64
+ logger.info(
65
+ f"Input: {query_audio}, Output: {output_audio}, ASR results: {results['asr_text']}, LLM results: {results['llm_text']}"
 
 
 
66
  )
67
+ with open(args.eval_results_csv, "a") as f:
68
+ f.write(
69
+ f"{query_audio},{config['asr_model']},{config['llm_model']},{config['svs_model']},{config['melody_source']},{config['language']},{config['speaker']},{output_audio},{results['asr_text']},{results['llm_text']},{metrics_str}\n"
70
+ )
71
  except Exception as e:
72
+ import traceback
73
+ logger.error(traceback.format_exc())
74
  raise e
75
 
76
 
interface.py CHANGED
@@ -159,8 +159,8 @@ class GradioInterface:
159
 
160
  return demo
161
  except Exception as e:
162
- print(f"error: {e}")
163
- breakpoint()
164
  return gr.Blocks()
165
 
166
  def update_character(self, character):
 
159
 
160
  return demo
161
  except Exception as e:
162
+ import traceback
163
+ print(traceback.format_exc())
164
  return gr.Blocks()
165
 
166
  def update_character(self, character):
modules/llm/minimax.py CHANGED
@@ -77,7 +77,8 @@ class MiniMaxLLM(AbstractLLMModel):
77
  )
78
  except Exception as e:
79
  print(f"Failed to load MiniMax model: {e}")
80
- breakpoint()
 
81
  raise e
82
 
83
  def generate(
 
77
  )
78
  except Exception as e:
79
  print(f"Failed to load MiniMax model: {e}")
80
+ import traceback
81
+ print(traceback.format_exc())
82
  raise e
83
 
84
  def generate(
modules/melody.py CHANGED
@@ -25,7 +25,7 @@ class MelodyController:
25
  def get_melody_constraints(self, max_num_phrases: int = 5) -> str:
26
  """Return a lyric-format prompt based on melody structure."""
27
  if self.mode == "gen":
28
- return ""
29
 
30
  elif self.mode == "sample":
31
  assert self.database is not None, "Song database is not loaded."
@@ -46,10 +46,11 @@ class MelodyController:
46
  )
47
  + "\n如果没有足够的信息回答,请使用最少的句子,不要重复、不要扩展、不要加入无关内容。\n"
48
  )
49
- return prompt
50
 
51
  else:
52
  raise ValueError(f"Unsupported melody mode: {self.mode}")
 
 
53
 
54
  def generate_score(
55
  self, lyrics: str, language: str
 
25
  def get_melody_constraints(self, max_num_phrases: int = 5) -> str:
26
  """Return a lyric-format prompt based on melody structure."""
27
  if self.mode == "gen":
28
+ prompt = ""
29
 
30
  elif self.mode == "sample":
31
  assert self.database is not None, "Song database is not loaded."
 
46
  )
47
  + "\n如果没有足够的信息回答,请使用最少的句子,不要重复、不要扩展、不要加入无关内容。\n"
48
  )
 
49
 
50
  else:
51
  raise ValueError(f"Unsupported melody mode: {self.mode}")
52
+ prompt += "请使用用户输入的语言回答"
53
+ return prompt
54
 
55
  def generate_score(
56
  self, lyrics: str, language: str
modules/utils/g2p.py CHANGED
@@ -3,12 +3,12 @@ import re
3
  import warnings
4
  from pathlib import Path
5
 
6
- from kanjiconv import KanjiConv
7
  from pypinyin import lazy_pinyin
8
 
9
  from .resources.pinyin_dict import PINYIN_DICT
10
 
11
- kanji_to_kana = KanjiConv()
12
 
13
  yoon_map = {
14
  "ぁ": "あ",
@@ -32,9 +32,9 @@ for plan in ace_phonemes_all_plans["plans"]:
32
 
33
 
34
  def preprocess_text(text: str, language: str) -> list[str]:
35
- text = text.replace(" ", "")
36
  if language == "mandarin":
37
  text_list = to_pinyin(text)
 
38
  elif language == "japanese":
39
  text_list = to_kana(text)
40
  else:
@@ -117,7 +117,9 @@ def replace_chouonpu(hiragana_text: str) -> str:
117
 
118
 
119
  def to_kana(text: str) -> list[str]:
120
- hiragana_text = kanji_to_kana.to_hiragana(text.replace(" ", ""))
 
 
121
  hiragana_text_wl = replace_chouonpu(hiragana_text).split(" ")
122
  final_ls = []
123
  for subword in hiragana_text_wl:
 
3
  import warnings
4
  from pathlib import Path
5
 
6
+ import pykakasi
7
  from pypinyin import lazy_pinyin
8
 
9
  from .resources.pinyin_dict import PINYIN_DICT
10
 
11
+ kks = pykakasi.kakasi()
12
 
13
  yoon_map = {
14
  "ぁ": "あ",
 
32
 
33
 
34
  def preprocess_text(text: str, language: str) -> list[str]:
 
35
  if language == "mandarin":
36
  text_list = to_pinyin(text)
37
+ text_list = [pinyin for pinyin in text_list if pinyin != " "]
38
  elif language == "japanese":
39
  text_list = to_kana(text)
40
  else:
 
117
 
118
 
119
  def to_kana(text: str) -> list[str]:
120
+ hiragana_text = "".join(
121
+ [item["hira"] for item in kks.convert(text.replace(" ", ""))]
122
+ )
123
  hiragana_text_wl = replace_chouonpu(hiragana_text).split(" ")
124
  final_ls = []
125
  for subword in hiragana_text_wl: