# claudeson / claudson / load.py
# Uploaded by joebruce1313 ("Upload 38004 files"), commit 1f5470c (verified).
# load_model.py
import torch
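# TODO: import CombinedMultiModalTransformer (from model.py) and ModelArgs
# (from model_args.py) here. main() uses them to build the model, and
# unpickling a whole saved model also requires the class to be importable.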
def main(checkpoint_path=None, from_state_dict=False):
    """Load a trained CombinedMultiModalTransformer and prepare it for inference.

    By default the checkpoint is treated as a whole pickled model object (as
    written by ``torch.save(model, path)``). With ``from_state_dict=True``, a
    fresh model is built from ``ModelArgs()`` and the checkpoint's
    ``'model_state_dict'`` entry is loaded into it instead.

    Args:
        checkpoint_path: Checkpoint file to read. Defaults to
            ``'my_model.pth'`` for a whole-model checkpoint, or
            ``'my_model_state_dict.pth'`` when ``from_state_dict`` is True.
        from_state_dict: Load a state-dict checkpoint instead of a whole
            pickled model.

    Returns:
        The loaded model, moved to CUDA when available (CPU otherwise) and
        switched to evaluation mode.
    """
    # Choose the device first so checkpoint tensors are mapped straight onto
    # it. Without map_location, a checkpoint saved on a GPU cannot be loaded
    # on a CPU-only machine.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if from_state_dict:
        if checkpoint_path is None:
            checkpoint_path = 'my_model_state_dict.pth'
        # Only this path needs a freshly constructed model. Building one
        # unconditionally would be wasted work when the whole model is
        # unpickled below.
        model = CombinedMultiModalTransformer(ModelArgs())
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # Assumes the checkpoint was saved as {'model_state_dict': ...};
        # adjust the key if it was saved differently.
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        if checkpoint_path is None:
            checkpoint_path = 'my_model.pth'
        # SECURITY: unpickling a whole model can execute arbitrary code, so
        # only load checkpoints from a trusted source. weights_only=False is
        # required here on PyTorch >= 2.6, whose default (True) refuses
        # pickled nn.Module objects.
        model = torch.load(checkpoint_path, map_location=device,
                           weights_only=False)

    model.to(device)
    # eval() switches layers such as dropout/batch-norm to inference behaviour.
    model.eval()
    # Returned so callers can run inference on the loaded model.
    return model
# Run the loader only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()