# Snapshot of commit 1f5470c (original size: 467 bytes).
import torch
# ... (import your model definition and ModelArgs from model.py and model_args.py)
def main(path=None, *, state_dict_only=False):
    """Build a CombinedMultiModalTransformer from default ModelArgs and save it.

    Args:
        path: Destination file for the checkpoint. When ``None`` (the
            default) the file is ``'my_model.pth'`` for a whole-model save,
            or ``'my_model_state_dict.pth'`` when ``state_dict_only`` is true.
        state_dict_only: When false (the default), the entire model object is
            saved with ``torch.save(model, ...)``. When true, only
            ``model.state_dict()`` is saved. This is the more portable format,
            but reloading it requires constructing the model first.
    """
    # ModelArgs and CombinedMultiModalTransformer are not defined in this
    # script; they must be imported at module level (see the placeholder
    # import comment at the top of the file).
    args = ModelArgs()
    model = CombinedMultiModalTransformer(args)

    # TODO: optionally load pre-trained weights into `model` here.

    if state_dict_only:
        if path is None:
            path = "my_model_state_dict.pth"
        # Stores only parameter/buffer tensors; restore with
        # model.load_state_dict(torch.load(path)).
        torch.save(model.state_dict(), path)
    else:
        if path is None:
            path = "my_model.pth"
        # Pickles the full module object by class reference. Loading needs the
        # defining module importable under the same name, and PyTorch >= 2.6
        # requires torch.load(path, weights_only=False) for such files.
        torch.save(model, path)
if __name__ == "__main__":
    # Run the export only when executed as a script, not on import.
    main()