File size: 4,801 Bytes
00f6d1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import argparse
import json
import sys
from pathlib import Path

import torch


def get_args():
    """Build the command-line interface for this evaluation script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``data_dir``, ``model_dir``, ``repo_src`` (Paths,
        required), ``split`` (default "test"), ``device`` (default None),
        ``output_json`` (Path or None) and the optional integer overrides
        ``main_context_tokens`` / ``right_context_tokens``.
    """
    cli = argparse.ArgumentParser(
        description="Evaluate the bundled Streaming GOPT checkpoint on a prepared val/test split."
    )

    # Mandatory locations: all three are parsed into pathlib.Path objects.
    required_locations = (
        ("--data-dir", "Directory containing *_chunks.npz and metadata.json."),
        ("--model-dir", "Directory containing best_audio_model.pth and config.json."),
        ("--repo-src", "Path to custom-gopt/src."),
    )
    for flag, help_text in required_locations:
        cli.add_argument(flag, type=Path, required=True, help=help_text)

    cli.add_argument("--split", type=str, default="test", choices=["val", "test", "train"])
    cli.add_argument(
        "--device",
        type=str,
        default=None,
        help="cuda / cuda:0 / cpu. Defaults to cuda if available.",
    )
    cli.add_argument("--output-json", type=Path, default=None)

    # Optional overrides; None means "use the largest value stored in config.json".
    cli.add_argument(
        "--main-context-tokens",
        type=int,
        default=None,
        help="Override evaluation main context tokens.",
    )
    cli.add_argument(
        "--right-context-tokens",
        type=int,
        default=None,
        help="Override evaluation right context tokens.",
    )

    parsed = cli.parse_args()
    return parsed


def _token_choices(args_dict, name):
    """Return the context-token choices stored for ``name`` in the training args.

    ``name`` is ``"main_context"`` or ``"right_context"``. A non-empty
    ``<name>_token_choices`` entry is returned as-is. Otherwise
    ``<name>_tokens`` is stringified and parsed as a comma-separated list of
    ints, skipping blank items. The fallback key is only read when needed.
    """
    stored = args_dict.get(f"{name}_token_choices")
    if stored:
        return stored
    raw = str(args_dict[f"{name}_tokens"])
    return [int(item.strip()) for item in raw.split(",") if item.strip()]


def _resolve_context(override, choices, name):
    """Return the evaluation context size: ``override`` if given, else ``max(choices)``.

    Raises:
        ValueError: if no override is given and ``choices`` is empty. A bare
            ``max()`` error would not say which setting was missing.
    """
    if override is not None:
        return int(override)
    if not choices:
        raise ValueError(
            f"config.json lists no {name} token values; "
            f"pass --{name.replace('_', '-')}-tokens explicitly."
        )
    return int(max(choices))


