"""Create a randomly initialized ChatDBS1 checkpoint.

Builds a ChatDBS1Model from a fixed configuration and writes its freshly
initialized (untrained) state dict to ``pytorch_model.bin`` in the current
working directory.
"""

import torch

from modeling_chatdbs1 import ChatDBS1Config, ChatDBS1Model

# Destination file, relative to the current working directory.
OUTPUT_PATH = "pytorch_model.bin"

# ----------------------------
# Config
# ----------------------------
# NOTE(review): the per-field meanings below are inferred from the GPT-style
# field names; confirm against ChatDBS1Config in modeling_chatdbs1.
config = ChatDBS1Config(
    n_embd=512,        # presumably the hidden/embedding width
    n_layer=12,        # presumably the number of transformer blocks
    n_head=8,          # attention heads; 512 / 8 = 64, so widths divide evenly
    vocab_size=50257,  # same size as the GPT-2 BPE vocabulary -- confirm tokenizer
    block_size=1024,   # presumably the maximum context length in tokens
)

# ----------------------------
# Initialize model
# ----------------------------
# No RNG seed is set here, so the initial weights presumably differ between
# runs; call torch.manual_seed(...) before this line if reproducibility matters.
model = ChatDBS1Model(config)

# ----------------------------
# Save random-initialized weights
# ----------------------------
# Only the state dict is written (not the config), so loading it requires
# constructing ChatDBS1Model with a matching config first.
torch.save(model.state_dict(), OUTPUT_PATH)
print(f"Random-initialized {OUTPUT_PATH} created successfully!")