File size: 3,001 Bytes
50dd0bc
 
20bd7b4
50dd0bc
 
 
120744c
50dd0bc
 
 
 
 
 
 
2ce9d86
20bd7b4
 
 
2ce9d86
9f5df1a
50dd0bc
 
 
20bd7b4
50dd0bc
 
 
 
 
 
 
 
 
 
 
 
 
120744c
50dd0bc
2ce9d86
9f5df1a
 
 
 
 
 
 
 
 
2ce9d86
 
 
 
 
 
 
 
 
 
9f5df1a
 
 
 
 
 
2ce9d86
9f5df1a
 
 
 
2ce9d86
1a42edd
 
2ce9d86
50dd0bc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from argparse import ArgumentParser
from logging import getLogger
from pathlib import Path

import yaml

from characters import get_character
from pipeline import SingingDialoguePipeline

# Module-level logger named after this module; main() uses it for warnings,
# per-file progress messages and error reporting.
logger = getLogger(__name__)


def get_parser():
    """Build the argument parser for the command-line interface.

    Options registered:
        --query_audios: one or more input audio paths (required).
        --config_path: YAML config path; defaults to
            ``config/cli/yaoyin_default.yaml``.
        --output_audio_folder: folder that receives the output audio (required).
        --eval_results_csv: optional CSV path for evaluation results.

    Returns:
        The configured ``ArgumentParser``; every option is parsed as a ``Path``.
    """
    cli = ArgumentParser()
    # Each entry is (flag, keyword arguments forwarded to add_argument).
    option_table = (
        ("--query_audios", {"nargs": "+", "type": Path, "required": True}),
        (
            "--config_path",
            {"type": Path, "default": "config/cli/yaoyin_default.yaml"},
        ),
        ("--output_audio_folder", {"type": Path, "required": True}),
        ("--eval_results_csv", {"type": Path, "default": None}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli


def load_config(config_path: Path):
    """Load a YAML configuration file.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file contents.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the contents are not valid YAML.
    """
    # Decode explicitly as UTF-8 so non-ASCII values are read the same way on
    # every platform instead of depending on the locale's default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def main():
    """Run the singing dialogue pipeline on every query audio from the CLI.

    For each input file the pipeline writes ``<stem>_response.wav`` into the
    output folder. When the config has a non-empty ``evaluators`` entry and
    ``--eval_results_csv`` is given, each output is also evaluated and one row
    per input is appended to that CSV file.

    The CSV header is written only when the file is missing or empty, so
    repeated runs append data rows without repeating the header. Rows are
    written with ``csv.writer`` so that text containing commas, quotes or
    newlines stays inside its own column. Metric values are appended as
    trailing columns, ordered by metric name.

    Raises:
        Exception: Any error raised while processing an input is logged with
            its traceback and then re-raised unchanged.
    """
    # Function-scope import keeps this block self-contained.
    import csv

    args = get_parser().parse_args()
    config = load_config(args.config_path)
    pipeline = SingingDialoguePipeline(config)
    speaker = config["speaker"]
    language = config["language"]
    character = get_character(config["prompt_template_character"])
    prompt_template = character.prompt
    args.output_audio_folder.mkdir(parents=True, exist_ok=True)

    eval_csv = args.eval_results_csv
    has_evaluators = bool(config.get("evaluators", {}))
    # Evaluation needs both configured evaluators and a CSV destination.
    do_eval = has_evaluators and eval_csv is not None

    header = (
        "query_audio",
        "asr_model",
        "llm_model",
        "svs_model",
        "melody_source",
        "language",
        "speaker",
        "output_audio",
        "asr_text",
        "llm_text",
        "metrics",
    )

    if has_evaluators:
        if eval_csv is not None:
            eval_csv.parent.mkdir(parents=True, exist_ok=True)
            # Only a new or empty file gets a header; appending one on every
            # run would interleave header lines with data rows.
            if not eval_csv.exists() or eval_csv.stat().st_size == 0:
                with open(eval_csv, "a", newline="", encoding="utf-8") as f:
                    csv.writer(f, lineterminator="\n").writerow(header)
        else:
            logger.warning("No eval_results_csv provided, skipping evaluation")

    try:
        for query_audio in args.query_audios:
            # NOTE: inputs that share a stem map to the same output file name.
            output_audio = args.output_audio_folder / f"{query_audio.stem}_response.wav"
            results = pipeline.run(
                query_audio,
                language,
                prompt_template,
                speaker,
                output_audio_path=output_audio,
            )
            if not do_eval:
                continue

            metrics = pipeline.evaluate(output_audio, **results)
            metrics.update(results.get("metrics", {}))
            logger.info(
                "Input: %s, Output: %s, ASR results: %s, LLM results: %s",
                query_audio,
                output_audio,
                results["asr_text"],
                results["llm_text"],
            )
            fields = [
                query_audio,
                config["asr_model"],
                config["llm_model"],
                config["svs_model"],
                config["melody_source"],
                config["language"],
                config["speaker"],
                output_audio,
                results["asr_text"],
                results["llm_text"],
            ]
            # Metric values follow the fixed columns, sorted by metric name.
            fields.extend(metrics[name] for name in sorted(metrics))
            # format(value) renders each field exactly as an f-string
            # placeholder would; csv.writer then adds quoting where needed.
            with open(eval_csv, "a", newline="", encoding="utf-8") as f:
                csv.writer(f, lineterminator="\n").writerow(
                    [format(value) for value in fields]
                )
    except Exception:
        # Log the full traceback, then let the original exception propagate.
        logger.exception("Pipeline run failed")
        raise


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()