# Source page metadata (scraped): file size 808 bytes, revision 1f5470c.
# load_model.py
import torch
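# TODO: import CombinedMultiModalTransformer (model definition, presumably in model.py) and ModelArgs (model_args.py); the model class must be importable for torch.load to unpickle a whole saved model.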
def main(checkpoint_path="my_model.pth", *, state_dict_path=None):
    """Load the trained model onto the best available device and return it in eval mode.

    By default the whole pickled model is loaded from *checkpoint_path*
    (a file written with ``torch.save(model, path)``). If *state_dict_path*
    is given instead, a fresh ``CombinedMultiModalTransformer(ModelArgs())``
    is built and its weights are restored from the ``'model_state_dict'``
    entry of that checkpoint.

    Args:
        checkpoint_path: Path to a file holding an entire pickled model.
        state_dict_path: Optional path to a checkpoint dict with a
            ``'model_state_dict'`` entry; when set it takes precedence over
            *checkpoint_path*.

    Returns:
        The loaded model, moved to ``cuda`` if available (else ``cpu``) and
        switched to evaluation mode.
    """
    # Choose the device before loading so tensors are mapped straight onto it;
    # without map_location a checkpoint saved on GPU fails on a CPU-only host.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if state_dict_path is None:
        # Unpickling a full nn.Module requires weights_only=False (the default
        # became True in PyTorch 2.6). This executes arbitrary pickle code, so
        # only load checkpoints from a trusted source. The model's class must
        # be importable in this module for unpickling to succeed.
        model = torch.load(checkpoint_path, map_location=device, weights_only=False)
    else:
        # Weights-only checkpoint: build the architecture, then restore weights.
        model = CombinedMultiModalTransformer(ModelArgs())
        checkpoint = torch.load(state_dict_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])

    model.to(device)
    # Switch submodules (e.g. dropout, batch-norm) to inference behaviour.
    model.eval()
    # ... (use the loaded model for inference)
    return model
# Run the loader only when executed as a script, not when imported.
if __name__ == "__main__":
    main()