"""Evaluate a Streaming GOPT checkpoint on a prepared data split.

The script reads ``metadata.json`` from the data directory and ``config.json``
plus ``best_audio_model.pth`` from the model directory, runs the project's
``validate`` routine once, and prints the resulting metrics as JSON
(optionally also writing them to a file).
"""

import argparse
import json
import sys
from pathlib import Path

import torch


def get_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script does when
            run from the command line.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate the bundled Streaming GOPT checkpoint on a prepared val/test split."
    )
    parser.add_argument(
        "--data-dir",
        type=Path,
        required=True,
        help="Directory containing *_chunks.npz and metadata.json.",
    )
    parser.add_argument(
        "--model-dir",
        type=Path,
        required=True,
        help="Directory containing best_audio_model.pth and config.json.",
    )
    parser.add_argument(
        "--repo-src",
        type=Path,
        required=True,
        help="Path to custom-gopt/src.",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        choices=["val", "test", "train"],
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="cuda / cuda:0 / cpu. Defaults to cuda if available.",
    )
    parser.add_argument("--output-json", type=Path, default=None)
    parser.add_argument(
        "--main-context-tokens",
        type=int,
        default=None,
        help="Override evaluation main context tokens.",
    )
    parser.add_argument(
        "--right-context-tokens",
        type=int,
        default=None,
        help="Override evaluation right context tokens.",
    )
    return parser.parse_args(argv)


def _parse_token_choices(args_dict, choices_key, fallback_key):
    """Return the context-token choices recorded in a training config.

    ``args_dict[choices_key]`` is returned as-is when it is present and
    truthy. Otherwise ``args_dict[fallback_key]`` is converted to a string and
    split on commas, and every non-blank piece is converted to ``int``.
    """
    stored_choices = args_dict.get(choices_key)
    if stored_choices:
        return stored_choices
    return [
        int(item.strip())
        for item in str(args_dict[fallback_key]).split(",")
        if item.strip()
    ]


def load_summary(
    data_dir,
    model_dir,
    repo_src,
    split,
    device_name,
    main_context_tokens=None,
    right_context_tokens=None,
):
    """Run one evaluation pass and return the metrics as a plain dict.

    Args:
        data_dir: ``Path`` of the directory holding ``metadata.json`` (also
            handed to ``StreamingChunkDataset``).
        model_dir: ``Path`` of the directory holding ``config.json`` and
            ``best_audio_model.pth``.
        repo_src: Directory put at the front of ``sys.path`` so that
            ``models`` and ``train_streaming_charsiu`` can be imported.
        split: Split name passed to ``StreamingChunkDataset``.
        device_name: Torch device string, or a falsy value to pick ``cuda``
            when available and ``cpu`` otherwise.
        main_context_tokens: Optional override; when ``None`` the maximum of
            the configured main-context choices is used.
        right_context_tokens: Optional override; when ``None`` the maximum of
            the configured right-context choices is used.

    Returns:
        A JSON-serializable dict with the split, device, the context sizes
        actually used, and the six values returned by ``validate`` converted
        to floats / lists of floats.

    Raises:
        KeyError: If ``config.json`` lacks one of the keys read below.
    """
    # Keep the project sources first on sys.path, but do not stack a duplicate
    # entry on every call.
    repo_path = str(repo_src)
    if not sys.path or sys.path[0] != repo_path:
        sys.path.insert(0, repo_path)

    # Project-local imports: only resolvable once repo_src is on sys.path.
    from models import StreamingGOPT, StreamingGOPTNoPhn
    from train_streaming_charsiu import (
        StreamingChunkDataset,
        load_model_state,
        make_loader,
        validate,
    )

    metadata = json.loads((data_dir / "metadata.json").read_text(encoding="utf-8"))
    cfg = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
    args_dict = cfg["args"]

    # Expose the saved training arguments as attributes, which is how the
    # code below (and validate) reads them.
    args = argparse.Namespace(**args_dict)
    args.main_context_token_choices = _parse_token_choices(
        args_dict, "main_context_token_choices", "main_context_tokens"
    )
    args.right_context_token_choices = _parse_token_choices(
        args_dict, "right_context_token_choices", "right_context_tokens"
    )

    device = torch.device(device_name or ("cuda" if torch.cuda.is_available() else "cpu"))
    if getattr(args, "tf32", False) and device.type == "cuda":
        # Older torch builds may lack set_float32_matmul_precision.
        if hasattr(torch, "set_float32_matmul_precision"):
            torch.set_float32_matmul_precision("high")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    dataset = StreamingChunkDataset(split, data_dir, metadata, final_only=(split != "train"))
    # NOTE(review): the positional arguments are presumably batch size (the
    # whole split at once) and a shuffle flag -- confirm against make_loader.
    loader = make_loader(dataset, len(dataset), False, int(getattr(args, "num_workers", 0)))

    model_cls = StreamingGOPT if args_dict["model"] == "streaming_gopt" else StreamingGOPTNoPhn
    model = model_cls(
        embed_dim=int(args_dict["embed_dim"]),
        num_heads=int(args_dict["heads"]),
        depth=int(args_dict["depth"]),
        input_dim=int(cfg["input_dim"]),
        seq_len=int(cfg["seq_len"]),
        phn_num=int(cfg["phn_num"]),
    )

    # SECURITY: torch.load may unpickle arbitrary objects depending on the
    # torch version and its weights_only default; only point --model-dir at
    # checkpoints from a trusted source.
    state = torch.load(model_dir / "best_audio_model.pth", map_location=device)
    load_model_state(model, state)
    model = model.to(device)
    model.eval()

    # "is not None" so that an explicit override of 0 is honoured.
    if main_context_tokens is not None:
        eval_main_context = int(main_context_tokens)
    else:
        eval_main_context = int(max(args.main_context_token_choices))
    if right_context_tokens is not None:
        eval_right_context = int(right_context_tokens)
    else:
        eval_right_context = int(max(args.right_context_token_choices))

    # NOTE(review): the literal -1 is presumably an epoch index placeholder --
    # confirm against validate's signature.
    mse, corr, utt_mse, utt_corr, word_mse, word_corr = validate(
        model,
        loader,
        args,
        -1,
        device,
        eval_main_context,
        eval_right_context,
    )

    return {
        "split": split,
        "device": str(device),
        "main_context_tokens": eval_main_context,
        "right_context_tokens": eval_right_context,
        "phone_mse": float(mse),
        "phone_pcc": float(corr),
        "utt_mse": [float(x) for x in utt_mse],
        "utt_pcc": [float(x) for x in utt_corr],
        "word_mse": [float(x) for x in word_mse],
        "word_pcc": [float(x) for x in word_corr],
    }


def main(argv=None):
    """Parse arguments, evaluate, print the JSON summary and optionally save it.

    Args:
        argv: Optional list of argument strings forwarded to ``get_args``;
            ``None`` means "use ``sys.argv[1:]``".
    """
    args = get_args(argv)
    summary = load_summary(
        data_dir=args.data_dir,
        model_dir=args.model_dir,
        repo_src=args.repo_src,
        split=args.split,
        device_name=args.device,
        main_context_tokens=args.main_context_tokens,
        right_context_tokens=args.right_context_tokens,
    )
    payload = json.dumps(summary, ensure_ascii=False, indent=2)
    print(payload)
    if args.output_json is not None:
        # Create missing parent directories so a fresh output path works.
        args.output_json.parent.mkdir(parents=True, exist_ok=True)
        args.output_json.write_text(payload, encoding="utf-8")


if __name__ == "__main__":
    main()