def load_summary(data_dir, model_dir, repo_src, split, device_name, main_context_tokens=None, right_context_tokens=None):
    """Load the Streaming GOPT checkpoint in ``model_dir`` and evaluate it on one split.

    Args:
        data_dir: Directory holding ``metadata.json`` plus the prepared split data.
        model_dir: Directory holding ``config.json`` and ``best_audio_model.pth``.
        repo_src: Training-repo source directory. It is placed first on
            ``sys.path`` so ``models`` and ``train_streaming_charsiu`` import from it.
        split: Split name passed to ``StreamingChunkDataset``. ``final_only`` is
            enabled for every split except "train".
        device_name: torch device string. ``None`` selects "cuda" when available,
            otherwise "cpu".
        main_context_tokens: Optional override for the main context size used in
            evaluation. Defaults to the largest configured choice.
        right_context_tokens: Optional override for the right context size used in
            evaluation. Defaults to the largest configured choice.

    Returns:
        dict with the split, the device string, the context sizes used,
        ``phone_mse`` / ``phone_pcc`` as floats, and ``utt_*`` / ``word_*``
        metrics as lists of floats.

    Raises:
        ValueError: if the split has no samples, or a context size cannot be
            determined from the config and no override was given.
    """
    repo_src_str = str(repo_src)
    # Prepend at most once per call site: repeated calls must not keep growing
    # sys.path, but the repo still has to win over same-named modules elsewhere.
    if not sys.path or sys.path[0] != repo_src_str:
        sys.path.insert(0, repo_src_str)

    from models import StreamingGOPT, StreamingGOPTNoPhn
    from train_streaming_charsiu import StreamingChunkDataset, load_model_state, make_loader, validate

    metadata = json.loads((data_dir / "metadata.json").read_text(encoding="utf-8"))
    cfg = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
    args_dict = cfg["args"]

    # Rebuild an attribute-style args object from the saved training args;
    # validate() and the settings below read options from it.
    args = argparse.Namespace(**args_dict)
    args.main_context_token_choices = _token_choices(args_dict, "main_context")
    args.right_context_token_choices = _token_choices(args_dict, "right_context")

    # Resolve the context sizes before the expensive model load, to fail fast.
    eval_main_context = _resolve_context(main_context_tokens, args.main_context_token_choices, "main_context")
    eval_right_context = _resolve_context(right_context_tokens, args.right_context_token_choices, "right_context")

    device = torch.device(device_name or ("cuda" if torch.cuda.is_available() else "cpu"))
    if getattr(args, "tf32", False) and device.type == "cuda":
        if hasattr(torch, "set_float32_matmul_precision"):
            torch.set_float32_matmul_precision("high")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    dataset = StreamingChunkDataset(split, data_dir, metadata, final_only=(split != "train"))
    if len(dataset) == 0:
        # The loader below is sized by len(dataset); a size of 0 gives an opaque error.
        raise ValueError(f"Split {split!r} in {data_dir} contains no samples.")
    # NOTE(review): assumes make_loader(dataset, batch_size, shuffle, num_workers),
    # i.e. the whole split forms one unshuffled batch. TODO confirm in make_loader.
    loader = make_loader(dataset, len(dataset), False, int(getattr(args, "num_workers", 0)))

    model_cls = StreamingGOPT if args_dict["model"] == "streaming_gopt" else StreamingGOPTNoPhn
    model = model_cls(
        embed_dim=int(args_dict["embed_dim"]),
        num_heads=int(args_dict["heads"]),
        depth=int(args_dict["depth"]),
        input_dim=int(cfg["input_dim"]),
        seq_len=int(cfg["seq_len"]),
        phn_num=int(cfg["phn_num"]),
    )

    # The model is still on CPU here, so load the weights on CPU as well.
    # Loading straight to `device` would briefly hold two copies in GPU memory.
    # torch.load unpickles the file: only point this at trusted checkpoints.
    state = torch.load(model_dir / "best_audio_model.pth", map_location="cpu")
    load_model_state(model, state)
    model = model.to(device)
    model.eval()

    # Evaluation only: disable autograd so no graph is built for the (whole-split) batch.
    with torch.no_grad():
        mse, corr, utt_mse, utt_corr, word_mse, word_corr = validate(
            model,
            loader,
            args,
            -1,
            device,
            eval_main_context,
            eval_right_context,
        )

    return {
        "split": split,
        "device": str(device),
        "main_context_tokens": eval_main_context,
        "right_context_tokens": eval_right_context,
        "phone_mse": float(mse),
        "phone_pcc": float(corr),
        "utt_mse": [float(x) for x in utt_mse],
        "utt_pcc": [float(x) for x in utt_corr],
        "word_mse": [float(x) for x in word_mse],
        "word_pcc": [float(x) for x in word_corr],
    }


def main():
    """Command-line entry point: parse flags, evaluate, print the JSON summary.

    The JSON summary always goes to stdout. It is also written to
    ``--output-json`` when that flag is given, creating parent directories
    as needed.
    """
    cli_args = get_args()
    result = load_summary(
        data_dir=cli_args.data_dir,
        model_dir=cli_args.model_dir,
        repo_src=cli_args.repo_src,
        split=cli_args.split,
        device_name=cli_args.device,
        main_context_tokens=cli_args.main_context_tokens,
        right_context_tokens=cli_args.right_context_tokens,
    )
    rendered = json.dumps(result, ensure_ascii=False, indent=2)
    print(rendered)

    destination = cli_args.output_json
    if destination is None:
        return
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(rendered, encoding="utf-8")


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()