jhansss committed on
Commit
9f5df1a
·
1 Parent(s): fa3bbac

Make eval_results_csv optional in CLI

Browse files
Files changed (1) hide show
  1. cli.py +20 -15
cli.py CHANGED
@@ -17,7 +17,7 @@ def get_parser():
17
  "--config_path", type=Path, default="config/cli/yaoyin_default.yaml"
18
  )
19
  parser.add_argument("--output_audio_folder", type=Path, required=True)
20
- parser.add_argument("--eval_results_csv", type=Path, required=True)
21
  return parser
22
 
23
 
@@ -38,11 +38,15 @@ def main():
38
  character = get_character(character_name)
39
  prompt_template = character.prompt
40
  args.output_audio_folder.mkdir(parents=True, exist_ok=True)
41
- args.eval_results_csv.parent.mkdir(parents=True, exist_ok=True)
42
- with open(args.eval_results_csv, "a") as f:
43
- f.write(
44
- f"query_audio,asr_model,llm_model,svs_model,melody_source,language,speaker,output_audio,asr_text,llm_text,metrics\n"
45
- )
 
 
 
 
46
  try:
47
  for query_audio in args.query_audios:
48
  output_audio = args.output_audio_folder / f"{query_audio.stem}_response.wav"
@@ -53,16 +57,17 @@ def main():
53
  speaker,
54
  output_audio_path=output_audio,
55
  )
56
- metrics = pipeline.evaluate(output_audio, **results)
57
- metrics.update(results.get("metrics", {}))
58
- metrics_str = ",".join([f"{metrics[k]}" for k in sorted(metrics.keys())])
59
- logger.info(
60
- f"Input: {query_audio}, Output: {output_audio}, ASR results: {results['asr_text']}, LLM results: {results['llm_text']}"
61
- )
62
- with open(args.eval_results_csv, "a") as f:
63
- f.write(
64
- f"{query_audio},{config['asr_model']},{config['llm_model']},{config['svs_model']},{config['melody_source']},{config['language']},{config['speaker']},{output_audio},{results['asr_text']},{results['llm_text']},{metrics_str}\n"
65
  )
 
 
 
 
66
  except Exception as e:
67
  logger.error(f"Error in main: {e}")
68
  breakpoint()
 
17
  "--config_path", type=Path, default="config/cli/yaoyin_default.yaml"
18
  )
19
  parser.add_argument("--output_audio_folder", type=Path, required=True)
20
+ parser.add_argument("--eval_results_csv", type=Path, default=None)
21
  return parser
22
 
23
 
 
38
  character = get_character(character_name)
39
  prompt_template = character.prompt
40
  args.output_audio_folder.mkdir(parents=True, exist_ok=True)
41
+ if config.get("evaluators", {}):
42
+ if args.eval_results_csv:
43
+ args.eval_results_csv.parent.mkdir(parents=True, exist_ok=True)
44
+ with open(args.eval_results_csv, "a") as f:
45
+ f.write(
46
+ f"query_audio,asr_model,llm_model,svs_model,melody_source,language,speaker,output_audio,asr_text,llm_text,metrics\n"
47
+ )
48
+ else:
49
+ logger.warning("No eval_results_csv provided, skipping evaluation")
50
  try:
51
  for query_audio in args.query_audios:
52
  output_audio = args.output_audio_folder / f"{query_audio.stem}_response.wav"
 
57
  speaker,
58
  output_audio_path=output_audio,
59
  )
60
+ if args.eval_results_csv and config.get("evaluators", {}):
61
+ metrics = pipeline.evaluate(output_audio, **results)
62
+ metrics.update(results.get("metrics", {}))
63
+ metrics_str = ",".join([f"{metrics[k]}" for k in sorted(metrics.keys())])
64
+ logger.info(
65
+ f"Input: {query_audio}, Output: {output_audio}, ASR results: {results['asr_text']}, LLM results: {results['llm_text']}"
 
 
 
66
  )
67
+ with open(args.eval_results_csv, "a") as f:
68
+ f.write(
69
+ f"{query_audio},{config['asr_model']},{config['llm_model']},{config['svs_model']},{config['melody_source']},{config['language']},{config['speaker']},{output_audio},{results['asr_text']},{results['llm_text']},{metrics_str}\n"
70
+ )
71
  except Exception as e:
72
  logger.error(f"Error in main: {e}")
73
  breakpoint()