# load_model.py
"""Load a saved CombinedMultiModalTransformer checkpoint and prepare it for inference."""

import torch

# ... (import your model definition and ModelArgs from model.py and model_args.py)
# from model import CombinedMultiModalTransformer
# from model_args import ModelArgs


def main():
    """Restore the model from disk, move it to the best available device, set eval mode.

    Returns:
        The loaded ``torch.nn.Module``, ready for inference.
    """
    # Choose the device before loading so checkpoint tensors are mapped
    # directly onto it; without map_location a CUDA-saved checkpoint
    # fails to load on a CPU-only machine.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Option A: load the entire pickled model object.
    # NOTE(security): torch.load unpickles arbitrary Python objects — only
    # load checkpoints from trusted sources.
    model = torch.load('my_model.pth', map_location=device)

    # Option B (preferred in practice): rebuild the architecture, then load
    # only the weights. Requires the imports commented out above.
    # args = ModelArgs()
    # model = CombinedMultiModalTransformer(args)
    # checkpoint = torch.load('my_model_state_dict.pth', map_location=device)
    # model.load_state_dict(checkpoint['model_state_dict'])  # saved as 'model_state_dict'

    model.to(device)

    # Evaluation mode disables dropout and uses running batch-norm statistics.
    model.eval()

    # ... (use the loaded model for inference)
    return model


if __name__ == "__main__":
    main()