# load_model.py
import torch
# ... (import your model definition and ModelArgs from model.py and model_args.py)
def main():
    """Load a saved model checkpoint and prepare it for inference.

    Reads ``my_model.pth`` (a fully pickled model object), moves it to the
    best available device, and switches it to evaluation mode.
    """
    # map_location='cpu' lets a checkpoint saved on a GPU machine load on a
    # CPU-only host; the model is moved to the right device afterwards.
    # SECURITY NOTE: torch.load unpickles arbitrary objects — only load
    # files you trust. For untrusted files, save/load a state_dict and
    # pass weights_only=True (see the alternative below).
    model = torch.load('my_model.pth', map_location='cpu')

    # Alternative: rebuild the architecture, then load only the weights.
    # args = ModelArgs()
    # model = CombinedMultiModalTransformer(args)
    # checkpoint = torch.load('my_model_state_dict.pth', weights_only=True)
    # model.load_state_dict(checkpoint['model_state_dict'])  # key used at save time

    # Move the model to the appropriate device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # eval() disables dropout and uses running batch-norm statistics.
    model.eval()
    # ... (use the loaded model for inference)


if __name__ == "__main__":
    main()