"""Print a JSON summary of a Supernova model's dimensions and parameter count."""

import argparse
import json
from typing import Any

import torch

from .config import ModelConfig
from .model import SupernovaModel
from .tokenizer import load_gpt2_tokenizer

# Parameter count that the "exact" field of the report is compared against.
TARGET_PARAM_COUNT = 25_000_000


def main(config_path: str) -> None:
    """Build the model described by *config_path* and print a JSON report.

    The report contains the tokenizer's vocabulary size, selected config
    fields (``n_positions``, ``d_model``, ``n_layers``, ``n_heads``), the
    total number of elements across all model parameters, and an ``exact``
    flag that is true only when that total equals ``TARGET_PARAM_COUNT``.

    Args:
        config_path: Path to a JSON file readable by
            ``ModelConfig.from_json_file``.

    Raises:
        ValueError: If the tokenizer's ``vocab_size`` differs from the
            ``vocab_size`` in the config.
    """
    cfg = ModelConfig.from_json_file(config_path)
    tok = load_gpt2_tokenizer()

    # Explicit raise rather than ``assert``: assertions are stripped under
    # ``python -O``, which would silently skip this consistency check.
    if tok.vocab_size != cfg.vocab_size:
        raise ValueError(
            f"tokenizer vocab_size ({tok.vocab_size}) does not match "
            f"config vocab_size ({cfg.vocab_size})"
        )

    model = SupernovaModel(cfg)
    total_params = sum(p.numel() for p in model.parameters())

    report = {
        "vocab_size": tok.vocab_size,
        "n_positions": cfg.n_positions,
        "d_model": cfg.d_model,
        "n_layers": cfg.n_layers,
        "n_heads": cfg.n_heads,
        "total_params": total_params,
        "exact": total_params == TARGET_PARAM_COUNT,
    }
    print(json.dumps(report, indent=2))


if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, help="path to the model config JSON file")
    args = ap.parse_args()
    main(args.config)