claudeson / claudson /save_model.py
joebruce1313's picture
Upload 38004 files
1f5470c verified
import torch
# TODO: import CombinedMultiModalTransformer (the model definition, from model.py) and ModelArgs (from model_args.py) here; main() references both names.
def main(path=None, *, state_dict_only=False):
    """Build a CombinedMultiModalTransformer from default ModelArgs and save it.

    Args:
        path: Destination file handed to ``torch.save``. When None (default),
            ``"my_model.pth"`` is used for a whole-model save and
            ``"my_model_state_dict.pth"`` for a state-dict save.
        state_dict_only: When False (default), the entire model object is
            pickled, exactly as this script originally did. When True, only
            ``model.state_dict()`` is saved instead.

    Returns:
        The path the checkpoint was written to.
    """
    if path is None:
        path = "my_model_state_dict.pth" if state_dict_only else "my_model.pth"

    args = ModelArgs()
    model = CombinedMultiModalTransformer(args)
    # TODO: optionally load pre-trained weights into `model` before saving.

    if state_dict_only:
        # Parameters/buffers only; reload via model.load_state_dict(torch.load(path))
        # on a freshly constructed model.
        torch.save(model.state_dict(), path)
    else:
        # Whole-model pickle: loading it later requires the model's class to be
        # importable under the same module path.
        torch.save(model, path)
    return path
if __name__ == "__main__":
    # Only export when executed as a script, never as a side effect of import.
    main()